# Created March 27, 2021 19:33.
# Save pashu123/4276bcfe9e352a3c4a32f9a1c1948ada to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import numpy as np | |
import ctypes | |
import gc, sys | |
from mlir.ir import * | |
from mlir.passmanager import * | |
from mlir.execution_engine import * | |
class MemRefDescriptor(ctypes.Structure):
    """ctypes layout of a rank-1 memref descriptor whose elements are ``float``.

    ctypes builds the struct from ``_fields_`` in declaration order, so the
    order of the names below is the in-memory order of the descriptor.
    """

    _fields_ = list(
        zip(
            ("allocated", "aligned", "offset", "sizes", "strides"),
            (
                ctypes.c_longlong,  # allocated: buffer address held as a plain integer
                ctypes.POINTER(ctypes.c_float),  # aligned: typed pointer to the data
                ctypes.c_longlong,  # offset
                ctypes.c_longlong * 1,  # sizes: a single extent (rank 1)
                ctypes.c_longlong * 1,  # strides: a single stride (rank 1)
            ),
        )
    )
# Reference: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ctypes.html
def npToCtype(np_array):
    """Wrap a 1-D ``float32`` NumPy array in a ``MemRefDescriptor``.

    No data is copied: the descriptor points straight into ``np_array``'s
    buffer, so keep the array alive for as long as the returned pointer is
    in use.

    Args:
        np_array: a 1-D NumPy array with dtype ``float32``.

    Returns:
        A ``ctypes`` pointer to a freshly built ``MemRefDescriptor``.

    Raises:
        ValueError: if ``np_array`` is not 1-D or its dtype is not
            ``float32``. The descriptor has room for exactly one size/stride
            and types its data pointer as ``float*``, so anything else would
            be silently misread.
    """
    if np_array.ndim != 1:
        raise ValueError(f"expected a 1-D array, got {np_array.ndim}-D")
    if np_array.dtype != np.float32:
        raise ValueError(f"expected dtype float32, got {np_array.dtype}")

    x = MemRefDescriptor()
    x.allocated = np_array.ctypes.data
    x.aligned = np_array.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    # The memref offset is an element count from `aligned`. NumPy's data
    # pointer already addresses the first element, so the offset is 0
    # (the old code stored the item size here, i.e. 4 elements past the start).
    x.offset = 0
    x.sizes[0] = np_array.shape[0]
    # NumPy strides are in bytes; memref strides are in elements.
    x.strides[0] = np_array.strides[0] // np_array.itemsize
    return ctypes.pointer(x)
# Everything goes to stderr and is flushed at once, so these lines interleave
# correctly with the errors/info that MLIR itself emits on stderr.
def log(*args):
    """Print ``args`` space-separated on one line to stderr, then flush."""
    print(*args, sep=" ", end="\n", file=sys.stderr, flush=True)
def run(f):
    """Run the zero-argument test ``f`` under a ``TEST:`` header, then assert no MLIR ``Context`` is still alive."""
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Collect garbage first, so that contexts reachable only from unreachable
    # reference cycles are freed before the live count is read.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
def lowerToLLVM(module):
    """Run the ``convert-std-to-llvm`` pass on ``module`` and return the same (now lowered) module object."""
    # Imported only for its side effects; no name from it is used below.
    # NOTE(review): presumably this registers the conversion passes so the
    # pipeline string can be parsed — confirm.
    import mlir.conversions  # noqa: F401

    pipeline = "convert-std-to-llvm"
    PassManager.parse(pipeline).run(module)
    return module
def testInvokeMemrefAdd():
    """JIT-compile a kernel that stores ``inp[0] + inp[0]`` into ``res[0]``, run it, and check the result.

    The kernel is marked ``llvm.emit_c_interface``, so each memref argument
    arrives as a pointer to its descriptor. ``ExecutionEngine.invoke`` needs
    a pointer to every argument, so each argument is passed as a pointer to
    a descriptor pointer.
    """
    with Context():
        module = Module.parse(
            """
module {
func @main(%arg0: memref<1xf32>, %arg1: memref<1xf32>) attributes { llvm.emit_c_interface } {
%0 = constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xf32>
%2 = memref.load %arg0[%0] : memref<1xf32>
%3 = addf %1, %2 : f32
memref.store %3, %arg1[%0] : memref<1xf32>
return
}
} """
        )
        inp_arr = np.random.rand(1).astype(np.float32)
        res_arr = np.random.rand(1).astype(np.float32)
        inp_ctype = npToCtype(inp_arr)
        res_ctype = npToCtype(res_arr)
        execution_engine = ExecutionEngine(lowerToLLVM(module))
        # Add one level of indirection: invoke() wants a pointer to each
        # argument, and each argument here is a descriptor pointer.
        execution_engine.invoke(
            "main", ctypes.pointer(inp_ctype), ctypes.pointer(res_ctype)
        )
        log("{0} + {0} = {1}".format(inp_arr, res_arr))
        # Doubling a finite float32 is exact, so an exact comparison is safe.
        assert res_arr[0] == inp_arr[0] + inp_arr[0]


run(testInvokeMemrefAdd)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.