Created
April 15, 2019 01:37
Test the bidirectional LSTM type
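"""A simple script to test which biLSTM type PyTorch uses.

type1 is taken here to mean that the two directions are stacked independently
and only concatenated at the final output; type2 means that each layer consumes
the concatenated outputs of both directions from the layer below.

The gradients are computed only w.r.t. the output of one single direction
(forward), so the gradients of the reverse direction in the first layer
(the `*_l0_reverse` parameters) should be zero if type1.
In my tests, it's type2.
"""
import torch
from torch import nn

# 2-layer bidirectional LSTM: input size 10, hidden size 100
bilstm = nn.LSTM(10, 100, 2, bidirectional=True)
# torch.Tensor(10, 10, 10) returns uninitialized memory, so use random values
fake_input = torch.randn(10, 10, 10)  # (seq_len, batch, input_size)

output, _ = bilstm(fake_input)
# The first 100 output features belong to the forward direction
output[..., :100].sum().backward()
print('==========BiLSTM on CPU==========')
for name, param in bilstm.named_parameters():
    print('{} grad: {}'.format(name, param.grad.mean().item()))

# Repeat on the GPU only if one is available
if torch.cuda.is_available():
    bilstm.zero_grad()
    bilstm.cuda()
    output, _ = bilstm(fake_input.cuda())
    output[..., :100].sum().backward()
    print('==========BiLSTM on GPU==========')
    for name, param in bilstm.named_parameters():
        print('{} grad: {}'.format(name, param.grad.mean().item()))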
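The mean of a gradient can be close to zero even when individual entries are not, so a stricter variant of the check above may be useful. The following is a minimal sketch, not part of the original script: the helper name `reverse_grad_abs_sums` and the small tensor sizes are assumptions made for illustration. It backpropagates from the forward-direction output only and reports the absolute sum of each reverse-direction gradient.

```python
import torch
from torch import nn


def reverse_grad_abs_sums(input_size=10, hidden_size=8, num_layers=2):
    """Hypothetical helper: backprop from the forward-direction output only
    and return {parameter name: sum of |grad|} for the reverse-direction
    parameters of a bidirectional LSTM."""
    torch.manual_seed(0)
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    x = torch.randn(5, 3, input_size)  # (seq_len, batch, input_size)
    output, _ = lstm(x)
    # The first hidden_size features of the output are the forward direction
    output[..., :hidden_size].sum().backward()
    sums = {}
    for name, param in lstm.named_parameters():
        if name.endswith('_reverse'):
            # Treat a missing gradient as zero
            sums[name] = 0.0 if param.grad is None else param.grad.abs().sum().item()
    return sums


sums = reverse_grad_abs_sums()
for name, value in sums.items():
    print('{} |grad| sum: {}'.format(name, value))
```

With type2 stacking, the `*_l0_reverse` sums should be nonzero, because the first-layer reverse outputs feed the second-layer forward direction, while the `*_l1_reverse` sums should be zero. With type1 stacking, every reverse-direction sum would be zero.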