Last active
February 27, 2023 14:38
-
-
Save bluenote10/3370da06204b94995614ed014410f6c2 to your computer and use it in GitHub Desktop.
Benchmark delay line
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
from contextlib import contextmanager | |
from pathlib import Path | |
from typing import List, Literal | |
import numba | |
import numba.cuda | |
import numpy as np | |
import tabulate | |
import torch | |
from torch.jit._script import script as torchscript | |
from torch.testing._comparison import assert_close | |
from typing_extensions import assert_never | |
import delay_line_cpp | |
def delay_line_impl(samples, delays):
    """Mix each sample with a delayed copy of the signal, in place.

    Computes ``samples[i] = 0.5 * (samples[i] + samples[i - delay_i])``,
    clamping the delayed index at 0. Written backend-agnostically so the
    same function body can be compiled by numba (CPU and CUDA) and by
    TorchScript.
    """
    num_samples = len(samples)
    for pos in range(num_samples):
        delayed_pos = pos - int(delays[pos].item())
        if delayed_pos < 0:
            delayed_pos = 0
        samples[pos] = 0.5 * (samples[pos] + samples[delayed_pos])
# Compiled variants of the reference implementation: numba for CPU,
# numba.cuda for a GPU kernel, TorchScript for the torch JIT.
delay_line_numba_cpu = numba.njit()(delay_line_impl)
delay_line_numba_cuda = numba.cuda.jit(delay_line_impl)
delay_line_torchscript = torchscript(delay_line_impl)
# Dispatch key selecting which backend `delay_line` runs.
Mode = Literal["plain_python", "torchscript", "numba", "cpp"]
def delay_line(samples: torch.Tensor, delays: torch.Tensor, mode: Mode):
    """Apply the delay line to `samples` in place using the given backend.

    Every backend mutates `samples` directly; callers should rely on the
    in-place update rather than on the return value.
    """
    if mode == "plain_python":
        return delay_line_impl(samples, delays)
    if mode == "torchscript":
        return delay_line_torchscript(samples, delays)  # type: ignore
    if mode == "numba":
        if not samples.is_cuda:
            # The numpy views share memory with the CPU tensors, so the
            # compiled kernel mutates `samples` directly.
            return delay_line_numba_cpu(samples.detach().numpy(), delays.detach().numpy())  # type: ignore
        # CUDA path: wrap the tensors as numba device arrays (zero-copy).
        samples_dev = numba.cuda.as_cuda_array(samples.detach())
        delays_dev = numba.cuda.as_cuda_array(delays.detach())
        return delay_line_numba_cuda[1, 8](samples_dev, delays_dev)  # type: ignore
    if mode == "cpp":
        delay_line_cpp.delay_line_forward(samples, delays)
        return None
    assert_never(mode)
def main():
    """Benchmark all delay-line backends on CPU and GPU and print a table.

    For each (device, backend) pair, runs 100 iterations on fresh random
    input, verifies the result against the plain-Python reference, and
    records the median runtime.
    """
    # Needed for CUDA just-in-time compilation. Use setdefault so an
    # already-configured CUDA_HOME in the environment is respected.
    os.environ.setdefault("CUDA_HOME", str(Path.home() / "bin/cuda"))
    size = 1024
    modes: List[Mode] = ["plain_python", "torchscript", "numba", "cpp"]
    results = []
    for use_gpu in [False, True]:
        device_name = "GPU" if use_gpu else "CPU"
        # The device is invariant per pass; hoist it out of the mode loop.
        device = torch.device("cuda:0" if use_gpu else "cpu")
        print(f" *** Testing on {device_name}")
        for mode in modes:
            timer = MeasureTime(f"{device_name} / {mode}", report_every=10)
            for _ in range(100):
                samples = torch.tensor(
                    np.random.uniform(-1, +1, size=(size,)), device=device, dtype=torch.float32
                )
                delays = torch.tensor(
                    np.random.randint(1, 100, size=(size,)), device=device, dtype=torch.float32
                )
                # Reference result from the plain-Python backend, computed
                # on a copy so `samples` stays untouched for the timed run.
                expected_output = samples.clone()
                delay_line(expected_output, delays, mode="plain_python")
                with timer.timed():
                    delay_line(samples, delays, mode)
                # All backends mutate `samples` in place.
                assert_close(samples, expected_output)
            results.append(
                {"Method": mode, "Device": device_name, "Median time [ms]": timer.median * 1000}
            )
        print()
    # Toggle to group the result table by backend instead of by device.
    sort_by_mode = False
    if sort_by_mode:
        results = sorted(
            results,
            key=lambda row: {"plain_python": 1, "torchscript": 2, "numba": 3, "cpp": 4}[
                row["Method"]
            ],
        )
    print(tabulate.tabulate(results, headers="keys", tablefmt="rounded_outline", floatfmt=".3f"))
