# torch_library: per-call latency of an in-place sine op dispatched through
# torch.ops (Python and C++ kernels), through a pybind binding, and via torch.sin_.
# Source: GitHub gist fxmarty/20df1ccb07a4492a0316e1dd0eac2232 by @fxmarty
# (last active July 17, 2024).
import torch
import time
from torch.profiler import ProfilerActivity, profile
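# This import is required: loading the compiled extension is presumably what registers
# its C++ ops under torch.ops.mycppops. Without it we get
# AttributeError: '_OpNamespace' 'mycppops' object has no attribute 'sin'.
# The module also provides mycppops.sin_pybind, which is benchmarked below.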
import mycppops
torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
@torch.library.impl("mylib::sin", "default")
def f(x):
torch.sin_(x)
def _benchmark(label, fn, x, n_runs):
    """Time ``fn`` on a private copy of ``x`` and print the mean latency per call.

    Args:
        label: Name printed in front of the timing, e.g. ``"torch.sin_"``.
        fn: Callable taking a single tensor. The candidates benchmarked here are
            in-place sine variants (for the C++/pybind ones this is assumed from
            their names; their source is not in this file).
        x: Input tensor. It is cloned first so that every implementation starts
            from the same values, rather than from whatever the previous
            in-place benchmark left behind.
        n_runs: Number of timed calls to average over.
    """
    x = x.clone()
    fn(x)  # warmup
    # GPU kernels launch asynchronously. Drain pending work before starting the
    # clock and again before stopping it, so the timing covers execution rather
    # than just the launches. CPU tensors skip this.
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"{label} took: {(end - start) * 1e3 / n_runs:.3f} ms")


a = torch.rand(8192)
n_runs = 20
print(f"input shape: {a.shape}, device: {a.device}, dtype: {a.dtype}")
print(f"Average over {n_runs} runs")
with torch.no_grad():
    _benchmark("torch.ops python sin_", torch.ops.mylib.sin, a, n_runs)
    _benchmark("torch.ops C++ sin_", torch.ops.mycppops.sin, a, n_runs)
    _benchmark("pybind sin_", mycppops.sin_pybind, a, n_runs)
    _benchmark("torch.sin_", torch.sin_, a, n_runs)