@mttk
Created September 11, 2018 13:37
LSTM stepwise backward

import torch

# Low-level LSTM cell: takes hx as an (h, c) pair and returns the next (h, c).
lstm = torch.lstm_cell

# One dict of parameter gradients per timestep, filled in during backward().
grad_history = []


def store_grads(name, timestep):
    def hook(grad):
        print(name, grad, timestep)
        grad_history[timestep][name] = grad
    return hook


def lstm_stepwise(input, h0, w_ih, w_hh, b_ih, b_hh):
    h = h0
    T = input.size(1)  # input is (batch, seq_len, input_size)
    for t in range(T):
        grad_history.append({})
        # Register hooks that record each parameter's gradient under timestep t.
        w_ih.register_hook(store_grads('w_ih', t))
        w_hh.register_hook(store_grads('w_hh', t))
        b_ih.register_hook(store_grads('b_ih', t))
        b_hh.register_hook(store_grads('b_hh', t))
        h = lstm(input[:, t, :], h, w_ih, w_hh, b_ih, b_hh)
    return h


out = lstm_stepwise(data, h0, w_ih, w_hh, b_ih, b_hh)
loss = criterion(out[0], y)  # out is (h_T, c_T); compare the final hidden state to the target
loss.backward()
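
The snippet above assumes that data, h0, w_ih, w_hh, b_ih, b_hh, criterion and y already exist. A minimal, hypothetical setup for running it (all names, shapes and initializations below are assumptions, not part of the gist) could look like this:

# Hypothetical setup; shapes follow torch.lstm_cell's stacked-gate convention
# (4 * hidden_size rows per weight matrix).
import torch
import torch.nn as nn

B, T, I, H = 4, 5, 10, 20                      # batch, seq_len, input_size, hidden_size

data = torch.randn(B, T, I)
h0 = (torch.zeros(B, H), torch.zeros(B, H))    # torch.lstm_cell takes hx as an (h, c) pair

# Parameters must be leaf tensors with requires_grad=True so the hooks fire in backward().
w_ih = torch.randn(4 * H, I, requires_grad=True)
w_hh = torch.randn(4 * H, H, requires_grad=True)
b_ih = torch.zeros(4 * H, requires_grad=True)
b_hh = torch.zeros(4 * H, requires_grad=True)

criterion = nn.MSELoss()
y = torch.randn(B, H)

After loss.backward(), grad_history[t] holds whatever the hooks registered at timestep t recorded for each parameter name.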