# sisp/7154041 (last active: December 26, 2015 12:49).
# Save sisp/7154041 to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review,
# open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import traceback

import numpy as np
import theano
import theano.tensor as T
# Default float dtype from Theano's configuration; used below to cast the
# shared input data so it matches what Theano expects.
floatX = theano.config.floatX
def scan1(a, y, x=None):
    """Scan whose stop condition reads ``y`` through ``non_sequences``.

    Reported by the author to work.

    Starting from ``a``, each step multiplies the state by 0.9. The loop runs
    for at most 10 steps (``n_steps=10``) and ends early via
    ``theano.scan_module.until(y_ <= 0)``.

    Parameters
    ----------
    a : initial value of the recurrent output (``outputs_info``).
    y : symbolic expression, passed to ``step`` as a non-sequence.
    x : unused; accepted so every ``scan*`` variant has the signature that
        ``run`` calls them with.

    Returns
    -------
    The ``(outputs, updates)`` pair returned by ``theano.scan``.
    """
    def step(a_, y_):
        # Argument order follows scan's convention: recurrent state first,
        # then non-sequences. y_ is what scan supplies for ``y``.
        return 0.9 * a_, theano.scan_module.until(y_ <= 0)
    return theano.scan(step, outputs_info=[a], non_sequences=[y], n_steps=10)
def scan11(a, y, x=None):
    """Variant of ``scan1`` whose stop condition uses a rewritten ``y_``.

    Reported by the author NOT to work. Per the author, this case is
    somewhat close to what they do in a line search.

    Inside ``step``, ``y_`` is cloned with ``replace={x: 2*x}``. The
    replacement therefore names the outer variable ``x``, which is not listed
    in ``non_sequences``.
    NOTE(review): presumably the failure comes from scan not picking up an
    outer variable that appears only in the ``until`` condition. This is
    unconfirmed.

    Parameters
    ----------
    a : initial value of the recurrent output (``outputs_info``).
    y : symbolic expression, passed to ``step`` as a non-sequence.
    x : symbolic variable to substitute inside ``y_``. Required despite the
        ``None`` default, which exists only to keep the shared signature.

    Returns
    -------
    The ``(outputs, updates)`` pair returned by ``theano.scan``.

    Raises
    ------
    TypeError
        If ``x`` is None. Without this check, ``2*x`` would fail with an
        opaque error from inside ``theano.scan``.
    """
    if x is None:
        raise TypeError('scan11 requires the symbolic input x')

    def step(a_, y_):
        # Rebuild the condition's graph with x scaled by 2.
        y_ = theano.clone(y_, replace={x: 2*x})
        return 0.9 * a_, theano.scan_module.until(y_ <= 0)
    return theano.scan(step, outputs_info=[a], non_sequences=[y], n_steps=10)
def scan2(a, y, x=None):
    """Variant whose stop condition reads the outer ``y`` directly.

    Reported by the author NOT to work.

    Unlike ``scan1``, ``y`` is not passed through ``non_sequences``;
    ``step`` closes over it instead.
    NOTE(review): presumably scan does not pick up outer variables that
    appear only in the ``until`` condition. Passing them via
    ``non_sequences``, as ``scan1`` does, is the form the author reports as
    working. This is unconfirmed.

    Parameters
    ----------
    a : initial value of the recurrent output (``outputs_info``).
    y : symbolic expression used in the stop condition via closure.
    x : unused; kept for the signature shared with the other variants.

    Returns
    -------
    The ``(outputs, updates)`` pair returned by ``theano.scan``.
    """
    def step(a_):
        return 0.9 * a_, theano.scan_module.until(y <= 0)
    return theano.scan(step, outputs_info=[a], n_steps=10)
def run(a, y, givens, scan_fn, x=None):
    """Build one scan variant, compile it, evaluate it once and print the result.

    Parameters
    ----------
    a : initial recurrent value, forwarded to ``scan_fn``.
    y : symbolic stop-condition expression, forwarded to ``scan_fn``.
    givens : substitution dict passed to ``theano.function``.
    scan_fn : one of the ``scan*`` builders. It is called as
        ``scan_fn(a, y, x)`` and must return ``(outputs, updates)``.
    x : forwarded to ``scan_fn`` as its third argument.
    """
    rval, updates = scan_fn(a, y, x)
    f = theano.function([], rval, givens=givens, updates=updates)
    # The original ``print f()`` statement is a SyntaxError on Python 3.
    # A single-argument print(...) prints the same value on Python 2 and 3.
    print(f())
if __name__ == '__main__':
    a = T.constant(1.0, name='a')
    x = T.vector('x')
    y = x.sum()
    X = theano.shared(np.random.uniform(size=10).astype(floatX), borrow=True)
    # Substitute the symbolic vector x with the shared data X at compile time.
    givens = {x: X}

    # (label, builder, x argument). The expected outcomes are the author's
    # observations.
    cases = [
        ('scan1', scan1, None),   # works
        ('scan11', scan11, x),    # does not work
        ('scan2', scan2, None),   # does not work
    ]
    for label, scan_fn, x_arg in cases:
        print('--- %s ---' % label)
        try:
            run(a, y, givens, scan_fn, x_arg)
        except Exception:
            # Top-level boundary of a reproduction script. Report the failure
            # and keep going, so that an exception in one case (previously
            # scan11) no longer prevents the later cases from running.
            traceback.print_exc()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.