Created December 28, 2021 17:31
Save jamesr66a/d3fcfbd2d41b48267ab993ae3a710c48 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

# Hidden (feature) size used for every square weight matrix and for the
# Linear layer in ExampleCode.
d_hid = 512
class ExampleCode(torch.nn.Module):
    """Small matmul / Linear / ReLU network with one residual connection.

    Parameters:
        mm_param:  ``d_hid x d_hid`` weight, applied twice via ``torch.mm``.
        mm_param2: ``d_hid x d_hid`` weight, applied once via ``torch.mm``.
        lin:       ``torch.nn.Linear(d_hid, d_hid)``, applied twice.

    Because ``mm_param`` and ``lin`` are each used at two points in
    ``forward``, their weights are shared between those uses.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept fixed: it determines RNG consumption and the
        # parameter registration order (e.g. in state_dict()).
        # Raw weights are drawn from torch.randn with no fan-in scaling.
        self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: 2-D tensor (``torch.mm`` requires matrices) whose second
               dimension is ``d_hid``.

        Returns:
            Tensor with ``x.shape[0]`` rows and ``d_hid`` columns.
        """
        # First projection; its output is also kept as the residual branch.
        residual = torch.mm(x, self.mm_param)
        hidden = torch.mm(torch.relu(residual), self.mm_param)
        hidden = torch.relu(self.lin(hidden))
        # Add the residual branch back, then apply the final two projections.
        hidden = torch.mm(hidden + residual, self.mm_param2)
        return self.lin(hidden)
# Reproducer: running ExampleCode on a batch split into chunks and
# concatenating the results should match the full-batch output up to
# floating-point tolerance.

# Seed so that the random weights and input are the same on every run.
torch.manual_seed(0)

ec = ExampleCode()
# Use d_hid rather than a hard-coded width so the input always matches the
# model. Avoid the name `input`, which shadows the builtin.
example_input = torch.randn(53, d_hid)
split_size = 10

# Inference-only comparison: no autograd graph is needed.
with torch.no_grad():
    # Reference output: full batch.
    ref_out = ec(example_input)
    # Test output: split the batch (53 rows -> chunks of 10, 10, 10, 10, 10, 3),
    # process each chunk separately, then concatenate along the batch dim.
    test_out = torch.cat(
        [ec(chunk) for chunk in torch.split(example_input, split_size, dim=0)],
        dim=0,
    )

# Test epsilon equivalence. `torch.testing.assert_allclose` is deprecated in
# favor of `assert_close`, whose float32 defaults are stricter (rtol=1.3e-6).
# Pass the tolerances the original check used (rtol=1e-4, atol=1e-5, as shown
# in the recorded failure below) so the check is unchanged.
torch.testing.assert_close(test_out, ref_out, rtol=1e-4, atol=1e-5)

# Example failure recorded from an unseeded run of the original
# `assert_allclose` call (exact numbers vary with the random weights):
#
#   AssertionError: Tensor-likes are not close!
#   Mismatched elements: 89 / 27136 (0.3%)
#   Greatest absolute difference: 0.0024471282958984375 at index (50, 494) (up to 1e-05 allowed)
#   Greatest relative difference: 0.010300797414561929 at index (26, 8) (up to 0.0001 allowed)
#
# NOTE(review): presumably these mismatches come from float32 rounding that
# differs between the 53-row and 10-/3-row matmuls (different kernels or
# reduction order). The weights are unscaled torch.randn values, so the
# activations may be large in magnitude, and an absolute tolerance of 1e-5
# may then be unrealistic. Confirm the intended tolerance before treating
# this as a real discrepancy.
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.