# Manual dense head on top of the last RNN output
Wd1 = tf.get_variable(name='dense1', shape=(rnn_size, 16))
Wd2 = tf.get_variable(name='dense2', shape=(16, 2))
dense1 = tf.nn.relu(tf.matmul(last_rnn_output, Wd1))  # [batch_size, 16]
dense2 = tf.matmul(dense1, Wd2)                       # [batch_size, 2]
pred = tf.nn.softmax(dense2)
hidden_states = tf.scan(fn=rnn_step,
                        elems=tf.transpose(embed, perm=[1, 0, 2]),  # batch_size*seq_len*dim --> seq_len*batch_size*dim
                        initializer=tf.zeros([batch_size, rnn_size]))
outputs = tf.transpose(hidden_states, perm=[1, 0, 2])  # convert back to original shape: batch_size*seq_len*dim
last_rnn_output = outputs[:, -1, :]  # extract the last output
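# As a quick, self-contained illustration of tf.scan's accumulation behaviour
# (a toy example, not part of the model): the carried value plays the same role
# as the RNN hidden state above.
#     running_sum = tf.scan(lambda acc, elem: acc + elem,
#                           tf.constant([1, 2, 3]),
#                           initializer=tf.constant(0))
#     # evaluates to [1, 3, 6]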
# our RNN variables
Wx = tf.get_variable(name='Wx', shape=[embedding_size, rnn_size])
Wh = tf.get_variable(name='Wh', shape=[rnn_size, rnn_size])
bias_rnn = tf.get_variable(name='brnn', initializer=tf.zeros([rnn_size]))

def rnn_step(prev_hidden_state, x):
    return tf.tanh(tf.matmul(x, Wx) + tf.matmul(prev_hidden_state, Wh) + bias_rnn)
embeddings_matrix = tf.get_variable("embedding", [vocabulary_size, embedding_size])
embed = tf.nn.embedding_lookup(embeddings_matrix, x)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_init_op)
    while True:
        try:
            _, step, c, acc = sess.run([train_step, global_step, cost, accuracy])
            if step % 50 == 0:
                print("Iter " + str(step) +
                      ", batch loss {:.6f}".format(c) +
                      ", batch accuracy {:.5f}".format(acc))
        except tf.errors.OutOfRangeError:
            break
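    # Evaluation on the test split is not shown above; a minimal sketch
    # (an assumption, reusing the test_init_op defined in the input pipeline below):
    sess.run(test_init_op)
    test_accuracies = []
    while True:
        try:
            test_accuracies.append(sess.run(accuracy))
        except tf.errors.OutOfRangeError:
            break
    print("Test accuracy: {:.5f}".format(sum(test_accuracies) / len(test_accuracies)))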
correct_pred = tf.equal(tf.argmax(pred, 1), y)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logit))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_step = optimizer.minimize(cost, global_step=global_step)
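# Note: global_step is used above but never defined in these snippets; a typical
# definition (an assumption here), placed before the minimize() call, would be:
#     global_step = tf.train.get_or_create_global_step()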
# rnn_cell = tf.contrib.rnn.BasicRNNCell(rnn_size)  # deprecated, removed in TensorFlow 2.0
rnn_cell = tf.keras.layers.SimpleRNNCell(rnn_size)
outputs, states = tf.nn.dynamic_rnn(rnn_cell, embed, dtype=tf.float32)
# RNN outputs: [batch_size, seq_len, hidden_size]
# extract only the last output
last_rnn_output = outputs[:, -1, :]
# Dense layers
dense1 = tf.layers.dense(last_rnn_output, 16, activation=tf.nn.relu)
logit = tf.layers.dense(dense1, 2)
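# In this version, the softmax probabilities used by the accuracy computation
# (an assumption, mirroring the manual dense head above) would be:
pred = tf.nn.softmax(logit)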
training_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels)).repeat(5).shuffle(1024).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((test_data, test_labels)).repeat(1).batch(batch_size)
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
train_init_op = iterator.make_initializer(training_dataset)
test_init_op = iterator.make_initializer(test_dataset)
# Input data
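# The model inputs x and y used in the other snippets are presumably
# (an assumption, not shown in the original) drawn from the iterator:
x, y = iterator.get_next()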
(train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.imdb.load_data(num_words=vocabulary_size)
train_data = tf.keras.preprocessing.sequence.pad_sequences(train_data, maxlen=256)
test_data = tf.keras.preprocessing.sequence.pad_sequences(test_data, maxlen=256)
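# The snippets reference several hyperparameters that are never defined in them.
# Plausible values (assumptions, not taken from the original), set before any of
# the data or graph code runs, might be:
#     vocabulary_size = 10000
#     embedding_size = 64
#     rnn_size = 64
#     batch_size = 128
#     learning_rate = 0.001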
""" | |
Assuming the original model looks like this: | |
model = Sequential() | |
model.add(Dense(2, input_dim=3, name='dense_1')) | |
model.add(Dense(3, name='dense_2')) | |
... | |
model.save_weights(fname) | |
""" | |
# new model |
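# A typical continuation (an assumption; it follows the standard Keras pattern
# for loading saved weights into a different architecture by layer name):
model = Sequential()
model.add(Dense(2, input_dim=3, name='dense_1'))  # name matches the saved layer, will be loaded
model.add(Dense(10, name='new_dense'))            # no matching name, will not be loaded

# load weights from the first model; only affects layers with matching names (dense_1)
model.load_weights(fname, by_name=True)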