Created May 14, 2016 00:17
Save bartvm/0f93143fe428c7f57f9c9ed9d7021f95 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import sys

import numpy as np
import theano
from theano import tensor
def main(length):
    """Evaluate the cost and gradient of a scan that scatters ``x`` into a buffer.

    At scan step ``t`` the element ``x[t]`` is written into position
    ``idxs[t]`` of a buffer with ``set_subtensor``. The buffer is threaded
    through the scan as recurrent state. The mean squared error between the
    buffer after the last step and ``y`` is differentiated with respect to
    ``x``. The compiled function is then evaluated on ``length`` evenly
    spaced points.

    Parameters
    ----------
    length : int
        Number of elements in ``x``, ``y`` and ``idxs``. Must be >= 1,
        because the cost reads the last scan step (``ys[-1]``).

    Returns
    -------
    list
        ``[cost, grads]`` as returned by the compiled Theano function.
        The same value is also printed.

    Raises
    ------
    ValueError
        If ``length`` is less than 1.
    """
    if length < 1:
        raise ValueError("length must be >= 1, got {}".format(length))

    # Sequential input
    x = tensor.vector('x')
    # Targets
    y = tensor.vector('y')
    # Sequential write positions: at step t, x[t] is written to idxs[t]
    idxs = tensor.lvector('idxs')
    # The buffer we write into. zeros_like gives it x's shape AND dtype,
    # instead of hard-coding float32 while x uses floatX.
    h = tensor.zeros_like(x)

    # At each step, write x_elem into position idx of the running buffer.
    def step(x_elem, idx, h_prev):
        return tensor.set_subtensor(h_prev[idx], x_elem)

    ys, _ = theano.scan(fn=step, outputs_info=[h], sequences=[x, idxs])
    # Only the buffer after the final step enters the cost.
    cost = tensor.sqr(ys[-1] - y).mean()
    grads = tensor.grad(cost, x)

    # Some numerical data to test this with
    x_val = np.linspace(0, 1, length)
    y_val = np.linspace(0, 1, length)
    # Reversal permutation: x[t] is written to position length - 1 - t.
    # This permutation is its own inverse, so np.argsort(idxs_val) would
    # produce the same array.
    idxs_val = np.arange(length)[::-1]

    f = theano.function([x, y, idxs], [cost, grads], allow_input_downcast=True)
    outputs = f(x_val, y_val, idxs_val)
    print(outputs)
    return outputs
if __name__ == "__main__":
    # Parse LENGTH with argparse. A missing or non-integer argument then
    # produces a usage message instead of an IndexError/ValueError traceback.
    parser = argparse.ArgumentParser(
        description="Evaluate cost and gradient of a scan that scatters x "
                    "into a buffer.")
    parser.add_argument("length", type=int,
                        help="number of elements in x, y and idxs (>= 1)")
    args = parser.parse_args()
    main(args.length)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.