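# Triton vs. PyTorch: convert nx4 bounding boxes from [x1, y1, x2, y2] to
# [x, y, w, h], with a correctness check and a memory-bandwidth benchmark.
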
import torch
import triton
import triton.language as tl
import numpy as np

def torch_xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y

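# Worked example: the box [0, 0, 4, 2] (top-left at the origin, bottom-right at
# (4, 2)) maps to [2, 1, 4, 2]: center (2, 1), width 4, height 2.
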
# Triton implementation of xyxy2xywh. The autotuner times every BLOCK_SIZE
# config and caches the fastest one per distinct value of `n_rows`.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
    ],
    key=['n_rows'],
)
@triton.jit
def triton_xyxy2xywh_kernel(
    input_ptr,
    output_ptr,
    n_rows,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program converts one [BLOCK_SIZE, 4] block of rows.
    pid = tl.program_id(0)
    stride_row = 4  # rows always have 4 columns
    offset_row = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_mask = offset_row[:, None] < n_rows  # guard the last, possibly partial block
    # Base pointer of each row; the four columns sit at fixed offsets 0..3 from it.
    row_ptrs = input_ptr + offset_row[:, None] * stride_row
    x0 = tl.load(row_ptrs, mask=row_mask)      # x1 (left)
    x1 = tl.load(row_ptrs + 1, mask=row_mask)  # y1 (top)
    x2 = tl.load(row_ptrs + 2, mask=row_mask)  # x2 (right)
    x3 = tl.load(row_ptrs + 3, mask=row_mask)  # y2 (bottom)
    y0 = (x0 + x2) / 2  # x center
    y1 = (x1 + x3) / 2  # y center
    y2 = x2 - x0  # width
    y3 = x3 - x1  # height
    out_ptrs = output_ptr + offset_row[:, None] * stride_row
    tl.store(out_ptrs, y0, mask=row_mask)
    tl.store(out_ptrs + 1, y1, mask=row_mask)
    tl.store(out_ptrs + 2, y2, mask=row_mask)
    tl.store(out_ptrs + 3, y3, mask=row_mask)

def triton_xyxy2xywh(x: torch.Tensor):
    assert x.is_cuda and x.ndim == 2 and x.shape[1] == 4
    n_rows = x.shape[0]
    output = torch.empty_like(x)
    # Launch one program per BLOCK_SIZE rows; the kernel masks the tail block.
    grid = lambda meta: (triton.cdiv(n_rows, meta['BLOCK_SIZE']),)
    triton_xyxy2xywh_kernel[grid](x, output, n_rows)
    return output

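# Note: stride_row is hard-coded to 4, so the kernel assumes a contiguous nx4
# tensor; pass x.contiguous() first if the input may be a non-contiguous view.
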
torch.manual_seed(0)
size = 767
x = torch.rand(size, 4, device='cuda')
output_torch = torch_xyxy2xywh(x)
output_triton = triton_xyxy2xywh(x)
print(f'max difference: {torch.max(torch.abs(output_triton - output_torch))}')
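# Both paths perform the same float32 arithmetic per element, so the reported
# max difference should be exactly 0.0.
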
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # argument names to use as an x-axis for the plot
        x_vals=[2 ** i for i in range(4, 24)],  # different possible values for `x_names`
        x_log=True,  # x axis is logarithmic
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=['Triton', 'Torch'],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel='GB/s',  # label name for the y-axis
        plot_name='xyxy2xywh-performance',  # name for the plot; also used as the file name when saving
        args={},  # values for function arguments not in `x_names`
    )
)
def benchmark(size, provider):
    x = torch.rand(size, 4, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]  # median plus 20th/80th percentiles
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_xyxy2xywh(x), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_xyxy2xywh(x), quantiles=quantiles)
    # Each row moves 32 bytes: 4 float32 values read plus 4 float32 values written.
    gbps = lambda ms: 32 * size / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(print_data=True, save_path='.')