@bayerj
Created December 22, 2011 12:10
import numpy as np
import theano
import theano.tensor as T
# Create parameters.
#
# All parameters are first allocated in one long contiguous array, stored as a
# shared variable. Reshaped subtensors of that array are then used in the
# model expressions. Why? We want to be able to differentiate with respect to
# groups of parameters that lie in contiguous memory.
#
# Most of the following code is just bookkeeping.
# Dimensionality of the input and of the hidden layer.
n_inpt, n_hidden = 1, 3
# We have 2 parameter groups with the following shapes.
W1shape = n_inpt, n_hidden
W2shape = n_hidden, n_hidden
n_pars = n_inpt * n_hidden + n_hidden**2
# Allocate big parameter array. Fill it with small random values rather than
# leaving it uninitialized, so that evaluating the graph gives finite numbers.
pars = theano.shared(0.1 * np.random.standard_normal(n_pars))
# Assign slices.
W1 = pars[:n_inpt * n_hidden].reshape(W1shape)
W2 = pars[-n_hidden * n_hidden:].reshape(W2shape)
# Define recurrent model. We are using a model where each input is a tensor
# of shape (T, B, D) where T is the number of timesteps, B is the number of
# sequences iterated over in parallel and D is the dimensionality of each
# item at a timestep.
inpt = T.tensor3('inpt')
target = T.tensor3('target')
# Flatten the time and batch axes so that a plain dot product can be used
# instead of tensordot, which is slower.
inpt_flat = inpt.reshape((inpt.shape[0] * inpt.shape[1], inpt.shape[2]))
hidden_flat = T.dot(inpt_flat, W1)
hidden = hidden_flat.reshape((inpt.shape[0], inpt.shape[1], n_hidden))
# Transfer function of the recurrent layer; use `lambda x: x` instead for a
# linear network.
transfer = T.nnet.sigmoid
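# Recurrence h_t = transfer(h_{t-1} W2 + x_t), where x_t is the projected
# input at time t and the initial state h_{-1} is zero. scan returns the
# stacked states and a dictionary of updates, which is not needed here.
hidden_rec, _ = theano.scan(
    lambda x, h_tm1: transfer(T.dot(h_tm1, W2) + x),
    sequences=hidden,
    outputs_info=[T.zeros_like(hidden[0])])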
# Flattened (T * B, n_hidden) view of the hidden states; not used by the
# cost below.
hidden_rec_flat = hidden_rec.reshape(
    (hidden_rec.shape[0] * hidden_rec.shape[1], hidden_rec.shape[2]))
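# Mean squared error between the hidden states and the targets, so `target`
# must have shape (T, B, n_hidden).
cost = ((hidden_rec - target)**2).mean()
# Gradient with respect to the flat parameter vector; it flows back through
# the reshaped slices W1 and W2.
d_cost_wrt_pars = T.grad(cost, pars)
# Hessian-vector product H p via the R-operator. The same product can also be
# built with the L-operator as T.grad(T.sum(d_cost_wrt_pars * p), pars).
p = T.dvector('p')
Hp = T.Rop(d_cost_wrt_pars, pars, p)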
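The parameter-packing trick at the top of the script depends on reshaped slices being views into the flat array. The minimal sketch below reproduces the same layout in plain NumPy, which stands in for Theano here and is not the script's own code. It shows that writing into the flat vector is visible through W1 and W2. It also shows that a gradient for the flat vector is just the per-group gradients, flattened and concatenated in the same order. The gradient values `gW1` and `gW2` are hypothetical placeholders.

```python
import numpy as np

# Same layout as the script: W1 first, then W2, packed into one flat vector.
n_inpt, n_hidden = 1, 3
n_pars = n_inpt * n_hidden + n_hidden ** 2
pars = np.zeros(n_pars)

# Slicing a contiguous 1-D array and reshaping the slice gives views, not copies.
W1 = pars[:n_inpt * n_hidden].reshape(n_inpt, n_hidden)
W2 = pars[n_inpt * n_hidden:].reshape(n_hidden, n_hidden)

# An update to the flat vector (e.g. one optimizer step) is seen by both views.
pars[:] = np.arange(n_pars)
print(W1)     # [[0. 1. 2.]]
print(W2[0])  # [3. 4. 5.]

# Hypothetical per-group gradients; the flat gradient concatenates them
# in the same order as the parameter layout.
gW1, gW2 = np.ones_like(W1), 2 * np.ones_like(W2)
g_pars = np.concatenate([gW1.ravel(), gW2.ravel()])
```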
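Flattening the time and batch axes before the dot product is equivalent to contracting the last axis of the (T, B, D) input with W1. This NumPy check uses arbitrary sizes chosen only for illustration and confirms that the two results agree. It says nothing about relative speed; the script's remark about tensordot being slower concerns Theano.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes: T timesteps, B sequences, D input dims, H hidden units.
n_steps, n_seqs, n_inpt, n_hidden = 5, 4, 2, 3
inpt = rng.standard_normal((n_steps, n_seqs, n_inpt))
W1 = rng.standard_normal((n_inpt, n_hidden))

# Flatten (T, B, D) -> (T * B, D), multiply, then restore (T, B, H).
inpt_flat = inpt.reshape(n_steps * n_seqs, n_inpt)
hidden = (inpt_flat @ W1).reshape(n_steps, n_seqs, n_hidden)

# Direct contraction of the last input axis with the first axis of W1.
hidden_td = np.tensordot(inpt, W1, axes=([2], [0]))
```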
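In the script, theano.scan applies the recurrence h_t = transfer(h_{t-1} W2 + x_t W1) along the time axis, starting from a zero state. The loop below is a rough NumPy sketch of the same forward pass and mean squared error. The helper names `forward` and `sigmoid`, the sizes, and the all-zero targets are assumptions made for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def forward(inpt, W1, W2, transfer=sigmoid):
    """Loop version of the scan recurrence h_t = transfer(h_{t-1} W2 + x_t W1).

    inpt has shape (T, B, D); the result has shape (T, B, H).
    """
    hidden = inpt @ W1  # input projection for all timesteps at once
    h = np.zeros((inpt.shape[1], W2.shape[0]))  # zero initial state
    states = []
    for x_t in hidden:  # iterate over the time axis, as scan does
        h = transfer(h @ W2 + x_t)
        states.append(h)
    return np.stack(states)


# Illustrative sizes: 6 timesteps, 2 sequences, n_inpt=1, n_hidden=3.
rng = np.random.default_rng(0)
inpt = rng.standard_normal((6, 2, 1))
W1 = rng.standard_normal((1, 3))
W2 = 0.1 * rng.standard_normal((3, 3))
target = np.zeros((6, 2, 3))  # placeholder targets, same shape as the states

hidden_rec = forward(inpt, W1, W2)
cost = ((hidden_rec - target) ** 2).mean()
```

With W2 set to zero and an identity transfer, the loop collapses to the plain input projection. This makes it a quick way to check the wiring.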
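`T.Rop(d_cost_wrt_pars, pars, p)` yields the Hessian-vector product H p. This is the directional derivative of the gradient along p, obtained without ever forming H. A cheap sanity check for such a product is a central finite difference of the gradient. The sketch below applies this check in NumPy to a hypothetical toy cost, f(θ) = Σ log(1 + exp(Xθ)), which is not the recurrent network's cost. It was chosen because its exact Hessian-vector product has a closed form.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical toy cost f(theta) = sum(log(1 + exp(X @ theta))).
def grad(theta, X):
    return X.T @ sigmoid(X @ theta)


def hvp_exact(theta, v, X):
    # H = X^T diag(s * (1 - s)) X, applied to v without forming H.
    s = sigmoid(X @ theta)
    return X.T @ (s * (1 - s) * (X @ v))


def hvp_fd(theta, v, X, eps=1e-5):
    # Central difference of the gradient along v approximates H v.
    return (grad(theta + eps * v, X) - grad(theta - eps * v, X)) / (2 * eps)


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 12))
theta = rng.standard_normal(12)
v = rng.standard_normal(12)
max_err = np.max(np.abs(hvp_fd(theta, v, X) - hvp_exact(theta, v, X)))
```

The same check could be applied to the Theano graph by comparing the compiled `Hp` against a finite difference of the compiled gradient at `pars` ± εp.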