class MeasureTime:
    """Collects wall-clock durations for a named operation and reports stats.

    Durations are recorded in seconds; `show_stats` is triggered
    automatically every `report_every` measurements.
    """

    def __init__(self, name: str, *, report_every: int):
        self.name = name
        self.every = report_every  # print stats every N measurements
        self.times: List[float] = []  # recorded durations, in seconds

    @contextmanager
    def timed(self):
        """Context manager that times its body and records the duration."""
        t1 = time.monotonic()
        yield
        t2 = time.monotonic()
        self.times.append(t2 - t1)
        if len(self.times) % self.every == 0:
            self.show_stats()

    def show_stats(self):
        """Print mean ± std and median, auto-scaling the unit (sec/ms/us)."""
        # Materialize the array once instead of once per statistic.
        times = self.times_array
        mean = np.mean(times)
        std = np.std(times)
        median = np.median(times)
        # Scale down until the mean is >= 1 in the chosen unit (or we run
        # out of units) — replaces the previously duplicated if-blocks.
        unit = "sec"
        for smaller_unit in ("ms", "us"):
            if mean >= 1:
                break
            mean *= 1000
            std *= 1000
            median *= 1000
            unit = smaller_unit
        msg = f"{self.name:<30s} mean: {mean:.1f} ± {std:.1f} {unit} median: {median:.3f} {unit}"
        print(msg)

    @property
    def times_array(self) -> np.ndarray:
        """All recorded times as a numpy array (rebuilt on each access)."""
        return np.array(self.times)

    @property
    def median(self) -> float:
        """Median recorded duration in seconds."""
        return float(np.median(self.times_array))
# Run the benchmark only when executed as a script (not on import).
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <torch/extension.h> | |
// C++ counterpart of the Python `delay_line_impl`: mixes each sample with
// a delayed copy of the signal, in place, clamping the delayed index at 0.
// NOTE(review): per-element Tensor indexing and `.item()` are slow; this
// mirrors the element-wise structure of the Python benchmark — presumably
// intentional for a like-for-like comparison (accessors would be faster).
void delay_line_forward(torch::Tensor samples, torch::Tensor delays) {
    int64_t input_size = samples.size(-1);
    for (int64_t i = 0; i < input_size; ++i) {
        // `delays` holds float values; item<int64_t>() truncates to int.
        int64_t delay = delays[i].item<int64_t>();
        int64_t index_delayed = i - delay;
        if (index_delayed < 0) {
            index_delayed = 0;
        }
        samples[i] = 0.5 * (samples[i] + samples[index_delayed]);
    }
}
// Python bindings: exposes delay_line_forward as `delay_line_cpp.delay_line_forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Import torch first so its symbols are available to the extension.
    py::module::import("torch");
    m.def("delay_line_forward", &delay_line_forward);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import setuptools | |
from torch.utils import cpp_extension | |
if __name__ == "__main__":
    # Build the `delay_line_cpp` extension from delay_line.cpp with -O3,
    # using torch's BuildExtension to pick up the correct include paths.
    setuptools.setup(
        name="cpp_extension_test",
        ext_modules=[
            cpp_extension.CppExtension("delay_line_cpp", ["delay_line.cpp"], extra_compile_args=["-O3"])
        ],
        cmdclass={"build_ext": cpp_extension.BuildExtension},
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment