Gist by @CoffeeVampir3, created November 21, 2025 14:39.

# --- AMX intrinsic wrappers (imported by the matmul module below as `.amx`). ---
from sys.intrinsics import llvm_intrinsic
from memory import UnsafePointer
from collections import InlineArray


# 64-byte tile configuration block consumed by the ldtilecfg instruction.
struct TileConfig(Movable):
    var palette_id: UInt8
    var start_row: UInt8
    var reserved: InlineArray[UInt8, 14]
    var colsb: InlineArray[UInt16, 16]  # bytes per row for each tile register
    var rows: InlineArray[UInt8, 16]    # row count for each tile register

    fn __init__(out self):
        self.palette_id = 0
        self.start_row = 0
        self.reserved = InlineArray[UInt8, 14](fill=0)
        self.colsb = InlineArray[UInt16, 16](fill=0)
        self.rows = InlineArray[UInt8, 16](fill=0)


fn make_i8_gemm_config[TILE_M: Int, TILE_K: Int, TILE_N: Int]() -> TileConfig:
    var cfg = TileConfig()
    cfg.palette_id = 1
    # tmm0/tmm1 hold A tiles: TILE_M rows of TILE_K uint8 bytes.
    cfg.rows[0] = TILE_M
    cfg.rows[1] = TILE_M
    cfg.colsb[0] = TILE_K
    cfg.colsb[1] = TILE_K
    # tmm2/tmm3 hold VNNI-packed B tiles: TILE_K/4 rows of TILE_N*4 bytes.
    cfg.rows[2] = TILE_K // 4
    cfg.rows[3] = TILE_K // 4
    cfg.colsb[2] = TILE_N * 4
    cfg.colsb[3] = TILE_N * 4
    # tmm4-tmm7 hold int32 accumulators: TILE_M rows of TILE_N*4 bytes.
    for i in range(4, 8):
        cfg.rows[i] = TILE_M
        cfg.colsb[i] = TILE_N * 4
    return cfg^


# Ask the Linux kernel for permission to use AMX tile data state via
# arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA). Returns True on success.
fn init_intel_amx() -> Bool:
    alias SYS_arch_prctl = 158
    alias ARCH_REQ_XCOMP_PERM = 0x1023
    alias XFEATURE_XTILEDATA = 18
    var result = __mlir_op.`pop.external_call`[
        func = "syscall".value,
        _type=Int64,
    ](Int64(SYS_arch_prctl), Int64(ARCH_REQ_XCOMP_PERM), Int64(XFEATURE_XTILEDATA))
    return result == 0


fn load_amx_tilecfg():
    var cfg = make_i8_gemm_config[16, 64, 16]()
    var cfg_ptr = UnsafePointer(to=cfg)
    llvm_intrinsic["llvm.x86.ldtilecfg", NoneType](cfg_ptr.bitcast[NoneType]())


fn release_amx_tilecfg():
    llvm_intrinsic["llvm.x86.tilerelease", NoneType]()


fn tile_zero[tile_id: Int]():
    constrained[tile_id >= 0 and tile_id < 8, "tile_id must be 0-7"]()
    llvm_intrinsic["llvm.x86.tilezero", NoneType](Int8(tile_id))


fn tile_load[tile_id: Int, byte_stride: Int, dtype: DType](ptr: UnsafePointer[Scalar[dtype]]):
    constrained[tile_id >= 0 and tile_id < 8, "tile_id must be 0-7"]()
    llvm_intrinsic["llvm.x86.tileloadd64", NoneType](
        Int8(tile_id),
        ptr,
        byte_stride,
    )


fn tile_store[tile_id: Int, byte_stride: Int, dtype: DType](ptr: UnsafePointer[Scalar[dtype]]):
    constrained[tile_id >= 0 and tile_id < 8, "tile_id must be 0-7"]()
    llvm_intrinsic["llvm.x86.tilestored64", NoneType](
        Int8(tile_id),
        ptr,
        byte_stride,
    )


fn tile_dp[tmm_c: Int, tmm_a: Int, tmm_b: Int, a_dtype: DType, b_dtype: DType](
    a_ptr: UnsafePointer[Scalar[a_dtype]],
    b_ptr: UnsafePointer[Scalar[b_dtype]],
):
    constrained[
        tmm_c >= 0 and tmm_c < 8 and tmm_a >= 0 and tmm_a < 8 and tmm_b >= 0 and tmm_b < 8,
        "tile register IDs must be 0-7",
    ]()

    # Select the tile dot-product intrinsic from the operand dtypes.
    @parameter
    if a_dtype == DType.uint8 and b_dtype == DType.int8:
        llvm_intrinsic["llvm.x86.tdpbusd", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.int8 and b_dtype == DType.uint8:
        llvm_intrinsic["llvm.x86.tdpbsud", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.int8 and b_dtype == DType.int8:
        llvm_intrinsic["llvm.x86.tdpbssd", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.uint8 and b_dtype == DType.uint8:
        llvm_intrinsic["llvm.x86.tdpbuud", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.bfloat16 and b_dtype == DType.bfloat16:
        llvm_intrinsic["llvm.x86.tdpbf16ps", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.float16 and b_dtype == DType.float16:
        llvm_intrinsic["llvm.x86.tdpfp16ps", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    else:
        constrained[False, "Unsupported dtype combination"]()
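

# Minimal, untested usage sketch of the wrappers above (illustrative only, not
# part of the original gist): compute one 16x16 int32 tile C = A x B, where A
# is 16x64 uint8 and B is the matching VNNI-packed int8 tile (16 rows of 64
# bytes). Buffers are left uninitialized here; fill them before real use.
fn example_single_tile():
    if not init_intel_amx():
        # Kernel refused XTILEDATA permission; AMX is unusable.
        return
    load_amx_tilecfg()

    var a = UnsafePointer[UInt8].alloc(16 * 64)  # A tile, row stride 64 bytes
    var b = UnsafePointer[Int8].alloc(16 * 64)   # packed B tile, row stride 64 bytes
    var c = UnsafePointer[Int32].alloc(16 * 16)  # C tile, row stride 64 bytes

    tile_zero[4]()          # clear the accumulator
    tile_load[0, 64](a)     # tmm0 <- A
    tile_load[2, 64](b)     # tmm2 <- B
    tile_dp[4, 0, 2](a, b)  # tmm4 += A * B  (uint8 x int8 -> int32)
    tile_store[4, 64](c)    # write the accumulator back to memory

    release_amx_tilecfg()
    a.free()
    b.free()
    c.free()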


# --- Blocked AMX matmul (separate module; imports the wrappers above as `.amx`). ---
from algorithm import parallelize
from layout import Layout, LayoutTensor

from .amx import tile_zero, tile_load, tile_store, tile_dp, load_amx_tilecfg, release_amx_tilecfg


fn matmul_amx_uint8_int8_blocked[
    M: Int,
    N: Int,
    K: Int,
    PACK_N_BLOCK: Int,
    PACK_K_BLOCK: Int,
    A_layout: Layout,
    B_layout: Layout,
    C_layout: Layout,
](
    A: LayoutTensor[DType.uint8, A_layout, _],
    B: LayoutTensor[DType.int8, B_layout, _],
    C: LayoutTensor[mut=True, DType.int32, C_layout, _],
    n_offset: Int = 0,
    n_size: Int = N,
):
    alias TILE_M = 16
    alias TILE_N = 16
    alias TILE_K = 64
    alias M_STEP = 32
    alias N_STEP = 32
    alias K_STEP = 64
    constrained[M % M_STEP == 0, "M must be divisible by M_STEP"]()
    constrained[N % N_STEP == 0, "N must be divisible by N_STEP"]()
    constrained[K % K_STEP == 0, "K must be divisible by K_STEP"]()
    constrained[M_STEP == TILE_M * 2, "M_STEP must be 2*TILE_M"]()
    constrained[N_STEP == TILE_N * 2, "N_STEP must be 2*TILE_N"]()
    constrained[K_STEP == TILE_K, "K_STEP must equal TILE_K"]()
    for k_block_begin in range(0, K, PACK_K_BLOCK):
        var k_block_size = min(PACK_K_BLOCK, K - k_block_begin)
        for m_begin in range(0, M, M_STEP):
            for n_begin in range(n_offset, n_offset + n_size, N_STEP):
                if k_block_begin == 0:
                    # First K block: start the four accumulator tiles at zero.
                    tile_zero[4]()
                    tile_zero[5]()
                    tile_zero[6]()
                    tile_zero[7]()
                else:
                    # Later K blocks: reload the partial sums stored in C.
                    var c00_ptr = C.ptr.offset(m_begin * N + n_begin)
                    var c01_ptr = C.ptr.offset(m_begin * N + n_begin + TILE_N)
                    var c10_ptr = C.ptr.offset((m_begin + TILE_M) * N + n_begin)
                    var c11_ptr = C.ptr.offset((m_begin + TILE_M) * N + n_begin + TILE_N)
                    tile_load[4, N * 4](c00_ptr)
                    tile_load[5, N * 4](c01_ptr)
                    tile_load[6, N * 4](c10_ptr)
                    tile_load[7, N * 4](c11_ptr)
                for k_begin in range(0, k_block_size, K_STEP):
                    var k = k_block_begin + k_begin
                    var a0_ptr = A.ptr.offset(m_begin * K + k)
                    var a1_ptr = A.ptr.offset((m_begin + TILE_M) * K + k)
                    # Locate the packed 32x64 B tile pair; this mirrors the
                    # blocked layout written by pack_b_matrix_vnni_int8_large.
                    var n_block_begin = (n_begin // PACK_N_BLOCK) * PACK_N_BLOCK
                    var n_block_size = min(PACK_N_BLOCK, N - n_block_begin)
                    var n_within_block = n_begin - n_block_begin
                    var k_block_base = (k // PACK_K_BLOCK) * PACK_K_BLOCK
                    var k_block_sz = min(PACK_K_BLOCK, K - k_block_base)
                    var k_within_block = k - k_block_base
                    var b_offset = (
                        n_block_begin * K
                        + k_block_base * n_block_size
                        + n_within_block * k_block_sz
                        + k_within_block * N_STEP
                    )
                    var b0_ptr = B.ptr.offset(b_offset)
                    var b1_ptr = B.ptr.offset(b_offset + TILE_N * K_STEP)
                    tile_load[0, K](a0_ptr)       # tmm0 <- A rows [m_begin, m_begin+16)
                    tile_load[1, K](a1_ptr)       # tmm1 <- A rows [m_begin+16, m_begin+32)
                    tile_load[2, K_STEP](b0_ptr)  # tmm2 <- packed B for columns [n_begin, n_begin+16)
                    tile_load[3, K_STEP](b1_ptr)  # tmm3 <- packed B for columns [n_begin+16, n_begin+32)
                    # 2x2 register blocking: tmm4..tmm7 accumulate the four
                    # combinations of the two A tiles and two B tiles.
                    tile_dp[4, 0, 2](a0_ptr, b0_ptr)
                    tile_dp[5, 0, 3](a0_ptr, b1_ptr)
                    tile_dp[6, 1, 2](a1_ptr, b0_ptr)
                    tile_dp[7, 1, 3](a1_ptr, b1_ptr)
                var c00_ptr = C.ptr.offset(m_begin * N + n_begin)
                var c01_ptr = C.ptr.offset(m_begin * N + n_begin + TILE_N)
                var c10_ptr = C.ptr.offset((m_begin + TILE_M) * N + n_begin)
                var c11_ptr = C.ptr.offset((m_begin + TILE_M) * N + n_begin + TILE_N)
                tile_store[4, N * 4](c00_ptr)
                tile_store[5, N * 4](c01_ptr)
                tile_store[6, N * 4](c10_ptr)
                tile_store[7, N * 4](c11_ptr)


fn matmul_amx_uint8_int8_blocked_parallel[
    M: Int,
    N: Int,
    K: Int,
    PACK_N_BLOCK: Int,
    PACK_K_BLOCK: Int,
    PARALLEL_N_CHUNK: Int,
    num_workers: Int,
    A_layout: Layout,
    B_layout: Layout,
    C_layout: Layout,
](
    A: LayoutTensor[DType.uint8, A_layout, _],
    B: LayoutTensor[DType.int8, B_layout, _],
    C: LayoutTensor[mut=True, DType.int32, C_layout, _],
):
    constrained[N % PARALLEL_N_CHUNK == 0, "N must be divisible by PARALLEL_N_CHUNK"]()

    @parameter
    fn worker(thread_id: Int):
        # Each worker owns a disjoint slice of N columns.
        var n_start = thread_id * PARALLEL_N_CHUNK
        var n_end = min(n_start + PARALLEL_N_CHUNK, N)
        var n_slice_size = n_end - n_start
        matmul_amx_uint8_int8_blocked[M, N, K, PACK_N_BLOCK, PACK_K_BLOCK, A_layout, B_layout, C_layout](
            A, B, C, n_start, n_slice_size
        )

    var num_blocks = N // PARALLEL_N_CHUNK
    parallelize[worker](num_blocks, num_workers)


# --- Quantization helpers (separate module; the VNNI kernel below imports it as `.quantization`). ---
from layout import Layout, LayoutTensor
from memory import ImmutUnsafePointer, MutUnsafePointer


# Bundles a quantized tensor with its float32 scale factors.
struct QuantizedTensor[
    dtype: DType,
    layout: Layout,
    scales_layout: Layout,
    origin: Origin[_],
]:
    var tensor: LayoutTensor[Self.dtype, Self.layout, Self.origin]
    var scales: LayoutTensor[DType.float32, Self.scales_layout, Self.origin]

    fn __init__(
        out self,
        tensor: LayoutTensor[Self.dtype, Self.layout, Self.origin],
        scales: LayoutTensor[DType.float32, Self.scales_layout, Self.origin],
    ):
        self.tensor = tensor
        self.scales = scales


# Per-row (channelwise over M) absmax quantization of a row-major M x K matrix
# to uint8. Values are clamped to [0, 255], so negative inputs collapse to 0.
fn quantize_symmetric_uint8_channelwise_row[
    M: Int,
    K: Int,
](
    data_fp32: ImmutUnsafePointer[Float32],
    data_uint8: MutUnsafePointer[UInt8],
    scales: MutUnsafePointer[Float32],
):
    for m in range(M):
        # Find the row's absolute maximum to derive its scale.
        var max_val = Float32(0)
        for k in range(K):
            var idx = m * K + k
            var val = abs(data_fp32[idx])
            if val > max_val:
                max_val = val
        var scale = max_val / 255.0
        if scale == 0:
            scale = 1.0
        scales[m] = scale
        for k in range(K):
            var idx = m * K + k
            var quantized = Int32(data_fp32[idx] / scale)
            if quantized > 255:
                quantized = 255
            elif quantized < 0:
                quantized = 0
            data_uint8[idx] = UInt8(quantized)


# Per-column (channelwise over N) absmax quantization of a row-major K x N
# matrix to int8, clamped to [-128, 127].
fn quantize_symmetric_int8_channelwise_col[
    K: Int,
    N: Int,
](
    data_fp32: ImmutUnsafePointer[Float32],
    data_int8: MutUnsafePointer[Int8],
    scales: MutUnsafePointer[Float32],
):
    for n in range(N):
        var max_val = Float32(0)
        for k in range(K):
            var idx = k * N + n
            var val = abs(data_fp32[idx])
            if val > max_val:
                max_val = val
        var scale = max_val / 127.0
        if scale == 0:
            scale = 1.0
        scales[n] = scale
        for k in range(K):
            var idx = k * N + n
            var quantized = Int32(data_fp32[idx] / scale)
            if quantized > 127:
                quantized = 127
            elif quantized < -128:
                quantized = -128
            data_int8[idx] = Int8(quantized)
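

# Hypothetical helper (not in the original gist) showing how the per-row and
# per-column scales recover a float result from the int32 accumulator:
# C_fp32[m, n] ~= C_int32[m, n] * a_scales[m] * b_scales[n].
fn dequantize_output[
    M: Int,
    N: Int,
](
    c_int32: ImmutUnsafePointer[Int32],
    a_scales: ImmutUnsafePointer[Float32],
    b_scales: ImmutUnsafePointer[Float32],
    c_fp32: MutUnsafePointer[Float32],
):
    for m in range(M):
        for n in range(N):
            var idx = m * N + n
            c_fp32[idx] = Float32(c_int32[idx]) * a_scales[m] * b_scales[n]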


# --- AVX-512 VNNI matmul for small shapes (separate module). ---
from algorithm import vectorize, parallelize
from layout import Layout, LayoutTensor
from sys.intrinsics import llvm_intrinsic

from .quantization import QuantizedTensor


# Thin wrapper over the AVX-512 VNNI dot-product intrinsic; `width` selects the
# 128/256/512-bit variant (width = 1, 2, or 4).
fn vpdpbusd[width: Int](
    src: SIMD[DType.int32, 4 * width],
    a: SIMD[DType.uint8, 16 * width],
    b: SIMD[DType.int8, 16 * width],
) -> SIMD[DType.int32, 4 * width]:
    return llvm_intrinsic[
        "llvm.x86.avx512.vpdpbusd." + String(128 * width),
        SIMD[DType.int32, 4 * width],
    ](src, a, b)


fn vpdpbusd_512(
    src: SIMD[DType.int32, 16],
    a: SIMD[DType.uint8, 64],
    b: SIMD[DType.int8, 64],
) -> SIMD[DType.int32, 16]:
    return vpdpbusd[4](src, a, b)
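

# Reference semantics for the intrinsic above (not part of the original gist):
# each 32-bit lane j computes src[j] + sum(UInt(a[4j+i]) * SInt(b[4j+i]) for
# i in 0..3), without saturation. A scalar, untested equivalent of the 512-bit
# form, kept only for documentation:
fn vpdpbusd_512_reference(
    src: SIMD[DType.int32, 16],
    a: SIMD[DType.uint8, 64],
    b: SIMD[DType.int8, 64],
) -> SIMD[DType.int32, 16]:
    var out = src
    for j in range(16):
        var acc = Int32(0)
        for i in range(4):
            # uint8 zero-extends, int8 sign-extends before the multiply.
            acc += Int32(a[4 * j + i]) * Int32(b[4 * j + i])
        out[j] = out[j] + acc
    return out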


fn small_matmul_uint8_vnni_channelwise_parallel[
    M: Int,
    N: Int,
    K: Int,
    num_workers: Int,
    A_layout: Layout,
    B_layout: Layout,
    C_layout: Layout,
    A_scales_layout: Layout,
    B_scales_layout: Layout,
    A_origin: Origin[_],
    B_origin: Origin[_],
    C_origin: MutOrigin,
](
    A: QuantizedTensor[DType.uint8, A_layout, A_scales_layout, A_origin],
    B: QuantizedTensor[DType.int8, B_layout, B_scales_layout, B_origin],
    C: LayoutTensor[mut=True, DType.int32, C_layout, C_origin],
):
    constrained[K % 64 == 0, "K must be divisible by 64 for VNNI"]()

    @parameter
    fn compute_row(m: Int):
        for n in range(N):
            var acc = SIMD[DType.int32, 16](0)

            # One 512-bit vpdpbusd per step: 64 bytes of A row m against
            # 64 bytes of B row n (B is stored transposed, N x K).
            @parameter
            fn dot_k[width: Int](k: Int):
                var a_u8 = A.tensor.load[width=64](m, k)
                var b_i8 = B.tensor.load[width=64](n, k)
                acc = vpdpbusd_512(acc, a_u8, b_i8)

            vectorize[dot_k, 64, size=K]()
            # Reduce the 16 int32 lanes, then apply the per-row and per-column scales.
            var scale_a = A.scales[m, 0][0]
            var scale_b = B.scales[0, n][0]
            var result = Int32(Float32(acc.reduce_add()) * scale_a * scale_b)
            C.store(m, n, SIMD[DType.int32, 1](result))

    parallelize[compute_row](M, num_workers)


# --- B-matrix packing routines (separate module). ---
from memory import ImmutUnsafePointer, MutUnsafePointer, alloc
from sys.intrinsics import _RegisterPackType, PrefetchOptions


# Transpose a row-major K x N int8 matrix into row-major N x K, matching the
# (n, k) indexing used by small_matmul_uint8_vnni_channelwise_parallel.
fn pack_b_matrix_vnni_int8_small[
    K: Int,
    N: Int,
](
    src: ImmutUnsafePointer[Int8],
    dst: MutUnsafePointer[Int8],
):
    for k in range(K):
        for n in range(N):
            var src_idx = n + k * N
            var dst_idx = n * K + k
            dst[dst_idx] = src[src_idx]


fn pack_b_matrix_vnni_int8_large[
    K: Int,
    N: Int,
    N_BLOCK: Int,
    K_BLOCK: Int,
    TILE_K: Int = 64,
    TILE_N: Int = 16,
    N_STEP: Int = 32,
    K_STEP: Int = 64,
](
    src: ImmutUnsafePointer[Int8],
    dst: MutUnsafePointer[Int8],
):
    constrained[N % N_STEP == 0, "N must be divisible by N_STEP"]()
    constrained[K % K_STEP == 0, "K must be divisible by K_STEP"]()
    constrained[N_STEP == TILE_N * 2, "N_STEP must be 2*TILE_N"]()
    constrained[K_STEP == TILE_K, "K_STEP must equal TILE_K"]()
    var temp = alloc[Int8](256)  # 16x16 scratch block reused by the transposes
    for n_block_begin in range(0, N, N_BLOCK):
        var n_block_end = min(n_block_begin + N_BLOCK, N)
        var n_block_size = n_block_end - n_block_begin
        for k_block_begin in range(0, K, K_BLOCK):
            var k_block_end = min(k_block_begin + K_BLOCK, K)
            var k_block_size = k_block_end - k_block_begin
            for n_begin in range(0, n_block_size, N_STEP):
                for k_begin in range(0, k_block_size, K_STEP):
                    # Destination offset of this N_STEP x K_STEP tile pair; the
                    # same formula is used by matmul_amx_uint8_int8_blocked to
                    # locate packed B tiles.
                    var tile_base = (
                        n_block_begin * K
                        + k_block_begin * n_block_size
                        + n_begin * k_block_size
                        + k_begin * N_STEP
                    )
                    # Copy a 32 x 64 slab of src (indexed as N x K) into the
                    # blocked destination, row by row.
                    for i in range(N_STEP):
                        var src_row = n_block_begin + n_begin + i
                        var src_col = k_block_begin + k_begin
                        for k in range(K_STEP):
                            dst[tile_base + i * K_STEP + k] = src[src_row * K + src_col + k]
                    # Rearrange the two 16-row halves of the slab in place.
                    transpose_16x16_vnni_int8_with_temp(
                        dst.offset(tile_base),
                        temp,
                        K_STEP,
                    )
                    transpose_16x16_vnni_int8_with_temp(
                        dst.offset(tile_base + TILE_N * K_STEP),
                        temp,
                        K_STEP,
                    )
    temp.free()


@always_inline
fn transpose_16x16_vnni_int8_with_temp(
    ptr: MutUnsafePointer[Int8],
    temp: MutUnsafePointer[Int8],
    src_stride: Int,
):
    alias VNNI_BLK = 4
    alias TILE_SIZE = 16
    # Stage the 16x16 byte block (row stride src_stride) into temp...
    for i in range(TILE_SIZE):
        for j in range(TILE_SIZE):
            temp[i * TILE_SIZE + j] = ptr[i * src_stride + j]
    # ...then write it back transposed with stride TILE_SIZE, walking the
    # source rows in groups of VNNI_BLK.
    for row in range(0, TILE_SIZE, VNNI_BLK):
        for col in range(TILE_SIZE):
            for vnni in range(VNNI_BLK):
                ptr[col * TILE_SIZE + row + vnni] = temp[(row + vnni) * TILE_SIZE + col]
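

# Untested usage sketch for the in-place transpose above (not in the original
# gist): rearrange a single contiguous 16x16 byte block (row stride 16) using a
# 256-byte scratch buffer, mirroring how pack_b_matrix_vnni_int8_large calls it.
fn example_transpose_block():
    var block = alloc[Int8](16 * 16)  # one 16x16 tile, row-major
    var scratch = alloc[Int8](256)    # scratch reused across tiles
    for i in range(256):
        block[i] = Int8(i % 127)      # arbitrary fill
    transpose_16x16_vnni_int8_with_temp(block, scratch, 16)
    scratch.free()
    block.free()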