Last active
February 22, 2026 12:34
-
-
Save sueszli/7d8f72bf998b012a926c454df2b8abbc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # dependencies = ["exo-lang"] | |
| # /// | |
| from __future__ import annotations | |
| import ctypes | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| from exo import * | |
| from exo import compile_procs | |
# Naive single-precision matrix multiply written in Exo's DSL:
# C[i, j] += A[i, k] * B[k, j] for every i < M, j < N, k < K.
# C is accumulated into (`+=`), not overwritten, so callers must pass a
# zeroed C to obtain A @ B.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# Compile the generic `sgemm` proc with clang, call it through ctypes and
# check the result against a pure-Python reference.
# Distinct sizes so that a swapped dimension or stride produces a wrong result.
M, N, K = 3, 5, 4
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    compile_procs([sgemm], tmp, "sgemm.c", "sgemm.h")
    lib_path = tmp / "libsgemm.so"
    subprocess.run(
        ["clang", "-shared", "-fPIC", "-O3", "-I", str(tmp), "-o", str(lib_path), str(tmp / "sgemm.c")],
        check=True,
    )
    lib = ctypes.CDLL(str(lib_path))
    # Leading c_void_p receives None; then the three sizes and the C, A, B buffers.
    # NOTE(review): Exo may declare `size` parameters as int_fast32_t, which is
    # wider than c_int on some platforms -- confirm against the generated sgemm.h.
    lib.sgemm.argtypes = [
        ctypes.c_void_p,
        ctypes.c_int,
        ctypes.c_int,
        ctypes.c_int,
        ctypes.POINTER(ctypes.c_float),
        ctypes.POINTER(ctypes.c_float),
        ctypes.POINTER(ctypes.c_float),
    ]
    lib.sgemm.restype = None
    # Small distinct integers: every product and sum is exact in float32, and a
    # wrong index shows up as a wrong value (constant fills would hide it).
    a_vals = [float(x % 5) for x in range(M * K)]
    b_vals = [float((3 * x) % 7) for x in range(K * N)]
    A = (ctypes.c_float * (M * K))(*a_vals)
    B = (ctypes.c_float * (K * N))(*b_vals)
    C = (ctypes.c_float * (M * N))()  # zero-initialised; the kernel accumulates into C
    lib.sgemm(None, M, N, K, C, A, B)
    # Row-major reference: C[i, j] = sum_k A[i, k] * B[k, j].
    expected = [
        sum(a_vals[i * K + k] * b_vals[k * N + j] for k in range(K))
        for i in range(M)
        for j in range(N)
    ]
    assert all(abs(got - want) < 1e-3 for got, want in zip(C, expected)), f"expected {expected}, got {list(C)}"
