# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @marcoleewow
# Last active December 6, 2017 06:21
# Show Gist options
#   • Save marcoleewow/d7dc5078a36927b3f13dbf64ac9905ad to your computer and use it in GitHub Desktop.
# Save marcoleewow/d7dc5078a36927b3f13dbf64ac9905ad to your computer and use it in GitHub Desktop.
# TensorFlow 1.4.0-rc1 CTC loss simple test
import numpy as np
import tensorflow as tf
def CTCLossSimpleTest():
    """Check ``tf.nn.ctc_loss`` against a hand-computed CTC example.

    Alex Graves paper example, vocabulary {'a'} plus the blank '-'::

        #       a    -
        P = [[0.3, 0.7],   # t = 0
             [0.4, 0.6]]   # t = 1

    The label "a" is produced by the alignment paths "a-", "-a" and "aa"::

        P(a) = P(a-) + P(-a) + P(aa)
             = 0.3*0.6 + 0.7*0.4 + 0.3*0.4 = 0.58
        negative log prob = -ln(0.58) = 0.544727175

    ``tf.nn.ctc_loss`` takes *logits* (it applies a softmax over the class
    axis itself) and uses the last class index as the blank. The
    probabilities are therefore fed in as ``log(P)``; because every row of
    P sums to 1, softmax(log(P)) == P.

    Raises:
        AssertionError: if TensorFlow's loss does not match -ln(0.58).
    """
    # vocabulary list; TF reserves the last class (num_classes - 1) for blank
    vocabularies = 'a'

    # per-frame class probabilities; columns are ('a', blank)
    P = [[0.3, 0.7],  # t = 0
         [0.4, 0.6]]  # t = 1
    print(P)

    P = np.array(P, dtype=np.float32)
    max_time, num_classes = P.shape
    assert num_classes == len(vocabularies) + 1

    # ctc_loss expects unnormalized log-probabilities and applies softmax
    # internally. Passing raw probabilities would silently renormalize them
    # as exp(P) / sum(exp(P)) and yield the wrong loss.
    logits = tf.convert_to_tensor(np.log(P))
    # add the batch axis: time-major layout (max_time, batch_size, num_classes)
    logits = tf.expand_dims(logits, axis=1)
    assert logits.shape == (max_time, 1, num_classes)

    # convert label string to a list of class indices
    label = 'a'
    label = [vocabularies.index(char) for char in label]
    label_length = len(label)

    # labels as a tf.SparseTensor with dense shape (batch_size, label_length)
    indices = [[0, i] for i in range(label_length)]
    labels = tf.SparseTensor(indices=indices, values=label,
                             dense_shape=[1, label_length])

    # sequence_length is the number of INPUT time steps per batch element,
    # not the label length: all max_time frames are valid here.
    sequence_length = tf.convert_to_tensor(np.atleast_1d(max_time),
                                           dtype=tf.int32)
    assert sequence_length.shape == (1,)  # shape = (batch_size,)

    # CTC loss, one value per batch element
    loss = tf.nn.ctc_loss(labels, logits, sequence_length)

    # the context manager releases the session's resources when done
    with tf.Session() as sess:
        loss_value = sess.run(loss)  # ndarray of shape (batch_size,)
    computed = float(loss_value[0])

    # compare truth prob to calculated prob
    true_prob = 0.58
    neg_log_true_prob = -1. * np.log(true_prob)
    print('true negative log prob = %.4f' % neg_log_true_prob)
    print('Tensorflow CTC calculated negative log prob = %.4f' % computed)
    # the loss is computed in float32, so exact equality cannot be expected
    assert np.isclose(computed, neg_log_true_prob, rtol=1e-4, atol=1e-6), (
        computed, neg_log_true_prob)
# Run the check only when executed as a script, not when imported.
if __name__ == "__main__":
    CTCLossSimpleTest()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.