# Last active
# July 17, 2024 10:10
# -
# -
# Save fxmarty/20df1ccb07a4492a0316e1dd0eac2232 to your computer and use it in GitHub Desktop.
# torch_library
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
import time

import torch
from torch.profiler import ProfilerActivity, profile

# We somehow need this import otherwise we get AttributeError: '_OpNamespace' 'mycppops' object has no attribute 'sin'
import mycppops
torch.library.define("mylib::sin", "(Tensor x) -> Tensor") | |
@torch.library.impl("mylib::sin", "default") | |
def f(x): | |
torch.sin_(x) | |
# Benchmark input: a 1-D tensor of 8192 random values. The timed ops below
# are called on this same tensor over and over; only the call cost matters,
# not the values it ends up holding.
a = torch.rand(8192)

# Number of timed calls that each reported average is taken over.
n_runs = 20

print(f"input shape: {a.shape}, device: {a.device}, dtype: {a.dtype}")
print(f"Average over {n_runs} runs")
# Time four ways of invoking an in-place sine on ``a``. Each section makes one
# untimed warmup call, then reports the mean wall-clock time of ``n_runs``
# back-to-back calls in milliseconds. The calls are deliberately written out
# inline (including the ``torch.ops.<ns>.<op>`` attribute lookups) so the
# measured path is exactly what a caller would execute.
with torch.no_grad():
    # 1) Op defined and implemented in Python through torch.library.
    # warmup
    torch.ops.mylib.sin(a)

    start = time.perf_counter()
    for _ in range(n_runs):
        torch.ops.mylib.sin(a)
    end = time.perf_counter()
    print(f"torch.ops python sin_ took: {(end - start) * 1e3 / n_runs:.3f} ms")

    # 2) Op exposed under torch.ops by the ``mycppops`` extension
    #    (presumably registered from C++ -- confirm against the extension).
    # warmup
    torch.ops.mycppops.sin(a)

    start = time.perf_counter()
    for _ in range(n_runs):
        torch.ops.mycppops.sin(a)
    end = time.perf_counter()
    print(f"torch.ops C++ sin_ took: {(end - start) * 1e3 / n_runs:.3f} ms")

    # 3) Function called directly on the ``mycppops`` module
    #    (a pybind binding, judging by its name).
    # warmup
    mycppops.sin_pybind(a)

    start = time.perf_counter()
    for _ in range(n_runs):
        mycppops.sin_pybind(a)
    end = time.perf_counter()
    print(f"pybind sin_ took: {(end - start) * 1e3 / n_runs:.3f} ms")

    # 4) Baseline: the built-in in-place op.
    # warmup
    torch.sin_(a)

    start = time.perf_counter()
    for _ in range(n_runs):
        torch.sin_(a)
    end = time.perf_counter()
    print(f"torch.sin_ took: {(end - start) * 1e3 / n_runs:.3f} ms")

# NOTE(review): leftover from a profiling run; it needs a ``prof`` object
# (presumably from ``torch.profiler.profile``) that this script never creates.
# prof.export_chrome_trace("tracenew.json")
# Sign up for free
# to join this conversation on GitHub.
# Already have an account?
# Sign in to comment