import numpy as np
import tensorflow as tf

"""
Generate some random data:
    data   = n x m  (n = num examples, m = num features)
    labels = n x c  (n = num examples, c = num classes)
"""
KARPATHY_CONST = 0.00025
class tfESN(object):
    def __init__(self,
                 in_channels,
                 max_len=1000,
                 n_units=100,
                 alpha=1.0,
                 tf_sess=None):
        self.n_inputs = in_channels
        self.n_units = n_units
        self.max_len = max_len
        self.alpha = alpha
        self.random_state_ = np.random.RandomState(12345)
        self.sess = tf.get_default_session() if tf_sess is None else tf_sess
        self.inputs = tf.placeholder(tf.float32, shape=(None, in_channels))
        self.W_out, self.h, self.pred, self.optim, self.loss = \
            self._build_graph(self.inputs)
        self.sess.run(tf.global_variables_initializer())
    def _build_graph(self, inputs):
        # Shift X right one position (zero row at the front) and zero-pad y
        # at the end to the same length, so y[t] is predicted from inputs up
        # to t-1. The [:, 1:2] slice makes input channel 1 the target.
        #inputs = tf.tile(inputs, [2, 1])
        y = tf.pad(inputs, ((0, 1), (0, 0)))[:, 1:2]
        X = tf.pad(inputs, ((1, 0), (0, 0)))
        # Get hidden activations
        h = self._esn_graph(X)
        # Add bias unit (append a constant-1 column to the states)
        h_b = tf.pad(h - 1.0, ((0, 0), (0, 1))) + 1.0
        # Get length of sequence (count of non-zero target rows)
        used = tf.sign(tf.reduce_max(tf.abs(y), axis=1))
        length = tf.cast(tf.stop_gradient(tf.reduce_sum(used)), tf.int32)
        # Remove first few states (washout of the initial transient);
        # integer division keeps the slice indices int32
        transient = tf.minimum(length // 10, 100)
        # Truncate h and y
        h_t = h_b[transient:length]
        y_t = y[transient:length]
        # Solve the ridge-regularised least squares for the readout:
        #   W_out = (H^T H + alpha * I)^(-1) H^T Y
        C = tf.matmul(h_t, h_t, transpose_a=True)
        C = C + self.alpha * tf.eye(self.n_units + 1)
        D = tf.matmul(h_t, y_t, transpose_a=True)
        W_out = tf.matmul(tf.matrix_inverse(C), D)
        #W_out = tf.matrix_solve_ls(h, y, l2_regularizer=1.0, name="W_out")
        # Get output layer according to solution
        y_pred = tf.matmul(h_b, W_out)
        # Get mean of squared loss over the non-transient region
        y_pred_t = y_pred[transient:length]
        loss = tf.reduce_mean(tf.square(y_pred_t - y_t))
        optimizer = tf.train.AdamOptimizer(KARPATHY_CONST * 40).minimize(loss)
        return W_out, h, y_pred, optimizer, loss
    def _crj_weights(self):
        # Cycle Reservoir with Jumps (CRJ) topology
        l = self.n_units // 15  # jump length
        N = self.n_units
        W_cyc = np.zeros((N, N))
        W_jmp = np.zeros((N, N))
        # Cycle connections
        for i in range(N):
            W_cyc[(i + 1) % N, i] = 1.0
        # Jump connections (bidirectional, between every l-th unit)
        for i in range(N // l):
            W_jmp[((i + 1) * l) % N, i * l] = 1.0
            W_jmp[i * l, ((i + 1) * l) % N] = 1.0
        # Random sign input weights in {-1, +1}
        W_in = self.random_state_.randint(
            2, size=(self.n_inputs, self.n_units)) * 2 - 1
        return W_cyc, W_jmp, W_in
    def _init_weights(self):
        r_c = 0.7  # cycle weight
        r_j = 0.4  # jump weight
        v = 0.9    # input scaling
        W_cyc, W_jmp, W_in = self._crj_weights()
        W = W_cyc * r_c + W_jmp * r_j
        W_in = v * W_in
        return W, W_in
    def _esn_graph(self, X):
        if False:
            # Alternative: bake the gains into full (trainable) matrices
            Ww, Vw = self._init_weights()
            W = tf.Variable(Ww, name='R', dtype=tf.float32)
            V = tf.Variable(Vw, name='V', dtype=tf.float32)
        else:
            # Keep the CRJ topology fixed and learn only three scalar gains:
            # cycle weight, jump weight and input scaling
            W_cyc, W_jmp, W_in = self._crj_weights()
            r = tf.Variable(np.ones(3), name='r', dtype=tf.float32)
            W = W_cyc * r[0] + W_jmp * r[1]
            V = W_in * r[2]
            self.r = r
        V_X = tf.expand_dims(tf.matmul(X, V), axis=1)
        # h[t] = tanh(h[t-1] @ W + V x[t]); the zero row padded onto the
        # front of X doubles as the zero initial state of the scan
        outputs = tf.scan(lambda h, x: tf.nn.tanh(tf.matmul(h, W) + x), V_X)
        return tf.squeeze(outputs, axis=1)
    def _rnn_graph(self, X):
        # Alternative to the ESN reservoir: a trainable vanilla RNN.
        # time_major=True treats the leading axis of X as time (batch of 1)
        X = tf.expand_dims(X, axis=1)
        rnn_cell = tf.contrib.rnn.BasicRNNCell(self.n_units)
        outputs, _ = tf.nn.dynamic_rnn(rnn_cell, X, dtype=tf.float32,
                                       time_major=True)
        return tf.squeeze(outputs, axis=1)
    def fit(self, X, train=False):
        # Ensure X is 2-D (add a feature axis to a flat sequence)
        if X.ndim < 2:
            X = np.reshape(X, (len(X), -1))
        # Pad to max length
        #X = np.pad(X, ((0, self.max_len - len(X)), (0, 0)), "constant")
        # Run the graph; train=True also takes one Adam step on the
        # reservoir gains
        if train:
            W_out, states, pred, _, loss = self.sess.run(
                [self.W_out, self.h, self.pred, self.optim, self.loss],
                feed_dict={self.inputs: X})
        else:
            W_out, states, pred, loss = self.sess.run(
                [self.W_out, self.h, self.pred, self.loss],
                feed_dict={self.inputs: X})
        return W_out, states, pred, loss
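
# Minimal usage sketch (illustrative, with assumed shapes; not from the
# original gist): drive the ESN with a random 2-channel sequence as
# described at the top of the file. fit() returns the closed-form readout,
# the reservoir states, the one-step predictions of channel 1, and the MSE.
if __name__ == "__main__":
    with tf.Session() as sess:
        esn = tfESN(in_channels=2, n_units=100, tf_sess=sess)
        data = np.random.randn(500, 2).astype(np.float32)
        W_out, states, pred, loss = esn.fit(data)
        print("loss: %.4f, W_out shape: %s" % (loss, W_out.shape))
        # Optionally adapt the reservoir gains r with a few Adam steps
        for _ in range(10):
            _, _, _, loss = esn.fit(data, train=True)
        print("loss after training: %.4f" % loss)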