print("ok")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # dependencies = ["exo-lang", "numpy"] | |
| # /// | |
| from __future__ import annotations | |
| import ctypes | |
| import subprocess | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| from exo import * | |
| from exo import compile_procs | |
| from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, unroll_loop | |
# Problem size; the schedule below fixes the kernel to these values via partial_eval.
M, N, K = 512, 512, 512
# Naive single-precision GEMM in Exo's DSL: C[i, j] += A[i, k] * B[k, j].
# C is accumulated into, not overwritten. The loop names i, j, k are
# referenced by name by the scheduling calls that follow.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# 4x16 register-blocked microkernel, k-blocked for L1.
# loop order: io > jo > ko > ki > ii(unrolled) > ji(unrolled)
# Specialise M, N, K to the constants above so every split below is exact.
p = sgemm.partial_eval(M=M, N=N, K=K)
p = rename(p, "sgemm_fast")
# Split i, j, k into outer/inner pairs of 4, 16 and 64 iterations;
# perfect=True asserts the extents divide evenly (true for 512).
p = divide_loop(p, "i", 4, ["io", "ii"], perfect=True)
p = divide_loop(p, "j", 16, ["jo", "ji"], perfect=True)
p = divide_loop(p, "k", 64, ["ko", "ki"], perfect=True)
# Starting order: io ii jo ji ko ki. Each call swaps a loop with its direct
# child; after these five swaps the order is io jo ko ki ii ji.
p = reorder_loops(p, "ii jo")
p = reorder_loops(p, "ji ko")
p = reorder_loops(p, "ii ko")
p = reorder_loops(p, "ji ki")
p = reorder_loops(p, "ii ki")
# Fully unroll the 4x16 inner tile.
p = unroll_loop(p, "ji") # must be first
# NOTE(review): presumably because unrolling ii first would copy the ji loop
# four times, so "ji" would no longer name a single loop -- confirm.
p = unroll_loop(p, "ii")
sgemm_fast = simplify(p)
def build(procs, tmp):
    """Compile Exo procedures into ``tmp/lib.so`` and load it with ctypes.

    ``compile_procs`` writes ``sgemm.c``/``sgemm.h`` into ``tmp``. The ``C``,
    ``A`` and ``B`` pointer parameters are then marked ``__restrict__`` by plain
    text substitution before clang builds the shared library.

    Args:
        procs: iterable of Exo procedures to compile into one C file.
        tmp: directory (``str`` or ``Path``) for the generated sources and library.

    Returns:
        The loaded ``ctypes.CDLL``.

    Raises:
        subprocess.CalledProcessError: if clang exits with a non-zero status.
    """
    import warnings

    tmp = Path(tmp)
    compile_procs(list(procs), tmp, "sgemm.c", "sgemm.h")
    c_path = tmp / "sgemm.c"
    src = c_path.read_text()
    # These patterns assume Exo's current spelling of the parameter list.
    # str.replace is a silent no-op when a pattern is absent, which would
    # quietly drop the no-alias hint from the benchmark, so report any miss.
    for old, new in (
        ("float* C,", "float* __restrict__ C,"),
        ("const float* A,", "const float* __restrict__ A,"),
        ("const float* B", "const float* __restrict__ B"),
    ):
        if old not in src:
            warnings.warn(f"build: {old!r} not found in generated sgemm.c; __restrict__ not applied", stacklevel=2)
        src = src.replace(old, new)
    c_path.write_text(src)
    lib = tmp / "lib.so"
    subprocess.run(
        ["clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math", "-I", str(tmp), "-o", str(lib), str(c_path)],
        check=True,
    )
    return ctypes.CDLL(str(lib))
def bench(fn, args, reset=None, n=30):
    """Return the mean time of one ``fn(*args())`` call, in milliseconds.

    ``fn`` is called ``n`` times as warm-up, then ``n`` more times while timing.

    Args:
        fn: callable under test.
        args: zero-argument callable returning the positional arguments for ``fn``.
        reset: optional zero-argument callable run before every ``fn`` call
            (e.g. to re-zero an accumulator). It is excluded from the timing,
            as is ``args()``; only the ``fn`` call itself is measured.
        n: number of warm-up calls and of timed calls.

    Raises:
        ValueError: if ``n`` < 1 (no timed calls to average).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    for _ in range(n):
        if reset:
            reset()
        fn(*args())
    total = 0.0
    for _ in range(n):
        if reset:
            reset()
        call_args = args()
        t0 = time.perf_counter()
        fn(*call_args)
        total += time.perf_counter() - t0
    return total / n * 1000
# Benchmark numpy's matmul against the scheduled Exo kernel and verify the kernel.
A = np.full((M, K), 2.0, dtype=np.float32)
B = np.full((K, N), 3.0, dtype=np.float32)
numpy_ms = bench(lambda: A @ B, lambda: (), n=30)
with tempfile.TemporaryDirectory() as tmp:
    lib = build([sgemm_fast], tmp)
    fp = ctypes.POINTER(ctypes.c_float)
    C = np.zeros((M, N), dtype=np.float32)
    fn = lib.sgemm_fast
    fn.restype = None
    fn.argtypes = [ctypes.c_void_p, fp, fp, fp]
    # Build the pointer tuple once instead of on every benchmark call.
    exo_args = (None, C.ctypes.data_as(fp), A.ctypes.data_as(fp), B.ctypes.data_as(fp))
    # The kernel accumulates into C (C += A @ B), so zero it before every run.
    exo_ms = bench(fn, lambda: exo_args, reset=lambda: C.fill(0))
    assert np.allclose(C, 6.0 * K, atol=1e-3)
    # Constant operands make every output element equal, which would hide
    # indexing mistakes; re-check the kernel once on seeded random data.
    rng = np.random.default_rng(0)
    A_chk = rng.random((M, K), dtype=np.float32)
    B_chk = rng.random((K, N), dtype=np.float32)
    C_chk = np.zeros((M, N), dtype=np.float32)
    fn(None, C_chk.ctypes.data_as(fp), A_chk.ctypes.data_as(fp), B_chk.ctypes.data_as(fp))
    expected = A_chk @ B_chk
    assert np.allclose(C_chk, expected, atol=1e-3), f"max abs error {np.abs(C_chk - expected).max()}"
gflop = 2 * M * N * K / 1e9  # work of one M x N x K product, in GFLOP
ratio = exo_ms / numpy_ms
# Report the comparison in whichever direction is true.
relative = f"{ratio:.1f}x slower" if ratio >= 1 else f"{1 / ratio:.1f}x faster"
print(f"numpy (BLAS): {numpy_ms:.2f} ms ({gflop/numpy_ms*1000:.0f} GFLOPS)")
print(f"exo 4x16: {exo_ms:.2f} ms ({gflop/exo_ms*1000:.0f} GFLOPS) ({relative})")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # dependencies = ["exo-lang", "numpy"] | |
| # /// | |
| from __future__ import annotations | |
| import ctypes | |
| import subprocess | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| from exo import compile_procs, f32, proc, seq, size | |
| from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, stage_mem, unroll_loop | |
# Problem size; partial_eval below fixes the kernel to these values.
M, N, K = 512, 512, 512
# DRAM appears in the @proc annotations below but is not among the names
# imported from exo above; import it so the `@ DRAM` memory annotations
# resolve to a defined name.
from exo import DRAM  # noqa: E402


# Naive single-precision GEMM in Exo's DSL: C[i, j] += A[i, k] * B[k, j].
# C is accumulated into, not overwritten. The loop names i, j, k are
# referenced by name by the scheduling calls that follow.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# Schedule: tile j by 16 and k by 64, move the k tile outermost, stage a
# 64x16 tile of B and a 16-wide row segment of C, then unroll the 16-wide
# inner j loop.
sgemm_opt = rename(sgemm.partial_eval(M=M, N=N, K=K), "sgemm_opt")
sgemm_opt = divide_loop(sgemm_opt, "j", 16, ["jo", "ji"], perfect=True)
sgemm_opt = divide_loop(sgemm_opt, "k", 64, ["ko", "ki"], perfect=True)
# Order before: i jo ji ko ki. Each swap exchanges a loop with its direct
# child; afterwards the order is ko jo i ki ji.
for swap in ["ji ko", "ji ki", "i jo", "i ko", "jo ko"]:
    sgemm_opt = reorder_loops(sgemm_opt, swap)
# Copy the current 64x16 block of B into B_pack once per (ko, jo), outside
# the i loop, so all M rows reuse it.
sgemm_opt = stage_mem(sgemm_opt, "for i in _: _", "B[ko*64 : ko*64+64, jo*16 : jo*16+16]", "B_pack")
# Accumulate the 16 outputs C[i, jo*16 : jo*16+16] in C_local across the
# whole ki loop instead of updating C directly.
sgemm_opt = stage_mem(sgemm_opt, "for ki in _: _", "C[i, jo*16 : jo*16+16]", "C_local")
sgemm_opt = simplify(unroll_loop(sgemm_opt, "ji"))
def build_exo(proc_obj):
    """Compile an Exo procedure with clang and return a Python wrapper for it.

    ``out.c``/``out.h`` and ``lib.so`` are written to a fresh directory from
    ``tempfile.mkdtemp()``, which is never removed.

    The returned ``wrapper(*args)`` passes ``None`` for the function's leading
    ``void*`` argument, then each argument in order: numpy arrays as
    ``float*`` pointers, anything else as ``ctypes.c_int``. It returns None.
    NOTE(review): Exo may declare ``size`` parameters as int_fast32_t, which
    can be wider than c_int -- confirm against the generated out.h.

    Raises (from the wrapper):
        TypeError: an array argument is not float32.
        ValueError: an array argument is not C-contiguous.
    """
    tmp = Path(tempfile.mkdtemp())
    compile_procs([proc_obj], tmp, "out.c", "out.h")
    subprocess.run(
        ["clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math", "-I", str(tmp), "-o", str(tmp / "lib.so"), str(tmp / "out.c")],
        check=True,
    )
    fn = getattr(ctypes.CDLL(str(tmp / "lib.so")), proc_obj.name())
    fn.restype = None
    f32_ptr = ctypes.POINTER(ctypes.c_float)

    def wrapper(*args):
        c_args = [None]
        for a in args:
            if isinstance(a, np.ndarray):
                # The kernel reads raw contiguous float32 memory; any other
                # dtype or a strided view would be silently misread.
                if a.dtype != np.float32:
                    raise TypeError(f"build_exo: expected a float32 array, got dtype {a.dtype}")
                if not a.flags.c_contiguous:
                    raise ValueError("build_exo: expected a C-contiguous array")
                c_args.append(a.ctypes.data_as(f32_ptr))
            else:
                c_args.append(ctypes.c_int(a))
        fn.argtypes = [ctypes.c_void_p] + [type(a) for a in c_args[1:]]
        fn(*c_args)

    return wrapper
def bench(fn, reset=None, n=10):
    """Return the mean time of one ``fn()`` call, in milliseconds.

    ``fn`` is called ``n`` times as warm-up, then ``n`` more times while timing.

    Args:
        fn: zero-argument callable under test.
        reset: optional zero-argument callable run before every ``fn`` call
            (e.g. to re-zero an output buffer). It is excluded from the
            timing; only the ``fn`` call itself is measured.
        n: number of warm-up calls and of timed calls.

    Raises:
        ValueError: if ``n`` < 1 (no timed calls to average).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    for _ in range(n):
        if reset:
            reset()
        fn()
    total = 0.0
    for _ in range(n):
        if reset:
            reset()
        t0 = time.perf_counter()
        fn()
        total += time.perf_counter() - t0
    return total / n * 1000
# numpy bench
A_np = np.full((M, K), 2.0, dtype=np.float32)
B_np = np.full((K, N), 3.0, dtype=np.float32)
C_buf = np.zeros((M, N), dtype=np.float32)
numpy_ms = bench(lambda: A_np @ B_np, n=20)
# exo bench; C_buf is zeroed before each run because the kernel accumulates into it
opt_fn = build_exo(sgemm_opt)
opt_ms = bench(lambda: opt_fn(C_buf, A_np, B_np), reset=lambda: C_buf.fill(0))
# Verify the tiled kernel before reporting its speed. Random operands (seeded
# for reproducibility) catch indexing mistakes that constant fills would hide.
rng = np.random.default_rng(0)
A_chk = rng.random((M, K), dtype=np.float32)
B_chk = rng.random((K, N), dtype=np.float32)
C_chk = np.zeros((M, N), dtype=np.float32)
opt_fn(C_chk, A_chk, B_chk)
expected = A_chk @ B_chk
assert np.allclose(C_chk, expected, atol=1e-3), f"max abs error {np.abs(C_chk - expected).max()}"
print(f"numpy (BLAS): {numpy_ms:.2f} ms")
print(f"exo tiled: {opt_ms:.2f} ms")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # dependencies = ["exo-lang"] | |
| # /// | |
from __future__ import annotations

import ctypes
import subprocess
import tempfile
import time
from pathlib import Path

from exo import DRAM, compile_procs, f32, proc, seq, size
from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, stage_mem, unroll_loop
# Problem size; partial_eval below fixes the kernel to these values.
M, N, K = 512, 512, 512
# Naive single-precision GEMM in Exo's DSL: C[i, j] += A[i, k] * B[k, j].
# C is accumulated into, not overwritten. The loop names i, j, k are
# referenced by name by the scheduling calls that follow.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
# Schedule: tile j by 16 and k by 64, move the k tile outermost, stage a
# 64x16 tile of B and a 16-wide row segment of C, then unroll the 16-wide
# inner j loop.
sgemm_opt = rename(sgemm.partial_eval(M=M, N=N, K=K), "sgemm_opt")
sgemm_opt = divide_loop(sgemm_opt, "j", 16, ["jo", "ji"], perfect=True)
sgemm_opt = divide_loop(sgemm_opt, "k", 64, ["ko", "ki"], perfect=True)
# Order before: i jo ji ko ki. Each swap exchanges a loop with its direct
# child; afterwards the order is ko jo i ki ji.
for swap in ["ji ko", "ji ki", "i jo", "i ko", "jo ko"]:
    sgemm_opt = reorder_loops(sgemm_opt, swap)
# Copy the current 64x16 block of B into B_pack once per (ko, jo), outside
# the i loop, so all M rows reuse it.
sgemm_opt = stage_mem(sgemm_opt, "for i in _: _", "B[ko*64 : ko*64+64, jo*16 : jo*16+16]", "B_pack")
# Accumulate the 16 outputs C[i, jo*16 : jo*16+16] in C_local across the
# whole ki loop instead of updating C directly.
sgemm_opt = stage_mem(sgemm_opt, "for ki in _: _", "C[i, jo*16 : jo*16+16]", "C_local")
sgemm_opt = simplify(unroll_loop(sgemm_opt, "ji"))
def build(p):
    """Compile Exo procedure ``p`` with clang and return a callable for it.

    The returned function takes the kernel's three buffers (``C``, ``A``,
    ``B`` as ``POINTER(c_float)``-compatible objects) and passes ``None`` for
    the leading ``void*`` argument. ``o.c``, ``o.h`` and ``lib.so`` go to a
    fresh ``tempfile.mkdtemp()`` directory that is never removed.

    Raises:
        subprocess.CalledProcessError: if clang exits with a non-zero status.
    """
    d = Path(tempfile.mkdtemp())
    compile_procs([p], d, "o.c", "o.h")
    # -fPIC: objects linked into a shared library must be position-independent
    # (matches the flags used by the other build helpers).
    subprocess.run(
        ["clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math", "-o", str(d / "lib.so"), str(d / "o.c")],
        check=True,
    )
    fn = getattr(ctypes.CDLL(str(d / "lib.so")), p.name())
    fn.argtypes = [ctypes.c_void_p] + [ctypes.POINTER(ctypes.c_float)] * 3
    # The kernel is called for its side effect; without this ctypes assumes an
    # int return and the lambda would hand back an arbitrary register value.
    fn.restype = None
    return lambda *args: fn(None, *args)
def bench(fn, args, n=10):
    """Time ``fn(*args)`` and return the mean per-call latency in milliseconds.

    The call is first repeated ``n`` times untimed (warm-up), then ``n`` more
    times inside a single ``time.perf_counter`` interval.
    """
    def run_batch():
        for _ in range(n):
            fn(*args)

    run_batch()  # warm-up
    start = time.perf_counter()
    run_batch()
    elapsed_s = time.perf_counter() - start
    return elapsed_s / n * 1000
# Flat, zero-initialised ctypes float buffers for C (M*N), A (M*K) and B (K*N).
# The operands are all zeros, so this run only measures speed; the result is not checked.
F32 = ctypes.c_float
C = (F32 * (M * N))()
A = (F32 * (M * K))()
B = (F32 * (K * N))()
opt_fn = build(sgemm_opt)
elapsed_ms = bench(opt_fn, [C, A, B])
print(f"time {elapsed_ms:.2f} ms")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # dependencies = ["exo-lang", "numpy"] | |
| # /// | |
| from __future__ import annotations | |
| import ctypes | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| import numpy as np | |
| from exo import * | |
| from exo.stdlib.scheduling import divide_loop, rename, reorder_loops, simplify, stage_mem, unroll_loop | |
# Problem size; partial_eval below fixes the kernel to these values.
M, N, K = 512, 512, 512
# Naive single-precision GEMM in Exo's DSL: C[i, j] += A[i, k] * B[k, j].
# C is accumulated into, not overwritten. The loop names i, j, k are
# referenced by name by the scheduling calls that follow.
@proc
def sgemm(M: size, N: size, K: size, C: f32[M, N] @ DRAM, A: f32[M, K] @ DRAM, B: f32[K, N] @ DRAM):
    for i in seq(0, M):
        for j in seq(0, N):
            for k in seq(0, K):
                C[i, j] += A[i, k] * B[k, j]
def build(p):
    """Compile Exo procedure ``p`` with clang and return a callable for it.

    The returned function takes the kernel's three buffers (``C``, ``A``,
    ``B`` as ``POINTER(c_float)``-compatible objects) and passes ``None`` for
    the leading ``void*`` argument. ``o.c``, ``o.h`` and ``lib.so`` go to a
    fresh ``tempfile.mkdtemp()`` directory that is never removed.

    Raises:
        subprocess.CalledProcessError: if clang exits with a non-zero status.
    """
    d = Path(tempfile.mkdtemp())
    compile_procs([p], d, "o.c", "o.h")
    # -fPIC: objects linked into a shared library must be position-independent
    # (matches the flags used by the other build helpers).
    subprocess.run(
        ["clang", "-shared", "-fPIC", "-O3", "-march=native", "-ffast-math", "-o", str(d / "lib.so"), str(d / "o.c")],
        check=True,
    )
    fn = getattr(ctypes.CDLL(str(d / "lib.so")), p.name())
    fn.argtypes = [ctypes.c_void_p] + [ctypes.POINTER(ctypes.c_float)] * 3
    # The kernel is called for its side effect; without this ctypes assumes an
    # int return and the lambda would hand back an arbitrary register value.
    fn.restype = None
    return lambda *args: fn(None, *args)
# Schedule: tile j by 16 and k by 64, move the k tile outermost, stage a
# 64x16 tile of B and a 16-wide row segment of C, then unroll the 16-wide
# inner j loop.
sgemm_opt = rename(sgemm.partial_eval(M=M, N=N, K=K), "sgemm_opt")
sgemm_opt = divide_loop(sgemm_opt, "j", 16, ["jo", "ji"], perfect=True)
sgemm_opt = divide_loop(sgemm_opt, "k", 64, ["ko", "ki"], perfect=True)
# Order before: i jo ji ko ki. Each swap exchanges a loop with its direct
# child; afterwards the order is ko jo i ki ji.
for swap in ["ji ko", "ji ki", "i jo", "i ko", "jo ko"]:
    sgemm_opt = reorder_loops(sgemm_opt, swap)
# Copy the current 64x16 block of B into B_pack once per (ko, jo), outside
# the i loop, so all M rows reuse it.
sgemm_opt = stage_mem(sgemm_opt, "for i in _: _", "B[ko*64 : ko*64+64, jo*16 : jo*16+16]", "B_pack")
# Accumulate the 16 outputs C[i, jo*16 : jo*16+16] in C_local across the
# whole ki loop instead of updating C directly.
sgemm_opt = stage_mem(sgemm_opt, "for ki in _: _", "C[i, jo*16 : jo*16+16]", "C_local")
sgemm_opt = simplify(unroll_loop(sgemm_opt, "ji"))
# Correctness check of the scheduled kernel against numpy's matmul.
# Seeded random operands: indexing mistakes show up as wrong values, and a
# failure can be reproduced exactly.
rng = np.random.default_rng(0)
A = rng.random((M, K), dtype=np.float32)
B = rng.random((K, N), dtype=np.float32)
C = np.zeros((M, N), dtype=np.float32)  # the kernel accumulates into C, so it must start at zero
F32_PTR = ctypes.POINTER(ctypes.c_float)
build(sgemm_opt)(C.ctypes.data_as(F32_PTR), A.ctypes.data_as(F32_PTR), B.ctypes.data_as(F32_PTR))
expected = A @ B
assert np.allclose(C, expected, atol=1e-3), f"max abs error {np.abs(C - expected).max()}"
print("ok.")
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
original lib:
mlir: