-
-
Save danijar/3f3b547ff68effb03e20c470af22c696 to your computer and use it in GitHub Desktop.
# Working example for my blog post at:
# http://danijar.com/variable-sequence-lengths-in-tensorflow/
import functools | |
import sets | |
import tensorflow as tf | |
from tensorflow.models.rnn import rnn_cell | |
from tensorflow.models.rnn import rnn | |
def lazy_property(function):
    """Turn a one-argument method into a read-only, memoized property.

    On first access the wrapped method is called and its result is stored
    on the instance under ``'_' + function.__name__``; every later access
    returns that stored attribute instead of calling the method again.
    """
    cache_name = '_' + function.__name__

    @functools.wraps(function)
    def getter(self):
        if hasattr(self, cache_name):
            return getattr(self, cache_name)
        setattr(self, cache_name, function(self))
        return getattr(self, cache_name)

    return property(getter)
class VariableSequenceClassification:
    """Classify variable-length sequences with a recurrent network.

    The graph is built eagerly in ``__init__``: ``prediction``, ``error``
    and ``optimize`` (and, through it, ``cost``) are lazy properties that
    are evaluated there once and cached on the instance afterwards.

    Sequences are expected to be zero-padded along the time axis; a time
    step whose features are all zero is treated as unused.
    """

    def __init__(self, data, target, num_hidden=200, num_layers=2,
                 learning_rate=0.003):
        """Build the model graph.

        Args:
          data: float32 tensor of shape (batch, time, features); padded time
            steps must be all zeros.
          target: float tensor of shape (batch, num_classes) holding the
            target distribution per example (normally one-hot).
          num_hidden: number of units in each GRU layer.
          num_layers: number of stacked GRU layers (values below 1 are
            treated as 1).
          learning_rate: step size of the RMSProp optimizer.
        """
        self.data = data
        self.target = target
        self._num_hidden = num_hidden
        self._num_layers = num_layers
        self._learning_rate = learning_rate
        # Touch the lazy properties so the whole graph exists up front.
        self.prediction
        self.error
        self.optimize

    @lazy_property
    def length(self):
        """Per-example sequence length as an int32 tensor of shape (batch,).

        A time step counts as used when any of its features is non-zero.
        The length is the position just past the *last* used step, so
        all-zero steps inside a sequence (e.g. empty leading rows of an
        image) no longer shorten it; only trailing zero steps are padding.
        """
        # 1.0 for time steps with any non-zero feature, 0.0 otherwise.
        used = tf.sign(tf.reduce_max(tf.abs(self.data), reduction_indices=2))
        # 1-based position of every time step, broadcast across the batch.
        positions = tf.cast(tf.range(1, tf.shape(used)[1] + 1), used.dtype)
        length = tf.reduce_max(used * positions, reduction_indices=1)
        return tf.cast(length, tf.int32)

    @lazy_property
    def prediction(self):
        """Class probabilities of shape (batch, num_classes).

        Runs ``num_layers`` stacked GRU layers over ``self.data`` (each
        example only up to its ``length``), takes the output at the last
        relevant step and feeds it through a softmax layer.
        """
        # Recurrent network. Build a separate cell object per layer so the
        # layers get their own weights.
        num_layers = max(1, self._num_layers)
        cells = [rnn_cell.GRUCell(self._num_hidden) for _ in range(num_layers)]
        cell = cells[0] if num_layers == 1 else rnn_cell.MultiRNNCell(cells)
        output, _ = rnn.dynamic_rnn(
            cell,
            self.data,  # the model's own input, not a module-level global
            dtype=tf.float32,
            sequence_length=self.length,
        )
        last = self._last_relevant(output, self.length)
        # Softmax layer.
        weight, bias = self._weight_and_bias(
            self._num_hidden, int(self.target.get_shape()[1]))
        prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
        return prediction

    @lazy_property
    def cost(self):
        """Cross-entropy summed over the batch.

        H(y, p) = -sum(y * log(p)). The leading minus makes the loss
        non-negative because log(p) <= 0 for probabilities.
        """
        # The epsilon keeps log() finite when the softmax saturates to 0;
        # otherwise 0 * log(0) = 0 * -inf = NaN poisons loss and gradients.
        cross_entropy = -tf.reduce_sum(
            self.target * tf.log(self.prediction + 1e-10))
        return cross_entropy

    @lazy_property
    def optimize(self):
        """RMSProp training op that minimizes ``cost``."""
        optimizer = tf.train.RMSPropOptimizer(self._learning_rate)
        return optimizer.minimize(self.cost)

    @lazy_property
    def error(self):
        """Fraction of examples whose predicted class (argmax) is wrong."""
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @staticmethod
    def _weight_and_bias(in_size, out_size):
        """Create a weight matrix and bias vector for a dense layer.

        Weights start from a truncated normal with stddev 0.01; biases
        start at the constant 0.1.
        """
        weight = tf.truncated_normal([in_size, out_size], stddev=0.01)
        bias = tf.constant(0.1, shape=[out_size])
        return tf.Variable(weight), tf.Variable(bias)

    @staticmethod
    def _last_relevant(output, length):
        """Gather each example's output at its last relevant time step.

        Args:
          output: RNN outputs of shape (batch, time, size); ``size`` must be
            known statically, ``batch`` and ``time`` may be dynamic.
          length: int32 tensor of shape (batch,) with each sequence length.

        Returns:
          Tensor of shape (batch, size).
        """
        batch_size = tf.shape(output)[0]
        # Read dynamically so inputs with an unknown (None) time axis work.
        max_length = tf.shape(output)[1]
        output_size = int(output.get_shape()[2])
        # Row index into the flattened (batch * time, size) outputs. Clamp
        # at 0 so a length-0 example reads its own first step instead of
        # the previous example's last step (or index -1 for the first one).
        index = tf.range(0, batch_size) * max_length + tf.maximum(length - 1, 0)
        flat = tf.reshape(output, [-1, output_size])
        relevant = tf.gather(flat, index)
        return relevant
if __name__ == '__main__':
    # We treat images as sequences of pixel rows.
    train, test = sets.Mnist()
    _, rows, row_size = train.data.shape
    num_classes = train.target.shape[1]
    num_epochs, steps_per_epoch, batch_size = 10, 100, 10
    data = tf.placeholder(tf.float32, [None, rows, row_size])
    target = tf.placeholder(tf.float32, [None, num_classes])
    model = VariableSequenceClassification(data, target)
    # The context manager closes the session and releases its resources
    # when training finishes or an exception escapes.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for epoch in range(num_epochs):
            for _ in range(steps_per_epoch):
                batch = train.sample(batch_size)
                sess.run(model.optimize,
                         {data: batch.data, target: batch.target})
            error = sess.run(model.error,
                             {data: test.data, target: test.target})
            print('Epoch {:2d} error {:3.1f}%'.format(epoch + 1, 100 * error))
Hello, I would like to use your code on variable-length strings and I need some help, please. My goal is to classify each string with a target (the target is 0 or 1, so there are only 2 classes); I do not work with images.
My training set file looks like the following:
1 s1 s2 ... sn
0 s1 s2 ... si
1 s1 s2
where the first column is the target of the sequence formed by the 's' values (the values may or may not be binary).
After reading the files, how can I adapt your main?
Thank you.
May I ask why you used a "-" in your cost function?
@Dellen
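Maybe you should refer to the definition of the cross-entropy loss function, $H(y, p) = -\sum_i y_i \log p_i$: every probability satisfies $p_i \le 1$, so $\log p_i \le 0$, and the minus sign makes the loss non-negative.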
-tf.reduce_sum(self.target * tf.log(self.prediction))
is equivalent to tf.reduce_sum(self.target * tf.log(1.0/self.prediction)), because $-\log p = \log(1/p)$.
.
Hi, I noticed that you removed the DropoutWrapper in this version (previously, it caused the model outputs to be random in the testing and evaluation stages). To avoid this problem, I modified the code like this:
def __init__(self, data, target, learning_rate=0.001, num_hidden=256, num_layers=2, dropout = 0.8):
    # Store the hyper-parameters, then build the graph by touching the
    # lazy properties.
    self.data = data
    self.target = target
    self._num_hidden = num_hidden
    self._num_layers = num_layers
    self._dropout = dropout  # passed as DropoutWrapper's output_keep_prob
    self._learning_rate = learning_rate
    self.prediction
    self.error
    self.optimize
    # NOTE(review): this assignment runs only after `self.prediction` has
    # already been built above, so it cannot influence the dropout check
    # made while building that graph.
    self._is_training = True
@lazy_property
def prediction(self):
    # Reader's variant of `prediction`: stacked LSTM layers with optional
    # output dropout, followed by the same last-step softmax layer.
    cell = tf.contrib.rnn.LSTMCell(self._num_hidden)
    # NOTE(review): `_is_training` is a bare name here, not
    # `self._is_training`. Because `prediction` is a lazy_property, this
    # branch is evaluated only once, when the graph is first built, so
    # changing the flag later does not add or remove the DropoutWrapper.
    # Consider feeding the keep probability through a placeholder (1.0 at
    # evaluation time) instead of branching in Python.
    if _is_training == True:
        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self._dropout)
    # NOTE(review): `[cell] * n` repeats the same cell object for every
    # layer; confirm this is intended for the TF version in use.
    cell = tf.contrib.rnn.MultiRNNCell([cell] * self._num_layers)
    output, _ = tf.nn.dynamic_rnn(
        cell,
        self.data,
        dtype=tf.float32,
        sequence_length=self.length,
    )
    last = self._last_relevant(output, self.length)
    # Softmax layer.
    weight, bias = self._weight_and_bias(
        self._num_hidden, int(self.target.get_shape()[1]))
    prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
    return prediction
def set_is_training(self, is_training):
    # Only stores the flag. It does not rebuild the cached `prediction`
    # graph, so it has no effect once that graph exists.
    self._is_training = is_training
In the training stage I call set_is_training(True); in the testing or evaluation stages I call set_is_training(False).
But I found that, because the prediction function is defined as a @lazy_property, the check `if _is_training == True:` is evaluated only the first time I access model.prediction. As a result, if I set _is_training = True in the training stage, the DropoutWrapper stays active in the testing and evaluation stages as well.
So, how can I use dropout during training and disable it during testing? Thank you very much for your help!
Hello,
Thank you for this post. Is this code still relevant? Would using gather_nd
solve the problem of retrieving the last relevant output? If so, how would you use it?
Here is an updated version that works with TF 1.4: https://gist.github.com/abaybektursun/98656e483ec6e918c26235b47f3f5d60
I have the same problem as MartinThoma.