Last active
January 11, 2016 15:12
-
-
Save dwf/8cd6d6d347e052353e4d to your computer and use it in GitHub Desktop.
Spooky bug in Theano and/or Blocks.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import io | |
| from blocks.bricks import Rectifier, MLP, Sequence, Logistic | |
| from blocks.select import Selector | |
| from theano import tensor | |
| from theano.printing import debugprint | |
def wtf(call_b):
    """Reproduce the "spooky" gradient failure seen with Theano + Blocks.

    Two ``Sequence`` bricks, ``a`` and ``b``, wrap the *same* child ``MLP``
    (each with its own ``Logistic``). Only ``a`` is used to build the output
    ``y``. When ``call_b`` is true, ``b.apply`` is also invoked on the same
    input and its result is thrown away.

    The assertions check that this extra call leaves two things unchanged:
    the printed graph of ``y`` and the parameters selected from ``a``. Even
    so, the final ``tensor.grad`` call is expected to error in that case. The
    explanation given in the discussion accompanying this script is that
    ``sub_mlp`` gets re-allocated during the implicit ``b.allocate()``.

    :param call_b: if true, call ``b.apply(x)`` after ``y`` has been built.
    """
    def graph_dump(var):
        # Capture debugprint's text rendering of ``var``'s graph so the
        # before/after snapshots can be compared as plain strings.
        sink = io.StringIO()
        debugprint(var, file=sink)
        return sink.getvalue()

    sub_mlp = MLP([Rectifier(), Rectifier()], [5, 4, 3])
    # Both sequences share sub_mlp; construction order matches the repro.
    a = Sequence([sub_mlp.apply, Logistic().apply])
    b = Sequence([sub_mlp.apply, Logistic().apply])

    x = tensor.matrix()
    y = a.apply(x)
    graph_before = graph_dump(y)

    if call_b:
        b.apply(x)  # output deliberately discarded

    graph_after = graph_dump(y)
    # The graph of y prints identically whether or not b.apply ran.
    assert graph_before == graph_after

    params = Selector([a]).get_parameters().values()
    # The selected parameters are also identical in both scenarios: exactly
    # the weight and bias of each of sub_mlp's two linear layers.
    assert len(params) == 4
    for layer_index in (0, 1):
        layer = sub_mlp.linear_transformations[layer_index]
        assert layer.W in params
        assert layer.b in params

    # Expected to raise when b.apply was called (see docstring).
    print(tensor.grad(y.sum(), list(params)))
if __name__ == "__main__":
    # Run the repro twice: first the control case, then the case where the
    # unused b.apply call is made.
    scenarios = (
        ("Without calling b.apply:", False),
        ("With calling b.apply (result never used!):", True),
    )
    for index, (heading, flag) in enumerate(scenarios):
        if index:
            print("--------------")
        print(heading)
        wtf(call_b=flag)
Author
That's nasty. I will try to figure out why we do not have a defence against multiple allocation.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The answer:
`sub_mlp` gets re-allocated during the implicit call to `b.allocate()` in line 18 (the `b.apply(x)` call).