# Source: GitHub Gist by @pashu123, created March 14, 2023 15:56.
# Gist ID: 9976376b2e387f1cb5d0493ba7f531f5
# ("Instantly share code, notes, and snippets." — GitHub Gist)
from iree import runtime as ireert
from iree.compiler import compile_str
import numpy as np
# MLIR input (linalg / tensor / arith / complex dialects) defining one entry
# point, @forward.
#
# @forward takes a tensor<1x6x32x64x2xf32> and returns a
# tensor<1x6x32x64xcomplex<f32>>. A 4-D all-parallel linalg.generic iterates
# over every output index (d0, d1, d2, d3). Its body reads elements
# [d0, d1, d2, d3, 0] and [d0, d1, d2, d3, 1] of the input's trailing size-2
# axis and combines them with complex.create (first operand = real part,
# second = imaginary part). Elementwise this is out = in[..., 0] + i * in[..., 1].
#
# The `outs` operand is a tensor.empty placeholder; the body never reads %out.
# The enclosing `module { ... }` has no symbol name, and the runtime code
# below reaches it as `ctx.modules.module`.
LINALG_IR = '''
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
func.func @forward(%arg0: tensor<1x6x32x64x2xf32>) -> tensor<1x6x32x64xcomplex<f32>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<1x6x32x64xcomplex<f32>>
%1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%0 : tensor<1x6x32x64xcomplex<f32>>) {
^bb0(%out: complex<f32>):
%2 = linalg.index 0 : index
%3 = linalg.index 1 : index
%4 = linalg.index 2 : index
%5 = linalg.index 3 : index
%extracted = tensor.extract %arg0[%2, %3, %4, %5, %c0] : tensor<1x6x32x64x2xf32>
%extracted_0 = tensor.extract %arg0[%2, %3, %4, %5, %c1] : tensor<1x6x32x64x2xf32>
%8 = complex.create %extracted, %extracted_0 : complex<f32>
linalg.yield %8 : complex<f32>
} -> tensor<1x6x32x64xcomplex<f32>>
return %1 : tensor<1x6x32x64xcomplex<f32>>
}
}
'''
# Compile LINALG_IR for the CPU backend, load it into the IREE runtime, run
# @forward on random input, and compare the result against a NumPy reference.
backend = "llvm-cpu"
# Let the CPU codegen use the features of the machine running this script.
args = ["--iree-llvmcpu-target-cpu-features=host"]
# Runtime driver passed to ireert.Config. Previously this variable was unused
# and "local-sync" was hard-coded below. The driver that actually ran is
# kept here and is now routed through this variable.
backend_config = "local-sync"
flatbuffer_blob = compile_str(LINALG_IR, target_backends=[backend], extra_args=args)
config = ireert.Config(backend_config)
vm_module = ireert.VmModule.from_flatbuffer(config.vm_instance, flatbuffer_blob)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
# The MLIR `module { ... }` in LINALG_IR is unnamed; it is accessed as `module`.
complex_compiled = ctx.modules.module
# Named `input_data` rather than `input` to avoid shadowing the builtin.
input_data = np.random.rand(1, 6, 32, 64, 2).astype(np.float32)
x = complex_compiled.forward(input_data)
result = x.to_host()
print(result)
# NumPy reference for the IR: index 0 of the trailing axis is the real part
# and index 1 is the imaginary part (the complex.create operand order).
# The result is printed rather than asserted so the repro always completes.
expected = input_data[..., 0] + 1j * input_data[..., 1]
print("matches numpy reference:", np.allclose(result, expected))
# (GitHub Gist page footer: sign in on GitHub to comment on this gist.)