Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created June 27, 2024 16:17
Show Gist options
  • Save pashu123/8f0eb3a85643bc711963524992d4679c to your computer and use it in GitHub Desktop.
// relu(input0 x input1^T), batched over dim 0 of input0. input1 (8640x3200, f16) is first
// broadcast to every batch and then packed, so the RHS pack runs over all batch copies.
util.func public @matmul_broad(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_broad(%input0: tensor<?x?x3200xf32>, %input1: tensor<8640x3200xf16>) -> (%output0: tensor<?x?x8640xf32>)"}} {
  %zero_f32 = arith.constant 0.000000e+00 : f32
  %zero_f16 = arith.constant 0.000000e+00 : f16

  // Import the arguments; input0 has dynamic dims 0 and 1.
  %batch = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %rows = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %lhs = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?x3200xf32>{%batch, %rows}
  %rhs = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<8640x3200xf16>

  // Number of 16-row tiles along dim 1 of input0 (rounded up).
  %row_tiles = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%rows]

  // Broadcast the RHS over the batch: bcast[b, n, k] = rhs[n, k].
  %rhs_bcast_init = tensor.empty(%batch) : tensor<?x8640x3200xf16>
  %rhs_bcast = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%rhs : tensor<8640x3200xf16>) outs(%rhs_bcast_init : tensor<?x8640x3200xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<?x8640x3200xf16>

  // Pack the LHS into 16x1 tiles; a partial last row tile is filled with 0.0.
  %lhs_packed_init = tensor.empty(%batch, %row_tiles) : tensor<?x?x3200x16x1xf32>
  %lhs_packed = tensor.pack %lhs padding_value(%zero_f32 : f32) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %lhs_packed_init : tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>

  // Pack the broadcast RHS into 16x1 tiles (8640 / 16 = 540 tiles along dim 1).
  %rhs_packed_init = tensor.empty(%batch) : tensor<?x540x3200x16x1xf16>
  %rhs_packed = tensor.pack %rhs_bcast padding_value(%zero_f16 : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %rhs_packed_init : tensor<?x8640x3200xf16> -> tensor<?x540x3200x16x1xf16>

  // Zero-initialized accumulator followed by the tiled batched matmul.
  %acc_init = tensor.empty(%batch, %row_tiles) : tensor<?x?x540x16x16xf32>
  %acc = linalg.fill ins(%zero_f32 : f32) outs(%acc_init : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %mmt4d = linalg.batch_mmt4d ins(%lhs_packed, %rhs_packed : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%acc : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>

  // Back to the row-major %batch x %rows x 8640 layout.
  %unpack_init = tensor.empty(%batch, %rows) : tensor<?x?x8640xf32>
  %matmul = tensor.unpack %mmt4d outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %unpack_init : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>

  // Elementwise ReLU: max(x, 0.0).
  %relu_init = tensor.empty(%batch, %rows) : tensor<?x?x8640xf32>
  %relu = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%matmul : tensor<?x?x8640xf32>) outs(%relu_init : tensor<?x?x8640xf32>) {
  ^bb0(%in: f32, %out: f32):
    %max = arith.maximumf %in, %zero_f32 : f32
    linalg.yield %max : f32
  } -> tensor<?x?x8640xf32>

  %result = hal.tensor.export %relu "output0" : tensor<?x?x8640xf32>{%batch, %rows} -> !hal.buffer_view
  util.return %result : !hal.buffer_view
}
// relu(input0 x input1^T), batched over dim 0 of input0. Unlike @matmul_broad, the f16 RHS is
// packed once (8640x3200 -> 540x3200x16x1) and only then broadcast over the batch. The ReLU is
// applied directly to the packed batch_mmt4d result, and the result is unpacked exactly once.
util.func public @matmul_broad_cast_relu(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @matmul_broad_cast_relu(%input0: tensor<?x?x3200xf32>, %input1: tensor<8640x3200xf16>) -> (%output0: tensor<?x?x8640xf32>)"}} {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<?x?x3200xf32>{%0, %1}
  %3 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<8640x3200xf16>
  // Pack the RHS once. 8640 and 3200 are exact multiples of the 16x1 tile, so no padding_value
  // is required.
  %4 = tensor.empty() : tensor<540x3200x16x1xf16>
  %pack = tensor.pack %3 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %4 : tensor<8640x3200xf16> -> tensor<540x3200x16x1xf16>
  // Broadcast the packed RHS over the batch dimension.
  %5 = tensor.empty(%0) : tensor<?x540x3200x16x1xf16>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%pack : tensor<540x3200x16x1xf16>) outs(%5 : tensor<?x540x3200x16x1xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<?x540x3200x16x1xf16>
  // Pack the LHS. Dim 1 (%1) is dynamic, so a partial last 16-row tile is padded with 0.0.
  %7 = affine.apply affine_map<()[s0] -> (s0 ceildiv 16)>()[%1]
  %8 = tensor.empty(%0, %7) : tensor<?x?x3200x16x1xf32>
  %pack_0 = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %8 : tensor<?x?x3200xf32> -> tensor<?x?x3200x16x1xf32>
  %9 = tensor.empty(%0, %7) : tensor<?x?x540x16x16xf32>
  %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  %11 = linalg.batch_mmt4d ins(%pack_0, %6 : tensor<?x?x3200x16x1xf32>, tensor<?x540x3200x16x1xf16>) outs(%10 : tensor<?x?x540x16x16xf32>) -> tensor<?x?x540x16x16xf32>
  // ReLU directly in the packed layout. No unpack -> re-pack round trip is needed; a re-pack
  // without padding_value is undefined when %1 is not a multiple of 16. Whatever the padded rows
  // hold is discarded by the unpack below.
  %12 = tensor.empty(%0, %7) : tensor<?x?x540x16x16xf32>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%11 : tensor<?x?x540x16x16xf32>) outs(%12 : tensor<?x?x540x16x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    %max = arith.maximumf %in, %cst : f32
    linalg.yield %max : f32
  } -> tensor<?x?x540x16x16xf32>
  // Unpack straight into a %0 x %1 x 8640 tensor, so the result's dims are exactly the {%0, %1}
  // declared on the export. Sizing it from the tile count (ceildiv(%1, 16) * 16) would leave the
  // padded rows in the output.
  %14 = tensor.empty(%0, %1) : tensor<?x?x8640xf32>
  %unpack = tensor.unpack %13 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %14 : tensor<?x?x540x16x16xf32> -> tensor<?x?x8640xf32>
  %15 = hal.tensor.export %unpack "output0" : tensor<?x?x8640xf32>{%0, %1} -> !hal.buffer_view
  util.return %15 : !hal.buffer_view
}
//#map = affine_map<(d0, d1, d2) -> (d1, d2)>
//#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//module {
// util.func public @broadcast_matmul_relu(%arg0: tensor<?x?x3200xf32>, %arg1: tensor<8640x3200xf16>) -> tensor<?x?x8640xf32> {
// %cst = arith.constant 0.000000e+00 : f32
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// %dim = tensor.dim %arg0, %c0 : tensor<?x?x3200xf32>
// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x3200xf32>
// %0 = tensor.empty(%dim) : tensor<?x8640x3200xf16>
// %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8640x3200xf16>) outs(%0 : tensor<?x8640x3200xf16>) {
// ^bb0(%in: f16, %out: f16):
// linalg.yield %in : f16
// } -> tensor<?x8640x3200xf16>
// %2 = tensor.empty(%dim, %dim_0) : tensor<?x?x8640xf32>
// %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
// %4 = linalg.batch_matmul_transpose_b ins(%arg0, %1 : tensor<?x?x3200xf32>, tensor<?x8640x3200xf16>) outs(%3 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
// %5 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<?x?x8640xf32>) outs(%2 : tensor<?x?x8640xf32>) {
// ^bb0(%in: f32, %out: f32):
// %6 = arith.maximumf %in, %cst : f32
// linalg.yield %6 : f32
// } -> tensor<?x?x8640xf32>
// util.return %5 : tensor<?x?x8640xf32>
// }
//}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment