Skip to content

Instantly share code, notes, and snippets.

@sueszli
Last active February 22, 2026 12:34
Show Gist options
  • Select an option

  • Save sueszli/7d8f72bf998b012a926c454df2b8abbc to your computer and use it in GitHub Desktop.

Select an option

Save sueszli/7d8f72bf998b012a926c454df2b8abbc to your computer and use it in GitHub Desktop.
# /// script
# dependencies = ["exo-lang"]
# ///
from __future__ import annotations
import ctypes
import subprocess
import tempfile
from pathlib import Path
from exo import *
from exo import compile_procs
# Reference single-precision matrix multiply in the Exo DSL: C += A @ B.
# M, N and K are runtime `size` parameters; C is accumulated into (not
# overwritten), so callers that want C = A @ B must pass a zeroed C.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# Smoke test: lower the naive kernel to C, build a shared library with clang,
# call it through ctypes on constant inputs and check every output element.
M, N, K = 4, 4, 4

with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    # Emit sgemm.c / sgemm.h into the scratch directory, then link them.
    compile_procs([sgemm], tmp, "sgemm.c", "sgemm.h")
    lib_path = tmp / "libsgemm.so"
    subprocess.run(
        ["clang", "-shared", "-fPIC", "-O3", "-I", str(tmp), "-o", str(lib_path), str(tmp / "sgemm.c")],
        check=True,
    )

    lib = ctypes.CDLL(str(lib_path))
    # The leading void* slot is filled with None (NULL) at the call below.
    # NOTE(review): the three sizes are declared as c_int here; confirm against
    # the generated sgemm.h that this matches the C type Exo emits for `size`.
    lib.sgemm.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float)]
    lib.sgemm.restype = None

    A = (ctypes.c_float * (M * K))(*[2.0] * (M * K))
    B = (ctypes.c_float * (K * N))(*[3.0] * (K * N))
    C = (ctypes.c_float * (M * N))()  # default-constructed, i.e. all zeros
    lib.sgemm(None, M, N, K, C, A, B)

    # Each output element is a K-term sum of 2.0 * 3.0. Raise explicitly
    # rather than `assert` so the check still runs under `python -O`.
    if not all(abs(c - 6.0 * K) < 1e-3 for c in C):
        raise AssertionError(f"expected {6.0 * K}, got {list(C)}")

print("ok")
# /// script
# dependencies = ["exo-lang", "numpy"]
# ///
from __future__ import annotations
import ctypes
import subprocess
import tempfile
import time
from pathlib import Path
import numpy as np
from exo import *
from exo import compile_procs
from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, unroll_loop
# Square problem size used for both the NumPy baseline and the Exo kernel.
M, N, K = 512, 512, 512


# Reference single-precision matrix multiply in the Exo DSL: C += A @ B.
# This is the unscheduled starting point; C is accumulated into, so it must
# be zeroed by the caller when C = A @ B is wanted.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# 4x16 register-blocked microkernel, k-blocked for L1.
# Target loop order: io > jo > ko > ki > ii (unrolled) > ji (unrolled).
# Specialise the generic kernel to the module-level M, N, K and give the
# result its own symbol name.
p = sgemm.partial_eval(M=M, N=N, K=K)
p = rename(p, "sgemm_fast")
# Split i by 4, j by 16 and k by 64. perfect=True is requested for each split;
# 512 is a multiple of all three factors.
p = divide_loop(p, "i", 4, ["io", "ii"], perfect=True)
p = divide_loop(p, "j", 16, ["jo", "ji"], perfect=True)
p = divide_loop(p, "k", 64, ["ko", "ki"], perfect=True)
# Nest is now io > ii > jo > ji > ko > ki. Each reorder below swaps one
# adjacent pair; the resulting order is shown next to the call.
p = reorder_loops(p, "ii jo")  # io > jo > ii > ji > ko > ki
p = reorder_loops(p, "ji ko")  # io > jo > ii > ko > ji > ki
p = reorder_loops(p, "ii ko")  # io > jo > ko > ii > ji > ki
p = reorder_loops(p, "ji ki")  # io > jo > ko > ii > ki > ji
p = reorder_loops(p, "ii ki")  # io > jo > ko > ki > ii > ji
# Unroll the two innermost loops, innermost (ji) before ii.
p = unroll_loop(p, "ji") # must be first
p = unroll_loop(p, "ii")
sgemm_fast = simplify(p)
def build(procs, tmp):
    """Compile Exo procs to a shared library and load it with ctypes.

    The generated ``sgemm.c`` is patched textually so that the ``C``, ``A``
    and ``B`` pointer parameters carry ``__restrict__`` before clang runs.

    Args:
        procs: Iterable of Exo procedures handed to ``compile_procs``.
        tmp: Directory (str or Path) that receives the generated sources
            and the linked ``lib.so``.

    Returns:
        The loaded ``ctypes.CDLL``.

    Raises:
        subprocess.CalledProcessError: If clang exits with a non-zero status.
    """
    out_dir = Path(tmp)
    c_file = out_dir / "sgemm.c"
    compile_procs(list(procs), out_dir, "sgemm.c", "sgemm.h")

    # (needle, replacement) pairs, applied in this order to the C source.
    restrict_patches = (
        ("float* C,", "float* __restrict__ C,"),
        ("const float* A,", "const float* __restrict__ A,"),
        ("const float* B", "const float* __restrict__ B"),
    )
    code = c_file.read_text()
    for needle, replacement in restrict_patches:
        code = code.replace(needle, replacement)
    c_file.write_text(code)

    so_file = out_dir / "lib.so"
    clang_cmd = [
        "clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math",
        "-I", str(out_dir),
        "-o", str(so_file),
        str(c_file),
    ]
    subprocess.run(clang_cmd, check=True)
    return ctypes.CDLL(str(so_file))
def bench(fn, args, reset=None, n=30):
    """Return the mean wall-clock time of ``fn(*args())`` in milliseconds.

    One untimed warm-up pass of ``n`` calls is followed by one timed pass of
    ``n`` calls. Only the call to ``fn`` itself is timed: the optional
    ``reset`` hook and the construction of the argument tuple run outside the
    measured region, so kernels that need a reset are not penalised relative
    to ones that do not.

    Args:
        fn: Callable under test.
        args: Zero-argument callable returning the positional arguments for
            one call of ``fn``.
        reset: Optional zero-argument callable run before every call.
        n: Number of calls per pass; must be non-zero.

    Returns:
        Mean time per call in milliseconds.
    """

    def one_pass():
        elapsed = 0.0
        for _ in range(n):
            if reset is not None:
                reset()
            call_args = args()
            start = time.perf_counter()
            fn(*call_args)
            elapsed += time.perf_counter() - start
        return elapsed

    one_pass()  # warm-up, result discarded
    return one_pass() / n * 1000
# Constant inputs: every element of A @ B is a K-term sum of 2.0 * 3.0.
A = np.full((M, K), 2.0, dtype=np.float32)
B = np.full((K, N), 3.0, dtype=np.float32)
numpy_ms = bench(lambda: A @ B, lambda: (), n=30)

with tempfile.TemporaryDirectory() as tmp:
    lib = build([sgemm_fast], tmp)
    fp = ctypes.POINTER(ctypes.c_float)
    C = np.zeros((M, N), dtype=np.float32)

    def reset():
        # The kernel accumulates into C, so zero it before every call.
        C.fill(0)

    fn = lib.sgemm_fast
    fn.restype = None
    # Leading void* slot (passed as None), then the C, A, B data pointers.
    fn.argtypes = [ctypes.c_void_p, fp, fp, fp]
    exo_ms = bench(fn, lambda: (None, C.ctypes.data_as(fp), A.ctypes.data_as(fp), B.ctypes.data_as(fp)), reset=reset)

    # C holds the output of the last (reset + single call) iteration. Raise
    # explicitly so the check is not stripped under `python -O`.
    if not np.allclose(C, 6.0 * K, atol=1e-3):
        raise AssertionError("sgemm_fast output does not match the expected constant 6.0 * K")

# Floating-point operations per multiply, in units of 1e9 (a count, not a
# rate); dividing by seconds per call below turns it into GFLOP/s.
gflops = 2 * M * N * K / 1e9
print(f"numpy (BLAS): {numpy_ms:.2f} ms ({gflops/numpy_ms*1000:.0f} GFLOPS)")
print(f"exo 4x16: {exo_ms:.2f} ms ({gflops/exo_ms*1000:.0f} GFLOPS) ({exo_ms/numpy_ms:.1f}x slower)")
# /// script
# dependencies = ["exo-lang", "numpy"]
# ///
from __future__ import annotations
import ctypes
import subprocess
import tempfile
import time
from pathlib import Path
import numpy as np
from exo import compile_procs, f32, proc, seq, size
from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, stage_mem, unroll_loop
# Square problem size used for both the NumPy baseline and the Exo kernel.
M, N, K = 512, 512, 512


# Reference single-precision matrix multiply in the Exo DSL: C += A @ B.
# C is accumulated into, so it must be zeroed by the caller when
# C = A @ B is wanted.
# NOTE(review): the signature refers to `DRAM`, which is not among the names
# this script imports from `exo` - confirm it is resolvable when the proc is
# parsed.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# optimize: specialise to the fixed M, N, K, tile j by 16 and k by 64,
# reorder, stage tiles of B and C into named buffers, then unroll ji.
sgemm_opt = sgemm.partial_eval(M=M, N=N, K=K)
sgemm_opt = rename(sgemm_opt, "sgemm_opt")
sgemm_opt = divide_loop(sgemm_opt, "j", 16, ["jo", "ji"], perfect=True)
sgemm_opt = divide_loop(sgemm_opt, "k", 64, ["ko", "ki"], perfect=True)
# Starting from i > jo > ji > ko > ki, the adjacent swaps below end at
# ko > jo > i > ki > ji.
for swap in ("ji ko", "ji ki", "i jo", "i ko", "jo ko"):
    sgemm_opt = reorder_loops(sgemm_opt, swap)
# Around the i loop: stage the 64x16 window of B selected by (ko, jo).
sgemm_opt = stage_mem(sgemm_opt, "for i in _: _", "B[ko*64 : ko*64+64, jo*16 : jo*16+16]", "B_pack")
# Around the ki loop: stage the 16-wide slice of row i of C.
sgemm_opt = stage_mem(sgemm_opt, "for ki in _: _", "C[i, jo*16 : jo*16+16]", "C_local")
sgemm_opt = unroll_loop(sgemm_opt, "ji")
sgemm_opt = simplify(sgemm_opt)
def build_exo(proc_obj):
    """Compile one Exo procedure with clang and return a Python callable for it.

    The generated sources and the shared library are written to a fresh
    temporary directory, which is removed when the interpreter exits
    (previously it was left behind on disk).

    Args:
        proc_obj: Exo procedure; its ``name()`` is used as the exported
            symbol to look up in the library.

    Returns:
        A ``wrapper(*args)`` callable. NumPy arrays are passed as ``float*``
        to their data, every other argument is wrapped in ``ctypes.c_int``,
        and a leading NULL pointer is prepended. The wrapper returns None.

    Raises:
        subprocess.CalledProcessError: If clang exits with a non-zero status.
    """
    import atexit
    import shutil

    tmp = Path(tempfile.mkdtemp())
    # Registered before compiling so the directory is cleaned up even when
    # code generation or clang fails.
    atexit.register(shutil.rmtree, tmp, ignore_errors=True)

    compile_procs([proc_obj], tmp, "out.c", "out.h")
    lib_path = tmp / "lib.so"
    subprocess.run(["clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math", "-I", str(tmp), "-o", str(lib_path), str(tmp / "out.c")], check=True)
    fn = getattr(ctypes.CDLL(str(lib_path)), proc_obj.name())

    # Call-invariant setup is done once here instead of on every call.
    fn.restype = None
    float_ptr = ctypes.POINTER(ctypes.c_float)

    def wrapper(*args):
        # NOTE(review): arrays are passed as float* regardless of dtype or
        # layout - callers are assumed to pass contiguous float32 arrays.
        c_args = [None]
        for a in args:
            if isinstance(a, np.ndarray):
                c_args.append(a.ctypes.data_as(float_ptr))
            else:
                c_args.append(ctypes.c_int(a))
        # The signature is derived from the actual arguments of this call.
        fn.argtypes = [ctypes.c_void_p] + [type(a) for a in c_args[1:]]
        fn(*c_args)

    return wrapper
def bench(fn, reset=None, n=10):
    """Return the mean wall-clock time of ``fn()`` in milliseconds.

    One untimed warm-up pass of ``n`` calls is followed by one timed pass of
    ``n`` calls. Only the call to ``fn`` is timed; the optional ``reset``
    hook runs outside the measured region so that its cost does not inflate
    the result of kernels that need one.

    Args:
        fn: Zero-argument callable under test.
        reset: Optional zero-argument callable run before every call.
        n: Number of calls per pass; must be non-zero.

    Returns:
        Mean time per call in milliseconds.
    """

    def one_pass():
        elapsed = 0.0
        for _ in range(n):
            if reset is not None:
                reset()
            start = time.perf_counter()
            fn()
            elapsed += time.perf_counter() - start
        return elapsed

    one_pass()  # warm-up, result discarded
    return one_pass() / n * 1000
