Created
December 31, 2017 09:23
-
-
Save talolard/8f58dd0a2d36417338cb8054f71ae86b to your computer and use it in GitHub Desktop.
Example of a GRU implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
GRU layer implementation originally taken from https://github.com/ottokart/punctuator2
'''
class GRULayer(object):
    """GRU (gated recurrent unit) layer whose weights are Theano shared variables.

    Notation follows "An Empirical Exploration of Recurrent Network
    Architectures" (Jozefowicz, Zaremba & Sutskever, 2015). The class defines
    no loop over time: ``step`` computes a single recurrence step and ``h0``
    is provided as an initial hidden state. NOTE(review): presumably the
    caller iterates ``step`` (e.g. with ``theano.scan``) starting from
    ``h0`` -- confirm against the caller.

    Args:
        rng: random source forwarded to ``weights_Glorot`` (presumably a
            ``numpy.random.RandomState`` -- confirm).
        n_in: size of the input vectors; number of rows of ``W_x`` and
            ``W_x_h``.
        n_out: size of the hidden state; width of each gate.
        minibatch_size: number of rows of the initial hidden state ``h0``.

    Attributes:
        h0: shared zero matrix of shape ``(minibatch_size, n_out)``. It is not
            listed in ``params``.
        params: the six weight/bias shared variables, in the order
            ``[W_x, W_h, b, W_x_h, W_h_h, b_h]``.
    """

    def __init__(self, rng, n_in, n_out, minibatch_size):
        super(GRULayer, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        # The reset (r) and update (z) gates are computed together, so their
        # parameters are stacked column-wise into matrices of this width.
        n_gates = n_out * 2

        # Initial hidden state. Creating the array directly in floatX avoids
        # building a float64 array only to convert/copy it with astype().
        self.h0 = theano.shared(
            value=np.zeros((minibatch_size, n_out), dtype=theano.config.floatX),
            name='h0',
            borrow=True,
        )

        # Gate parameters: one projection feeds both gates; step() splits it.
        self.W_x = weights_Glorot(n_in, n_gates, 'W_x', rng)
        self.W_h = weights_Glorot(n_out, n_gates, 'W_h', rng)
        # The trailing 0 is presumably the constant fill value -- confirm
        # against weights_const.
        self.b = weights_const(1, n_gates, 'b', 0)

        # Candidate-state ("input") parameters.
        self.W_x_h = weights_Glorot(n_in, n_out, 'W_x_h', rng)
        self.W_h_h = weights_Glorot(n_out, n_out, 'W_h_h', rng)
        self.b_h = weights_const(1, n_out, 'b_h', 0)

        self.params = [self.W_x, self.W_h, self.b, self.W_x_h, self.W_h_h, self.b_h]

    def step(self, x_t, h_tm1):
        """Compute the hidden state for one time step.

        Args:
            x_t: input at step t (presumably ``(minibatch_size, n_in)`` --
                confirm against the caller).
            h_tm1: hidden state from step t-1 (presumably the same shape as
                ``h0``).

        Returns:
            The new hidden state ``h_t = z * h_tm1 + (1 - z) * h``, where ``h``
            is the tanh candidate state.
        """
        # Both gates come from a single sigmoid. Column block 0 is used as the
        # reset gate r and block 1 as the update gate z (assuming
        # _slice(t, size, i) returns the i-th width-`size` column block --
        # confirm against _slice).
        rz = T.nnet.sigmoid(T.dot(x_t, self.W_x) + T.dot(h_tm1, self.W_h) + self.b)
        r = _slice(rz, self.n_out, 0)
        z = _slice(rz, self.n_out, 1)
        # Candidate state: the reset gate scales the previous state before
        # its recurrent projection.
        h = T.tanh(T.dot(x_t, self.W_x_h) + T.dot(h_tm1 * r, self.W_h_h) + self.b_h)
        # Interpolate: z -> 1 keeps the previous state, z -> 0 takes the candidate.
        h_t = z * h_tm1 + (1. - z) * h
        return h_t
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment