Created
May 11, 2017 15:50
-
-
Save apaszke/40eef574d59c751b38568aa10227ba8d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From baaec5d9f9e0e32bbd7d089d99698e1d83966f5d Mon Sep 17 00:00:00 2001
From: Adam Paszke <[email protected]>
Date: Thu, 11 May 2017 07:36:09 -0700
Subject: [PATCH 1/2] Disable fused RNN kernels

---
 test/test_nn.py                               | 24 ++++++++++++++----------
 torch/nn/_functions/rnn.py                    |  7 ++++---
 torch/nn/_functions/thnn/rnnFusedPointwise.py |  2 ++
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/test/test_nn.py b/test/test_nn.py
index cfe2304..d2a649e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1673,16 +1673,20 @@ class TestNN(NNTestCase):
         self.assertEqual(hidden1, hidden2)
 
     def _test_rnn_retain_variables(self, dtype):
-        rnn = nn.LSTM(10, 20, num_layers=2).type(dtype)
-        input = Variable(torch.randn(5, 6, 10).type(dtype), requires_grad=True)
-        output = rnn(input)
-        output[0].sum().backward(retain_graph=True)
-        grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
-        rnn.zero_grad()
-        input.grad.data.zero_()
-        output[0].sum().backward(retain_graph=True)
-        grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
-        self.assertEqual(grads, grads2)
+        rnns = [nn.LSTM(10, 20, num_layers=2).type(dtype),
+                nn.GRU(10, 20, num_layers=2).type(dtype),
+                nn.RNN(10, 20, num_layers=2).type(dtype)]
+        for rnn in rnns:
+            input = Variable(torch.randn(5, 6, 10).type(dtype), requires_grad=True)
+            output = rnn(input)
+            output[0].sum().backward(retain_graph=True)
+            grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
+            for i in range(4):
+                rnn.zero_grad()
+                input.grad.data.zero_()
+                output[0].sum().backward(retain_graph=True)
+                grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
+                self.assertEqual(grads, grads2)
 
     def test_rnn_retain_variables(self):
         self._test_rnn_retain_variables(torch.DoubleTensor)
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 6881a2d..85a705e 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -20,7 +20,8 @@ def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
 
 def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
-    if input.is_cuda:
+    # TODO: enable fused again
+    if False and input.is_cuda:
         igates = F.linear(input, w_ih)
         hgates = F.linear(hidden[0], w_hh)
         state = fusedBackend.LSTMFused()
@@ -43,8 +44,8 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
 
 def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
-
-    if input.is_cuda:
+    # TODO: enable fused again
+    if False and input.is_cuda:
         gi = F.linear(input, w_ih)
         gh = F.linear(hidden, w_hh)
         state = fusedBackend.GRUFused()
diff --git a/torch/nn/_functions/thnn/rnnFusedPointwise.py b/torch/nn/_functions/thnn/rnnFusedPointwise.py
index 19f05e5..7a789a7 100644
--- a/torch/nn/_functions/thnn/rnnFusedPointwise.py
+++ b/torch/nn/_functions/thnn/rnnFusedPointwise.py
@@ -8,6 +8,7 @@ class GRUFused(Function):
         self.backend = None
 
     def forward(self, input_gate, hidden_gate, hx, ibias=None, hbias=None):
+        raise RuntimeError("fused RNNs are disabled")
         if self.backend is None:
             self.backend = type2backend[type(input_gate)]
         hy = input_gate.new()
@@ -46,6 +47,7 @@ class LSTMFused(Function):
         self.backend = None
 
     def forward(self, input_gate, hidden_gate, cx, ibias=None, hbias=None):
+        raise RuntimeError("fused RNNs are disabled")
         if self.backend is None:
             self.backend = type2backend[type(input_gate)]
         hy = input_gate.new()
-- 
2.9.3
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.