# Created May 27, 2016 16:23
# Source: GitHub gist bartvm/18f0571ead92bb51838c9500a6209436
# (GitHub page note: this file may contain hidden or bidirectional Unicode
# text; review it in an editor that reveals hidden Unicode characters.)
"""Show how Theano back-propagates through ``tensor.set_subtensor``.

Row 1 (the second row) of a shared ``storage`` matrix is overwritten with
``W.dot(x)``, the result is reduced to a scalar with ``mean()``, and that
scalar is differentiated with respect to both ``W`` and ``storage``. The
overwritten row of ``storage`` receives zero gradient, while ``W`` receives
the gradient that flowed into that row.
"""
import numpy
import theano
from theano import tensor, config

# The parameters: W is a 3x3 matrix holding 0..8, storage is a 3x3 zero matrix.
W = theano.shared(numpy.arange(9, dtype=config.floatX).reshape(3, 3))
storage = theano.shared(numpy.zeros((3, 3), dtype=config.floatX))

# The input
x = tensor.vector('x')

# The operation: overwrite row 1 of storage with a value that is a function
# of W, then reduce to a scalar so it can be differentiated.
new_storage = tensor.set_subtensor(storage[1], W.dot(x))
out = new_storage.mean()

# Now for the gradients of the scalar `out` w.r.t. both shared variables.
f = theano.function([x], tensor.grad(out, [W, storage]))
grad_W, grad_storage = f(numpy.ones(3, dtype=config.floatX))
print(grad_W)
print(grad_storage)

# `out` is the mean of 9 entries, so d(out)/d(new_storage) is 1/9 everywhere.
# The second row of storage gets a zero gradient because set_subtensor
# overwrote it, so its old values never reach `out`.
# W's gradient is the outer product of [1/9, 1/9, 1/9] (the gradient flowing
# into the overwritten row) and [1, 1, 1] (the input x).
#
# Expected output:
# [[ 0.11111111 0.11111111 0.11111111]
#  [ 0.11111111 0.11111111 0.11111111]
#  [ 0.11111111 0.11111111 0.11111111]]
# [[ 0.11111111 0.11111111 0.11111111]
#  [ 0. 0. 0. ]
#  [ 0.11111111 0.11111111 0.11111111]]
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.