Created
September 26, 2024 15:55
-
-
Save pashu123/7c8d20a28aaca2c17b66179906e381db to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Dispatch function that issues a single "iree_uk_mmt4d" microkernel call per
// invocation. The tiles it operates on are selected directly from the
// workgroup ids (no loops in this version of the IR).
func.func @matmul_dispatch_2_mmt4d_4x4x64x16x16x1_f32() { | |
// Scalar constants. The i32 values and %c1/%c64 are forwarded as trailing
// operands of the ukernel call below; %c0/%c16384/%c32768 are the offsets
// given to the three binding subspans.
// NOTE(review): 1537 is presumably a ukernel flags value -- its meaning is
// not visible in this dump; confirm against the mmt4d ukernel definition.
%c1537_i32 = arith.constant 1537 : i32 | |
%c1_i32 = arith.constant 1 : i32 | |
%c16_i32 = arith.constant 16 : i32 | |
%c64 = arith.constant 64 : index | |
%c1 = arith.constant 1 : index | |
%c0 = arith.constant 0 : index | |
%c16384 = arith.constant 16384 : index | |
%c32768 = arith.constant 32768 : index | |
// First input: binding 0 at offset 0, flagged "ReadOnly|Indirect",
// viewed as 4x64x16x1xf32.
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %0, 64 : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> | |
// Second input: also binding 0, but at offset 16384. The result type carries
// a layout offset of 4096, i.e. 16384 divided by the 4-byte f32 size.
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c16384) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %1, 64 : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> | |
// Output: binding 1 at offset 32768 (layout offset 8192 in the type),
// viewed as 4x4x16x16xf32. This binding is not flagged ReadOnly.
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c32768) flags(Indirect) : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %2, 64 : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> | |
// Workgroup ids along dimensions 0 (x) and 1 (y) pick the tiles below.
// The two workgroup counts are defined but have no uses in this function.
%workgroup_id_x = hal.interface.workgroup.id[0] : index | |
%workgroup_count_x = hal.interface.workgroup.count[0] : index | |
%workgroup_id_y = hal.interface.workgroup.id[1] : index | |
%workgroup_count_y = hal.interface.workgroup.count[1] : index | |
// 1x64x16x1 slice of the first input, indexed by workgroup_id_y on dim 0.
%subview = memref.subview %0[%workgroup_id_y, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
// 1x64x16x1 slice of the second input, indexed by workgroup_id_x on dim 0.
%subview_0 = memref.subview %1[%workgroup_id_x, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
// 1x1x16x16 output tile at [workgroup_id_y, workgroup_id_x].
%subview_1 = memref.subview %2[%workgroup_id_y, %workgroup_id_x, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> to memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
// Microkernel call: ins = the two input slices, outs = the output tile,
// followed by the scalar operands (1, 1, 64, 16, 16, 1, 1537).
// The i32 result %3 is not used afterwards.
%3 = iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%subview, %subview_0 : memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_1 : memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) (%c1, %c1, %c64, %c16_i32, %c16_i32, %c1_i32, %c1537_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.fields = ["processor_data"]} strided_outer_dims(1) -> i32 | |
return | |
} | |
} | |
// Variant of the same dispatch function in which the tile selection is
// expressed as two nested scf.for loops: each loop starts at a workgroup id,
// runs to the bound 4, and steps by the matching workgroup count. The
// "iree_uk_mmt4d" microkernel is called once per (%arg0, %arg1) pair.
// NOTE(review): this dump appears truncated -- both loops are closed at the
// end, but no `return` or function-closing brace follows.
func.func @matmul_dispatch_2_mmt4d_4x4x64x16x16x1_f32() { | |
// Scalar constants. The i32 values and %c1/%c64 are forwarded as trailing
// operands of the ukernel call; %c4 is the upper bound of both loops;
// %c0/%c16384/%c32768 are the offsets given to the three binding subspans.
// NOTE(review): 1537 is presumably a ukernel flags value -- its meaning is
// not visible in this dump; confirm against the mmt4d ukernel definition.
%c1537_i32 = arith.constant 1537 : i32 | |
%c1_i32 = arith.constant 1 : i32 | |
%c16_i32 = arith.constant 16 : i32 | |
%c64 = arith.constant 64 : index | |
%c1 = arith.constant 1 : index | |
%c4 = arith.constant 4 : index | |
%c0 = arith.constant 0 : index | |
%c16384 = arith.constant 16384 : index | |
%c32768 = arith.constant 32768 : index | |
// First input: binding 0 at offset 0, flagged "ReadOnly|Indirect",
// viewed as 4x64x16x1xf32.
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %0, 64 : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> | |
// Second input: also binding 0, but at offset 16384. The result type carries
// a layout offset of 4096, i.e. 16384 divided by the 4-byte f32 size.
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c16384) flags("ReadOnly|Indirect") : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %1, 64 : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> | |
// Output: binding 1 at offset 32768 (layout offset 8192 in the type),
// viewed as 4x4x16x16xf32. This binding is not flagged ReadOnly.
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c32768) flags(Indirect) : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %2, 64 : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> | |
// Workgroup id/count along dimensions 0 (x) and 1 (y); the ids are the loop
// lower bounds and the counts are the loop steps.
%workgroup_id_x = hal.interface.workgroup.id[0] : index | |
%workgroup_count_x = hal.interface.workgroup.count[0] : index | |
%workgroup_id_y = hal.interface.workgroup.id[1] : index | |
%workgroup_count_y = hal.interface.workgroup.count[1] : index | |
// Outer loop over dim 0 of the first input / dim 0 of the output.
scf.for %arg0 = %workgroup_id_y to %c4 step %workgroup_count_y { | |
// 1x64x16x1 slice of the first input at row %arg0; it depends only on the
// outer induction variable and is taken outside the inner loop.
%subview = memref.subview %0[%arg0, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
// Inner loop over dim 0 of the second input / dim 1 of the output.
scf.for %arg1 = %workgroup_id_x to %c4 step %workgroup_count_x { | |
// Note the naming in this dump: %subview_0 is the 1x1x16x16 OUTPUT tile at
// [%arg0, %arg1], and %subview_1 is the 1x64x16x1 slice of the second input.
%subview_0 = memref.subview %2[%arg0, %arg1, 0, 0] [1, 1, 16, 16] [1, 1, 1, 1] : memref<4x4x16x16xf32, strided<[1024, 256, 16, 1], offset: 8192>, #hal.descriptor_type<storage_buffer>> to memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
%subview_1 = memref.subview %1[%arg1, 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1] : memref<4x64x16x1xf32, strided<[1024, 16, 1, 1], offset: 4096>, #hal.descriptor_type<storage_buffer>> to memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
// Microkernel call: ins = the two input slices, outs = the output tile,
// followed by the scalar operands (1, 1, 64, 16, 16, 1, 1537).
// The i32 result %3 is not used afterwards.
%3 = iree_codegen.ukernel.generic "iree_uk_mmt4d" ins(%subview, %subview_1 : memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, memref<1x64x16x1xf32, strided<[1024, 16, 1, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) outs(%subview_0 : memref<1x1x16x16xf32, strided<[1024, 256, 16, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) (%c1, %c1, %c64, %c16_i32, %c16_i32, %c1_i32, %c1537_i32 : index, index, index, i32, i32, i32, i32) fn_def_attrs {hal.import.bitcode = true, hal.import.fields = ["processor_data"]} strided_outer_dims(1) -> i32 | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment