Last active
May 1, 2025 12:06
-
-
Save saulshanabrook/51ba3040fcedb5fea4697ec7365c1a7c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/mlir_egglog/jit_engine.py b/src/mlir_egglog/jit_engine.py | |
index de9ac2f..e3b3aa0 100644 | |
--- a/src/mlir_egglog/jit_engine.py | |
+++ b/src/mlir_egglog/jit_engine.py | |
@@ -10,7 +10,6 @@ from egglog import RewriteOrRule, Ruleset | |
import llvmlite.binding as llvm | |
from mlir_egglog.llvm_runtime import ( | |
- create_execution_engine, | |
init_llvm, | |
compile_mod, | |
) | |
@@ -35,8 +34,6 @@ class JITEngine: | |
omppath = find_omp_path() | |
ctypes.CDLL(omppath, mode=os.RTLD_NOW) | |
- self.ee = create_execution_engine() | |
- | |
def run_frontend( | |
self, | |
fn: FunctionType, | |
@@ -65,11 +62,11 @@ class JITEngine: | |
llvm_ir += "\n" | |
mod = llvm.parse_assembly(llvm_ir) | |
- mod = compile_mod(self.ee, mod) | |
+ ee = compile_mod(mod) | |
# Resolve the function address | |
func_name = f"_mlir_ciface_{KERNEL_NAME}" | |
- address = self.ee.get_function_address(func_name) | |
+ address = ee.get_function_address(func_name) | |
assert address, "Function must be compiled successfully." | |
return address | |
diff --git a/src/mlir_egglog/llvm_runtime.py b/src/mlir_egglog/llvm_runtime.py | |
index 7a2310c..2728f9f 100644 | |
--- a/src/mlir_egglog/llvm_runtime.py | |
+++ b/src/mlir_egglog/llvm_runtime.py | |
@@ -11,22 +11,21 @@ def init_llvm(): | |
llvm.initialize_all_asmprinters() | |
-def compile_mod(engine, mod): | |
+def compile_mod(mod): | |
mod.verify() | |
- engine.add_module(mod) | |
- engine.finalize_object() | |
- engine.run_static_constructors() | |
- return mod | |
+ engine = create_execution_engine(mod) | |
+ # engine.finalize_object() | |
+ # engine.run_static_constructors() | |
+ return engine | |
-def create_execution_engine(): | |
+def create_execution_engine(mod): | |
+ with llvm.create_module_pass_manager() as pm: | |
+ with llvm.create_pass_manager_builder() as pmb: | |
+ pmb.populate(pm) | |
+ pm.run(mod) | |
+ print("MOD!\n", mod, "done\n\n\n") | |
target = llvm.Target.from_default_triple() | |
target_machine = target.create_target_machine() | |
- backing_mod = llvm.parse_assembly("") | |
- engine = llvm.create_mcjit_compiler(backing_mod, target_machine) | |
+ engine = llvm.create_mcjit_compiler(mod, target_machine) | |
return engine | |
- | |
- | |
-def compile_ir(engine, llvm_ir): | |
- mod = llvm.parse_assembly(llvm_ir) | |
- return compile_mod(engine, mod) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ uv run pytest tests/test_basic_expressions.py::TestBasicExpressions::test_relu_function -s | |
======================================================================================== test session starts ======================================================================================== | |
platform darwin -- Python 3.12.5, pytest-8.3.5, pluggy-1.5.0 | |
rootdir: /Users/saul/p/mlir-egglog | |
configfile: pyproject.toml | |
collected 1 item | |
tests/test_basic_expressions.py Generated MLIR: | |
func.func @kernel_worker( | |
%arg0: memref<?xf32>, | |
%arg1: memref<?xf32> | |
) attributes {llvm.emit_c_interface} { | |
%c0 = index.constant 0 | |
// Get dimension of input array | |
%dim = memref.dim %arg0, %c0 : memref<?xf32> | |
// Process each element in a flattened manner | |
affine.for %idx = %c0 to %dim { | |
%arg_x = affine.load %arg0[%idx] : memref<?xf32> | |
%v1 = arith.constant 0.000000e+00 : f32 | |
%v2 = arith.maximumf %arg_x, %v1 : f32 | |
affine.store %v2, %arg1[%idx] : memref<?xf32> | |
} | |
return | |
} | |
0.44.0 | |
; ModuleID = 'LLVMDialectModule' | |
source_filename = "LLVMDialectModule" | |
define void @kernel_worker(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, ptr %5, ptr %6, i64 %7, i64 %8, i64 %9) { | |
%11 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } undef, ptr %5, 0 | |
%12 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %11, ptr %6, 1 | |
%13 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %12, i64 %7, 2 | |
%14 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %13, i64 %8, 3, 0 | |
%15 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %14, i64 %9, 4, 0 | |
%16 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } undef, ptr %0, 0 | |
%17 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %16, ptr %1, 1 | |
%18 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %17, i64 %2, 2 | |
%19 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %18, i64 %3, 3, 0 | |
%20 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %19, i64 %4, 4, 0 | |
%21 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %20, 3, 0 | |
br label %22 | |
22: ; preds = %25, %10 | |
%23 = phi i64 [ %32, %25 ], [ 0, %10 ] | |
%24 = icmp slt i64 %23, %21 | |
br i1 %24, label %25, label %33 | |
25: ; preds = %22 | |
%26 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %20, 1 | |
%27 = getelementptr float, ptr %26, i64 %23 | |
%28 = load float, ptr %27, align 4 | |
%29 = call float @llvm.maximum.f32(float %28, float 0.000000e+00) | |
%30 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %15, 1 | |
%31 = getelementptr float, ptr %30, i64 %23 | |
store float %29, ptr %31, align 4 | |
%32 = add i64 %23, 1 | |
br label %22 | |
33: ; preds = %22 | |
ret void | |
} | |
define void @_mlir_ciface_kernel_worker(ptr %0, ptr %1) { | |
%3 = load { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, align 8 | |
%4 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 0 | |
%5 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 1 | |
%6 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 2 | |
%7 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 3, 0 | |
%8 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %3, 4, 0 | |
%9 = load { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, align 8 | |
%10 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 0 | |
%11 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 1 | |
%12 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 2 | |
%13 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 3, 0 | |
%14 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %9, 4, 0 | |
call void @kernel_worker(ptr %4, ptr %5, i64 %6, i64 %7, i64 %8, ptr %10, ptr %11, i64 %12, i64 %13, i64 %14) | |
ret void | |
} | |
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) | |
declare float @llvm.maximum.f32(float, float) #0 | |
attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } | |
!llvm.module.flags = !{!0} | |
!0 = !{i32 2, !"Debug Info Version", i32 3} | |
Parsing LLVM assembly. | |
MOD! | |
; ModuleID = '<string>' | |
source_filename = "LLVMDialectModule" | |
; Function Attrs: argmemonly nofree nosync nounwind | |
define void @kernel_worker(ptr nocapture readnone %0, ptr nocapture readonly %1, i64 %2, i64 %3, i64 %4, ptr nocapture readnone %5, ptr nocapture writeonly %6, i64 %7, i64 %8, i64 %9) local_unnamed_addr #0 { | |
%11 = icmp sgt i64 %3, 0 | |
br i1 %11, label %.lr.ph, label %._crit_edge | |
.lr.ph: ; preds = %10, %.lr.ph | |
%12 = phi i64 [ %17, %.lr.ph ], [ 0, %10 ] | |
%13 = getelementptr float, ptr %1, i64 %12 | |
%14 = load float, ptr %13, align 4 | |
%15 = tail call float @llvm.maximum.f32(float %14, float 0.000000e+00) | |
%16 = getelementptr float, ptr %6, i64 %12 | |
store float %15, ptr %16, align 4 | |
%17 = add nuw nsw i64 %12, 1 | |
%18 = icmp slt i64 %17, %3 | |
br i1 %18, label %.lr.ph, label %._crit_edge | |
._crit_edge: ; preds = %.lr.ph, %10 | |
ret void | |
} | |
; Function Attrs: nofree nosync nounwind | |
define void @_mlir_ciface_kernel_worker(ptr nocapture readonly %0, ptr nocapture readonly %1) local_unnamed_addr #1 { | |
%.unpack = load ptr, ptr %0, align 8 | |
%.elt1 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 1 | |
%.unpack2 = load ptr, ptr %.elt1, align 8 | |
%.elt3 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 2 | |
%.unpack4 = load i64, ptr %.elt3, align 8 | |
%.elt5 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 3 | |
%.unpack6.unpack = load i64, ptr %.elt5, align 8 | |
%.elt7 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, i64 0, i32 4 | |
%.unpack8.unpack = load i64, ptr %.elt7, align 8 | |
%.unpack11 = load ptr, ptr %1, align 8 | |
%.elt12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 1 | |
%.unpack13 = load ptr, ptr %.elt12, align 8 | |
%.elt14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 2 | |
%.unpack15 = load i64, ptr %.elt14, align 8 | |
%.elt16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 3 | |
%.unpack17.unpack = load i64, ptr %.elt16, align 8 | |
%.elt18 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %1, i64 0, i32 4 | |
%.unpack19.unpack = load i64, ptr %.elt18, align 8 | |
tail call void @kernel_worker(ptr %.unpack, ptr %.unpack2, i64 %.unpack4, i64 %.unpack6.unpack, i64 %.unpack8.unpack, ptr %.unpack11, ptr %.unpack13, i64 %.unpack15, i64 %.unpack17.unpack, i64 %.unpack19.unpack) | |
ret void | |
} | |
; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn | |
declare float @llvm.maximum.f32(float, float) #2 | |
attributes #0 = { argmemonly nofree nosync nounwind } | |
attributes #1 = { nofree nosync nounwind } | |
attributes #2 = { mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn } | |
!llvm.module.flags = !{!0} | |
!0 = !{i32 2, !"Debug Info Version", i32 3} | |
done | |
LLVM ERROR: Cannot select: 0x7fb51609d4d8: f32 = fmaximum 0x7fb51609dcf8, ConstantFP:f32<0.000000e+00> | |
0x7fb51609dcf8: f32,ch = load<(load (s32) from %ir.uglygep1)> 0x7fb515f56308, 0x7fb51609de98, undef:i64 | |
0x7fb51609de98: i64 = add 0x7fb51609dbc0, 0x7fb51609ddc8 | |
0x7fb51609dbc0: i64,ch = CopyFromReg 0x7fb515f56308, Register:i64 %3 | |
0x7fb51609d338: i64 = Register %3 | |
0x7fb51609ddc8: i64 = shl 0x7fb51609d268, Constant:i8<2> | |
0x7fb51609d268: i64,ch = CopyFromReg 0x7fb515f56308, Register:i64 %0 | |
0x7fb51609dc28: i64 = Register %0 | |
0x7fb51609d2d0: i8 = Constant<2> | |
0x7fb51609daf0: i64 = undef | |
0x7fb51609d7b0: f32 = ConstantFP<0.000000e+00> | |
In function: kernel_worker | |
Fatal Python error: Aborted | |
Current thread 0x00007ff85f1fddc0 (most recent call first): | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/llvmlite/binding/ffi.py", line 197 in __call__ | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/llvmlite/binding/executionengine.py", line 71 in get_function_address | |
File "/Users/saul/p/mlir-egglog/src/mlir_egglog/jit_engine.py", line 69 in run_backend | |
File "/Users/saul/p/mlir-egglog/src/mlir_egglog/jit_engine.py", line 85 in jit_compile | |
File "/Users/saul/p/mlir-egglog/tests/test_basic_expressions.py", line 138 in test_relu_function | |
File "/usr/local/Cellar/python@3.12/3.12.5/Frameworks/Python.framework/Versions/3.12/lib/python3.12/unittest/case.py", line 589 in _callTestMethod | |
File "/usr/local/Cellar/python@3.12/3.12.5/Frameworks/Python.framework/Versions/3.12/lib/python3.12/unittest/case.py", line 634 in run | |
File "/usr/local/Cellar/python@3.12/3.12.5/Frameworks/Python.framework/Versions/3.12/lib/python3.12/unittest/case.py", line 690 in __call__ | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/unittest.py", line 351 in runtest | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 174 in pytest_runtest_call | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__ | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 242 in <lambda> | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 341 in from_call | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 241 in call_and_report | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 132 in runtestprotocol | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__ | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 362 in pytest_runtestloop | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__ | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 337 in _main | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 283 in wrap_session | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/main.py", line 330 in pytest_cmdline_main | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__ | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/config/__init__.py", line 175 in main | |
File "/Users/saul/p/mlir-egglog/.venv/lib/python3.12/site-packages/_pytest/config/__init__.py", line 201 in console_main | |
File "/Users/saul/p/mlir-egglog/.venv/bin/pytest", line 10 in <module> | |
Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, black.const, black.mode, black.cache, black._width_table, blib2to3.pgen2.token, blib2to3.pgen2.grammar, blib2to3.pytree, black.strings, blib2to3.pgen2.tokenize, blib2to3.pgen2.parse, blib2to3.pgen2.pgen, blib2to3.pgen2.driver, blib2to3.pygram, black.nodes, black.comments, black.handle_ipynb_magics, black.brackets, black.lines, black.numerics, black.rusty, black.trans, black.linegen, black.parsing, black.ranges, black (total: 27) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment