Created
June 26, 2013 14:28
-
-
Save bayerj/5867800 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
# theano is imported before gnumpy as in the original; GPU/CUDA
# initialisation order may depend on it, so keep this order.
import theano
import theano.tensor as T
import gnumpy
import theano.misc.gnumpy_utils as gput
# A small (1, 2) gnumpy array holding [[2, 1]]. It is used both as the
# runtime input and as the shape template for `function` below.
g = gnumpy.zeros((1, 2))
g += np.array([[2, 1]])

# Symbolic input and the expression to evaluate.
x = T.matrix()
expr = 2 * x

# Naive approach: compile `expr` with an ordinary symbolic input. Per the
# original author's note, this does not work when called with a CudaNdarray,
# which is what `function` below works around. Reuse `expr` (instead of
# rebuilding `2 * x`) so both approaches compile the same graph.
f = theano.function([x], expr)
def function(inpt, inpt_templates, expr):
    """Compile `expr` so that it can be called with gnumpy arrays.

    Each symbolic input in `inpt` is replaced by a float32 Theano shared
    variable, and `expr` is compiled into a function with no explicit
    inputs that reads whatever values those shared variables currently
    hold. The returned wrapper converts each argument with
    `gput.garray_to_cudandarray`, stores it in the matching shared variable,
    and then calls the compiled function.

    :param inpt: symbolic input variables appearing in `expr`.
    :param inpt_templates: one array per entry of `inpt`. Only its `.shape`
        is used, to allocate the initial float32 value of the corresponding
        shared variable.
    :param expr: symbolic Theano expression (or list of expressions) to
        compile.
    :returns: a callable taking one array per entry of `inpt`, in the same
        order, and returning the result of evaluating `expr`.
    :raises ValueError: if `inpt` and `inpt_templates` differ in length.
    """
    inpt = list(inpt)
    inpt_templates = list(inpt_templates)
    if len(inpt) != len(inpt_templates):
        # zip() would otherwise silently drop the surplus entries, leaving
        # an input unreplaced or creating a shared variable nobody uses.
        raise ValueError(
            'got %d inputs but %d input templates'
            % (len(inpt), len(inpt_templates)))

    # One shared variable per input. The dtype is fixed to float32,
    # presumably because CudaNdarray only holds float32 data -- confirm.
    shared = [theano.shared(np.empty(t.shape, dtype='float32'))
              for t in inpt_templates]

    # Swap every symbolic input for its shared variable in a single clone,
    # rather than cloning the whole graph once per input.
    expr = theano.clone(expr, replace=dict(zip(inpt, shared)))

    # No explicit arguments: all inputs are read from the shared variables.
    f = theano.function([], expr)

    def inner(*args):
        # Without this check a missing argument would silently leave a stale
        # value from an earlier call in its shared variable.
        if len(args) != len(shared):
            raise TypeError(
                'expected %d arguments, got %d' % (len(shared), len(args)))
        for a, s in zip(args, shared):
            s.set_value(gput.garray_to_cudandarray(a))
        return f()

    return inner
# Working approach: compile `expr` against a shared variable shaped like `g`,
# then evaluate it with `g` itself as the input.
f2 = function(inpt=[x], inpt_templates=[g], expr=expr)
f2(g)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment