Last active
December 26, 2015 18:39
-
-
Save sisp/7196092 to your computer and use it in GitHub Desktop.
This is the output I get for the two function calls. [-3.11292315 -2.80163074 -2.52146769 -2.26932096 -2.04238892 -1.83815002 -1.65433502 -1.4889015 -1.34001136 -1.20601022]
Traceback (most recent call last): File "test_scan.py", line 52, in <module> run(y, dy_dw, w, None, givens, scan) # doesn't work File "test_scan.py", line 36, in run rval, u… [traceback truncated in the original page]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import theano
import theano.tensor as T

# Alias for Theano's configured floating-point dtype ('float32' or 'float64');
# used below so shared-variable data matches theano.config.
floatX = theano.config.floatX
def scan(y, dy_dw, w, x):
    """Build a simplified line-search-style `theano.scan` loop.

    Parameters
    ----------
    y : Theano scalar expression
        Cost expression built from `w` (and the graph input `x`).
    dy_dw : Theano expression
        Negative gradient of `y` w.r.t. `w`; serves both as the update
        proposal and as the initial value of the scan output.
    w : Theano shared variable
        Parameter the cost depends on.
    x : Theano variable or None
        Graph input of `y`. When given it is forwarded to `theano.scan`
        as an extra non-sequence; when None it is omitted (the failing
        case this gist demonstrates).

    Returns
    -------
    (rval, updates)
        Whatever `theano.scan` returns for this step function.
    """
    def step(dy_dw_, y_, w_, a_, *x_):
        """
        This function computes part of what is done during a simple line search
        iteration. The line search I'm referring to is described in [1] on
        page 34. `y_` is somewhat similar to the left-hand side of the
        termination condition. `y_ < 0` is just an arbitrary condition that is
        in fact never met, but it suffices in order to demonstrate the
        situation.

        This function decreases the step size of the update proposal (here it's
        the negative gradient) by a factor `a_` in every iteration until the
        cost `y_` falls below some value (here 0). This doesn't make much sense
        in terms of line search, but it shows the critical computations that
        cause Theano to throw an error, in case the input of the graph `x` is
        not added to the `non_sequences` list.

        [1] http://www.cs.toronto.edu/~jmartens/docs/HF_book_chapter.pdf
        """
        # Evaluate the cost as if `w` had already been updated by `dy_dw_`.
        y_ = theano.clone(y_, replace={w_: w_ + dy_dw_})
        # Shrink the proposal by `a_`; stop once the cloned cost goes negative.
        return a_ * dy_dw_, theano.scan_module.until(y_ < 0)

    a = T.constant(0.9, name='a')
    # BUG FIX: the original used `if x`, but truth-testing a symbolic Theano
    # variable is undefined (TensorVariable truth value raises an error in
    # later Theano versions). Test identity against None explicitly, which is
    # exactly what the callers in `__main__` pass.
    return theano.scan(step,
                       outputs_info=[dy_dw],
                       non_sequences=[y, w, a] + ([x] if x is not None else []),
                       n_steps=10)
def run(y, dy_dw, w, x, givens, scan_fn):
    """Build the scan graph via `scan_fn`, compile it, and print one result."""
    outputs, updates = scan_fn(y, dy_dw, w, x)
    fn = theano.function([], outputs, givens=givens, updates=updates)
    print(fn())
if __name__ == '__main__':
    # Some meaningless cost function, just to have something here:
    # a squared weighted input, summed over the vector.
    x = T.vector('x')
    w = theano.shared(np.asarray(np.random.uniform(), dtype=floatX), borrow=True)
    y = ((w * x) ** 2).sum()
    dy_dw = -T.grad(y, w)

    # Concrete data is substituted for the symbolic input via `givens`.
    X = theano.shared(np.random.uniform(size=10).astype(floatX), borrow=True)
    givens = {x: X}

    run(y, dy_dw, w, x, givens, scan)     # works
    run(y, dy_dw, w, None, givens, scan)  # doesn't work
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment