@jamesr66a
Created December 28, 2021 17:31
import torch
d_hid = 512
class ExampleCode(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two raw matmul parameters plus a Linear layer that is applied twice
        self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param)
        # Save activation for the skip connection below
        skip_connection = x
        x = torch.relu(x)
        x = torch.mm(x, self.mm_param)
        x = self.lin(x)
        x = torch.relu(x)
        x = x + skip_connection
        x = torch.mm(x, self.mm_param2)
        x = self.lin(x)
        return x
ec = ExampleCode()
input = torch.randn(53, 512)
# Reference output: full batch
ref_out = ec(input)
# Test output: split batch, process separately, cat together
split_batches = torch.split(input, 10, dim=0)
split_results = []
for batch in split_batches:
    split_results.append(ec(batch))
test_out = torch.cat(split_results, dim=0)
# Test epsilon equivalence
torch.testing.assert_allclose(test_out, ref_out)
"""
AssertionError: Tensor-likes are not close!
Mismatched elements: 89 / 27136 (0.3%)
Greatest absolute difference: 0.0024471282958984375 at index (50, 494) (up to 1e-05 allowed)
Greatest relative difference: 0.010300797414561929 at index (26, 8) (up to 0.0001 allowed)
"""