import itertools
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE_CPU = torch.device("cpu")
DEVICE_GPU = triton.runtime.driver.active.get_active_torch_device()
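# Note: DEVICE_GPU resolves to whatever device the active Triton backend
# drives (e.g. a CUDA device under the NVIDIA backend, an HIP device under
# ROCm), so the same script can be compared against the CPU baseline below.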
def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
        "gfx940",
        "gfx941",
        "gfx942",
        "gfx90a",
        "gfx908",
    )
def naive_softmax(x):
    """Row-wise softmax in plain PyTorch ops, kept as a reference.

    Subtracting the per-row max before exp() avoids overflow without
    changing the result.
    """
    x_max = x.max(dim=1)[0]
    z = x - x_max[:, None]
    numerator = torch.exp(z)
    denominator = numerator.sum(dim=1)
    ret = numerator / denominator[:, None]
    return ret
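# A minimal sanity check (not in the original gist): the reference
# implementation should agree with torch.softmax row-wise on CPU.
_x = torch.randn(4, 8)
assert torch.allclose(naive_softmax(_x), torch.softmax(_x, dim=1), atol=1e-6)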
@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_stages: tl.constexpr,
):
    # Persistent kernel: each program starts at its program id and strides
    # over the rows, so a fixed-size grid covers all n_rows.
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # Each row is loaded as a single block; BLOCK_SIZE is the next power
        # of two >= n_cols, and padding lanes are filled with -inf so they
        # contribute nothing to the max or the sum.
        row_start_ptr = input_ptr + row_idx * input_row_stride
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        # Numerically stable softmax: subtract the row max before exp.
        row_minus_max = row - tl.max(row, axis=0)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
# Query hardware limits, used below to pick an occupancy-based grid size.
properties = driver.active.utils.get_device_properties(DEVICE_GPU.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}  # per-shape kernel cache (kept from the upstream tutorial; unused below)
def softmax(x):
    n_rows, n_cols = x.shape

    # tl.arange requires a power-of-two extent, so each row's block is padded
    # up to the next power of two (e.g. n_cols=781 -> BLOCK_SIZE=1024).
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    num_warps = 8
    num_stages = 4 if SIZE_SMEM > 200000 else 2
    y = torch.empty_like(x)

    # Pre-compile the kernel to inspect its register and shared-memory usage.
    kernel = softmax_kernel.warmup(
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_stages=num_stages,
        num_warps=num_warps,
        grid=(1,),
    )
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared

    # Estimate how many concurrent programs fit per SM given register and
    # shared-memory pressure, then size the persistent grid accordingly.
    if is_hip():
        # NUM_REGS counts regular-purpose registers; CDNA architectures
        # additionally expose an equal number of accumulation registers.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = min(NUM_SM * occupancy, n_rows)

    kernel[(num_programs, 1, 1)](
        y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages
    )
    return y
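# Example usage (a minimal sketch; assumes a device supported by the active
# Triton backend):
#   out = softmax(torch.randn(8, 1000, device=DEVICE_GPU, dtype=torch.float16))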
# Correctness check: compare the Triton result against torch.softmax on CPU.
dtype = torch.float16
torch.manual_seed(0)
x_cpu = torch.randn(1823, 781, device=DEVICE_CPU, dtype=dtype)
x_gpu = x_cpu.to(DEVICE_GPU)
y_torch = torch.softmax(x_cpu, dim=1)
y_triton = softmax(x_gpu)
assert torch.allclose(y_triton.to(DEVICE_CPU), y_torch, atol=1e-05, rtol=1e-02), (
    y_triton.to(DEVICE_CPU),
    y_torch,
)
# Benchmark Triton (on the active Triton device) against torch.softmax
# (on CPU) across several dtypes.
configurations = list(
    itertools.product(["triton", "torch"], [torch.float32, torch.float16, torch.bfloat16])
)
names = [f"{provider} {dtype}" for provider, dtype in configurations]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[128 * i for i in range(2, 100)],
        line_arg="configuration",
        line_vals=configurations,
        line_names=names,
        # styles=[("blue", "-"), ("green", "-")],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},
    )
)
def benchmark(M, N, configuration):
    provider, dtype = configuration
    # torch runs on CPU; triton runs on the active Triton device.
    DEVICE = DEVICE_CPU if provider == "torch" else DEVICE_GPU
    x = torch.randn(M, N, device=DEVICE, dtype=dtype)
    if provider == "triton":
        # Use a fresh stream so do_bench timings are not serialized behind
        # unrelated work on the default stream.
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
    if provider == "torch":
        ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    elif provider == "triton":
        ms = triton.testing.do_bench(lambda: softmax(x))
    # Effective bandwidth: every element is read once and written once.
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)
import matplotlib.pyplot as plt

benchmark.run(show_plots=True, print_data=True)
plt.savefig("benchmark.png", dpi=72, bbox_inches="tight")