Skip to content

Instantly share code, notes, and snippets.

@davidberard98
Created June 30, 2023 04:23
Show Gist options
  • Save davidberard98/066fd2115f59f5889ef61e4527d1eba5 to your computer and use it in GitHub Desktop.
Save davidberard98/066fd2115f59f5889ef61e4527d1eba5 to your computer and use it in GitHub Desktop.
import torch
import torch._dynamo
import torch._inductor.inductor_prims
def fn(values, boundaries):
    """Eager baseline: bucket ``values`` against ``boundaries``.

    Thin wrapper around ``torch.bucketize`` so the eager path can be called
    the same way as the Inductor path in the benchmark.

    Args:
        values: Tensor of values to look up.
        boundaries: 1-D tensor of bucket edges; ``torch.bucketize`` expects
            it to be sorted in ascending order.

    Returns:
        The result of ``torch.bucketize(values, boundaries)`` (bucket
        indices, one per element of ``values``).
    """
    return torch.bucketize(values, boundaries)
def fn_ind(values, boundaries):
return torch.ops.prims._inductor_bucketize(values, boundaries)
def get_inputs(device="cuda", values_shape=(16, 1024, 1024), num_boundaries=1025):
    """Build one random ``(values, boundaries)`` pair for the benchmark.

    Both tensors are sampled with ``torch.rand`` on the default device and
    then moved to ``device``. With the default arguments this yields the same
    shapes and placement as the original hard-coded version.

    Args:
        device: Target device for both tensors. Defaults to ``"cuda"``.
        values_shape: Shape of the ``values`` tensor.
        num_boundaries: Number of bucket edges to generate.

    Returns:
        Tuple ``(values, boundaries)`` where ``boundaries`` is a 1-D tensor
        sorted in ascending order.
    """
    values = torch.rand(values_shape).to(device)
    # Sort before moving so the bucket edges are in ascending order.
    boundaries = torch.rand((num_boundaries,)).sort().values.to(device)
    return (values, boundaries)
# Pre-build every input set up front so tensor creation and host-to-device
# copies are not part of the timed regions below.
inputs = [get_inputs() for _ in range(32)]
opt_fn = torch.compile(fn_ind)

# Events must be created with enable_timing=True for elapsed_time() to work.
start_evt_eager = torch.cuda.Event(enable_timing=True)
end_evt_eager = torch.cuda.Event(enable_timing=True)
start_evt_pt2 = torch.cuda.Event(enable_timing=True)
end_evt_pt2 = torch.cuda.Event(enable_timing=True)

# Warm-up pass: run both paths once per input so compilation and any
# first-call overhead happen before timing starts.
for inp in inputs:
    fn(*inp)
    opt_fn(*inp)
torch.cuda.synchronize()

# Time the eager path.
start_evt_eager.record()
for inp in inputs:
    fn(*inp)
end_evt_eager.record()
# Wait for the recorded events to complete before reading elapsed_time().
torch.cuda.synchronize()
print(f"Eager {start_evt_eager.elapsed_time(end_evt_eager) / len(inputs)} ms")

# Time the compiled path.
start_evt_pt2.record()
for inp in inputs:
    opt_fn(*inp)
end_evt_pt2.record()
torch.cuda.synchronize()
print(f"PT2 {start_evt_pt2.elapsed_time(end_evt_pt2) / len(inputs)} ms")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment