Compute confidence score for CTC-decoded text using TensorFlow
""" | |
Compute score for decoded text in a CTC-trained neural network using TensorFlow: | |
1. decode text with best path decoding (or some other decoder) | |
2. feed decoded text into loss function | |
3. loss is negative logarithm of probability | |
Example data: two time-steps, 2 labels (0, 1) and the blank label (2). | |
Decoding results in [0] (i.e. string containing one entry for label 0). | |
The probability is the sum over all paths yielding [0], these are: [0, 0], [0, 2], [2, 0] | |
with probability | |
0.6*0.3 + 0.6*0.6 + 0.3*0.3 = 0.63. | |
Expected output: | |
Best path decoding: [0] | |
Loss: 0.462035 | |
Probability: 0.63 | |
""" | |
import numpy as np
import tensorflow as tf

# size of input data
batchSize = 1
numClasses = 3
numTimesteps = 2

def createGraph():
    "create computation graph"
    tinputs = tf.placeholder(tf.float32, [numTimesteps, batchSize, numClasses])
    tseqLen = tf.placeholder(tf.int32, [None])  # list of sequence lengths in batch
    # ground truth text as a sparse tensor: (indices, values, dense shape)
    tgroundtruth = tf.SparseTensor(tf.placeholder(tf.int64, shape=[None, 2]),
                                   tf.placeholder(tf.int32, [None]),
                                   tf.placeholder(tf.int64, [2]))
    tloss = tf.nn.ctc_loss(tgroundtruth, tinputs, tseqLen)
    tbest = tf.nn.ctc_greedy_decoder(tinputs, tseqLen, merge_repeated=True)
    return (tinputs, tseqLen, tgroundtruth, tloss, tbest)

def getData():
    "get data as logits (softmax not yet applied)"
    seqLen = [numTimesteps]
    # time-major input of shape [numTimesteps, batchSize, numClasses];
    # taking the log turns the probabilities from the docstring into logits
    inputs = np.log(np.asarray([[[0.6, 0.1, 0.3]], [[0.3, 0.1, 0.6]]], np.float32))
    return (inputs, seqLen)

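# Sanity check (a minimal sketch, not part of the original gist): the docstring's
# arithmetic can be reproduced by brute force. The function below enumerates all
# label paths of the given length, collapses each path CTC-style (merge repeated
# labels, then drop blanks) and sums the probabilities of the paths that collapse
# to the given labelling. The name bruteForceProbability is our own.
def bruteForceProbability(probs, labelling, blank):
    "sum probabilities of all paths collapsing to the given labelling"
    import itertools
    total = 0.0
    for path in itertools.product(range(len(probs[0])), repeat=len(probs)):
        # CTC collapse: merge consecutive repeated labels, then remove blanks
        collapsed = []
        prev = None
        for label in path:
            if label != prev and label != blank:
                collapsed.append(label)
            prev = label
        if collapsed == list(labelling):
            # the probability of a path is the product of per-timestep probabilities
            p = 1.0
            for (t, label) in enumerate(path):
                p *= probs[t][label]
            total += p
    return total
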
def toLabelString(decoderOutput):
    "map sparse tensor from decoder to label string"
    decoded = decoderOutput[0][0]
    # collect the decoded values per batch element
    encodedLabels = [[] for _ in range(batchSize)]
    for (idxVal, idx2d) in enumerate(decoded.indices):
        value = decoded.values[idxVal]
        batch = idx2d[0]
        encodedLabels[batch].append(value)
    return encodedLabels[0]

def main():
    # initialize
    (tinputs, tseqLen, tgroundtruth, tloss, tbest) = createGraph()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # get data
    (inputs, seqLen) = getData()

    # decode with best path decoding (greedy decoder)
    retBest = sess.run(tbest, {tinputs: inputs, tseqLen: seqLen})
    print('Best path decoding:', toLabelString(retBest))

    # for the decoded result, compute the loss
    decoded = retBest[0][0]
    retLoss = sess.run(tloss, {tinputs: inputs, tseqLen: seqLen,
                               tgroundtruth: (decoded.indices, decoded.values, decoded.dense_shape)})

    # print loss and probability of decoded result
    print('Loss:', retLoss[0])
    print('Probability:', np.exp(-retLoss[0]))
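
    # cross-check (our addition, so this prints one line beyond the docstring's
    # expected output): the brute-force path enumeration from above should give
    # the same probability as exp(-loss)
    probs = np.exp(inputs)[:, 0, :]  # undo the log and drop the batch dimension
    print('Brute-force probability:', bruteForceProbability(probs, toLabelString(retBest), numClasses - 1))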
if __name__ == '__main__':
    main()