import itertools

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE_CPU = torch.device("cpu")
DEVICE_GPU = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    # True when the active Triton backend targets AMD GPUs (HIP).
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    # True on AMD CDNA architectures (MI100/MI200/MI300 family), which expose
    # a larger register budget used in the occupancy estimate below.
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
        "gfx940",
        "gfx941",
        "gfx942",
        "gfx90a",
        "gfx908",
    )


def naive_softmax(x):
    # Row-wise softmax in eager PyTorch; subtracting the row max first keeps
    # exp() from overflowing.
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    ret = numerator / denominator[:, None]
    return ret
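

# Why fuse? naive_softmax materializes every intermediate (x_max, z,
# numerator, denominator), so each element of x makes several round trips
# through DRAM. The Triton kernel below keeps an entire row on-chip, reading
# each input element once and writing each output element once, which is close
# to optimal for a memory-bound op.
#
# Illustrative sanity check (hypothetical shape, CPU only):
_x = torch.randn(4, 7)
assert torch.allclose(naive_softmax(_x), torch.softmax(_x, dim=1), atol=1e-6)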


@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_stages: tl.constexpr,
):
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        row_start_ptr = input_ptr + row_idx * input_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
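

# How the kernel is laid out: this is a "persistent" kernel. Each program
# starts at its program_id and strides over rows by the total grid size, so a
# fixed number of programs covers all n_rows. tl.range(..., num_stages=...)
# lets the compiler software-pipeline the loop, overlapping the loads of one
# row with the math of the previous one. BLOCK_SIZE is the next power of two
# >= n_cols, so every row fits in a single block; the mask handles the tail
# (out-of-range lanes load -inf, which is the identity for max and contributes
# exp(-inf) = 0 to the sum).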

# Hardware limits used below to size the persistent grid.
properties = driver.active.utils.get_device_properties(DEVICE_GPU.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # Each program loads one full row, so BLOCK_SIZE must be the next power
    # of two greater than or equal to n_cols.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    num_warps = 8
    num_stages = 4 if SIZE_SMEM > 200000 else 2
    y = torch.empty_like(x)

    # Pre-compile the kernel so we can inspect its register and shared-memory
    # usage before choosing a grid size.
    kernel = softmax_kernel.warmup(
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_stages=num_stages,
        num_warps=num_warps,
        grid=(1,),
    )
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared

    # Estimate how many programs fit on one SM/CU, limited first by
    # registers and then by shared memory.
    if is_hip():
        NUM_GPRS = NUM_REGS
        if is_cdna():
            # CDNA GPUs have an additional accumulation register file,
            # effectively doubling the usable budget.
            NUM_GPRS = NUM_REGS * 2
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)

    # Launch one persistent program per available slot, but never more
    # programs than there are rows.
    num_programs = min(NUM_SM * occupancy, n_rows)
    kernel[(num_programs, 1, 1)](
        y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages
    )
    return y
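

# Minimal usage sketch (illustrative; assumes a CUDA or HIP device is active):
#
#   out = softmax(torch.randn(8, 1000, device=DEVICE_GPU))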


dtype = torch.float16
torch.manual_seed(0)
x_cpu = torch.randn(1823, 781, device=DEVICE_CPU, dtype=dtype)
x_gpu = x_cpu.to(DEVICE_GPU)
y_torch = torch.softmax(x_cpu, dim=1)
y_triton = softmax(x_gpu)
assert torch.allclose(y_triton.to(DEVICE_CPU), y_torch, atol=1e-05, rtol=1e-02), \
    (y_triton.to(DEVICE_CPU), y_torch)
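
# Note on tolerances: the inputs are float16, so a loose rtol of 1e-2 is used
# to absorb half-precision rounding between the CPU reference and the GPU
# kernel.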

configurations = list(itertools.product(["triton", "torch"], [torch.float32, torch.float16, torch.bfloat16]))
names = [f"{provider} {dtype}" for provider, dtype in configurations]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[128 * i for i in range(2, 100)],
        line_arg="configuration",
        line_vals=configurations,
        line_names=names,
        # styles=[("blue", "-"), ("green", "-")],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},
    )
)
def benchmark(M, N, configuration):
    provider, dtype = configuration
    # torch runs on the CPU, triton on the GPU (see the note below).
    DEVICE = DEVICE_CPU if provider == "torch" else DEVICE_GPU
    x = torch.randn(M, N, device=DEVICE, dtype=dtype)
    if provider == "triton":
        # Benchmark on a fresh stream to avoid interference from other work.
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
    if provider == "torch":
        ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    elif provider == "triton":
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)
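

# Bandwidth model: softmax reads each element once and writes it once, so the
# op moves 2 * numel * element_size bytes; dividing by the measured time gives
# effective GB/s. Note that the "torch" provider runs on the CPU here, so the
# plot compares CPU torch.softmax against the GPU Triton kernel.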

import matplotlib.pyplot as plt

benchmark.run(show_plots=True, print_data=True)
plt.savefig("benchmark.png", dpi=72, bbox_inches="tight")