Clockwork RNN and Temporal Kernel RNN in TensorFlow
'''These implementations do not yet use `tf.scan` or other TensorFlow ops to speed up the unrolled loops (written against the TensorFlow 1.x API).'''
import numpy as np
import tensorflow as tf
from functools import reduce
from tqdm import trange
import matplotlib.pyplot as plt

class RNN:

    def __init__(self, n_steps, n_hidden, input_size=1, output_size=1,
                 learning_rate=1e-3, output_dir='./',
                 built_in_cell=tf.contrib.rnn.BasicRNNCell, sess=None):
        self.n_steps = n_steps
        self.n_hidden = n_hidden
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.built_in_cell = built_in_cell
        # build model
        self._build_model()
        # initialize loss and optimizer
        self._loss_optimizer()
        # summary
        self.summary = tf.summary.merge_all()
        # open Session
        if sess is None:
            self.sess = tf.Session()
        else:
            self.sess = sess
        init = tf.global_variables_initializer()
        self.sess.run(init)
        # writers
        self.train_writer = tf.summary.FileWriter(output_dir + '/train',
                                                  self.sess.graph)
        self.val_writer = tf.summary.FileWriter(output_dir + '/val')

    def _build_model(self):
        self.X = tf.placeholder(dtype=tf.float32,
                                shape=(None, self.n_steps, self.input_size),
                                name='inputs')
        self.Y = tf.placeholder(dtype=tf.float32,
                                shape=[None, self.output_size], name="targets")
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        with tf.variable_scope("RNN"):
            cell = self.built_in_cell(
                num_units=self.n_hidden, activation=tf.tanh)
            rnn_outputs, _ = tf.nn.dynamic_rnn(cell=cell, dtype=tf.float32,
                                               inputs=self.X)
            self.output = tf.layers.dense(rnn_outputs[:, -1, :],
                                          units=self.output_size, name='output')

    def _loss_optimizer(self):
        error = tf.reduce_sum(tf.square(self.Y - self.output),
                              axis=1, name='error')
        self.loss = tf.reduce_mean(error, name="loss")
        tf.summary.scalar("mse", self.loss)
        optimizer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate)
        self.training = optimizer.minimize(self.loss,
                                           global_step=self.global_step,
                                           name='training')

    def fit(self, generator, steps, validation_data=None, savefig='./figure'):
        training_loss = []
        validation_loss = []
        for _ in trange(steps):
            X_once, Y_once = next(generator)
            _, train_step, train_loss, train_summary = self.sess.run(
                [self.training, self.global_step, self.loss, self.summary],
                feed_dict={self.X: X_once, self.Y: Y_once})
            self.train_writer.add_summary(train_summary, train_step)
            training_loss.append(train_loss)
            if (validation_data is not None) and (train_step % 50 == 0):
                val_loss, val_pred, val_summary = self.sess.run(
                    [self.loss, self.output, self.summary],
                    feed_dict={self.X: validation_data[0],
                               self.Y: validation_data[1]})
                # print(val_loss, end=',')
                plt.clf()
                plt.plot(validation_data[1].flatten()[:50])
                plt.plot(val_pred.flatten()[:50])
                plt.title("Loss: %s" % val_loss)
                plt.savefig(savefig + '_%s.png' % train_step)
                self.val_writer.add_summary(val_summary, train_step)
                validation_loss.append(val_loss)
        return training_loss, validation_loss

    def predict(self, X):
        pred = self.sess.run(self.output,
                             feed_dict={self.X: X})
        return pred

    @property
    def trainable_size(self):
        trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        return sum([int(reduce(lambda x, y: x * y, trainable.get_shape()))
                    for trainable in trainables])
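
# A minimal sketch of a batch generator compatible with `RNN.fit` above: `fit` pulls
# `(X, Y)` pairs from the generator and feeds them to the placeholders built in
# `_build_model`, so X must be shaped (batch, n_steps, input_size) and Y must be
# shaped (batch, output_size). The sine-wave task and the helper name `toy_batches`
# are illustrative assumptions, not part of the original gist.
def toy_batches(n_steps, batch_size=32):
    while True:
        starts = np.random.uniform(0, 2 * np.pi, size=(batch_size, 1))
        series = np.sin(starts + 0.1 * np.arange(n_steps + 1))  # (batch, n_steps + 1)
        X = series[:, :-1, np.newaxis]   # (batch, n_steps, 1): the observed window
        Y = series[:, -1:]               # (batch, 1): the next value to predict
        yield X, Y
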
class ClockworkRNN(RNN):
    '''A Clockwork RNN. https://arxiv.org/abs/1402.3511'''

    def __init__(self, n_steps, n_hidden, input_size=1, output_size=1,
                 learning_rate=1e-3, periods=[1], output_dir='./', sess=None):
        if n_hidden % len(periods) != 0:
            raise ValueError("ClockworkRNN requires `n_hidden` to be "
                             "a multiple of the number of `periods`.")
        self.n_steps = n_steps
        self.n_hidden = n_hidden
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.periods = periods
        # build model
        self._build_model()
        # initialize loss and optimizer
        self._loss_optimizer()
        # summary
        self.summary = tf.summary.merge_all()
        # open Session
        if sess is None:
            self.sess = tf.Session()
        else:
            self.sess = sess
        init = tf.global_variables_initializer()
        self.sess.run(init)
        # writers
        self.train_writer = tf.summary.FileWriter(output_dir + '/train',
                                                  self.sess.graph)
        self.val_writer = tf.summary.FileWriter(output_dir + '/val')
    def _build_model(self):
        self.X = tf.placeholder(dtype=tf.float32,
                                shape=(None, self.n_steps, self.input_size),
                                name='inputs')
        self.Y = tf.placeholder(dtype=tf.float32,
                                shape=[None, self.output_size], name="targets")
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        group_size = self.n_hidden // len(self.periods)
        mask = np.zeros((self.n_hidden, self.n_hidden), np.float32)
        period = np.zeros(self.n_hidden)
        for i, t in enumerate(self.periods):
            mask[i * group_size:(i + 1) * group_size, i * group_size:] = 1
            period[i * group_size:(i + 1) * group_size] = t
        clockwork_mask = tf.constant(
            mask, dtype=tf.float32, name='clockwork_mask')
        clockwork_period = tf.constant(
            period, dtype=tf.int32, name='clockwork_period')
        with tf.variable_scope("clockwork_cell"):
            input_weights = tf.get_variable("input_weights",
                                            shape=[self.input_size,
                                                   self.n_hidden])  # W_I
            state_weights = tf.get_variable("state_weights_",
                                            shape=[self.n_hidden,
                                                   self.n_hidden])  # W_H'
            state_weights = tf.multiply(state_weights, clockwork_mask,
                                        name='state_weights')  # W_H
            self.state_weights = state_weights
            biases = tf.get_variable("biases", shape=[self.n_hidden])  # b_H
            state = tf.zeros_like(dtype=tf.float32,
                                  tensor=tf.matmul(self.X[:, 0, :],
                                                   input_weights))
            for time_step in range(self.n_steps):
                wI_x = tf.matmul(self.X[:, time_step, :], input_weights)
                wH_y = tf.matmul(state, state_weights)
                current_state = tf.tanh(wH_y + wI_x + biases)
                # Note: this implementation is no faster than a simple RNN,
                # since every group's candidate state is still computed at
                # every step and only masked out by `tf.where`.
                current_state = tf.where(
                    tf.equal(tf.mod(time_step, clockwork_period), 0),
                    tf.transpose(current_state),
                    tf.transpose(state))
                state = tf.transpose(current_state)
            self.output = tf.layers.dense(
                state, units=self.output_size, name='output')
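
# For reference, a small worked example of the clockwork mask construction above
# (illustrative numbers, not values used anywhere in this gist): with n_hidden=4
# and periods=[1, 2], group_size is 2 and the loop in `_build_model` yields
#
#     mask   = [[1, 1, 1, 1],
#               [1, 1, 1, 1],
#               [0, 0, 1, 1],
#               [0, 0, 1, 1]]
#     period = [1, 1, 2, 2]
#
# `state_weights` keeps only the masked entries of W_H', and the `tf.where` in the
# time loop overwrites a group's state only at steps where time_step % period == 0;
# at all other steps the group simply carries its previous state forward.
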
class MomentumRNN(RNN):
    '''Temporal Kernel Recurrent Neural Networks.
    http://www.cs.utoronto.ca/~ilya/pubs/2008/tkrnn.pdf

    A slightly modified implementation that seems to train better (or at least
    faster): effectively a Temporal Kernel RNN with a full kernel size.'''
    def __init__(self, n_steps, n_hidden, input_size=1, output_size=1,
                 learning_rate=1e-3, output_dir='./', sess=None, fix_lambdas=None):
        if fix_lambdas is not None and (fix_lambdas > 1 or fix_lambdas < 0):
            raise ValueError('`fix_lambdas` should always be in [0, 1]')
        self.n_steps = n_steps
        self.n_hidden = n_hidden
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.fix_lambdas = fix_lambdas
        # build model
        self._build_model()
        # initialize loss and optimizer
        self._loss_optimizer()
        # summary
        self.summary = tf.summary.merge_all()
        # open Session
        if sess is None:
            self.sess = tf.Session()
        else:
            self.sess = sess
        init = tf.global_variables_initializer()
        self.sess.run(init)
        # writers
        self.train_writer = tf.summary.FileWriter(output_dir + '/train',
                                                  self.sess.graph)
        self.val_writer = tf.summary.FileWriter(output_dir + '/val')
    def _build_model(self):
        self.X = tf.placeholder(dtype=tf.float32,
                                shape=(None, self.n_steps, self.input_size),
                                name='inputs')
        self.Y = tf.placeholder(dtype=tf.float32,
                                shape=[None, self.output_size], name="targets")
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        with tf.variable_scope("momentum_cell"):
            # Wxy
            input_weights = tf.get_variable("input_weights",
                                            shape=[self.input_size, self.n_hidden])
            # Wyy
            state_weights = tf.get_variable("state_weights",
                                            shape=[self.n_hidden, self.n_hidden])
            # b
            biases = tf.get_variable("biases", shape=[self.n_hidden])
            # Syt
            acc_state = tf.zeros_like(dtype=tf.float32,
                                      tensor=tf.matmul(self.X[:, 0, :], input_weights))
            # Sxt
            acc_input = tf.zeros_like(dtype=tf.float32,
                                      tensor=self.X[:, 0, :])
            if self.fix_lambdas is not None:
                lambdas_x = tf.constant(
                    [self.fix_lambdas] * self.input_size, name='lambdas_x', dtype=tf.float32)
                lambdas_y = tf.constant(
                    [self.fix_lambdas] * self.n_hidden, name='lambdas_y', dtype=tf.float32)
            else:
                # unconstrained (pre-sigmoid) lambdas_x
                lambdas_x = tf.get_variable("lambdasx_", shape=[self.input_size],
                                            initializer=tf.constant_initializer(0.0))
                # lambdas_x constrained to (0, 1)
                lambdas_x = tf.nn.sigmoid(lambdas_x, name='lambdas_x')
                # unconstrained (pre-sigmoid) lambdas_y
                lambdas_y = tf.get_variable("lambdasy_", shape=[self.n_hidden],
                                            initializer=tf.constant_initializer(0.0))
                # lambdas_y constrained to (0, 1)
                lambdas_y = tf.nn.sigmoid(lambdas_y, name='lambdas_y')
            for time_step in range(self.n_steps):
                acc_input = self.X[:, time_step, :] + \
                    tf.multiply(lambdas_x, acc_input)
                Wxy_Sxt = tf.matmul(acc_input, input_weights)
                Wyy_Syt = tf.matmul(acc_state, state_weights)
                state = tf.tanh(Wxy_Sxt + Wyy_Syt + biases)
                acc_state = state + tf.multiply(lambdas_y, acc_state)
            self.output = tf.layers.dense(
                state, units=self.output_size, name='output')
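
# The time loop in `MomentumRNN._build_model` above implements the accumulator
# recurrences below (written out here for reference; * denotes element-wise
# multiplication, @ a matrix product):
#
#     S_x(t) = x_t + lambdas_x * S_x(t-1)
#     y_t    = tanh(S_x(t) @ W_xy + S_y(t-1) @ W_yy + b)
#     S_y(t) = y_t + lambdas_y * S_y(t-1)
#
# with both accumulators initialized to zero, and lambdas_x, lambdas_y either
# fixed constants in [0, 1] or sigmoid-constrained trainable vectors.
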
class TemporalKernelRNN(RNN):
    '''The variant as it appears in the original paper. It is extremely slow,
    because this implementation unrolls many redundant accumulation steps, and
    the design also seems to perform worse than MomentumRNN.'''

    def __init__(self, n_steps, n_hidden, input_size=1, output_size=1,
                 learning_rate=1e-3, output_dir='./', sess=None,
                 fix_lambdas=None, temporal_kernel_size=None):
        if fix_lambdas is not None and (fix_lambdas > 1 or fix_lambdas < 0):
            raise ValueError('`fix_lambdas` should always be in [0, 1]')
        if temporal_kernel_size is None:
            # this case is the same as MomentumRNN
            self.temporal_kernel_size = n_steps
        else:
            self.temporal_kernel_size = temporal_kernel_size
        self.n_steps = n_steps
        self.n_hidden = n_hidden
        self.input_size = input_size
        self.output_size = output_size
        self.learning_rate = learning_rate
        self.fix_lambdas = fix_lambdas
        # build model
        self._build_model()
        # initialize loss and optimizer
        self._loss_optimizer()
        # summary
        self.summary = tf.summary.merge_all()
        # open Session
        if sess is None:
            self.sess = tf.Session()
        else:
            self.sess = sess
        init = tf.global_variables_initializer()
        self.sess.run(init)
        # writers
        self.train_writer = tf.summary.FileWriter(output_dir + '/train',
                                                  self.sess.graph)
        self.val_writer = tf.summary.FileWriter(output_dir + '/val')
    def _build_model(self):
        self.X = tf.placeholder(dtype=tf.float32,
                                shape=(None, self.n_steps, self.input_size),
                                name='inputs')
        self.Y = tf.placeholder(dtype=tf.float32,
                                shape=[None, self.output_size], name="targets")
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        with tf.variable_scope("temporal_kernel_cell"):
            # Wxy
            input_weights = tf.get_variable("input_weights",
                                            shape=[self.input_size, self.n_hidden])
            # Wyy
            state_weights = tf.get_variable("state_weights",
                                            shape=[self.n_hidden, self.n_hidden])
            # b
            biases = tf.get_variable("biases", shape=[self.n_hidden])
            if self.fix_lambdas is not None:
                lambdas_x = tf.constant(
                    [self.fix_lambdas] * self.input_size, name='lambdas_x', dtype=tf.float32)
                lambdas_y = tf.constant(
                    [self.fix_lambdas] * self.n_hidden, name='lambdas_y', dtype=tf.float32)
            else:
                # unconstrained (pre-sigmoid) lambdas_x
                lambdas_x = tf.get_variable("lambdasx_", shape=[self.input_size],
                                            initializer=tf.constant_initializer(0.0))
                # lambdas_x constrained to (0, 1)
                lambdas_x = tf.nn.sigmoid(lambdas_x, name='lambdas_x')
                # unconstrained (pre-sigmoid) lambdas_y
                lambdas_y = tf.get_variable("lambdasy_", shape=[self.n_hidden],
                                            initializer=tf.constant_initializer(0.0))
                # lambdas_y constrained to (0, 1)
                lambdas_y = tf.nn.sigmoid(lambdas_y, name='lambdas_y')
            states = []
            for time_step in range(self.n_steps):
                # Sxt
                acc_input = tf.zeros_like(dtype=tf.float32,
                                          tensor=self.X[:, 0, :])
                # Syt
                acc_state = tf.zeros_like(dtype=tf.float32,
                                          tensor=tf.matmul(self.X[:, 0, :], input_weights))
                for i, state in enumerate(states[-self.temporal_kernel_size:]):
                    acc_input = self.X[:, time_step - self.n_steps + 1 + i, :] + \
                        lambdas_x * acc_input
                    acc_state = state + lambdas_y * acc_state
                Wxy_Sxt = tf.matmul(acc_input, input_weights)
                Wyy_Syt = tf.matmul(acc_state, state_weights)
                states.append(tf.tanh(Wxy_Sxt + Wyy_Syt + biases,
                                      name='state_at_%s' % time_step))
            self.output = tf.layers.dense(
                states[-1], units=self.output_size, name='output')
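

# A minimal usage sketch under stated assumptions: it reuses the illustrative
# `toy_batches` generator defined after the base RNN class, and the hyperparameters
# (sequence length, hidden size, periods, number of steps) are arbitrary example
# values, not settings from the original gist.
if __name__ == '__main__':
    N_STEPS = 30
    gen = toy_batches(N_STEPS)
    X_val, Y_val = next(gen)  # hold one batch out as a fixed validation set
    model = ClockworkRNN(n_steps=N_STEPS, n_hidden=32, periods=[1, 2, 4, 8],
                         output_dir='./logs/cwrnn')
    print('trainable parameters:', model.trainable_size)
    train_loss, val_loss = model.fit(gen, steps=500,
                                     validation_data=(X_val, Y_val),
                                     savefig='./cwrnn')
    print('prediction shape:', model.predict(X_val).shape)  # (batch, output_size)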