# GitHub gist sueszli/76acdde675002f64a34a6a3967cbb7fa by @sueszli,
# created February 22, 2026.
# /// script
# dependencies = ["exo-lang", "numpy"]
# ///
from __future__ import annotations
import ctypes
import subprocess
import tempfile
from pathlib import Path
import numpy as np
from exo import *
from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, stage_mem, unroll_loop
# Problem size for C (M x N) += A (M x K) @ B (K x N); these values are baked
# into sgemm_opt through partial_eval.
M = N = K = 512
# Reference (unscheduled) SGEMM written as an Exo proc: C += A @ B, where
# A is M x K, B is K x N and C is M x N, all placed in DRAM.
# The loop names i, j and k are targeted by name by the scheduling calls
# (divide_loop on "j" and "k", patterns mentioning "i") further down this
# script, so they must not be renamed here.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                # Accumulates into C rather than overwriting it, so the
                # caller must initialise C (the driver passes zeros).
                C[i, j] += A[i, k] * B[k, j]
def build(p, *, n_ptr_args=3):
    """Compile an Exo proc into a shared library and return a Python callable.

    The proc is lowered to C with ``compile_procs``, compiled with clang into
    ``lib.so`` inside a fresh temporary directory, and loaded via ``ctypes``.

    Args:
        p: The Exo proc to compile; ``p.name()`` is used as the C symbol.
        n_ptr_args: Number of ``float*`` buffer arguments that follow the
            leading context pointer. The default of 3 fits ``sgemm_opt``
            (C, A, B) once ``partial_eval`` has removed the size arguments.

    Returns:
        A callable ``f(*buffers)`` that forwards the given ``float*`` pointers
        to the compiled function, passing ``None`` as the context pointer.

    Raises:
        subprocess.CalledProcessError: if the C compiler fails.
    """
    # The temporary directory is deliberately not removed: it holds the
    # generated o.c/o.h and the library that stays loaded for this process.
    out_dir = Path(tempfile.mkdtemp(prefix="exo_build_"))
    compile_procs([p], out_dir, "o.c", "o.h")
    so_path = out_dir / "lib.so"
    # -fPIC is required for -shared on toolchains that do not default to
    # position-independent code. -ffast-math lets the compiler reassociate
    # float math, so results may differ slightly from strict evaluation order.
    subprocess.run(
        [
            "clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math",
            "-o", str(so_path), str(out_dir / "o.c"),
        ],
        check=True,
    )
    fn = getattr(ctypes.CDLL(str(so_path)), p.name())
    # Leading void* is the context argument; the data buffers follow.
    fn.argtypes = [ctypes.c_void_p] + [ctypes.POINTER(ctypes.c_float)] * n_ptr_args
    # The generated entry point returns void; without this ctypes would
    # interpret an unrelated register value as a C int return.
    fn.restype = None
    return lambda *args: fn(None, *args)
# Tile sizes: j is split into J_TILE-wide strips and k into K_TILE-deep
# panels. They are used both for the loop splits and for the stage_mem
# windows below, so they are defined once to keep the two in sync.
J_TILE, K_TILE = 16, 64
if N % J_TILE or K % K_TILE:
    # divide_loop(..., perfect=True) requires exact divisibility.
    raise ValueError(f"N={N} must be a multiple of {J_TILE} and K={K} a multiple of {K_TILE}")

sgemm_opt = rename(sgemm.partial_eval(M=M, N=N, K=K), "sgemm_opt")
sgemm_opt = divide_loop(sgemm_opt, "j", J_TILE, ["jo", "ji"], perfect=True)
sgemm_opt = divide_loop(sgemm_opt, "k", K_TILE, ["ko", "ki"], perfect=True)
# Loop order is now i > jo > ji > ko > ki; these adjacent swaps turn it into
# ko > jo > i > ki > ji. The order of the swaps matters.
for swap in ["ji ko", "ji ki", "i jo", "i ko", "jo ko"]:
    sgemm_opt = reorder_loops(sgemm_opt, swap)
# Copy the (ko, jo) tile of B into B_pack before the i loop, so every row i
# reuses the same packed tile.
sgemm_opt = stage_mem(
    sgemm_opt,
    "for i in _: _",
    f"B[ko*{K_TILE} : ko*{K_TILE}+{K_TILE}, jo*{J_TILE} : jo*{J_TILE}+{J_TILE}]",
    "B_pack",
)
# Keep the J_TILE-wide strip of row i of C in C_local across the ki loop.
sgemm_opt = stage_mem(
    sgemm_opt,
    "for ki in _: _",
    f"C[i, jo*{J_TILE} : jo*{J_TILE}+{J_TILE}]",
    "C_local",
)
# Fully unroll the innermost ji loop over the strip.
sgemm_opt = simplify(unroll_loop(sgemm_opt, "ji"))
# Seeded inputs so that any mismatch is reproducible.
rng = np.random.default_rng(0)
A = rng.random((M, K), dtype=np.float32)
B = rng.random((K, N), dtype=np.float32)
# The kernel accumulates (C += A @ B), so C must start at zero.
C = np.zeros((M, N), dtype=np.float32)
F32_PTR = ctypes.POINTER(ctypes.c_float)
# Buffers are passed as raw pointers, so they must be C-contiguous; these
# freshly allocated arrays are.
build(sgemm_opt)(C.ctypes.data_as(F32_PTR), A.ctypes.data_as(F32_PTR), B.ctypes.data_as(F32_PTR))
# Float64 reference, so the comparison measures the kernel's rounding rather
# than that of a float32 reference product.
expected = A.astype(np.float64) @ B.astype(np.float64)
# Explicit check instead of `assert`, which `python -O` would strip.
if not np.allclose(C, expected, atol=1e-3):
    max_err = float(np.max(np.abs(C - expected)))
    raise AssertionError(f"sgemm_opt result mismatch: max abs error {max_err:.3g}")
print("ok.")
# Discussion for this script: https://gist.github.com/sueszli/76acdde675002f64a34a6a3967cbb7fa