# Source: GitHub Gist by @leslie-fang-intel
# Gist: leslie-fang-intel/0e0da35d787f554fbb01ec3b5266b6ca
# Created: April 2, 2025 09:54
import torch
import torch._inductor.config as config
# config.realize_opcount_threshold = 1
class SimpleModel(torch.nn.Module):
    """Minimal module computing ``(x0 + x1) * x2``.

    The script's ``__main__`` block runs it both eagerly and through
    ``torch.compile`` and compares the two results.
    """

    def forward(self, x0, x1, x2):
        """Return ``(x0 + x1) * x2``.

        Args:
            x0: First addend tensor.
            x1: Second addend tensor, added elementwise to ``x0``.
            x2: Tensor multiplied elementwise with the sum.

        Returns:
            The tensor ``(x0 + x1) * x2``. Standard PyTorch broadcasting
            rules apply to both operations.
        """
        tmp = x0 + x1
        tmp2 = tmp * x2
        return tmp2
if __name__ == "__main__":
    # Compare eager-mode output against the torch.compile'd module on the
    # same random inputs; prints True when they match within tolerance.
    with torch.no_grad():
        model = SimpleModel().eval()
        x0 = torch.randn(2, 64, dtype=torch.float32)
        x1 = torch.randn(2, 64, dtype=torch.float32)
        x2 = torch.randn(2, 64, dtype=torch.float32)
        # Eager reference result.
        ref_res = model(x0, x1, x2)
        # Compiled result; compilation happens lazily on this first call.
        cfn = torch.compile(model)
        res = cfn(x0, x1, x2)
        print(torch.allclose(ref_res, res), flush=True)
    print("done", flush=True)
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")