Created
July 26, 2017 15:25
-
-
Save paultsw/4c857bbe5df5a02f7e1b1cc6b638b9c9 to your computer and use it in GitHub Desktop.
Compute percent identity between two base sequences.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tensorflow ops for computing percent-identity between two sequences. | |
""" | |
import tensorflow as tf | |
def sparsify_seq(dseq, maxtime, pad_value=0):
    """
    Convert a dense rank-2 sequence tensor `dseq` to a sparse tensor, masking out
    pad values so they do not spuriously inflate edit distances when two
    sparsified tensors are compared.

    Args:
      dseq: rank-2 tensor representing a sequence. Of shape [batch, maxtime].
      maxtime: Python integer giving the padded sequence length. Currently unused:
        the dense shape is read from `dseq` at run time. Kept for interface
        compatibility with existing callers.
      pad_value: a value to ignore in the sparsify process.
    Returns:
      a tf.SparseTensor whose int32 values are the non-pad entries of `dseq`,
      each indexed by its original position in `dseq`, and whose dense_shape is
      the (int64) shape of `dseq`.
    """
    # Coordinates of every non-pad entry; tf.where with a single argument
    # returns int64 indices, which is what tf.SparseTensor expects.
    indices = tf.where(tf.not_equal(dseq, pad_value))
    # tf.cast replaces the deprecated tf.to_int32 (removed in TensorFlow 2.x)
    # and produces the same int32 result on TensorFlow 1.x.
    values = tf.cast(tf.gather_nd(dseq, indices=indices), tf.int32)
    dense_shape = tf.shape(dseq, out_type=tf.int64)
    return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
def pct_identity(basecalls, true_base_labels, max_time, pad_token=0):
    """
    Build ops that score predicted basecalls against true base labels by identity.

    Identity here is one minus the normalized edit (Levenshtein) distance between
    the pad-masked predicted sequence and the pad-masked true sequence.

    Args:
      * basecalls: a tf.int32 tensor of shape [batch, max_time] with values in [0,num_labels).
      * true_base_labels: a rank-2 tf.int32 tensor of shape [batch, max_time] with values in [0,num_labels).
      * max_time: a python integer indicating the sequence-length (padding symbols included) of each tensor.
        It is forwarded to `sparsify_seq()`.
      * pad_token: a python integer giving the label that represents a padding character.
        Entries equal to it are dropped by `sparsify_seq()` before distances are computed.
    Returns:
      A vector of shape [batch] with dtype=tf.float32 holding (1. - edit_dist) per row.
      This is a fraction, not a percentage.
      NOTE(review): tf.edit_distance(normalize=True) divides by the truth length, so
      a basecall much longer than its truth can presumably yield a negative identity,
      and an all-padding truth row may yield an infinite distance. Confirm whether
      callers need clipping or masking.
    """
    with tf.name_scope("PctIdentity"):
        # Hypothesis first, then truth: the same argument order tf.edit_distance expects.
        hypothesis_sp = sparsify_seq(basecalls, max_time, pad_value=pad_token)
        truth_sp = sparsify_seq(true_base_labels, max_time, pad_value=pad_token)
        normalized_dist = tf.edit_distance(hypothesis_sp, truth_sp, normalize=True)
        identity = 1. - normalized_dist
    return identity
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment