Created
November 21, 2025 14:39
-
-
Save CoffeeVampir3/d82917f6fce60c0c2cdf00629c4de67d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from sys.intrinsics import llvm_intrinsic | |
| from memory import UnsafePointer | |
| from collections import InlineArray | |
struct TileConfig(Movable):
    """Memory image consumed by `ldtilecfg` to configure the AMX tile registers.

    Field sizes add up to 1 + 1 + 14 + 16*2 + 16 = 64 bytes. Field order is
    significant: it must match the hardware tile-configuration format
    (palette, start_row, reserved bytes, per-tile bytes-per-row, per-tile
    row counts). NOTE(review): this assumes Mojo lays the fields out in
    declaration order with natural alignment (offsets 0, 1, 2, 16, 48).
    """
    # Palette selector; 0 means "no tiles configured".
    var palette_id: UInt8
    var start_row: UInt8
    # Must be zero.
    var reserved: InlineArray[UInt8, 14]
    # Bytes per row for tmm0..tmm15 (only the first 8 are used by palette 1).
    var colsb: InlineArray[UInt16, 16]
    # Row count for tmm0..tmm15 (only the first 8 are used by palette 1).
    var rows: InlineArray[UInt8, 16]

    fn __init__(out self):
        """Creates an all-zero configuration (every tile unconfigured)."""
        self.palette_id = 0
        self.start_row = 0
        self.reserved = InlineArray[UInt8, 14](fill=0)
        self.colsb = InlineArray[UInt16, 16](fill=0)
        self.rows = InlineArray[UInt8, 16](fill=0)
fn make_i8_gemm_config[TILE_M: Int, TILE_K: Int, TILE_N: Int]() -> TileConfig:
    """Builds a palette-1 tile configuration for an 8-bit integer GEMM micro-kernel.

    Register assignment:
      * tmm0, tmm1: A tiles, TILE_M rows x TILE_K bytes (one 8-bit value per byte).
      * tmm2, tmm3: B tiles in VNNI layout, TILE_K/4 rows x TILE_N*4 bytes
        (four consecutive k values per output column, packed into a dword).
      * tmm4..tmm7: int32 accumulators, TILE_M rows x TILE_N*4 bytes.
    Tiles 8..15 stay zero (palette 1 has only eight tile registers).

    Parameters:
        TILE_M: Rows of the A and C tiles (1..16).
        TILE_K: Bytes per A row, i.e. the k depth of one tile step; must be a
            positive multiple of 4 and at most 64.
        TILE_N: int32 columns of a C tile (1..16).

    Returns:
        A `TileConfig` ready to be passed to `ldtilecfg`.
    """
    # Palette 1 limits: each tile holds at most 16 rows of at most 64 bytes.
    alias MAX_ROWS = 16
    alias MAX_COLSB = 64
    constrained[
        TILE_M > 0 and TILE_M <= MAX_ROWS,
        "TILE_M must be in 1..16 (palette 1 row limit)",
    ]()
    constrained[
        TILE_K > 0 and TILE_K % 4 == 0 and TILE_K <= MAX_COLSB,
        "TILE_K must be a positive multiple of 4 and at most 64 bytes",
    ]()
    constrained[
        TILE_N > 0 and TILE_N * 4 <= MAX_COLSB,
        "TILE_N must be in 1..16 (64-byte row limit for int32 tiles)",
    ]()

    var cfg = TileConfig()
    cfg.palette_id = 1
    # A operand tiles.
    for a_tile in range(0, 2):
        cfg.rows[a_tile] = UInt8(TILE_M)
        cfg.colsb[a_tile] = UInt16(TILE_K)
    # B operand tiles (VNNI: K/4 rows, 4 bytes per output column).
    for b_tile in range(2, 4):
        cfg.rows[b_tile] = UInt8(TILE_K // 4)
        cfg.colsb[b_tile] = UInt16(TILE_N * 4)
    # int32 accumulator tiles.
    for c_tile in range(4, 8):
        cfg.rows[c_tile] = UInt8(TILE_M)
        cfg.colsb[c_tile] = UInt16(TILE_N * 4)
    return cfg^
fn init_intel_amx() -> Bool:
    """Asks the Linux kernel for permission to use AMX tile data in this process.

    Issues the raw `syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM,
    XFEATURE_XTILEDATA)` through libc's `syscall` entry point.

    Returns:
        True if the syscall returned 0, False otherwise.

    NOTE(review): the constants are x86-64 Linux specific (158 is the x86-64
    syscall number of arch_prctl), so this only makes sense on that target.
    """
    alias SYS_arch_prctl = 158
    alias ARCH_REQ_XCOMP_PERM = 0x1023
    # XSAVE feature number of the AMX tile data state component.
    alias XFEATURE_XTILEDATA = 18
    # NOTE(review): `syscall` is variadic in C; calling it through a
    # fixed-arity external call is assumed to work for integer arguments on
    # the x86-64 SysV ABI. Confirm if porting.
    var result = __mlir_op.`pop.external_call`[
        func = "syscall".value,
        _type=Int64,
    ](Int64(SYS_arch_prctl), Int64(ARCH_REQ_XCOMP_PERM), Int64(XFEATURE_XTILEDATA))
    return result == 0
fn load_amx_tilecfg[TILE_M: Int = 16, TILE_K: Int = 64, TILE_N: Int = 16]():
    """Loads the 8-bit GEMM tile configuration into the calling thread (`ldtilecfg`).

    Tile configuration is per-thread state: every thread that executes tile
    instructions must call this first. The defaults (16, 64, 16) are the
    shapes the AMX matmul kernels in this project hard-code.

    Parameters:
        TILE_M: Rows of the A/C tiles; forwarded to `make_i8_gemm_config`.
        TILE_K: Bytes per A row; forwarded to `make_i8_gemm_config`.
        TILE_N: int32 columns of a C tile; forwarded to `make_i8_gemm_config`.
    """
    var cfg = make_i8_gemm_config[TILE_M, TILE_K, TILE_N]()
    var cfg_ptr = UnsafePointer(to=cfg)
    llvm_intrinsic["llvm.x86.ldtilecfg", NoneType](cfg_ptr.bitcast[NoneType]())
    # Mojo ends a value's lifetime after its last direct use. `cfg` is only
    # read through `cfg_ptr` above, so end its lifetime explicitly here to
    # guarantee the 64-byte block outlives the `ldtilecfg` read.
    _ = cfg^
fn release_amx_tilecfg():
    """Releases the calling thread's AMX tile state (`tilerelease`).

    After this, tile instructions on this thread require a fresh
    `ldtilecfg` before use.
    """
    llvm_intrinsic["llvm.x86.tilerelease", NoneType]()
fn tile_zero[tile_id: Int]():
    """Zeroes tile register `tmm<tile_id>` (`tilezero`).

    Parameters:
        tile_id: Tile register index; checked at compile time to be 0-7.
    """
    constrained[tile_id >= 0 and tile_id < 8, "tile_id must be 0-7"]()
    # The tile index comes from a compile-time parameter because the LLVM
    # intrinsic needs it as a constant i8 operand.
    llvm_intrinsic["llvm.x86.tilezero", NoneType](Int8(tile_id))
fn tile_load[tile_id: Int, byte_stride: Int, dtype: DType](ptr: UnsafePointer[Scalar[dtype]]):
    """Loads tile register `tmm<tile_id>` from memory (`tileloadd`).

    The number of rows and bytes per row come from the currently loaded tile
    configuration. Consecutive rows are read `byte_stride` bytes apart,
    starting at `ptr`.

    Parameters:
        tile_id: Tile register index; checked at compile time to be 0-7.
        byte_stride: Distance in bytes between consecutive rows in memory.
        dtype: Element type of `ptr`; only used so typed pointers are accepted.

    Args:
        ptr: Address of the first row.
    """
    constrained[tile_id >= 0 and tile_id < 8, "tile_id must be 0-7"]()
    llvm_intrinsic["llvm.x86.tileloadd64", NoneType](
        Int8(tile_id),
        ptr,
        byte_stride
    )
fn tile_store[tile_id: Int, byte_stride: Int, dtype: DType](ptr: UnsafePointer[Scalar[dtype]]):
    """Stores tile register `tmm<tile_id>` to memory (`tilestored`).

    The number of rows and bytes per row come from the currently loaded tile
    configuration. Consecutive rows are written `byte_stride` bytes apart,
    starting at `ptr`.

    Parameters:
        tile_id: Tile register index; checked at compile time to be 0-7.
        byte_stride: Distance in bytes between consecutive rows in memory.
        dtype: Element type of `ptr`; only used so typed pointers are accepted.

    Args:
        ptr: Address of the first destination row.
    """
    constrained[tile_id >= 0 and tile_id < 8, "tile_id must be 0-7"]()
    llvm_intrinsic["llvm.x86.tilestored64", NoneType](
        Int8(tile_id),
        ptr,
        byte_stride
    )
fn tile_dp[tmm_c: Int, tmm_a: Int, tmm_b: Int, a_dtype: DType, b_dtype: DType](
    a_ptr: UnsafePointer[Scalar[a_dtype]],
    b_ptr: UnsafePointer[Scalar[b_dtype]],
):
    """Accumulates the tile dot product `tmm_c += tmm_a * tmm_b`.

    The instruction is selected at compile time from the operand element types:
    u8*s8 -> tdpbusd, s8*u8 -> tdpbsud, s8*s8 -> tdpbssd, u8*u8 -> tdpbuud,
    bf16*bf16 -> tdpbf16ps, f16*f16 -> tdpfp16ps. Any other combination is a
    compile-time error.

    Parameters:
        tmm_c: Accumulator tile register (0-7).
        tmm_a: Left operand tile register (0-7).
        tmm_b: Right operand tile register (0-7, VNNI layout).
        a_dtype: Element type of the A operand.
        b_dtype: Element type of the B operand.

    Args:
        a_ptr: Not dereferenced; only used to infer `a_dtype`.
        b_ptr: Not dereferenced; only used to infer `b_dtype`.
    """
    constrained[
        tmm_c >= 0 and tmm_c < 8 and tmm_a >= 0 and tmm_a < 8 and tmm_b >= 0 and tmm_b < 8,
        "tile register IDs must be 0-7",
    ]()
    # The AMX dot-product instructions raise #UD when any two tile operands
    # name the same register, so reject that at compile time.
    constrained[
        tmm_c != tmm_a and tmm_c != tmm_b and tmm_a != tmm_b,
        "tile register IDs must be pairwise distinct",
    ]()

    @parameter
    if a_dtype == DType.uint8 and b_dtype == DType.int8:
        llvm_intrinsic["llvm.x86.tdpbusd", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.int8 and b_dtype == DType.uint8:
        llvm_intrinsic["llvm.x86.tdpbsud", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.int8 and b_dtype == DType.int8:
        llvm_intrinsic["llvm.x86.tdpbssd", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.uint8 and b_dtype == DType.uint8:
        llvm_intrinsic["llvm.x86.tdpbuud", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.bfloat16 and b_dtype == DType.bfloat16:
        llvm_intrinsic["llvm.x86.tdpbf16ps", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    elif a_dtype == DType.float16 and b_dtype == DType.float16:
        llvm_intrinsic["llvm.x86.tdpfp16ps", NoneType](Int8(tmm_c), Int8(tmm_a), Int8(tmm_b))
    else:
        constrained[False, "Unsupported dtype combination"]()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from algorithm import parallelize | |
| from layout import Layout, LayoutTensor | |
| from .amx import tile_zero, tile_load, tile_store, tile_dp, load_amx_tilecfg, release_amx_tilecfg | |
fn matmul_amx_uint8_int8_blocked[
    M: Int,
    N: Int,
    K: Int,
    PACK_N_BLOCK: Int,
    PACK_K_BLOCK: Int,
    A_layout: Layout,
    B_layout: Layout,
    C_layout: Layout,
](
    A: LayoutTensor[DType.uint8, A_layout, _],
    B: LayoutTensor[DType.int8, B_layout, _],
    C: LayoutTensor[mut=True, DType.int32, C_layout, _],
    n_offset: Int = 0,
    n_size: Int = N,
):
    """Computes columns [n_offset, n_offset + n_size) of C = A x B with Intel AMX.

    A is uint8 and B is int8; products are accumulated into int32 C via
    tdpbusd. Each step produces a 32x32 block of C from four 16x16
    accumulator tiles:
      * tmm0/tmm1: the two 16-row slices of A.
      * tmm2/tmm3: the two 16-column halves of the packed B step.
      * tmm4..tmm7: the C accumulators.

    Preconditions (not checked at runtime):
      * The calling thread has loaded the tile configuration with
        `load_amx_tilecfg()` using its default 16/64/16 shapes.
      * A is read as contiguous row-major M x K bytes (row stride K).
      * C is read/written as contiguous row-major M x N int32 (row stride N).
        The layout parameters are not consulted.
      * B is in the packed layout written by
        `pack_b_matrix_vnni_int8_large[K, N, PACK_N_BLOCK, PACK_K_BLOCK]`,
        i.e. the same PACK_N_BLOCK / PACK_K_BLOCK values.
      * n_offset and n_size are multiples of 32. Each step writes 32 columns,
        so a partial step would overrun the requested column range.

    When K spans several K blocks, the partial sums are kept in C between
    blocks: the first block zeroes the accumulators, later blocks reload them.
    """
    alias TILE_M = 16
    alias TILE_N = 16
    alias TILE_K = 64
    alias M_STEP = 32
    alias N_STEP = 32
    alias K_STEP = 64
    # Byte distance between consecutive rows of A (uint8) and C (int32).
    alias A_ROW_BYTES = K
    alias C_ROW_BYTES = N * 4
    constrained[M % M_STEP == 0, "M must be divisible by M_STEP"]()
    constrained[N % N_STEP == 0, "N must be divisible by N_STEP"]()
    constrained[K % K_STEP == 0, "K must be divisible by K_STEP"]()
    constrained[M_STEP == TILE_M * 2, "M_STEP must be 2*TILE_M"]()
    constrained[N_STEP == TILE_N * 2, "N_STEP must be 2*TILE_N"]()
    constrained[K_STEP == TILE_K, "K_STEP must equal TILE_K"]()
    # A 32-column / 64-deep step must never straddle two packed blocks.
    constrained[
        PACK_N_BLOCK > 0 and PACK_N_BLOCK % N_STEP == 0,
        "PACK_N_BLOCK must be a positive multiple of N_STEP",
    ]()
    constrained[
        PACK_K_BLOCK > 0 and PACK_K_BLOCK % K_STEP == 0,
        "PACK_K_BLOCK must be a positive multiple of K_STEP",
    ]()

    for k_block_begin in range(0, K, PACK_K_BLOCK):
        var k_block_size = min(PACK_K_BLOCK, K - k_block_begin)
        for m_begin in range(0, M, M_STEP):
            var a_row0 = m_begin * K
            var a_row1 = (m_begin + TILE_M) * K
            for n_begin in range(n_offset, n_offset + n_size, N_STEP):
                var c00_ptr = C.ptr.offset(m_begin * N + n_begin)
                var c01_ptr = C.ptr.offset(m_begin * N + n_begin + TILE_N)
                var c10_ptr = C.ptr.offset((m_begin + TILE_M) * N + n_begin)
                var c11_ptr = C.ptr.offset((m_begin + TILE_M) * N + n_begin + TILE_N)

                if k_block_begin == 0:
                    tile_zero[4]()
                    tile_zero[5]()
                    tile_zero[6]()
                    tile_zero[7]()
                else:
                    tile_load[4, C_ROW_BYTES](c00_ptr)
                    tile_load[5, C_ROW_BYTES](c01_ptr)
                    tile_load[6, C_ROW_BYTES](c10_ptr)
                    tile_load[7, C_ROW_BYTES](c11_ptr)

                # Start of this 32-column strip inside the packed K block of B.
                # Invariant over the k loop below, so computed once here.
                var n_block_begin = (n_begin // PACK_N_BLOCK) * PACK_N_BLOCK
                var n_block_size = min(PACK_N_BLOCK, N - n_block_begin)
                var n_within_block = n_begin - n_block_begin
                var b_strip_base = (
                    n_block_begin * K
                    + k_block_begin * n_block_size
                    + n_within_block * k_block_size
                )

                for k_begin in range(0, k_block_size, K_STEP):
                    var k = k_block_begin + k_begin
                    var a0_ptr = A.ptr.offset(a_row0 + k)
                    var a1_ptr = A.ptr.offset(a_row1 + k)
                    # Each packed step holds N_STEP rows of K_STEP bytes; the
                    # second 16-column half follows the first.
                    var b_offset = b_strip_base + k_begin * N_STEP
                    var b0_ptr = B.ptr.offset(b_offset)
                    var b1_ptr = B.ptr.offset(b_offset + TILE_N * K_STEP)

                    tile_load[0, A_ROW_BYTES](a0_ptr)
                    tile_load[1, A_ROW_BYTES](a1_ptr)
                    tile_load[2, K_STEP](b0_ptr)
                    tile_load[3, K_STEP](b1_ptr)

                    tile_dp[4, 0, 2](a0_ptr, b0_ptr)
                    tile_dp[5, 0, 3](a0_ptr, b1_ptr)
                    tile_dp[6, 1, 2](a1_ptr, b0_ptr)
                    tile_dp[7, 1, 3](a1_ptr, b1_ptr)

                tile_store[4, C_ROW_BYTES](c00_ptr)
                tile_store[5, C_ROW_BYTES](c01_ptr)
                tile_store[6, C_ROW_BYTES](c10_ptr)
                tile_store[7, C_ROW_BYTES](c11_ptr)
fn matmul_amx_uint8_int8_blocked_parallel[
    M: Int,
    N: Int,
    K: Int,
    PACK_N_BLOCK: Int,
    PACK_K_BLOCK: Int,
    PARALLEL_N_CHUNK: Int,
    num_workers: Int,
    A_layout: Layout,
    B_layout: Layout,
    C_layout: Layout,
](
    A: LayoutTensor[DType.uint8, A_layout, _],
    B: LayoutTensor[DType.int8, B_layout, _],
    C: LayoutTensor[mut=True, DType.int32, C_layout, _],
):
    """Runs `matmul_amx_uint8_int8_blocked` in parallel over column chunks of C.

    N is split into N / PARALLEL_N_CHUNK disjoint column ranges. Each task
    computes its range for all rows of C and the full K depth, so tasks never
    write the same C element.

    The process must already have AMX permission (see `init_intel_amx`).
    The tile configuration is loaded inside every task because it is
    per-thread state and pool threads start unconfigured.
    """
    # The kernel advances in 32-column steps (its N_STEP). A chunk that is
    # not a multiple of 32 would make one task write into its neighbour's
    # columns.
    alias KERNEL_N_STEP = 32
    constrained[N % PARALLEL_N_CHUNK == 0, "N must be divisible by PARALLEL_N_CHUNK"]()
    constrained[
        PARALLEL_N_CHUNK % KERNEL_N_STEP == 0,
        "PARALLEL_N_CHUNK must be a multiple of 32",
    ]()

    @parameter
    fn worker(thread_id: Int):
        # Configure tiles on whichever thread runs this task. Loading the
        # same configuration again is harmless if it was already loaded.
        # No tilerelease here: the task may run on the calling thread, whose
        # configuration the caller may still rely on.
        load_amx_tilecfg()
        var n_start = thread_id * PARALLEL_N_CHUNK
        var n_end = min(n_start + PARALLEL_N_CHUNK, N)
        var n_slice_size = n_end - n_start
        matmul_amx_uint8_int8_blocked[M, N, K, PACK_N_BLOCK, PACK_K_BLOCK, A_layout, B_layout, C_layout](
            A, B, C, n_start, n_slice_size
        )

    var num_blocks = N // PARALLEL_N_CHUNK
    parallelize[worker](num_blocks, num_workers)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from layout import Layout, LayoutTensor | |
| from memory import ImmutUnsafePointer, MutUnsafePointer | |
struct QuantizedTensor[
    dtype: DType,
    layout: Layout,
    scales_layout: Layout,
    origin: Origin[_],
]:
    """Pairs a quantized tensor view with its float32 dequantization scales.

    Both fields are `LayoutTensor` views sharing `origin`. NOTE(review):
    presumably the underlying buffers are owned by the caller; confirm.

    Parameters:
        dtype: Element type of the quantized values (the matmul kernels use
            uint8 for A and int8 for B).
        layout: Layout of `tensor`.
        scales_layout: Layout of `scales`.
        origin: Origin shared by both views.
    """
    var tensor: LayoutTensor[Self.dtype, Self.layout, Self.origin]
    var scales: LayoutTensor[DType.float32, Self.scales_layout, Self.origin]

    fn __init__(
        out self,
        tensor: LayoutTensor[Self.dtype, Self.layout, Self.origin],
        scales: LayoutTensor[DType.float32, Self.scales_layout, Self.origin],
    ):
        """Wraps an existing quantized tensor view and its scales view."""
        self.tensor = tensor
        self.scales = scales
fn quantize_symmetric_uint8_channelwise_row[
    M: Int,
    K: Int,
](
    data_fp32: ImmutUnsafePointer[Float32],
    data_uint8: MutUnsafePointer[UInt8],
    scales: MutUnsafePointer[Float32],
):
    """Quantizes a row-major M x K float32 matrix to uint8, one scale per row.

    For row m, scale = max_k |x[m, k]| / 255, or 1.0 for an all-zero row.
    Each value becomes round(x / scale) clamped to [0, 255], so negative
    inputs map to 0. Dequantize with q * scale.

    Parameters:
        M: Number of rows.
        K: Number of columns.

    Args:
        data_fp32: Input, read as data_fp32[m * K + k].
        data_uint8: Output, written as data_uint8[m * K + k].
        scales: Output, one scale per row, written as scales[m].
    """
    for m in range(M):
        var row_base = m * K

        var max_val = Float32(0)
        for k in range(K):
            var val = abs(data_fp32[row_base + k])
            if val > max_val:
                max_val = val

        var scale = max_val / 255.0
        if scale == 0:
            scale = 1.0
        scales[m] = scale

        for k in range(K):
            var idx = row_base + k
            # Round to nearest instead of truncating toward zero: truncation
            # biases every value toward zero and doubles the worst-case error.
            var quantized = Int32(round(data_fp32[idx] / scale))
            if quantized > 255:
                quantized = 255
            elif quantized < 0:
                quantized = 0
            data_uint8[idx] = UInt8(quantized)
fn quantize_symmetric_int8_channelwise_col[
    K: Int,
    N: Int,
](
    data_fp32: ImmutUnsafePointer[Float32],
    data_int8: MutUnsafePointer[Int8],
    scales: MutUnsafePointer[Float32],
):
    """Quantizes a row-major K x N float32 matrix to int8, one scale per column.

    For column n, scale = max_k |x[k, n]| / 127, or 1.0 for an all-zero column.
    Each value becomes round(x / scale) clamped to [-128, 127]. Dequantize
    with q * scale.

    Parameters:
        K: Number of rows.
        N: Number of columns.

    Args:
        data_fp32: Input, read as data_fp32[k * N + n].
        data_int8: Output, written as data_int8[k * N + n] (same K x N layout).
        scales: Output, one scale per column, written as scales[n].
    """
    for n in range(N):
        var max_val = Float32(0)
        for k in range(K):
            var val = abs(data_fp32[k * N + n])
            if val > max_val:
                max_val = val

        var scale = max_val / 127.0
        if scale == 0:
            scale = 1.0
        scales[n] = scale

        for k in range(K):
            var idx = k * N + n
            # Round to nearest instead of truncating toward zero: truncation
            # biases every value toward zero and doubles the worst-case error.
            var quantized = Int32(round(data_fp32[idx] / scale))
            if quantized > 127:
                quantized = 127
            elif quantized < -128:
                quantized = -128
            data_int8[idx] = Int8(quantized)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from algorithm import vectorize, parallelize | |
| from layout import Layout, LayoutTensor | |
| from sys.intrinsics import llvm_intrinsic | |
| from .quantization import QuantizedTensor | |
fn vpdpbusd[width: Int](
    src: SIMD[DType.int32, 4 * width],
    a: SIMD[DType.uint8, 16 * width],
    b: SIMD[DType.int8, 16 * width]
) -> SIMD[DType.int32, 4 * width]:
    """Unsigned-by-signed byte dot-product accumulate (AVX-512 VNNI `vpdpbusd`).

    Each int32 lane i of the result is src[i] plus the sum over j = 0..3 of
    a[4*i + j] (unsigned) times b[4*i + j] (signed).

    Parameters:
        width: Vector length selector. The intrinsic name suffix is
            128 * width, so only 1 (128-bit), 2 (256-bit) and 4 (512-bit)
            name real intrinsics.

    NOTE(review): LLVM has historically declared these intrinsics with
    <N x i32> operands for the byte vectors; confirm the u8/i8 vector types
    used here match the LLVM version bundled with the Mojo toolchain.
    """
    return llvm_intrinsic[
        "llvm.x86.avx512.vpdpbusd." + String(128 * width),
        SIMD[DType.int32, 4 * width]
    ](src, a, b)
fn vpdpbusd_512(
    src: SIMD[DType.int32, 16],
    a: SIMD[DType.uint8, 64],
    b: SIMD[DType.int8, 64]
) -> SIMD[DType.int32, 16]:
    """512-bit `vpdpbusd`: 64 u8 x s8 products accumulated into 16 int32 lanes.

    Thin wrapper over `vpdpbusd[4]`.
    """
    return vpdpbusd[4](src, a, b)
fn small_matmul_uint8_vnni_channelwise_parallel[
    M: Int,
    N: Int,
    K: Int,
    num_workers: Int,
    A_layout: Layout,
    B_layout: Layout,
    C_layout: Layout,
    A_scales_layout: Layout,
    B_scales_layout: Layout,
    A_origin: Origin[_],
    B_origin: Origin[_],
    C_origin: MutOrigin,
](
    A: QuantizedTensor[DType.uint8, A_layout, A_scales_layout, A_origin],
    B: QuantizedTensor[DType.int8, B_layout, B_scales_layout, B_origin],
    C: LayoutTensor[mut=True, DType.int32, C_layout, C_origin],
):
    """Quantized matmul with AVX-512 VNNI, parallelised over rows of C.

    For each (m, n), the int32 dot product of A row m (uint8, indexed (m, k))
    and B row n (int8, indexed (n, k), i.e. B is stored transposed as N x K)
    is accumulated 64 bytes at a time with `vpdpbusd_512`. The result is then
    scaled by A.scales[m, 0] * B.scales[0, n].

    NOTE(review): the dequantized float product is converted to int32 by
    truncation before being stored in C, which discards the fractional part.
    Confirm whether C should hold raw int32 accumulators instead.
    """
    constrained[K % 64 == 0, "K must be divisible by 64 for VNNI"]()

    @parameter
    fn compute_row(m: Int):
        var row_scale = A.scales[m, 0][0]
        for n in range(N):
            var lanes = SIMD[DType.int32, 16](0)
            # K is a multiple of 64, so every step consumes a full 64-byte chunk.
            for k in range(0, K, 64):
                lanes = vpdpbusd_512(
                    lanes,
                    A.tensor.load[width=64](m, k),
                    B.tensor.load[width=64](n, k),
                )
            var col_scale = B.scales[0, n][0]
            var dequantized = Float32(lanes.reduce_add()) * row_scale * col_scale
            C.store(m, n, SIMD[DType.int32, 1](Int32(dequantized)))

    parallelize[compute_row](M, num_workers)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from memory import ImmutUnsafePointer, MutUnsafePointer, alloc | |
| from sys.intrinsics import _RegisterPackType, PrefetchOptions | |
fn pack_b_matrix_vnni_int8_small[
    K: Int,
    N: Int,
](
    src: ImmutUnsafePointer[Int8],
    dst: MutUnsafePointer[Int8],
):
    """Transposes a row-major K x N int8 matrix into a row-major N x K matrix.

    Output row n holds column n of the input, so each output column's K values
    are contiguous. Despite the name, no 4-byte VNNI interleave is applied.

    Parameters:
        K: Rows of `src` / columns of `dst`.
        N: Columns of `src` / rows of `dst`.

    Args:
        src: Input, read as src[k * N + n].
        dst: Output, written as dst[n * K + k].
    """
    for col in range(N):
        var out_row = col * K
        for depth in range(K):
            dst[out_row + depth] = src[depth * N + col]
fn pack_b_matrix_vnni_int8_large[
    K: Int,
    N: Int,
    N_BLOCK: Int,
    K_BLOCK: Int,
    TILE_K: Int = 64,
    TILE_N: Int = 16,
    N_STEP: Int = 32,
    K_STEP: Int = 64,
](
    src: ImmutUnsafePointer[Int8],
    dst: MutUnsafePointer[Int8],
):
    """Packs int8 B data into the blocked AMX layout read by the AMX matmul kernel.

    `src` is read as src[row * K + col] with row < N, i.e. an N x K matrix
    whose row n holds the K values of output column n. NOTE(review): this is
    the transpose of a K x N weight matrix; confirm callers pass it that way.

    `dst` is organised as N_BLOCK x K_BLOCK blocks. Inside each block, every
    (32-column, 64-deep) step occupies N_STEP * K_STEP bytes at offset
        n_block_begin*K + k_block_begin*n_block_size
        + n_begin*k_block_size + k_begin*N_STEP
    and holds two consecutive 16-column tiles. Each tile is converted to the
    VNNI layout by `transpose_16x16_vnni_int8_with_temp`.

    The consuming kernel must use the same N_BLOCK / K_BLOCK values.
    """
    constrained[N % N_STEP == 0, "N must be divisible by N_STEP"]()
    constrained[K % K_STEP == 0, "K must be divisible by K_STEP"]()
    constrained[N_STEP == TILE_N * 2, "N_STEP must be 2*TILE_N"]()
    constrained[K_STEP == TILE_K, "K_STEP must equal TILE_K"]()
    # A step that straddled two blocks would write past its block's region
    # and be overwritten by the next block.
    constrained[
        N_BLOCK > 0 and N_BLOCK % N_STEP == 0,
        "N_BLOCK must be a positive multiple of N_STEP",
    ]()
    constrained[
        K_BLOCK > 0 and K_BLOCK % K_STEP == 0,
        "K_BLOCK must be a positive multiple of K_STEP",
    ]()

    # Scratch buffer handed to the tile transpose (16 x 16 bytes).
    var temp = alloc[Int8](256)
    for n_block_begin in range(0, N, N_BLOCK):
        var n_block_size = min(N_BLOCK, N - n_block_begin)
        for k_block_begin in range(0, K, K_BLOCK):
            var k_block_size = min(K_BLOCK, K - k_block_begin)
            for n_begin in range(0, n_block_size, N_STEP):
                for k_begin in range(0, k_block_size, K_STEP):
                    var tile_base = (
                        n_block_begin * K
                        + k_block_begin * n_block_size
                        + n_begin * k_block_size
                        + k_begin * N_STEP
                    )
                    var src_col = k_block_begin + k_begin
                    # Copy N_STEP source rows of K_STEP bytes each.
                    for i in range(N_STEP):
                        var src_row = n_block_begin + n_begin + i
                        for k in range(K_STEP):
                            dst[tile_base + i * K_STEP + k] = src[src_row * K + src_col + k]
                    # Convert each 16-row half to the VNNI B-tile layout.
                    transpose_16x16_vnni_int8_with_temp(
                        dst.offset(tile_base),
                        temp,
                        K_STEP
                    )
                    transpose_16x16_vnni_int8_with_temp(
                        dst.offset(tile_base + TILE_N * K_STEP),
                        temp,
                        K_STEP
                    )
    temp.free()
@always_inline
fn transpose_16x16_vnni_int8_with_temp(
    ptr: MutUnsafePointer[Int8],
    temp: MutUnsafePointer[Int8],
    src_stride: Int
):
    """Converts one 16-row int8 B tile in place to the AMX VNNI layout.

    On entry, row n (at ptr + n * src_stride) holds 64 consecutive k bytes for
    output column n. On exit, row r holds, for every column n in 0..15, the
    four bytes k = 4r .. 4r+3 of column n at byte offset 4n. This is the
    TILE_K/4 rows x TILE_N*4 bytes layout that tdpbusd expects for its B
    operand.

    The conversion is a 16 x 16 transpose of 4-byte groups. It is done by
    swapping each group above the diagonal with its mirror, so the whole
    16 x 64-byte tile is rewritten and no scratch memory is needed. `temp` is
    kept only for interface compatibility and is not touched.

    Args:
        ptr: First byte of the tile.
        temp: Unused.
        src_stride: Byte distance between tile rows; must be at least 64 so
            rows do not overlap.
    """
    alias VNNI_BLK = 4
    alias TILE_SIZE = 16
    for r in range(TILE_SIZE):
        for c in range(r + 1, TILE_SIZE):
            var upper = r * src_stride + c * VNNI_BLK
            var lower = c * src_stride + r * VNNI_BLK
            for v in range(VNNI_BLK):
                var byte = ptr[upper + v]
                ptr[upper + v] = ptr[lower + v]
                ptr[lower + v] = byte
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment