// Gist by @pashu123, created June 13, 2024 13:06.
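// IREE codegen IR for @matmul_broad_dispatch_2_batch_mmt4d_DxDx540x3200x16x16x1_f32xf16xf32:
// a broadcasted matmul lowered to linalg.batch_mmt4d under the Mmt4dTilingExpert pipeline.
// The first listing shows the dispatch after tiling; the second listing (the trailing
// `module`) shows the same dispatch after vectorization to vector.transfer/vector.contract ops.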
func.func @matmul_broad_dispatch_2_batch_mmt4d_DxDx540x3200x16x16x1_f32xf16xf32() attributes {translation_info = #iree_codegen.translation_info<Mmt4dTilingExpert>} {
  %c1 = arith.constant 1 : index
  %c3200 = arith.constant 3200 : index
  %c540 = arith.constant 540 : index
  %c55296000 = arith.constant 55296000 : index
  %c0 = arith.constant 0 : index
  %c32_i64 = arith.constant 32 : i64
  %cst = arith.constant 0.000000e+00 : f32
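  // Reassemble three 64-bit values (%10, %15, %20) from pairs of 32-bit push
  // constants: value = low | (high << 32).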
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = hal.interface.constant.load[4] : i32
  %3 = hal.interface.constant.load[5] : i32
  %4 = hal.interface.constant.load[6] : i32
  %5 = hal.interface.constant.load[7] : i32
  %6 = arith.extui %0 : i32 to i64
  %7 = arith.extui %1 : i32 to i64
  %8 = arith.shli %7, %c32_i64 : i64
  %9 = arith.ori %6, %8 : i64
  %10 = arith.index_castui %9 : i64 to index
  %11 = arith.extui %2 : i32 to i64
  %12 = arith.extui %3 : i32 to i64
  %13 = arith.shli %12, %c32_i64 : i64
  %14 = arith.ori %11, %13 : i64
  %15 = arith.index_castui %14 : i64 to index
  %16 = arith.extui %4 : i32 to i64
  %17 = arith.extui %5 : i32 to i64
  %18 = arith.shli %17, %c32_i64 : i64
  %19 = arith.ori %16, %18 : i64
  %20 = arith.index_castui %19 : i64 to index
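  // Bind the dispatch buffers: %10 is the byte offset of the writeonly result
  // subspan; %15 and %20 are the dynamic sizes of the two leading dimensions.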
  %21 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<540x3200x16x1xf16>>
  %22 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c55296000) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x3200x16x1xf32>>{%15, %20}
  %23 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%10) : !flow.dispatch.tensor<writeonly:tensor<?x?x540x16x16xf32>>{%15, %20}
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index
  %workgroup_id_z = hal.interface.workgroup.id[2] : index
  %workgroup_count_z = hal.interface.workgroup.count[2] : index
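  // Distribute the tile loops across workgroups: z and y walk the two dynamic
  // outer dims (%15, %20); x walks the 540 tiles of the third dim.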
  scf.for %arg0 = %workgroup_id_z to %15 step %workgroup_count_z {
    scf.for %arg1 = %workgroup_id_y to %20 step %workgroup_count_y {
      scf.for %arg2 = %workgroup_id_x to %c540 step %workgroup_count_x {
        %24 = flow.dispatch.tensor.load %23, offsets = [%arg0, %arg1, %arg2, 0, 0], sizes = [1, 1, 1, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<?x?x540x16x16xf32>>{%15, %20} -> tensor<1x1x1x16x16xf32>
        %25 = flow.dispatch.tensor.load %22, offsets = [%arg0, %arg1, 0, 0, 0], sizes = [1, 1, 3200, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x3200x16x1xf32>>{%15, %20} -> tensor<1x1x3200x16x1xf32>
        %26 = flow.dispatch.tensor.load %21, offsets = [%arg2, 0, 0, 0], sizes = [1, 3200, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<540x3200x16x1xf16>> -> tensor<1x3200x16x1xf16>
        %27 = tensor.empty() : tensor<1x1x3200x16x1xf16>
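        // Broadcast the f16 tile along a new unit leading dimension so it matches
        // the 5-D batch_mmt4d operand shape.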
        %28 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%26 : tensor<1x3200x16x1xf16>) outs(%27 : tensor<1x1x3200x16x1xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 16], [0, 0, 0, 0, 0], [0, 0, 0, 1, 0]]>} {
        ^bb0(%in: f16, %out: f16):
          linalg.yield %in : f16
        } -> tensor<1x1x3200x16x1xf16>
        %29 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 16], [0, 0, 0, 0, 0], [0, 0, 0, 1, 0]]>} ins(%cst : f32) outs(%24 : tensor<1x1x1x16x16xf32>) -> tensor<1x1x1x16x16xf32>
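        // Zero-fill the 16x16 accumulator tile, then reduce over the 3200 K-tiles,
        // one tile per iteration.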
        %30 = scf.for %arg3 = %c0 to %c3200 step %c1 iter_args(%arg4 = %29) -> (tensor<1x1x1x16x16xf32>) {
          %extracted_slice = tensor.extract_slice %25[0, 0, %arg3, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : tensor<1x1x3200x16x1xf32> to tensor<1x1x1x16x1xf32>
          %extracted_slice_0 = tensor.extract_slice %28[0, 0, %arg3, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : tensor<1x1x3200x16x1xf16> to tensor<1x1x1x16x1xf16>
          %31 = linalg.batch_mmt4d {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 0, 16, 16, 0], [0, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0]]>} ins(%extracted_slice, %extracted_slice_0 : tensor<1x1x1x16x1xf32>, tensor<1x1x1x16x1xf16>) outs(%arg4 : tensor<1x1x1x16x16xf32>) -> tensor<1x1x1x16x16xf32>
          scf.yield %31 : tensor<1x1x1x16x16xf32>
        }
        flow.dispatch.tensor.store %30, %23, offsets = [%arg0, %arg1, %arg2, 0, 0], sizes = [1, 1, 1, 16, 16], strides = [1, 1, 1, 1, 1] : tensor<1x1x1x16x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?x540x16x16xf32>>{%15, %20}
      }
    }
  }
  return
}
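
// The same dispatch after vectorization: the broadcast generic, linalg.fill, and
// linalg.batch_mmt4d above are rewritten into vector.transfer_read/transfer_write,
// vector.broadcast, and vector.contract, with the contraction maps hoisted to #map..#map2.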
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
#translation = #iree_codegen.translation_info<Mmt4dTilingExpert>
module {
  func.func @matmul_broad_dispatch_2_batch_mmt4d_DxDx540x3200x16x16x1_f32xf16xf32() attributes {translation_info = #translation} {
    %cst = arith.constant dense<0.000000e+00> : vector<1x1x1x16x16xf32>
    %cst_0 = arith.constant 0.000000e+00 : f16
    %c1 = arith.constant 1 : index
    %c3200 = arith.constant 3200 : index
    %c540 = arith.constant 540 : index
    %c55296000 = arith.constant 55296000 : index
    %c0 = arith.constant 0 : index
    %c32_i64 = arith.constant 32 : i64
    %cst_1 = arith.constant 0.000000e+00 : f32
    %0 = hal.interface.constant.load[0] : i32
    %1 = hal.interface.constant.load[1] : i32
    %2 = hal.interface.constant.load[4] : i32
    %3 = hal.interface.constant.load[5] : i32
    %4 = hal.interface.constant.load[6] : i32
    %5 = hal.interface.constant.load[7] : i32
    %6 = arith.extui %0 : i32 to i64
    %7 = arith.extui %1 : i32 to i64
    %8 = arith.shli %7, %c32_i64 : i64
    %9 = arith.ori %6, %8 : i64
    %10 = arith.index_castui %9 : i64 to index
    %11 = arith.extui %2 : i32 to i64
    %12 = arith.extui %3 : i32 to i64
    %13 = arith.shli %12, %c32_i64 : i64
    %14 = arith.ori %11, %13 : i64
    %15 = arith.index_castui %14 : i64 to index
    %16 = arith.extui %4 : i32 to i64
    %17 = arith.extui %5 : i32 to i64
    %18 = arith.shli %17, %c32_i64 : i64
    %19 = arith.ori %16, %18 : i64
    %20 = arith.index_castui %19 : i64 to index
    %21 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<540x3200x16x1xf16>>
    %22 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c55296000) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<?x?x3200x16x1xf32>>{%15, %20}
    %23 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%10) : !flow.dispatch.tensor<writeonly:tensor<?x?x540x16x16xf32>>{%15, %20}
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    %workgroup_id_z = hal.interface.workgroup.id[2] : index
    %workgroup_count_z = hal.interface.workgroup.count[2] : index
    scf.for %arg0 = %workgroup_id_z to %15 step %workgroup_count_z {
      scf.for %arg1 = %workgroup_id_y to %20 step %workgroup_count_y {
        scf.for %arg2 = %workgroup_id_x to %c540 step %workgroup_count_x {
          %24 = flow.dispatch.tensor.load %23, offsets = [%arg0, %arg1, %arg2, 0, 0], sizes = [1, 1, 1, 16, 16], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<?x?x540x16x16xf32>>{%15, %20} -> tensor<1x1x1x16x16xf32>
          %25 = flow.dispatch.tensor.load %22, offsets = [%arg0, %arg1, 0, 0, 0], sizes = [1, 1, 3200, 16, 1], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x3200x16x1xf32>>{%15, %20} -> tensor<1x1x3200x16x1xf32>
          %26 = flow.dispatch.tensor.load %21, offsets = [%arg2, 0, 0, 0], sizes = [1, 3200, 16, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<540x3200x16x1xf16>> -> tensor<1x3200x16x1xf16>
          %27 = tensor.empty() : tensor<1x1x3200x16x1xf16>
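          // The broadcast is now vectorized: read the f16 tile, vector.broadcast a
          // unit leading dim, and write the result into the empty tensor.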
          %28 = vector.transfer_read %26[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true, true]} : tensor<1x3200x16x1xf16>, vector<1x3200x16x1xf16>
          %29 = vector.broadcast %28 : vector<1x3200x16x1xf16> to vector<1x1x3200x16x1xf16>
          %30 = vector.transfer_write %29, %27[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true]} : vector<1x1x3200x16x1xf16>, tensor<1x1x3200x16x1xf16>
          %31 = vector.transfer_write %cst, %24[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true]} : vector<1x1x1x16x16xf32>, tensor<1x1x1x16x16xf32>
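          // The inner K loop threads the zero-initialized accumulator tile through iter_args.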
          %32 = scf.for %arg3 = %c0 to %c3200 step %c1 iter_args(%arg4 = %31) -> (tensor<1x1x1x16x16xf32>) {
            %extracted_slice = tensor.extract_slice %25[0, 0, %arg3, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : tensor<1x1x3200x16x1xf32> to tensor<1x1x1x16x1xf32>
            %extracted_slice_2 = tensor.extract_slice %30[0, 0, %arg3, 0, 0] [1, 1, 1, 16, 1] [1, 1, 1, 1, 1] : tensor<1x1x3200x16x1xf16> to tensor<1x1x1x16x1xf16>
            %33 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true, true, true]} : tensor<1x1x1x16x1xf32>, vector<1x1x1x16x1xf32>
            %34 = vector.transfer_read %extracted_slice_2[%c0, %c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true, true, true]} : tensor<1x1x1x16x1xf16>, vector<1x1x1x16x1xf16>
            %35 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0, %c0], %cst_1 {in_bounds = [true, true, true, true, true]} : tensor<1x1x1x16x16xf32>, vector<1x1x1x16x16xf32>
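            // Widen the f16 operand to f32, then accumulate one K step of the
            // 16x16 outer product via vector.contract.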
            %36 = arith.extf %34 : vector<1x1x1x16x1xf16> to vector<1x1x1x16x1xf32>
            %37 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %33, %36, %35 : vector<1x1x1x16x1xf32>, vector<1x1x1x16x1xf32> into vector<1x1x1x16x16xf32>
            %38 = vector.transfer_write %37, %arg4[%c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true, true]} : vector<1x1x1x16x16xf32>, tensor<1x1x1x16x16xf32>
            scf.yield %38 : tensor<1x1x1x16x16xf32>
          }
          flow.dispatch.tensor.store %32, %23, offsets = [%arg0, %arg1, %arg2, 0, 0], sizes = [1, 1, 1, 16, 16], strides = [1, 1, 1, 1, 1] : tensor<1x1x1x16x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x?x540x16x16xf32>>{%15, %20}
        }
      }
    }
    return
  }
}