Skip to content

Instantly share code, notes, and snippets.

@saulshanabrook
Last active May 1, 2025 12:06
Show Gist options
  • Save saulshanabrook/51ba3040fcedb5fea4697ec7365c1a7c to your computer and use it in GitHub Desktop.
Save saulshanabrook/51ba3040fcedb5fea4697ec7365c1a7c to your computer and use it in GitHub Desktop.
diff --git a/src/mlir_egglog/jit_engine.py b/src/mlir_egglog/jit_engine.py
index de9ac2f..e3b3aa0 100644
--- a/src/mlir_egglog/jit_engine.py
+++ b/src/mlir_egglog/jit_engine.py
@@ -10,7 +10,6 @@ from egglog import RewriteOrRule, Ruleset
 
 import llvmlite.binding as llvm
 from mlir_egglog.llvm_runtime import (
-    create_execution_engine,
     init_llvm,
     compile_mod,
 )
@@ -35,8 +34,6 @@ class JITEngine:
         omppath = find_omp_path()
         ctypes.CDLL(omppath, mode=os.RTLD_NOW)
 
-        self.ee = create_execution_engine()
-
     def run_frontend(
         self,
         fn: FunctionType,
@@ -65,11 +62,11 @@ class JITEngine:
         llvm_ir += "\n"
 
         mod = llvm.parse_assembly(llvm_ir)
-        mod = compile_mod(self.ee, mod)
+        ee = compile_mod(mod)
 
         # Resolve the function address
         func_name = f"_mlir_ciface_{KERNEL_NAME}"
-        address = self.ee.get_function_address(func_name)
+        address = ee.get_function_address(func_name)
 
         assert address, "Function must be compiled successfully."
         return address
diff --git a/src/mlir_egglog/llvm_runtime.py b/src/mlir_egglog/llvm_runtime.py
index 7a2310c..2728f9f 100644
--- a/src/mlir_egglog/llvm_runtime.py
+++ b/src/mlir_egglog/llvm_runtime.py
@@ -11,22 +11,21 @@ def init_llvm():
     llvm.initialize_all_asmprinters()
 
 
-def compile_mod(engine, mod):
+def compile_mod(mod):
     mod.verify()
-    engine.add_module(mod)
-    engine.finalize_object()
-    engine.run_static_constructors()
-    return mod
+    engine = create_execution_engine(mod)
+    # engine.finalize_object()
+    # engine.run_static_constructors()
+    return engine
 
 
-def create_execution_engine():
+def create_execution_engine(mod):
+    with llvm.create_module_pass_manager() as pm:
+        with llvm.create_pass_manager_builder() as pmb:
+            pmb.populate(pm)
+        pm.run(mod)
+    print("MOD!\n", mod, "done\n\n\n")
     target = llvm.Target.from_default_triple()
     target_machine = target.create_target_machine()
-    backing_mod = llvm.parse_assembly("")
-    engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
+    engine = llvm.create_mcjit_compiler(mod, target_machine)
     return engine
-
-
-def compile_ir(engine, llvm_ir):
-    mod = llvm.parse_assembly(llvm_ir)
-    return compile_mod(engine, mod)
$ uv run pytest tests/test_basic_expressions.py::TestBasicExpressions::test_relu_function -s
======================================================================================== test session starts ========================================================================================
platform darwin -- Python 3.12.5, pytest-8.3.5, pluggy-1.5.0
rootdir: /Users/saul/p/mlir-egglog
configfile: pyproject.toml
collected 1 item
tests/test_basic_expressions.py Generated MLIR:
func.func @kernel_worker(
%arg0: memref<?xf32>,
%arg1: memref<?xf32>
) attributes {llvm.emit_c_interface} {
%c0 = index.constant 0
// Get dimension of input array
%dim = memref.dim %arg0, %c0 : memref<?xf32>
// Process each element in a flattened manner
affine.for %idx = %c0 to %dim {
%arg_x = affine.load %arg0[%idx] : memref<?xf32>
%v1 = arith.constant 0.000000e+00 : f32
%v2 = arith.maximumf %arg_x, %v1 : f32
affine.store %v2, %arg1[%idx] : memref<?xf32>
}
return
}
0.44.0
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
define void @kernel_worker(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, ptr %5, ptr %6, i64 %7, i64 %8, i64 %9) {
%11 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } undef, ptr %5, 0
%12 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %11, ptr %6, 1
%13 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %12, i64 %7, 2
%14 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %13, i64 %8, 3, 0
%15 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %14, i64 %9, 4, 0
%16 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } undef, ptr %0, 0
%17 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %16, ptr %1, 1
%18 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %17, i64 %2, 2
%19 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %18, i64 %3, 3, 0
%20 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %19, i64 %4, 4, 0
%21 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %20, 3, 0
br label %22
22: ; preds = %25, %10
%23 = phi i64 [ %32, %25 ], [ 0, %10 ]
%24 = icmp slt i64 %23, %21
br i1 %24, label %25, label %33
25: ; preds = %22
%26 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %20, 1
%27 = getelementptr float, ptr %26, i64 %23
%28 = load float, ptr %27, align 4
%29 = call float @llvm.maximum.f32(float %28, float 0.000000e+00)
%30 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %15, 1
%31 = getelementptr float, ptr %30, i64 %23
store float %29, ptr %31, align 4
%32 = add i64 %23, 1
br label %22
33: ; preds = %22
ret void
}
define void @_mlir_ciface_kernel_worker(ptr %0, ptr %1) {
%3 = load { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, align 8
%4 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 0
%5 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 1
%6 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 2
%7 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 3, 0
%8 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 4, 0
%9 = load { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, align 8
%10 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 0
%11 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 1
%12 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 2
%13 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 3, 0
%14 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 4, 0
call void @kernel_worker(ptr %4, ptr %5, i64 %6, i64 %7, i64 %8, ptr %10, ptr %11, i64 %12, i64 %13, i64 %14)
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare float @llvm.maximum.f32(float, float) #0
attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}
Parsing LLVM assembly.
MOD!
; ModuleID = '<string>'
source_filename = "LLVMDialectModule"
; Function Attrs: argmemonly nofree nosync nounwind
define void @kernel_worker(ptr nocapture readnone %0, ptr nocapture readonly %1, i64 %2, i64 %3, i64 %4, ptr nocapture readnone %5, ptr nocapture writeonly %6, i64 %7, i64 %8, i64 %9) local_unnamed_addr #0 {
%11 = icmp sgt i64 %3, 0
br i1 %11, label %.lr.ph, label %._crit_edge
.lr.ph: ; preds = %10, %.lr.ph
%12 = phi i64 [ %17, %.lr.ph ], [ 0, %10 ]
%13 = getelementptr float, ptr %1, i64 %12
%14 = load float, ptr %13, align 4
%15 = tail call float @llvm.maximum.f32(float %14, float 0.000000e+00)
%16 = getelementptr float, ptr %6, i64 %12
store float %15, ptr %16, align 4
%17 = add nuw nsw i64 %12, 1
%18 = icmp slt i64 %17, %3
br i1 %18, label %.lr.ph, label %._crit_edge
._crit_edge: ; preds = %.lr.ph, %10
ret void
}
; Function Attrs: nofree nosync nounwind
define void @_mlir_ciface_kernel_worker(ptr nocapture readonly %0, ptr nocapture readonly %1) local_unnamed_addr #1 {
%.unpack = load ptr, ptr %0, align 8
%.elt1 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 1
%.unpack2 = load ptr, ptr %.elt1, align 8
%.elt3 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 2
%.unpack4 = load i64, ptr %.elt3, align 8
%.elt5 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 3
%.unpack6.unpack = load i64, ptr %.elt5, align 8
%.elt7 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 4
%.unpack8.unpack = load i64, ptr %.elt7, align 8
%.unpack11 = load ptr, ptr %1, align 8
%.elt12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 1
%.unpack13 = load ptr, ptr %.elt12, align 8
%.elt14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 2
%.unpack15 = load i64, ptr %.elt14, align 8
%.elt16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 3
%.unpack17.unpack = load i64, ptr %.elt16, align 8
%.elt18 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 4
%.unpack19.unpack = load i64, ptr %.elt18, align 8
tail call void @kernel_worker(ptr %.unpack, ptr %.unpack2, i64 %.unpack4, i64 %.unpack6.unpack, i64 %.unpack8.unpack, ptr %.unpack11, ptr %.unpack13, i64 %.unpack15, i64 %.unpack17.unpack, i64 %.unpack19.unpack)
ret void
}
; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn
declare float @llvm.maximum.f32(float, float) #2
attributes #0 = { argmemonly nofree nosync nounwind }
attributes #1 = { nofree nosync nounwind }
attributes #2 = { mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn }
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}
done
LLVM ERROR: Cannot select: 0x7fb51609d4d8: f32 = fmaximum 0x7fb51609dcf8, ConstantFP:f32<0.000000e+00>
0x7fb51609dcf8: f32,ch = load<(load (s32) from %ir.uglygep1)> 0x7fb515f56308, 0x7fb51609de98, undef:i64
0x7fb51609de98: i64 = add 0x7fb51609dbc0, 0x7fb51609ddc8
0x7fb51609dbc0: i64,ch = CopyFromReg 0x7fb515f56308, Register:i64 %3
0x7fb51609d338: i64 = Register %3
0x7fb51609ddc8: i64 = shl 0x7fb51609d268, Constant:i8<2>
0x7fb51609d268: i64,ch = CopyFromReg 0x7fb515f56308, Register:i64 %0
0x7fb51609dc28: i64 = Register %0
0x7fb51609d2d0: i8 = Constant<2>
0x7fb51609daf0: i64 = undef
0x7fb51609d7b0: f32 = ConstantFP<0.000000e+00>
In function: kernel_worker
Fatal Python error: Aborted
Current thread 0x00007ff85f1fddc0 (most recent call first):
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/llvmlite/binding/ffi.py", line 197 in __call__
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/llvmlite/binding/executionengine.py", line 71 in get_function_address
File "/Users/saul/p/mlir-egglog/src/mlir_egglog/jit_engine.py", line 69 in run_backend
File "/Users/saul/p/mlir-egglog/src/mlir_egglog/jit_engine.py", line 85 in jit_compile
File "/Users/saul/p/mlir-egglog/tests/test_basic_expressions.py", line 138 in test_relu_function
File "/usr/local/Cellar/python@3.12/3.12.5/Frameworks/Python.framework/Versions/3.12/lib/python3.12/unittest/case.py", line 589 in _callTestMethod
File "/usr/local/Cellar/python@3.12/3.12.5/Frameworks/Python.framework/Versions/3.12/lib/python3.12/unittest/case.py", line 634 in run
File "/usr/local/Cellar/python@3.12/3.12.5/Frameworks/Python.framework/Versions/3.12/lib/python3.12/unittest/case.py", line 690 in __call__
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/unittest.py", line 351 in runtest
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 174 in pytest_runtest_call
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 242 in <lambda>
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 341 in from_call
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 241 in call_and_report
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 132 in runtestprotocol
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 362 in pytest_runtestloop
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 337 in _main
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 283 in wrap_session
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 330 in pytest_cmdline_main
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/config/__init__.py", line 175 in main
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/config/__init__.py", line 201 in console_main
File "/Users/saul/p/mlir-egglog/.venv/bin/pytest", line 10 in <module>
Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, black.const, black.mode, black.cache, black._width_table, blib2to3.pgen2.token, blib2to3.pgen2.grammar, blib2to3.pytree, black.strings, blib2to3.pgen2.tokenize, blib2to3.pgen2.parse, blib2to3.pgen2.pgen, blib2to3.pgen2.driver, blib2to3.pygram, black.nodes, black.comments, black.handle_ipynb_magics, black.brackets, black.lines, black.numerics, black.rusty, black.trans, black.linegen, black.parsing, black.ranges, black (total: 27)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment