#!/usr/bin/env python3
# Minimal TensorFlow 1.x example: a recurrent layer followed by batch
# normalization, with the batch-norm update ops wired into the train op.
import numpy as np
import keras.layers as kl
import tensorflow as tf

class_count = 5
width, height = 100, 100

# Placeholders: one sample of `width` timesteps with `height` features each,
# plus one integer class label.
ph = tf.placeholder(shape=[1, width, height], dtype=tf.float32)
labels_ph = tf.placeholder(shape=[1], dtype=tf.int32)
# GRU with batch normalization
with tf.variable_scope("GRU1"):
    # kl.GRU returns the final hidden state, shape [1, 20].
    # (The original mixed an unused tf.contrib.rnn.GRUCell with a
    # kl.SimpleRNN call; a single Keras GRU matches the stated intent.)
    final_state = kl.GRU(20)(ph)
    bn = kl.BatchNormalization()
    norm_output_gru = bn(final_state)
# The prediction layer
logits = kl.Dense(class_count, activation=None)(norm_output_gru)
onehot_labels = tf.one_hot(labels_ph, depth=class_count)

# Loss and optimizer
loss = tf.losses.softmax_cross_entropy(onehot_labels, logits)
optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
# Batch normalization tracks its moving mean/variance through separate update
# ops, so the train op must depend on them explicitly; see
# https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
with tf.control_dependencies(bn.updates):
    train_op = optimizer.minimize(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One training step on dummy data; the labels must be integer-typed.
    sess.run([train_op, loss],
             feed_dict={ph: np.zeros([1, width, height], dtype=np.float32),
                        labels_ph: np.zeros([1], dtype=np.int32)})
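
For comparison, here is a minimal sketch of the same model in TF 2.x / tf.keras (an addition for illustration, not part of the original gist): BatchNormalization's moving-statistics updates are applied automatically during fit(), so the explicit tf.control_dependencies wiring above is no longer needed.

# Hypothetical TF 2.x equivalent of the script above, for illustration only.
import numpy as np
import tensorflow as tf

class_count = 5
width, height = 100, 100

model = tf.keras.Sequential([
    tf.keras.layers.GRU(20, input_shape=(width, height)),  # final hidden state
    tf.keras.layers.BatchNormalization(),  # updates handled inside fit()
    tf.keras.layers.Dense(class_count),    # logits
])
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
# One training step on the same dummy data as above.
model.fit(np.zeros([1, width, height], dtype=np.float32),
          np.zeros([1], dtype=np.int32), epochs=1)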