@mallamanis
Created April 6, 2016 12:39
Stupid simple Theano example that lets each step of a `scan` use all previous states.
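Here is the recurrence the code computes. Write $x_i$ for the $i$-th row of `input`, $s_i$ for the $i$-th row of the state buffer, and $n$ for the number of input rows. At $i = 0$ the sum is empty, so the row that starts out as `initial` is overwritten with $x_0$:

$$
s_i = x_i + \sum_{j=0}^{i-1} s_j, \qquad i = 0, \dots, n-1 .
$$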
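```python
import numpy as np
import theano
import theano.tensor as T

in_seq = T.matrix(name='input')  # one row per time step
initial_state = theano.shared(np.array([0., 0.], dtype=theano.config.floatX), name='initial')
# Buffer with one state row per time step; row 0 starts as initial_state
initial_scan_state = T.zeros((in_seq.shape[0], initial_state.shape[0]))
initial_scan_state = T.set_subtensor(initial_scan_state[0], initial_state)

def fun(i, current_element, state):
    # New state = sum of all earlier rows + current input (at i=0 the sum is empty, so row 0 is overwritten)
    res = T.sum(state[:i], axis=0) + current_element
    return T.set_subtensor(state[i], res)

# The whole buffer is carried through scan, so every step sees all earlier states
r, updates = theano.scan(fun, sequences=[T.arange(in_seq.shape[0]), in_seq], outputs_info=[initial_scan_state])
scf = theano.function(inputs=[in_seq], outputs=r[-1])  # buffer after the last step = all states
```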
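If you want to check what `scf` returns without installing Theano, here is a minimal pure-Python sketch of the same recurrence. The function name `all_previous_states_scan` is hypothetical, and the sketch assumes the input is a list of equal-length rows of floats. It mirrors the snippet's behaviour, including the fact that the initial state is overwritten at step 0.

```python
def all_previous_states_scan(rows):
    """Hypothetical pure-Python reference for the Theano scan above (not a Theano API).

    rows: list of equal-length lists of floats, one per time step.
    Returns the final state buffer, where state[i] = rows[i] + sum(state[:i]).
    """
    states = []
    for x in rows:
        # Element-wise sum of every state written so far (zeros before the first step)
        prev_sum = [sum(col) for col in zip(*states)] if states else [0.0] * len(x)
        states.append([p + xi for p, xi in zip(prev_sum, x)])
    return states


print(all_previous_states_scan([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
# → [[1.0, 0.0], [1.0, 1.0], [3.0, 2.0]]
```

For the same input (as a float array), `scf` should return the same three rows. In the Theano version, `r` symbolically stacks the buffer produced at every step, so it has shape `(n, n, 2)`, and `r[-1]` selects the buffer after the last step.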