Last active
September 18, 2018 15:02
-
-
Save mttk/38f55a82605b2e25485484ad136eb231 to your computer and use it in GitHub Desktop.
Undefined tensor error in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Toy problem setup: a tiny batch of random sequences with binary targets.

# Prefer the first CUDA device, but fall back to the CPU so the script still
# runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
embed_dim = 2   # size of each input feature vector
hidden_dim = 2  # LSTM state size
B = 4           # batch size
T = 2           # number of timesteps per sequence

# One target class index per batch element.
y = torch.tensor([0, 1, 0, 1]).to(device)
criterion = nn.CrossEntropyLoss()

# Batch-first input of shape (B, T, embed_dim).
data = torch.randn((B, T, embed_dim)).to(device)

# Initial state: one random vector repeated for every batch row, giving a
# (B, hidden_dim) tensor. The state has to be sized by hidden_dim (the
# original used embed_dim, which only worked because both are 2).
h0 = torch.randn(hidden_dim).repeat(B, 1).to(device)
h0 = (h0, h0)  # the same tensor serves as both hidden and cell state

# In this script nn.LSTM only supplies correctly shaped parameters; the
# forward pass is done manually with torch.lstm_cell.
net = nn.LSTM(embed_dim, hidden_dim)
net.to(device)
params = dict(net.named_parameters())
w_hh = params['weight_hh_l0']
w_ih = params['weight_ih_l0']
b_hh = params['bias_hh_l0']
b_ih = params['bias_ih_l0']
lstm = torch.lstm_cell

# grad_history[t][name] holds the gradient captured by the hook registered
# for parameter `name` at timestep `t`.
grad_history = []
def store_grads(name, timestep):
    """Build a tensor hook that logs a gradient and files it in grad_history.

    The returned callable prints ``name``, the gradient and ``timestep``,
    then stores the gradient at ``grad_history[timestep][name]``. The entry
    ``grad_history[timestep]`` must already exist when the hook fires.
    """
    def _record(grad):
        print(name, grad, timestep)
        slot = grad_history[timestep]
        slot[name] = grad

    return _record
def lstm_stepwise(input, h0, w_ih, w_hh, b_ih, b_hh):
    """Unroll an LSTM cell over time, hooking the weights at every step.

    Args:
        input: batch-first tensor indexed as ``input[:, t, :]``; the number
            of timesteps is read from dimension 1.
        h0: initial ``(hidden, cell)`` state passed to the first cell call.
        w_ih, w_hh, b_ih, b_hh: LSTM cell weights and biases. They must
            require grad, since a hook is registered on a view of each.

    Returns:
        A list whose first element is ``h0`` followed by the result of the
        cell call for every timestep, in order.

    Side effects:
        Appends one empty dict per timestep to the module-level
        ``grad_history``; the registered hooks fill it during backward.
    """
    hiddens = [h0]
    num_steps = input.size(1)  # time is dimension 1: input is batch-first
    for t in range(num_steps):
        grad_history.append({})  # slot the hooks for this step write into

        # Take fresh views of the *original* parameters at every step.
        # Rebinding w_ih = w_ih[:] would chain each view onto the previous
        # one, so the hook for step t would see gradients accumulated from
        # all later steps instead of this step's own contribution.
        step_params = {
            'w_ih': w_ih[:],
            'w_hh': w_hh[:],
            'b_ih': b_ih[:],
            'b_hh': b_hh[:],
        }

        # Dict order matches the (w_ih, w_hh, b_ih, b_hh) argument order
        # expected by the cell function.
        hiddens.append(
            lstm(input[:, t, :], hiddens[-1], *step_params.values())
        )

        for name, view in step_params.items():
            view.register_hook(store_grads(name, t))
    return hiddens
# Run the unrolled LSTM and unpack the state produced by the final step.
hiddens = lstm_stepwise(data, h0, w_ih, w_hh, b_ih, b_hh)
h, c = hiddens[-1]

# This presumably triggers the error: the loss is computed from the cell
# state only, so the shallow per-step copies (w[:]) of the output-gate
# weights are not reached by the backprop signal.
loss = criterion(c, y)
loss.backward()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment