Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created April 20, 2024 05:13
Show Gist options
  • Save pashu123/306bb71d86d18561d1e944d4e6a64d28 to your computer and use it in GitHub Desktop.
Save pashu123/306bb71d86d18561d1e944d4e6a64d28 to your computer and use it in GitHub Desktop.
import time

import torch
def matmul_func(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Batch-multiply ``a`` by the transpose of a shared 2-D matrix ``b``.

    ``b`` is broadcast along a new leading batch dimension sized like
    ``a.shape[0]``, its last two dimensions are swapped, the result is cast
    to float32, and the product is computed with ``torch.bmm``.

    Args:
        a: 3-D tensor of shape ``(batch, m, k)``. Because the second operand
            is cast to float32 and ``torch.bmm`` needs matching dtypes, ``a``
            is expected to be float32 as well.
        b: 2-D tensor of shape ``(n, k)``; it is cast to float32 before the
            multiplication (the script below passes a half-precision tensor).

    Returns:
        A float32 tensor of shape ``(batch, m, n)``.
    """
    batch = a.shape[0]
    # Replicate ``b`` once per batch entry of ``a``: (n, k) -> (batch, n, k).
    y = torch.broadcast_to(b, (batch, b.shape[0], b.shape[1]))
    # Swap the last two dims to (batch, k, n) so the inner dimensions line up
    # with ``a``; the cast is applied after the transpose, on the batched
    # operand, to keep the operation order of the benchmarked graph.
    y = torch.transpose(y, 1, 2).float()
    return torch.bmm(a, y)
# Compile the function with TorchScript and print its IR graph so the exact
# sequence of ops being timed is visible in the output.
compiled_module = torch.jit.script(matmul_func)
print(compiled_module.graph)

# Benchmark inputs: a float32 batch of shape (4, 128, 3200) and a
# half-precision matrix of shape (8640, 3200).
a = torch.randn(4, 128, 3200)
b = torch.randn(8640, 3200).to(torch.half)

# Time every call separately and report each one in milliseconds. Early
# iterations may include one-time JIT warm-up cost, so later iterations are
# the more representative numbers.
for _ in range(10):
    start = time.perf_counter()
    x = compiled_module(a, b)
    end = time.perf_counter()
    print(f"The time taken for the operation is: {(end - start)*1000} ms")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment