
Minjang Kim (minjang)

  • Facebook
  • Menlo Park, CA
minjang / matmul_kernel.ttmir
Last active November 20, 2024 08:03
TTMIR for matmul_kernel (03-matrix-multiplication-cpu.py)
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
#loc = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0)
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cpu", "triton_gpu.threads-per-warp" = 1 : i32} {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg1: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg2: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg3: i32 {tt.
minjang / matmul_kernel.asm
Last active November 20, 2024 07:57
x86-64 (AVX512) for matmul_kernel (03-matrix-multiplication-cpu.py) from TTMIR
.text
.file "LLVMDialectModule"
.section .rodata,"a",@progbits
.p2align 6, 0x0 # -- Begin function matmul_kernel
.LCPI0_0:
.zero 4
.long 1 # 0x1
.long 2 # 0x2
.long 3 # 0x3
.long 4 # 0x4
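The `.LCPI0_0` constant pool above is cut off after five entries. The short Python check below decodes only the visible bytes and confirms two assembler facts: `.zero 4` emits the same four bytes as `.long 0`, so the visible prefix reads 0, 1, 2, 3, 4, and `.p2align 6` requests 2^6 = 64-byte alignment, which is the width of one 512-bit AVX-512 zmm register. Reading the pool as an index vector for a vectorized load (for example, one derived from `tl.arange`) is an assumption; the truncated listing does not show it.

```python
import struct

# Visible part of .LCPI0_0: `.zero 4` (four zero bytes) followed by
# `.long 1` .. `.long 4` (32-bit little-endian integers on x86-64).
visible_pool = b"\x00" * 4 + struct.pack("<4i", 1, 2, 3, 4)

# Four zero bytes are exactly the encoding of the int32 value 0,
# so the visible prefix decodes to the sequence 0, 1, 2, 3, 4.
decoded = struct.unpack("<5i", visible_pool)

# `.p2align 6` aligns to 2**6 bytes; a zmm register holds 512 bits.
alignment_bytes = 2 ** 6
zmm_bytes = 512 // 8

print(decoded)
print(alignment_bytes == zmm_bytes)
```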
minjang / 03-matrix-multiplication-cpu.py
Created November 20, 2024 08:00
matmul_kernel for 03-matrix-multiplication-cpu.py without leaky_relu
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak, #
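The stride comment above says each `stride_*` argument is how far to move a pointer for one step along a dimension. Below is a minimal pure-Python sketch of that idea, assuming flat Python lists stand in for the `a_ptr`/`b_ptr`/`c_ptr` buffers and using hypothetical block sizes `BLOCK_M`, `BLOCK_N`, `BLOCK_K`. It computes C = A·B one output tile at a time, accumulating over K in blocks and skipping out-of-bounds elements the way masked loads and stores would. It only illustrates the tiling and stride arithmetic; it is not the kernel's code and does not model program-ID scheduling or vectorization.

```python
# Hypothetical block sizes for illustration (the Triton kernel takes its own
# block sizes as compile-time parameters that are cut off in the listing above).
BLOCK_M, BLOCK_N, BLOCK_K = 2, 2, 2


def offset(row, col, stride_row, stride_col):
    """Flat index of element (row, col) given per-dimension strides."""
    return row * stride_row + col * stride_col


def tiled_matmul(a, b, c, M, N, K,
                 stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn):
    """Write A (M x K) @ B (K x N) into c, one BLOCK_M x BLOCK_N tile at a time."""
    for m0 in range(0, M, BLOCK_M):          # one tile row per "program" along M
        for n0 in range(0, N, BLOCK_N):      # one tile column per "program" along N
            acc = [[0.0] * BLOCK_N for _ in range(BLOCK_M)]
            for k0 in range(0, K, BLOCK_K):  # walk the K dimension block by block
                for i in range(BLOCK_M):
                    for j in range(BLOCK_N):
                        for kk in range(BLOCK_K):
                            m, n, k = m0 + i, n0 + j, k0 + kk
                            # Out-of-bounds elements contribute nothing (like a masked load).
                            if m < M and n < N and k < K:
                                acc[i][j] += (a[offset(m, k, stride_am, stride_ak)]
                                              * b[offset(k, n, stride_bk, stride_bn)])
            # Store the tile, skipping rows/columns past the matrix edge (like a masked store).
            for i in range(BLOCK_M):
                for j in range(BLOCK_N):
                    m, n = m0 + i, n0 + j
                    if m < M and n < N:
                        c[offset(m, n, stride_cm, stride_cn)] = acc[i][j]


M, N, K = 3, 2, 4
a_rows = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]      # M x K
b_rows = [[1, 0], [0, 1], [2, 3], [-1, 1]]                   # K x N
a = [x for row in a_rows for x in row]                        # row-major: stride_am=K, stride_ak=1
b = [b_rows[k][n] for n in range(N) for k in range(K)]        # column-major: stride_bk=1, stride_bn=K
c = [0.0] * (M * N)                                           # row-major: stride_cm=N, stride_cn=1
tiled_matmul(a, b, c, M, N, K, K, 1, 1, K, N, 1)
print(c)  # [3.0, 15.0, 11.0, 35.0, 19.0, 55.0]
```

Because the function sees only strides, the same code handles B stored column-major (`stride_bk=1`, `stride_bn=K`) without any transpose, and M = 3 not being a multiple of `BLOCK_M` exercises the masking path.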