# numpy bench: constant-filled float32 operands, 20 timed matmuls.
A_np = np.full(shape=(M, K), fill_value=2.0, dtype=np.float32)
B_np = np.full(shape=(K, N), fill_value=3.0, dtype=np.float32)
numpy_ms = bench(
    lambda: A_np @ B_np,
    n=20,
)

# exo bench: the kernel accumulates into C_buf, so it is zeroed before
# every call through the reset hook.
C_buf = np.zeros(shape=(M, N), dtype=np.float32)
opt_fn = build_exo(sgemm_opt)
opt_ms = bench(
    lambda: opt_fn(C_buf, A_np, B_np),
    reset=lambda: C_buf.fill(0),
)

print(f"numpy (BLAS): {numpy_ms:.2f} ms")
print(f"exo tiled: {opt_ms:.2f} ms")
# /// script
# dependencies = ["exo-lang"]
# ///
import ctypes
import subprocess
import tempfile
import time
from pathlib import Path
from exo import compile_procs, f32, proc, seq, size
from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, stage_mem, unroll_loop
# Square problem size the kernel is specialised to.
M, N, K = 512, 512, 512


# Reference single-precision matrix multiply in the Exo DSL: C += A @ B.
# C is accumulated into rather than overwritten.
# NOTE(review): the signature refers to `DRAM`, which is not among the names
# this script imports from `exo`, and this script has no
# `from __future__ import annotations` - confirm the annotations resolve on
# the interpreter in use.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# Schedule: specialise to the fixed M, N, K, tile j by 16 and k by 64,
# reorder, stage tiles of B and C into named buffers, then unroll ji.
sgemm_opt = sgemm.partial_eval(M=M, N=N, K=K)
sgemm_opt = rename(sgemm_opt, "sgemm_opt")
sgemm_opt = divide_loop(sgemm_opt, "j", 16, ["jo", "ji"], perfect=True)
sgemm_opt = divide_loop(sgemm_opt, "k", 64, ["ko", "ki"], perfect=True)
# Starting from i > jo > ji > ko > ki, the adjacent swaps below end at
# ko > jo > i > ki > ji.
for swap in ("ji ko", "ji ki", "i jo", "i ko", "jo ko"):
    sgemm_opt = reorder_loops(sgemm_opt, swap)
# Around the i loop: stage the 64x16 window of B selected by (ko, jo).
sgemm_opt = stage_mem(sgemm_opt, "for i in _: _", "B[ko*64 : ko*64+64, jo*16 : jo*16+16]", "B_pack")
# Around the ki loop: stage the 16-wide slice of row i of C.
sgemm_opt = stage_mem(sgemm_opt, "for ki in _: _", "C[i, jo*16 : jo*16+16]", "C_local")
sgemm_opt = unroll_loop(sgemm_opt, "ji")
sgemm_opt = simplify(sgemm_opt)
def build(p):
    """Compile one Exo procedure with clang and return a callable for it.

    Sources and the shared library go to a fresh temporary directory that is
    removed when the interpreter exits (previously it was left on disk).
    The library is built with ``-fPIC`` and the function is declared as
    returning nothing, so the returned callable yields None.

    Args:
        p: Exo procedure; ``p.name()`` is the symbol looked up in the library.

    Returns:
        ``f(C, A, B)`` taking three ``float*``-compatible ctypes arguments;
        a leading NULL pointer is supplied automatically.

    Raises:
        subprocess.CalledProcessError: If clang exits with a non-zero status.
    """
    import atexit
    import shutil

    d = Path(tempfile.mkdtemp())
    # Registered first so the directory is cleaned up even if a later step fails.
    atexit.register(shutil.rmtree, d, ignore_errors=True)

    compile_procs([p], d, "o.c", "o.h")
    so_path = d / "lib.so"
    subprocess.run(["clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math", "-o", str(so_path), str(d / "o.c")], check=True)

    fn = getattr(ctypes.CDLL(str(so_path)), p.name())
    fn.argtypes = [ctypes.c_void_p] + [ctypes.POINTER(ctypes.c_float)] * 3
    # Without this ctypes would read an int return value the call never sets.
    fn.restype = None
    return lambda *args: fn(None, *args)
def bench(fn, args, n=10):
    """Return the mean wall-clock time of ``fn(*args)`` in milliseconds.

    Runs ``n`` untimed warm-up calls, then times a second batch of ``n``
    calls as a whole and divides by ``n``.

    Args:
        fn: Callable under test.
        args: Sequence of positional arguments, reused for every call.
        n: Number of calls per batch; must be non-zero.

    Returns:
        Mean time per call in milliseconds.
    """

    def run_batch():
        for _ in range(n):
            fn(*args)

    run_batch()  # warm-up
    started = time.perf_counter()
    run_batch()
    elapsed_s = time.perf_counter() - started
    return elapsed_s / n * 1000
# Operands are default-constructed ctypes float arrays (all zeros); this
# script only times the kernel and does not check its output.
F32 = ctypes.c_float
C = (F32 * (M * N))()
A = (F32 * (M * K))()
B = (F32 * (K * N))()

opt_fn = build(sgemm_opt)
elapsed_ms = bench(opt_fn, [C, A, B])
print(f"time {elapsed_ms:.2f} ms")
# /// script
# dependencies = ["exo-lang", "numpy"]
# ///
from __future__ import annotations
import ctypes
import subprocess
import tempfile
from pathlib import Path
import numpy as np
from exo import *
from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, stage_mem, unroll_loop
# Square problem size the kernel is specialised to and checked at.
M, N, K = 512, 512, 512


# Reference single-precision matrix multiply in the Exo DSL: C += A @ B.
# C is accumulated into, so it must be zeroed by the caller when
# C = A @ B is wanted.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
def build(p):
    """Compile one Exo procedure with clang and return a callable for it.

    Sources and the shared library go to a fresh temporary directory that is
    removed when the interpreter exits (previously it was left on disk).
    The library is built with ``-fPIC`` and the function is declared as
    returning nothing, so the returned callable yields None.

    Args:
        p: Exo procedure; ``p.name()`` is the symbol looked up in the library.

    Returns:
        ``f(C, A, B)`` taking three ``float*`` arguments; a leading NULL
        pointer is supplied automatically.

    Raises:
        subprocess.CalledProcessError: If clang exits with a non-zero status.
    """
    import atexit
    import shutil

    d = Path(tempfile.mkdtemp())
    # Registered first so the directory is cleaned up even if a later step fails.
    atexit.register(shutil.rmtree, d, ignore_errors=True)

    compile_procs([p], d, "o.c", "o.h")
    so_path = d / "lib.so"
    subprocess.run(["clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math", "-o", str(so_path), str(d / "o.c")], check=True)

    fn = getattr(ctypes.CDLL(str(so_path)), p.name())
    fn.argtypes = [ctypes.c_void_p] + [ctypes.POINTER(ctypes.c_float)] * 3
    # Without this ctypes would read an int return value the call never sets.
    fn.restype = None
    return lambda *args: fn(None, *args)
# Schedule: specialise to the fixed M, N, K, tile j by 16 and k by 64,
# reorder, stage tiles of B and C into named buffers, then unroll ji.
sgemm_opt = sgemm.partial_eval(M=M, N=N, K=K)
sgemm_opt = rename(sgemm_opt, "sgemm_opt")
sgemm_opt = divide_loop(sgemm_opt, "j", 16, ["jo", "ji"], perfect=True)
sgemm_opt = divide_loop(sgemm_opt, "k", 64, ["ko", "ki"], perfect=True)
# Starting from i > jo > ji > ko > ki, the adjacent swaps below end at
# ko > jo > i > ki > ji.
for swap in ("ji ko", "ji ki", "i jo", "i ko", "jo ko"):
    sgemm_opt = reorder_loops(sgemm_opt, swap)
# Around the i loop: stage the 64x16 window of B selected by (ko, jo).
sgemm_opt = stage_mem(sgemm_opt, "for i in _: _", "B[ko*64 : ko*64+64, jo*16 : jo*16+16]", "B_pack")
# Around the ki loop: stage the 16-wide slice of row i of C.
sgemm_opt = stage_mem(sgemm_opt, "for ki in _: _", "C[i, jo*16 : jo*16+16]", "C_local")
sgemm_opt = unroll_loop(sgemm_opt, "ji")
sgemm_opt = simplify(sgemm_opt)
# Correctness check: run the scheduled kernel once on random float32 inputs
# and compare against NumPy's matmul.
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)  # kernel accumulates, so start from zero
F32_PTR = ctypes.POINTER(ctypes.c_float)
build(sgemm_opt)(C.ctypes.data_as(F32_PTR), A.ctypes.data_as(F32_PTR), B.ctypes.data_as(F32_PTR))

# Raise explicitly instead of `assert` so the comparison is not stripped
# under `python -O`, and report how far off the result is.
reference = A @ B
if not np.allclose(C, reference, atol=1e-3):
    raise AssertionError(f"sgemm_opt mismatch: max abs error {np.abs(C - reference).max()}")
print("ok.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment