Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created September 26, 2024 15:55
Show Gist options
  • Save pashu123/7c8d20a28aaca2c17b66179906e381db to your computer and use it in GitHub Desktop.
Save pashu123/7c8d20a28aaca2c17b66179906e381db to your computer and use it in GitHub Desktop.
// Dispatch function that calls the "iree_uk_mmt4d" microkernel once, on the
// tile selected by this workgroup's x/y ids. It takes no arguments and returns
// no values; all data is reached through the HAL interface bindings below.
func.func @matmul_dispatch_2_mmt4d_4x4x64x16x16x1_f32() {
  // Scalar operands forwarded to the microkernel call at the end of the body.
  // NOTE(review): 1537 is presumably a flags bitfield for the ukernel --
  // confirm against the iree_uk_mmt4d declaration.
  %c1537_i32 = arith.constant 1537 : i32
  %c1_i32 = arith.constant 1 : i32
  %c16_i32 = arith.constant 16 : i32
  %c64 = arith.constant 64 : index
  %c1 = arith.constant 1 : index
  // Offsets passed to the three binding subspans.
  %c0 = arith.constant 0 : index
  %c16384 = arith.constant 16384 : index
  %c32768 = arith.constant 32768 : index

  // First ukernel input: 4x64x16x1xf32 view of binding 0 at offset 0.
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %0, 64 : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>>
  // Second ukernel input: also a view of binding 0, at offset 16384. That
  // value equals 4096 f32 elements x 4 bytes, consistent with `offset: 4096`
  // in the strided layout.
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c16384) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %1, 64 : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>>
  // Ukernel output: 4x4x16x16xf32 view of binding 1 at offset 32768
  // (8192 f32 elements x 4 bytes, consistent with `offset: 8192`).
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c32768) flags(Indirect) : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %2, 64 : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>>

  // Workgroup ids pick the tile. The two workgroup counts are queried but not
  // referenced anywhere later in this function.
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index

  // Slice %0 along dim 0 by workgroup_id_y and %1 along dim 0 by
  // workgroup_id_x; the output tile is indexed by (workgroup_id_y,
  // workgroup_id_x) in its two outer dims.
  %subview = memref.subview %0[%workgroup_id_y, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %subview_0 = memref.subview %1[%workgroup_id_x, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %subview_1 = memref.subview %2[%workgroup_id_y, %workgroup_id_x, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> to memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>

  // Microkernel call on the selected 1x64x16x1 inputs and 1x1x16x16 output.
  // The i32 result %3 is not used afterwards.
  %3 = iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%subview, %subview_0 : memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_1 : memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) (%c1, %c1, %c64, %c16_i32, %c16_i32, %c1_i32, %c1537_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.fields = ["processor_data"]} strided_outer_dims(1) -> i32
  return
}
// Dispatch function that calls the "iree_uk_mmt4d" microkernel inside two
// nested loops over the 4x4 grid of output tiles. Each loop starts at this
// workgroup's id and advances by the workgroup count, so a workgroup visits
// indices id, id + count, ... below 4. It takes no arguments and returns no
// values; all data is reached through the HAL interface bindings below.
func.func @matmul_dispatch_2_mmt4d_4x4x64x16x16x1_f32() {
  // Scalar operands forwarded to the microkernel call inside the loops.
  // NOTE(review): 1537 is presumably a flags bitfield for the ukernel --
  // confirm against the iree_uk_mmt4d declaration.
  %c1537_i32 = arith.constant 1537 : i32
  %c1_i32 = arith.constant 1 : i32
  %c16_i32 = arith.constant 16 : i32
  %c64 = arith.constant 64 : index
  %c1 = arith.constant 1 : index
  // Upper bound shared by both loops.
  %c4 = arith.constant 4 : index
  // Offsets passed to the three binding subspans.
  %c0 = arith.constant 0 : index
  %c16384 = arith.constant 16384 : index
  %c32768 = arith.constant 32768 : index

  // First ukernel input: 4x64x16x1xf32 view of binding 0 at offset 0.
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %0, 64 : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>>
  // Second ukernel input: also a view of binding 0, at offset 16384. That
  // value equals 4096 f32 elements x 4 bytes, consistent with `offset: 4096`
  // in the strided layout.
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c16384) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %1, 64 : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>>
  // Ukernel output: 4x4x16x16xf32 view of binding 1 at offset 32768
  // (8192 f32 elements x 4 bytes, consistent with `offset: 8192`).
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c32768) flags(Indirect) : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %2, 64 : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>>

  // Workgroup ids are the loop lower bounds; workgroup counts are the steps.
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %workgroup_count_x = hal.interface.workgroup.count[0] : index
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_count_y = hal.interface.workgroup.count[1] : index

  // Outer loop: %arg0 indexes dim 0 of %0 and dim 0 of the output.
  scf.for %arg0 = %workgroup_id_y to %c4 step %workgroup_count_y {
    // This slice depends only on %arg0, so it is taken once per outer
    // iteration rather than inside the inner loop.
    %subview = memref.subview %0[%arg0, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    // Inner loop: %arg1 indexes dim 0 of %1 and dim 1 of the output.
    scf.for %arg1 = %workgroup_id_x to %c4 step %workgroup_count_x {
      // Note: here %subview_0 is the OUTPUT tile and %subview_1 is the second
      // input tile.
      %subview_0 = memref.subview %2[%arg0, %arg1, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> to memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      %subview_1 = memref.subview %1[%arg1, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      // Microkernel call on the current tile; the i32 result %3 is not used.
      %3 = iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%subview, %subview_1 : memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_0 : memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) (%c1, %c1, %c64, %c16_i32, %c16_i32, %c1_i32, %c1537_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.fields = ["processor_data"]} strided_outer_dims(1) -> i32
    }
  }
  // The pasted snippet ended after the loops; a func.func body needs a
  // terminator and a closing brace to be well-formed.
  return
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment