Last active
November 27, 2025 03:23
-
-
Save xzuyn/e80cc38c401ce69845aa955daaf7c70a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Copyright 2025 xzuyn | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| import triton | |
| import triton.language as tl | |
@triton.jit
def quantize_kernel_rhu(  # round-half-up
    a_ptr,  # fp32
    a_quant_ptr,  # uint8
    scale_ptr,  # fp32
    min_ptr,  # fp32
    n_elements,  # int
    NUM_QUANT_BLOCKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Blockwise affine uint8 quantization: each program quantizes one
    # BLOCK_SIZE chunk of a_ptr as q = round((a - chunk_min) / scale) with
    # round-half-up, and emits one (scale, min) pair per chunk so the chunk
    # can later be dequantized as a = q * scale + min.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    A_fp32 = tl.load(a_ptr + offsets, mask=mask)
    # Neutralize out-of-bounds lanes with +/-inf so they cannot win the
    # min/max reductions for a partial tail chunk.
    chunk_min = tl.min(tl.where(mask, A_fp32, float("inf")), axis=0)
    scale = (tl.max(tl.where(mask, A_fp32, float("-inf")), axis=0) - chunk_min) / 255.0
    # A constant chunk gives scale == 0: divide by 1 here to avoid 0/0, and
    # force the quantized value to 0 in the clamp section below.
    is_scale_zero = scale == 0.0
    A_fp32 = (A_fp32 - chunk_min) / tl.where(is_scale_zero, 1.0, scale)
    # round-half-up to nearest int
    A_fp32 = A_fp32 + 0.5
    A_fp32 = tl.floor(A_fp32)
    # clamp to 0..255
    A_fp32 = tl.where(is_scale_zero, 0, A_fp32)
    A_fp32 = tl.where(A_fp32 < 0, 0, A_fp32)
    A_fp32 = tl.where(A_fp32 > 255, 255, A_fp32)
    # store quantized bytes (partial store supported by mask)
    tl.store(a_quant_ptr + offsets, A_fp32.to(tl.uint8), mask=mask)
    # store per-block scale & min (only if block exists)
    # NOTE(review): the host launches exactly cdiv(n_elements, BLOCK_SIZE)
    # programs with NUM_QUANT_BLOCKS equal to that count, so this guard
    # appears always-true — defensive only.
    if pid < NUM_QUANT_BLOCKS:
        tl.store(scale_ptr + pid, scale)
        tl.store(min_ptr + pid, chunk_min)
@triton.jit
def quantize_kernel_rhe(  # round-half-even
    a_ptr,  # fp32
    a_quant_ptr,  # uint8
    scale_ptr,  # fp32
    min_ptr,  # fp32
    n_elements,  # int
    NUM_QUANT_BLOCKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Blockwise affine uint8 quantization: q = round((a - chunk_min) / scale)
    # with ties rounded to the nearest EVEN integer (banker's rounding).
    # One program handles one BLOCK_SIZE chunk and emits one (scale, min) pair.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    A_fp32 = tl.load(a_ptr + offsets, mask=mask)
    # Neutralize out-of-bounds lanes with +/-inf so a partial tail chunk
    # reduces only over its valid elements.
    chunk_min = tl.min(tl.where(mask, A_fp32, float("inf")), axis=0)
    scale = (tl.max(tl.where(mask, A_fp32, float("-inf")), axis=0) - chunk_min) / 255.0
    # Constant chunk -> scale == 0: divide by 1 to avoid 0/0, output forced to 0 below.
    is_scale_zero = scale == 0.0
    A_fp32 = (A_fp32 - chunk_min) / tl.where(is_scale_zero, 1.0, scale)
    # round-half-even to nearest int: on an exact .5 tie with an even floor,
    # keep the floor (the even neighbor); otherwise round half up — which on
    # an odd-floor tie lands on floor + 1, the even neighbor.
    floor_val = tl.floor(A_fp32)
    A_fp32 = tl.where(
        ((A_fp32 - floor_val) == 0.5) & ((floor_val.to(tl.int32) % 2) == 0),
        floor_val,
        tl.floor(A_fp32 + 0.5),
    )
    # clamp to 0..255
    A_fp32 = tl.where(is_scale_zero, 0, A_fp32)
    A_fp32 = tl.where(A_fp32 < 0, 0, A_fp32)
    A_fp32 = tl.where(A_fp32 > 255, 255, A_fp32)
    # store quantized bytes (partial store supported by mask)
    tl.store(a_quant_ptr + offsets, A_fp32.to(tl.uint8), mask=mask)
    # store per-block scale & min (only if block exists; the host grid is
    # exactly NUM_QUANT_BLOCKS programs, so this guard is defensive)
    if pid < NUM_QUANT_BLOCKS:
        tl.store(scale_ptr + pid, scale)
        tl.store(min_ptr + pid, chunk_min)
@triton.jit
def quantize_kernel_sr(  # stochastic-rounding
    a_ptr,  # fp32
    a_quant_ptr,  # uint8
    scale_ptr,  # fp32
    min_ptr,  # fp32
    seed,  # int
    n_elements,  # int
    NUM_QUANT_BLOCKS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Blockwise affine uint8 quantization with stochastic rounding:
    # floor(x + u) with u ~ U[0, 1) rounds x up with probability frac(x),
    # so the quantization is unbiased in expectation.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    A_fp32 = tl.load(a_ptr + offsets, mask=mask)
    # Neutralize out-of-bounds lanes with +/-inf so a partial tail chunk
    # reduces only over its valid elements.
    chunk_min = tl.min(tl.where(mask, A_fp32, float("inf")), axis=0)
    scale = (tl.max(tl.where(mask, A_fp32, float("-inf")), axis=0) - chunk_min) / 255.0
    # Constant chunk -> scale == 0: divide by 1 to avoid 0/0, output forced to 0 below.
    is_scale_zero = scale == 0.0
    A_fp32 = (A_fp32 - chunk_min) / tl.where(is_scale_zero, 1.0, scale)
    # stochastic rounding to nearest int
    # NOTE(review): the seed passed to tl.rand is offset per lane
    # (seed + offsets + pid) on top of the per-lane `offsets` argument;
    # tl.rand is usually given a scalar seed — confirm the vectorized seed
    # is intentional and not redundant.
    A_fp32 = A_fp32 + tl.rand((seed + offsets + pid).to(tl.uint32), offsets)
    A_fp32 = tl.floor(A_fp32)
    # clamp to 0..255
    A_fp32 = tl.where(is_scale_zero, 0, A_fp32)
    A_fp32 = tl.where(A_fp32 < 0, 0, A_fp32)
    A_fp32 = tl.where(A_fp32 > 255, 255, A_fp32)
    # store quantized bytes (partial store supported by mask)
    tl.store(a_quant_ptr + offsets, A_fp32.to(tl.uint8), mask=mask)
    # store per-block scale & min (only if block exists; the host grid is
    # exactly NUM_QUANT_BLOCKS programs, so this guard is defensive)
    if pid < NUM_QUANT_BLOCKS:
        tl.store(scale_ptr + pid, scale)
        tl.store(min_ptr + pid, chunk_min)
@triton.jit
def dequantize_kernel(
    a_ptr,  # uint8
    a_dequant_ptr,  # fp32
    scale_ptr,  # fp32
    min_ptr,  # fp32
    n_elements,  # int
    QUANT_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Inverse of the quantize kernels: a = q * scale + min, where each element
    # looks up the (scale, min) pair of the QUANT_BLOCK_SIZE-sized chunk it
    # belongs to. BLOCK_SIZE (the launch tile) may differ from QUANT_BLOCK_SIZE.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    # Which quantization chunk does each element fall into?
    qblock_idx = offs // QUANT_BLOCK_SIZE
    q_vals = tl.load(a_ptr + offs, mask=in_bounds)
    chunk_scale = tl.load(scale_ptr + qblock_idx, mask=in_bounds)
    chunk_min = tl.load(min_ptr + qblock_idx, mask=in_bounds)
    dequant = q_vals.to(tl.float32) * chunk_scale + chunk_min
    tl.store(a_dequant_ptr + offs, dequant, mask=in_bounds)
@triton.jit
def add_stochastic_kernel(
    a_ptr,  # bf16
    b_ptr,  # fp32
    alpha,  # float
    seed,  # int
    n_elements,  # int
    BLOCK_SIZE: tl.constexpr,
):
    # In-place fused update A += alpha * B computed in fp32, written back to
    # bf16 with stochastic rounding: uniform integer noise in [0, 2^16) is
    # added to the 16 mantissa bits that bf16 discards, so the result rounds
    # up with probability proportional to the discarded fraction (unbiased
    # in expectation, unlike truncation or round-to-nearest).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # load A and B
    A_bf16 = tl.load(a_ptr + offsets, mask=mask)
    B_fp32 = tl.load(b_ptr + offsets, mask=mask)
    # cast A from bf16 to fp32 (adds 16 empty bits to the mantissa)
    A_fp32 = A_bf16.cast(tl.float32)
    # A + (alpha * B)
    A_fp32 = A_fp32 + (alpha * B_fp32)
    # stochastic rounding to nearest bf16 decimal
    # bitcast A from fp32 to u32 so we can do bit manipulation
    A_u32 = A_fp32.cast(tl.uint32, bitcast=True)
    # create u32 random noise, mask off its upper 16 bits, and add into A
    # (a carry out of the low 16 bits bumps the kept bf16 mantissa by one ulp)
    A_u32 = A_u32 + (tl.randint((seed + offsets + pid).to(tl.uint32), offsets) & 0xFFFF)
    # mask off the lower 16 bits of A
    A_u32 = A_u32 & 0xFFFF0000
    # bitcast the masked A from u32 to fp32
    A_fp32 = A_u32.cast(tl.float32, bitcast=True)
    # cast A from fp32 to bf16 (drop the extra 16 bits in the mantissa)
    A_bf16 = A_fp32.cast(tl.bfloat16)
    tl.store(a_ptr + offsets, A_bf16, mask=mask)
def quantize_state_triton_rhu(A_fp32, block_size):
    """Blockwise affine uint8 quantization with round-half-up rounding.

    Returns a flat uint8 tensor plus a quant_state dict holding the
    per-block scales/mins, the block size, and the original shape.
    Tensors with <= 1 elements are returned untouched with an empty state.
    """
    n_elements = A_fp32.numel()
    if n_elements <= 1:
        return A_fp32, {}
    orig_shape = A_fp32.shape
    num_blocks = -(-n_elements // block_size)  # ceil(n_elements / block_size)
    device = A_fp32.device
    mins = torch.empty((num_blocks,), dtype=torch.float32, device=device)
    scales = torch.empty((num_blocks,), dtype=torch.float32, device=device)
    A_u8 = torch.empty_like(A_fp32, dtype=torch.uint8, device=device)
    flat_in = A_fp32.reshape(-1)
    flat_out = A_u8.reshape(-1)

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    quantize_kernel_rhu[grid](
        flat_in,
        flat_out,
        scales,
        mins,
        n_elements,
        NUM_QUANT_BLOCKS=num_blocks,
        BLOCK_SIZE=block_size,
    )
    quant_state = {
        "scales": scales,
        "mins": mins,
        "block_size": block_size,
        "shape": orig_shape,
    }
    return flat_out, quant_state
def quantize_state_triton_rhe(A_fp32, block_size):
    """Blockwise affine uint8 quantization with round-half-even rounding.

    Returns (flat uint8 tensor, quant_state dict). quant_state carries the
    per-block "scales" and "mins" tensors, the "block_size", and the input
    "shape". Tensors with <= 1 elements pass through with an empty state.
    """
    n_elements = A_fp32.numel()
    if n_elements <= 1:
        return A_fp32, {}
    saved_shape = A_fp32.shape
    num_blocks = (n_elements + block_size - 1) // block_size
    dev = A_fp32.device
    mins = torch.empty((num_blocks,), dtype=torch.float32, device=dev)
    scales = torch.empty((num_blocks,), dtype=torch.float32, device=dev)
    quantized = torch.empty_like(A_fp32, dtype=torch.uint8, device=dev).reshape(-1)
    source = A_fp32.reshape(-1)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    quantize_kernel_rhe[grid](
        source,
        quantized,
        scales,
        mins,
        n_elements,
        NUM_QUANT_BLOCKS=num_blocks,
        BLOCK_SIZE=block_size,
    )
    return quantized, {
        "scales": scales,
        "mins": mins,
        "block_size": block_size,
        "shape": saved_shape,
    }
def quantize_state_triton_sr(A_fp32, block_size, generator=None):
    """Blockwise affine uint8 quantization with stochastic rounding.

    Args:
        A_fp32: contiguous float32 tensor to quantize.
        block_size: number of elements sharing one (scale, min) pair.
        generator: optional torch.Generator used to draw the kernel seed
            (for reproducibility).

    Returns:
        (flat uint8 tensor, quant_state dict) where quant_state holds the
        per-block "scales" and "mins", the "block_size", and the original
        "shape". Tensors with <= 1 elements are returned untouched with an
        empty state.
    """
    n_elements = A_fp32.numel()
    if n_elements <= 1:
        return A_fp32, {}
    # torch.randint accepts generator=None, so the previous if/else that
    # duplicated this call for the no-generator case was redundant.
    seed = int(torch.randint(0, 2 ** 32 - 1, (1,), generator=generator).item())
    shape = A_fp32.shape
    num_blocks = (n_elements + block_size - 1) // block_size  # ceil division
    mins = torch.empty((num_blocks,), dtype=torch.float32, device=A_fp32.device)
    scales = torch.empty((num_blocks,), dtype=torch.float32, device=A_fp32.device)
    A_u8 = torch.empty_like(A_fp32, dtype=torch.uint8, device=A_fp32.device)
    # .view(-1) below requires contiguity; fail loudly rather than copy.
    assert A_fp32.dtype == torch.float32
    assert A_fp32.is_contiguous()
    assert A_u8.is_contiguous()
    A_fp32 = A_fp32.view(-1)
    A_u8 = A_u8.view(-1)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    quantize_kernel_sr[grid](
        A_fp32,
        A_u8,
        scales,
        mins,
        seed,
        n_elements,
        NUM_QUANT_BLOCKS=num_blocks,
        BLOCK_SIZE=block_size,
    )
    return A_u8, {
        "scales": scales,
        "mins": mins,
        "block_size": block_size,
        "shape": shape,
    }
def dequantize_state_triton(A_u8, quant_state):
    """Dequantize a uint8 tensor produced by a quantize_state_triton_* helper
    back to float32 using its per-block scale/min state, restoring the shape
    recorded in quant_state["shape"].
    """
    n_elements = A_u8.numel()
    # Tensors with <= 1 elements were never quantized (the quantize helpers
    # return them untouched with an empty state), so pass them straight back.
    if n_elements <= 1:
        return A_u8
    out_fp32 = torch.empty_like(A_u8, dtype=torch.float32, device=A_u8.device)
    assert A_u8.dtype == torch.uint8
    assert A_u8.is_contiguous()
    assert out_fp32.is_contiguous()
    flat_q = A_u8.view(-1)
    flat_out = out_fp32.view(-1)

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    dequantize_kernel[grid](
        flat_q,
        flat_out,
        quant_state["scales"],
        quant_state["mins"],
        n_elements,
        QUANT_BLOCK_SIZE=quant_state["block_size"],
        BLOCK_SIZE=1024,  # TODO: Tune
    )
    return flat_out.view(quant_state["shape"])
def add_stochastic_triton(A_bf16, B_fp32, alpha=1.0, generator=None):
    """In-place A_bf16 += alpha * B_fp32, rounded back to bf16 stochastically.

    Args:
        A_bf16: contiguous bfloat16 tensor; updated in place.
        B_fp32: tensor of the same shape; cast to float32 if it is not already.
        alpha: scalar multiplier applied to B_fp32.
        generator: optional torch.Generator used to draw the kernel seed.

    Returns:
        A_bf16 (same storage), viewed back to its original shape.
    """
    n_elements = A_bf16.numel()
    if n_elements == 0:
        return A_bf16
    assert A_bf16.shape == B_fp32.shape
    assert A_bf16.dtype == torch.bfloat16
    assert A_bf16.is_contiguous()
    with torch.no_grad():
        if B_fp32.dtype != torch.float32:
            B_fp32 = B_fp32.to(dtype=torch.float32)
        shape = A_bf16.shape
        A_bf16 = A_bf16.view(-1)
        B_fp32 = B_fp32.view(-1)
        # torch.randint accepts generator=None, so the previous if/else that
        # duplicated this call for the no-generator case was redundant.
        seed = int(torch.randint(0, 2 ** 32 - 1, (1,), generator=generator).item())
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_stochastic_kernel[grid](
            A_bf16,
            B_fp32,
            float(alpha),
            seed,
            n_elements,
            BLOCK_SIZE=1024,  # TODO: tune
        )
    return A_bf16.view(shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment