Skip to content

Instantly share code, notes, and snippets.

@saulshanabrook
Created April 29, 2025 10:30
Show Gist options
Save saulshanabrook/e6daf6615d0a2f65055d23f477eadd9d to your computer and use it in GitHub Desktop.
$ git diff
diff --git a/src/mlir_egglog/mlir_backend.py b/src/mlir_egglog/mlir_backend.py
index 6ab6867..91f3f53 100644
--- a/src/mlir_egglog/mlir_backend.py
+++ b/src/mlir_egglog/mlir_backend.py
@@ -49,6 +49,7 @@ COMMON_INITIAL_OPTIONS = (
"-convert-vector-to-scf",
"-convert-linalg-to-loops",
"-lower-affine",
+ "--convert-arith-to-llvm",
)
# OpenMP lowering sequence
@@ -136,13 +137,13 @@ class MLIRCompiler:
assert out_mode in "tb"
with (
- NamedTemporaryFile(mode=f"w{in_mode}") as src_file,
- NamedTemporaryFile(mode=f"r{out_mode}") as out_file,
+ NamedTemporaryFile(mode=f"w{in_mode}", delete=False, delete_on_close=False) as src_file,
+ NamedTemporaryFile(mode=f"r{out_mode}", delete=False, delete_on_close=False) as out_file,
):
src_file.write(src)
src_file.flush()
-
shell_cmd = *cmd, src_file.name, "-o", out_file.name
+ print(" ".join(shell_cmd))
if self._debug:
print(shell_cmd)
subprocess.run(shell_cmd)
$ mlir-opt --version
Homebrew LLVM version 20.1.3
Optimized build.
$ uv run pytest -x tests/test_basic_expressions.py::TestBasicExpressions::test_arithmetic_expression
==================================================================================== test session starts ====================================================================================
platform darwin -- Python 3.12.5, pytest-8.3.5, pluggy-1.5.0
rootdir: /Users/saul/p/mlir-egglog
configfile: pyproject.toml
collected 1 item
tests/test_basic_expressions.py F
========================================================================================= FAILURES ==========================================================================================
______________________________________________________________________ TestBasicExpressions.test_arithmetic_expression ______________________________________________________________________
self = <test_basic_expressions.TestBasicExpressions testMethod=test_arithmetic_expression>
def test_arithmetic_expression(self):
def arithmetic_fn(x):
return x * 2.0 + 1.0
# Test frontend compilation (MLIR generation)
mlir_code = compile(arithmetic_fn, debug=True)
self.assertIn("arith.mulf", mlir_code)
self.assertIn("arith.addf", mlir_code)
# Test full pipeline compilation
jit = JITEngine()
> func_addr = jit.jit_compile(arithmetic_fn)
tests/test_basic_expressions.py:23:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/mlir_egglog/jit_engine.py:86: in jit_compile
address = self.run_backend(mlir)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <mlir_egglog.jit_engine.JITEngine object at 0x106ec8800>
mlir_src = '\nfunc.func @kernel_worker(\n %arg0: memref<?xf32>,\n %arg1: memref<?xf32>\n) attributes {llvm.emit_c_interface... %v4 = arith.addf %v3, %v2 : f32\n affine.store %v4, %arg1[%idx] : memref<?xf32>\n }\n return\n}\n'
def run_backend(self, mlir_src: str) -> bytes:
mlir_compiler = MLIRCompiler(debug=False)
mlir_omp = mlir_compiler.to_llvm_dialect(mlir_src, target=Target.BASIC_LOOPS)
llvm_ir = mlir_compiler.mlir_translate_to_llvm_ir(mlir_omp)
print(llvm_ir)
print("Parsing LLVM assembly.")
# try:
# Clean up the LLVM IR by ensuring proper line endings and formatting
llvm_ir = llvm_ir.strip()
# Clean up problematic attribute strings (hack for divergence in modern LLVM IR syntax with old llvmlite)
llvm_ir = llvm_ir.replace("captures(none)", " ")
llvm_ir = llvm_ir.replace("memory(argmem: readwrite)", "")
llvm_ir = llvm_ir.replace("memory(none)", "")
llvm_ir += "\n"
# mod = llvm.parse_assembly(llvm_ir)
# mod = compile_mod(self.ee, mod)
# Resolve the function address
func_name = f"_mlir_ciface_{KERNEL_NAME}"
address = self.ee.get_function_address(func_name)
> assert address, "Function must be compiled successfully."
E AssertionError: Function must be compiled successfully.
src/mlir_egglog/jit_engine.py:72: AssertionError
----------------------------------------------------------------------------------- Captured stdout call ------------------------------------------------------------------------------------
0.44.0
mlir-opt --debugify-level=locations --snapshot-op-locations --inline -affine-loop-normalize -affine-parallelize -affine-super-vectorize --affine-scalrep -lower-affine -convert-vector-to-scf -convert-linalg-to-loops -lower-affine --convert-arith-to-llvm -convert-scf-to-cf -cse -convert-vector-to-llvm -convert-math-to-llvm -expand-strided-metadata -finalize-memref-to-llvm -convert-func-to-llvm -convert-index-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts --llvm-request-c-wrappers /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/tmpbot3e3rm -o /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/tmpna4wdae8
mlir-translate --mlir-print-local-scope --mlir-print-debuginfo=false --print-after-all --mlir-to-llvmir --verify-diagnostics /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/tmpsg0qr32j -o /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/tmpkh9a08d3
Parsing LLVM assembly.
----------------------------------------------------------------------------------- Captured stderr call ------------------------------------------------------------------------------------
/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/tmpsg0qr32j:30:11: error: unexpected error: Dialect `arith' not found for custom op 'arith.cmpi'
%25 = arith.cmpi slt, %24, %22 : index
^
/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/tmpsg0qr32j:30:11: error: unexpected note: Registered dialects: acc, amx, arm_neon, arm_sme, arm_sve, builtin, dlti, func, gpu, llvm, nvvm, omp, rocdl, spirv, vcix, x86vector ; for more info on dialect registration see https://mlir.llvm.org/getting_started/Faq/#registered-loaded-dependent-whats-up-with-dialects-management
%25 = arith.cmpi slt, %24, %22 : index
^
================================================================================== short test summary info ==================================================================================
FAILED tests/test_basic_expressions.py::TestBasicExpressions::test_arithmetic_expression - AssertionError: Function must be compiled successfully.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
===================================================================================== 1 failed in 0.63s =====================================================================================
$ cat /var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/tmpna4wdae8
module {
llvm.func @kernel_worker(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64) attributes {llvm.emit_c_interface} {
%0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%1 = llvm.insertvalue %arg5, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%2 = llvm.insertvalue %arg6, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%3 = llvm.insertvalue %arg7, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.insertvalue %arg8, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%5 = llvm.insertvalue %arg9, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%7 = llvm.insertvalue %arg0, %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%8 = llvm.insertvalue %arg1, %7[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%9 = llvm.insertvalue %arg2, %8[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%10 = llvm.insertvalue %arg3, %9[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%11 = llvm.insertvalue %arg4, %10[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%12 = llvm.mlir.constant(0 : index) : i64
%13 = llvm.mlir.constant(2.000000e+00 : f32) : f32
%14 = llvm.mlir.constant(1.000000e+00 : f32) : f32
%15 = llvm.mlir.constant(1 : index) : i64
%16 = builtin.unrealized_conversion_cast %15 : i64 to index
%17 = llvm.mlir.constant(1 : index) : i64
%18 = llvm.extractvalue %11[3] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%19 = llvm.alloca %17 x !llvm.array<1 x i64> : (i64) -> !llvm.ptr
llvm.store %18, %19 : !llvm.array<1 x i64>, !llvm.ptr
%20 = llvm.getelementptr %19[0, %12] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<1 x i64>
%21 = llvm.load %20 : !llvm.ptr -> i64
%22 = builtin.unrealized_conversion_cast %21 : i64 to index
llvm.br ^bb1(%12 : i64)
^bb1(%23: i64): // 2 preds: ^bb0, ^bb2
%24 = builtin.unrealized_conversion_cast %23 : i64 to index
%25 = arith.cmpi slt, %24, %22 : index
llvm.cond_br %25, ^bb2, ^bb3
^bb2: // pred: ^bb1
%26 = llvm.extractvalue %11[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%27 = llvm.getelementptr %26[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%28 = llvm.load %27 : !llvm.ptr -> f32
%29 = llvm.fmul %28, %13 : f32
%30 = llvm.fadd %29, %14 : f32
%31 = llvm.extractvalue %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%32 = llvm.getelementptr %31[%23] : (!llvm.ptr, i64) -> !llvm.ptr, f32
llvm.store %30, %32 : f32, !llvm.ptr
%33 = arith.addi %24, %16 : index
%34 = builtin.unrealized_conversion_cast %33 : index to i64
llvm.br ^bb1(%34 : i64)
^bb3: // pred: ^bb1
llvm.return
}
llvm.func @_mlir_ciface_kernel_worker(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {llvm.emit_c_interface} {
%0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.extractvalue %0[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%5 = llvm.extractvalue %0[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%7 = llvm.extractvalue %6[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%8 = llvm.extractvalue %6[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%9 = llvm.extractvalue %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%10 = llvm.extractvalue %6[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%11 = llvm.extractvalue %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
llvm.call @kernel_worker(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
llvm.return
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment