Skip to content

Instantly share code, notes, and snippets.

@xzuyn
Last active November 27, 2025 03:23
Show Gist options
  • Select an option

  • Save xzuyn/e80cc38c401ce69845aa955daaf7c70a to your computer and use it in GitHub Desktop.

Select an option

Save xzuyn/e80cc38c401ce69845aa955daaf7c70a to your computer and use it in GitHub Desktop.
# Copyright 2025 xzuyn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import triton.language as tl
@triton.jit
def quantize_kernel_rhu(  # round-half-up
    a_ptr,  # fp32 input values
    a_quant_ptr,  # uint8 output codes
    scale_ptr,  # fp32, one scale per quant block
    min_ptr,  # fp32, one minimum per quant block
    n_elements,  # int
    NUM_QUANT_BLOCKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Affine-quantize one fp32 block to uint8 with round-half-up.

    Each program handles BLOCK_SIZE elements: it computes the block's min
    and (max - min) / 255 scale, then maps every value into 0..255.
    """
    prog = tl.program_id(axis=0)
    start = prog * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(a_ptr + idx, mask=in_bounds)
    # Neutralize out-of-bounds lanes with +/-inf before the reductions.
    blk_min = tl.min(tl.where(in_bounds, vals, float("inf")), axis=0)
    blk_max = tl.max(tl.where(in_bounds, vals, float("-inf")), axis=0)
    blk_scale = (blk_max - blk_min) / 255.0
    degenerate = blk_scale == 0.0
    # Substitute 1.0 for a zero scale so a constant block does not divide by 0.
    vals = (vals - blk_min) / tl.where(degenerate, 1.0, blk_scale)
    # Round half up: floor(x + 0.5).
    vals = tl.floor(vals + 0.5)
    # Constant blocks quantize to 0; everything else clamps into 0..255.
    vals = tl.where(degenerate, 0, vals)
    vals = tl.where(vals < 0, 0, vals)
    vals = tl.where(vals > 255, 255, vals)
    # Partial store is handled by the mask.
    tl.store(a_quant_ptr + idx, vals.to(tl.uint8), mask=in_bounds)
    # One scale/min pair per quant block.
    if prog < NUM_QUANT_BLOCKS:
        tl.store(scale_ptr + prog, blk_scale)
        tl.store(min_ptr + prog, blk_min)
@triton.jit
def quantize_kernel_rhe(  # round-half-even
    a_ptr,  # fp32 input values
    a_quant_ptr,  # uint8 output codes
    scale_ptr,  # fp32, one scale per quant block
    min_ptr,  # fp32, one minimum per quant block
    n_elements,  # int
    NUM_QUANT_BLOCKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Affine-quantize one fp32 block to uint8 with round-half-even.

    Identical to the round-half-up kernel except that exact .5 ties round
    to the nearest even integer (banker's rounding).
    """
    prog = tl.program_id(axis=0)
    start = prog * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(a_ptr + idx, mask=in_bounds)
    # Neutralize out-of-bounds lanes with +/-inf before the reductions.
    blk_min = tl.min(tl.where(in_bounds, vals, float("inf")), axis=0)
    blk_max = tl.max(tl.where(in_bounds, vals, float("-inf")), axis=0)
    blk_scale = (blk_max - blk_min) / 255.0
    degenerate = blk_scale == 0.0
    # Substitute 1.0 for a zero scale so a constant block does not divide by 0.
    vals = (vals - blk_min) / tl.where(degenerate, 1.0, blk_scale)
    # Round half even: a tie (.5) with an even floor keeps the floor,
    # otherwise fall through to floor(x + 0.5) (which yields floor + 1
    # on a tie with an odd floor, i.e. the even neighbour).
    floor_v = tl.floor(vals)
    tie_on_even = ((vals - floor_v) == 0.5) & ((floor_v.to(tl.int32) % 2) == 0)
    vals = tl.where(
        tie_on_even,
        floor_v,
        tl.floor(vals + 0.5),
    )
    # Constant blocks quantize to 0; everything else clamps into 0..255.
    vals = tl.where(degenerate, 0, vals)
    vals = tl.where(vals < 0, 0, vals)
    vals = tl.where(vals > 255, 255, vals)
    # Partial store is handled by the mask.
    tl.store(a_quant_ptr + idx, vals.to(tl.uint8), mask=in_bounds)
    # One scale/min pair per quant block.
    if prog < NUM_QUANT_BLOCKS:
        tl.store(scale_ptr + prog, blk_scale)
        tl.store(min_ptr + prog, blk_min)
@triton.jit
def quantize_kernel_sr(  # stochastic-rounding
    a_ptr,  # fp32 input values
    a_quant_ptr,  # uint8 output codes
    scale_ptr,  # fp32, one scale per quant block
    min_ptr,  # fp32, one minimum per quant block
    seed,  # int RNG seed
    n_elements,  # int
    NUM_QUANT_BLOCKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Affine-quantize one fp32 block to uint8 with stochastic rounding.

    Each value rounds up with probability equal to its fractional part:
    floor(x + u) with u ~ U[0, 1).
    """
    prog = tl.program_id(axis=0)
    start = prog * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(a_ptr + idx, mask=in_bounds)
    # Neutralize out-of-bounds lanes with +/-inf before the reductions.
    blk_min = tl.min(tl.where(in_bounds, vals, float("inf")), axis=0)
    blk_max = tl.max(tl.where(in_bounds, vals, float("-inf")), axis=0)
    blk_scale = (blk_max - blk_min) / 255.0
    degenerate = blk_scale == 0.0
    # Substitute 1.0 for a zero scale so a constant block does not divide by 0.
    vals = (vals - blk_min) / tl.where(degenerate, 1.0, blk_scale)
    # Stochastic rounding: add per-element uniform noise, then floor.
    noise = tl.rand((seed + idx + prog).to(tl.uint32), idx)
    vals = tl.floor(vals + noise)
    # Constant blocks quantize to 0; everything else clamps into 0..255.
    vals = tl.where(degenerate, 0, vals)
    vals = tl.where(vals < 0, 0, vals)
    vals = tl.where(vals > 255, 255, vals)
    # Partial store is handled by the mask.
    tl.store(a_quant_ptr + idx, vals.to(tl.uint8), mask=in_bounds)
    # One scale/min pair per quant block.
    if prog < NUM_QUANT_BLOCKS:
        tl.store(scale_ptr + prog, blk_scale)
        tl.store(min_ptr + prog, blk_min)
@triton.jit
def dequantize_kernel(
    a_ptr,  # uint8 quantized codes
    a_dequant_ptr,  # fp32 output
    scale_ptr,  # fp32, one scale per quant block
    min_ptr,  # fp32, one minimum per quant block
    n_elements,  # int
    QUANT_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Invert the blockwise affine quantization: out = u8 * scale + min."""
    prog = tl.program_id(axis=0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Each element looks up the scale/min of the quant block it belongs to.
    qblock = idx // QUANT_BLOCK_SIZE
    codes = tl.load(a_ptr + idx, mask=in_bounds)
    blk_scale = tl.load(scale_ptr + qblock, mask=in_bounds)
    blk_min = tl.load(min_ptr + qblock, mask=in_bounds)
    out = (codes.to(tl.float32) * blk_scale) + blk_min
    tl.store(a_dequant_ptr + idx, out, mask=in_bounds)
@triton.jit
def add_stochastic_kernel(
    a_ptr,  # bf16, updated in place
    b_ptr,  # fp32
    alpha,  # float
    seed,  # int RNG seed
    n_elements,  # int
    BLOCK_SIZE: tl.constexpr,
):
    """In-place A += alpha * B, stochastically rounding the result to bf16."""
    prog = tl.program_id(axis=0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    a_bf16 = tl.load(a_ptr + idx, mask=in_bounds)
    b_fp32 = tl.load(b_ptr + idx, mask=in_bounds)
    # Accumulate in fp32; widening bf16 -> fp32 just appends 16 zero
    # mantissa bits, so it is exact.
    acc = a_bf16.cast(tl.float32) + (alpha * b_fp32)
    # Stochastic rounding to bf16: bitcast to u32, add 16 random low bits,
    # then truncate the low half of the mantissa.
    bits = acc.cast(tl.uint32, bitcast=True)
    bits = bits + (tl.randint((seed + idx + prog).to(tl.uint32), idx) & 0xFFFF)
    bits = bits & 0xFFFF0000
    # Bitcast back to fp32, then narrow to bf16 (drops the zeroed low bits).
    rounded = bits.cast(tl.float32, bitcast=True).cast(tl.bfloat16)
    tl.store(a_ptr + idx, rounded, mask=in_bounds)
def quantize_state_triton_rhu(A_fp32, block_size):
    """Blockwise uint8 quantization of an fp32 tensor (round-half-up).

    Returns a flat uint8 tensor plus a quant_state dict holding the
    per-block "scales" and "mins", the "block_size", and the original
    "shape". Tensors with at most one element are returned unmodified
    together with an empty state.
    """
    n_elements = A_fp32.numel()
    if n_elements <= 1:
        return A_fp32, {}
    shape = A_fp32.shape
    num_blocks = -(-n_elements // block_size)  # ceil division
    device = A_fp32.device
    mins = torch.empty((num_blocks,), dtype=torch.float32, device=device)
    scales = torch.empty((num_blocks,), dtype=torch.float32, device=device)
    A_u8 = torch.empty_like(A_fp32, dtype=torch.uint8, device=device)
    # reshape (not view) tolerates non-contiguous inputs; A_u8 is rebound
    # to the flat tensor the kernel fills, and that tensor is returned.
    A_fp32 = A_fp32.reshape(-1)
    A_u8 = A_u8.reshape(-1)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    quantize_kernel_rhu[grid](
        A_fp32,
        A_u8,
        scales,
        mins,
        n_elements,
        NUM_QUANT_BLOCKS=num_blocks,
        BLOCK_SIZE=block_size,
    )
    return A_u8, {
        "scales": scales,
        "mins": mins,
        "block_size": block_size,
        "shape": shape,
    }
def quantize_state_triton_rhe(A_fp32, block_size):
    """Blockwise uint8 quantization of an fp32 tensor (round-half-even).

    Same contract as quantize_state_triton_rhu, but ties (.5) round to the
    nearest even code. Returns (flat uint8 tensor, quant_state dict);
    tensors with at most one element come back unmodified with an empty
    state.
    """
    n_elements = A_fp32.numel()
    if n_elements <= 1:
        return A_fp32, {}
    shape = A_fp32.shape
    num_blocks = -(-n_elements // block_size)  # ceil division
    device = A_fp32.device
    mins = torch.empty((num_blocks,), dtype=torch.float32, device=device)
    scales = torch.empty((num_blocks,), dtype=torch.float32, device=device)
    A_u8 = torch.empty_like(A_fp32, dtype=torch.uint8, device=device)
    # reshape (not view) tolerates non-contiguous inputs; A_u8 is rebound
    # to the flat tensor the kernel fills, and that tensor is returned.
    A_fp32 = A_fp32.reshape(-1)
    A_u8 = A_u8.reshape(-1)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    quantize_kernel_rhe[grid](
        A_fp32,
        A_u8,
        scales,
        mins,
        n_elements,
        NUM_QUANT_BLOCKS=num_blocks,
        BLOCK_SIZE=block_size,
    )
    return A_u8, {
        "scales": scales,
        "mins": mins,
        "block_size": block_size,
        "shape": shape,
    }
def quantize_state_triton_sr(A_fp32, block_size, generator=None):
    """Blockwise uint8 quantization of an fp32 tensor (stochastic rounding).

    Args:
        A_fp32: fp32 tensor to quantize (any shape).
        block_size: elements per quantization block (one scale/min pair each).
        generator: optional torch.Generator used to draw the kernel seed.

    Returns:
        (flat uint8 tensor, quant_state dict) where quant_state carries the
        per-block "scales" and "mins", the "block_size", and the original
        "shape". Tensors with at most one element are returned unmodified
        with an empty state.
    """
    n_elements = A_fp32.numel()
    if n_elements <= 1:
        return A_fp32, {}
    assert A_fp32.dtype == torch.float32
    # torch.randint treats generator=None as "use the default RNG", so a
    # single call replaces the previous if/else duplication.
    seed = torch.randint(0, 2 ** 32 - 1, (1,), generator=generator).item()
    shape = A_fp32.shape
    num_blocks = (n_elements + block_size - 1) // block_size
    mins = torch.empty((num_blocks,), dtype=torch.float32, device=A_fp32.device)
    scales = torch.empty((num_blocks,), dtype=torch.float32, device=A_fp32.device)
    A_u8 = torch.empty_like(A_fp32, dtype=torch.uint8, device=A_fp32.device)
    # reshape (not view) so non-contiguous inputs are accepted, matching the
    # rhu/rhe variants (the previous view(-1) + contiguity asserts rejected
    # them); A_u8 is rebound to the flat tensor the kernel fills.
    A_fp32 = A_fp32.reshape(-1)
    A_u8 = A_u8.reshape(-1)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    quantize_kernel_sr[grid](
        A_fp32,
        A_u8,
        scales,
        mins,
        seed,
        n_elements,
        NUM_QUANT_BLOCKS=num_blocks,
        BLOCK_SIZE=block_size,
    )
    return A_u8, {
        "scales": scales,
        "mins": mins,
        "block_size": block_size,
        "shape": shape,
    }
def dequantize_state_triton(A_u8, quant_state):
    """Reconstruct an fp32 tensor from blockwise uint8 quantization.

    Inverts quantize_state_triton_* using the "scales", "mins",
    "block_size", and "shape" recorded in quant_state. Tensors with at
    most one element (which were never quantized) are returned as-is.
    """
    n_elements = A_u8.numel()
    if n_elements <= 1:
        return A_u8
    assert A_u8.dtype == torch.uint8
    assert A_u8.is_contiguous()
    A_fp32 = torch.empty_like(A_u8, dtype=torch.float32, device=A_u8.device)
    assert A_fp32.is_contiguous()
    flat_u8 = A_u8.view(-1)
    flat_fp32 = A_fp32.view(-1)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    dequantize_kernel[grid](
        flat_u8,
        flat_fp32,
        quant_state["scales"],
        quant_state["mins"],
        n_elements,
        QUANT_BLOCK_SIZE=quant_state["block_size"],
        BLOCK_SIZE=1024,  # TODO: Tune
    )
    return flat_fp32.view(quant_state["shape"])
def add_stochastic_triton(A_bf16, B_fp32, alpha=1.0, generator=None):
    """In-place A += alpha * B with stochastic rounding back to bf16.

    Args:
        A_bf16: contiguous bf16 tensor, updated in place.
        B_fp32: tensor with the same shape as A; cast to fp32 if needed.
        alpha: scalar applied to B before the addition.
        generator: optional torch.Generator used to draw the kernel seed.

    Returns:
        A_bf16 (same storage), viewed with its original shape. Empty
        tensors are returned untouched.
    """
    n_elements = A_bf16.numel()
    if n_elements == 0:
        return A_bf16
    assert A_bf16.shape == B_fp32.shape
    assert A_bf16.dtype == torch.bfloat16
    assert A_bf16.is_contiguous()
    with torch.no_grad():
        if B_fp32.dtype != torch.float32:
            B_fp32 = B_fp32.to(dtype=torch.float32)
        shape = A_bf16.shape
        # A must remain a view (the kernel writes through it) and is
        # asserted contiguous above. B is read-only, so reshape is used
        # instead of view: a non-contiguous B that is already fp32 skips
        # the .to() copy, and view(-1) would raise on it.
        A_bf16 = A_bf16.view(-1)
        B_fp32 = B_fp32.reshape(-1)
        # torch.randint treats generator=None as "use the default RNG",
        # so a single call replaces the previous if/else duplication.
        seed = torch.randint(0, 2 ** 32 - 1, (1,), generator=generator).item()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_stochastic_kernel[grid](
            A_bf16,
            B_fp32,
            float(alpha),
            seed,
            n_elements,
            BLOCK_SIZE=1024,  # TODO: tune
        )
    return A_bf16.view(shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment