Skip to content

Instantly share code, notes, and snippets.

@bjacob
Created September 20, 2021 17:54
Show Gist options
  • Save bjacob/9bfc27e0f9dcf85d2c878c7063f0e083 to your computer and use it in GitHub Desktop.
Save bjacob/9bfc27e0f9dcf85d2c878c7063f0e083 to your computer and use it in GitHub Desktop.
// Indexing maps over a 3-D (m, k, n) iteration space in which k is the
// contracted dimension.
// LHS operand: element (m, k).
#map0 = affine_map<(m, k, n) -> (m, k)>
// RHS operand: element (k, n).
#map1 = affine_map<(m, k, n) -> (k, n)>
// Accumulator / result: element (m, n); independent of k.
#map2 = affine_map<(m, k, n) -> (m, n)>
// 2-D identity map: each iteration point (i, j) addresses element (i, j).
#map3 = affine_map<(i, j) -> (i, j)>
// End-to-end matmul test module. @matmul_test builds three 10x10 inputs,
// evaluates them with both @actual (named linalg.matmul) and @expected
// (equivalent hand-written linalg.generic), and checks the results are equal.
module {
  // Implementation under test: the named linalg.matmul op, accumulating
  // %arg0 * %arg1 into %arg2.
  // NOTE(review): `noinline` presumably keeps the caller's constant inputs from
  // being propagated into this body and folded away — confirm.
  func private @actual(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> attributes {noinline} {
    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }

  // Reference implementation:
  //   result[m, n] = arg2[m, n] + sum_k arg0[m, k] * arg1[k, n]
  // The iteration space is (d0, d1, d2) = (m, k, n); d1 is the reduction
  // dimension, so the iterator types are parallel / reduction / parallel.
  func private @expected(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> attributes {noinline} {
    %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
      // %arg3 = lhs element, %arg4 = rhs element, %arg5 = running accumulator.
      %1 = mulf %arg3, %arg4 : f32
      %2 = addf %1, %arg5 : f32
      linalg.yield %2 : f32
    } -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }

  // Returns a %rows x %cols f32 tensor holding 0.0 where the row index equals
  // the column index and 1.0 everywhere else (all-ones minus the identity
  // pattern). Shared by all three test inputs instead of repeating the
  // generator inline.
  // NOTE(review): this is NOT an identity matrix; if one was intended, the
  // select operands below are swapped — confirm against the test generator.
  func private @ones_with_zero_diagonal(%rows: index, %cols: index) -> tensor<?x?xf32> {
    %init = linalg.init_tensor [%rows, %cols] : tensor<?x?xf32>
    %filled = linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%init : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      // Each value depends only on its position; %in and %out are never read.
      %i = linalg.index 0 : index
      %j = linalg.index 1 : index
      %on_diagonal = cmpi eq, %i, %j : index
      %zero = constant 0.000000e+00 : f32
      %one = constant 1.000000e+00 : f32
      %value = select %on_diagonal, %zero, %one : f32
      linalg.yield %value : f32
    } -> tensor<?x?xf32>
    return %filled : tensor<?x?xf32>
  }

  // Test entry point: compares @actual against @expected on 10x10 inputs.
  func @matmul_test() attributes {iree.abi.stub, iree.reflection = {MatmulTest = "entry"}} {
    // A single size constant serves every dimension of every input.
    %c10 = constant 10 : index
    %lhs = call @ones_with_zero_diagonal(%c10, %c10) : (index, index) -> tensor<?x?xf32>
    %rhs = call @ones_with_zero_diagonal(%c10, %c10) : (index, index) -> tensor<?x?xf32>
    %acc = call @ones_with_zero_diagonal(%c10, %c10) : (index, index) -> tensor<?x?xf32>
    %actual = call @actual(%lhs, %rhs, %acc) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    %expected = call @expected(%lhs, %rhs, %acc) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    // Every input element is 0.0 or 1.0 and the reduction length is 10, so all
    // products and partial sums are small integers that f32 represents
    // exactly. Bit-exact equality is therefore sound regardless of the order
    // in which each implementation sums.
    check.expect_eq(%actual, %expected) : tensor<?x?xf32>
    return
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment