Skip to content

Instantly share code, notes, and snippets.

@leslie-fang-intel
Created January 9, 2025 09:25
Show Gist options
  • Select an option

  • Save leslie-fang-intel/9e731ab132cd51c23d50041ef4c2ea88 to your computer and use it in GitHub Desktop.

Select an option

Save leslie-fang-intel/9e731ab132cd51c23d50041ef4c2ea88 to your computer and use it in GitHub Desktop.
import torch
from torch._inductor import config
# Problem sizes for the repro: inputs are shaped (batch_size, in_features),
# and in_features/out_features also size the module's Linear layer.
batch_size, in_features, out_features = 4, 512, 1024

# Dtype of the random inputs. Swap in torch.float16 to run the same repro
# in half precision.
dtype = torch.bfloat16

# Value handed to M(bias=...) when the module is constructed.
bias = True
class M(torch.nn.Module):
    """Minimal module used by the torch.compile repro.

    ``forward`` only applies ReLU to its input. A ``Linear`` layer is
    registered as ``linear0`` but is not applied in ``forward``.
    """

    def __init__(self, bias):
        super(M, self).__init__()
        # NOTE(review): ``bias`` is accepted but ignored -- ``linear0`` is
        # always built without a bias term. Confirm this is intentional for
        # the repro rather than a leftover from a ``bias=bias`` variant.
        self.linear0 = torch.nn.Linear(
            in_features=in_features,
            out_features=out_features,
            bias=False,
        )

    def forward(self, x):
        """Return ``torch.relu(x)``.

        NOTE(review): ``self.linear0`` is not used here; presumably the
        repro was reduced down to the activation alone -- confirm.
        """
        activated = torch.relu(x)
        return activated
if __name__ == "__main__":
    # Inference-only repro, so autograd tracking is disabled.
    with torch.no_grad():
        # Keep the original RNG order: the first input is drawn before M's
        # Linear weights are initialized, then the second input afterwards.
        x = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).eval()

        # Eager result is the reference for the compiled module.
        ref_res = m(x)
        cm = torch.compile(m)
        act_res = cm(x)
        # The original computed ref_res but never checked it; compare so the
        # repro actually reports a numerical mismatch.
        torch.testing.assert_close(act_res, ref_res)

        # Second call with a different leading dim, which makes
        # torch.compile recompile (presumably exercising the dynamic-shape
        # path this repro targets -- confirm).
        x2 = torch.randn(batch_size + 2, in_features, dtype=dtype)
        act_res2 = cm(x2)
        ref_res2 = m(x2)
        torch.testing.assert_close(act_res2, ref_res2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment