Created June 6, 2025 19:04
-
-
Save bjacob/f762d7aff2f4ba0bb08d224f09906a00 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// IREE HAL executable holding a single dispatch: an f8E4M3FNUZ x f8E4M3FNUZ
// matmul with f32 accumulation and a dynamic number of rows, fused with an
// elementwise epilogue whose result is stored as f8E4M3FNUZ.
hal.executable public @prefill_bs4$async_dispatch_23 {
  // Only variant: ROCm HSACO (HIP ABI) for gfx942. The target attribute lists
  // the capabilities visible to code generation (MFMA intrinsics, subgroup
  // size 64, workgroup-size / memory limits).
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "+sramecc,-xnack", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) {
    // Entry point. The pipeline layout declares 13 push constants and 4
    // storage-buffer bindings. The count region receives two workload values
    // (%arg1, %arg2) and derives the workgroup count from the function body
    // via workgroup_count_from_slice; in the body these are tied to
    // iree_tensor_ext.dispatch.workload.ordinal 0 and 1.
    hal.executable.export public @prefill_bs4$async_dispatch_23_matmul_like_Dx14336x4096_f8E4M3FNUZxf8E4M3FNUZxf32 ordinal(0) layout(#hal.pipeline.layout<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) count(%arg0: !hal.device, %arg1: index, %arg2: index) -> (index, index, index) {
      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice %arg1, %arg2
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @prefill_bs4$async_dispatch_23_matmul_like_Dx14336x4096_f8E4M3FNUZxf8E4M3FNUZxf32() {
        %c32_i64 = arith.constant 32 : i64
        %c4 = arith.constant 4 : index
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant 1.000000e+00 : f32
        // Clamp bounds applied just before the final f32 -> f8E4M3FNUZ
        // truncation (presumably the finite range of f8E4M3FNUZ — confirm).
        %cst_1 = arith.constant -2.400000e+02 : f32
        %cst_2 = arith.constant 2.400000e+02 : f32
        // Load the 13 i32 push constants, ordinals 0..12.
        %0 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
        %1 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
        %2 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
        %3 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
        %4 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
        %5 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
        %6 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(6) : i32
        %7 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(7) : i32
        %8 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(8) : i32
        %9 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(9) : i32
        %10 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(10) : i32
        %11 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(11) : i32
        %12 = hal.interface.constant.load layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(12) : i32
        // 64-bit values are rebuilt from (low, high) i32 pairs as
        // zext(low) | (zext(high) << 32).
        // Ordinals (0, 1) -> %17.
        %13 = arith.extui %0 : i32 to i64
        %14 = arith.extui %1 : i32 to i64
        %15 = arith.shli %14, %c32_i64 : i64
        %16 = arith.ori %13, %15 : i64
        %17 = arith.index_castui %16 : i64 to index
        // Ordinals (2, 3) -> %22.
        %18 = arith.extui %2 : i32 to i64
        %19 = arith.extui %3 : i32 to i64
        %20 = arith.shli %19, %c32_i64 : i64
        %21 = arith.ori %18, %20 : i64
        %22 = arith.index_castui %21 : i64 to index
        // Ordinals (4, 5) -> %27; stream.values lists the 32 values it may take.
        %23 = arith.extui %4 : i32 to i64
        %24 = arith.extui %5 : i32 to i64
        %25 = arith.shli %24, %c32_i64 : i64
        %26 = arith.ori %23, %25 : i64
        %27 = arith.index_castui %26 {stream.alignment = 64 : index, stream.values = [1092633152 : index, 1310754240 : index, 1528875328 : index, 1746996416 : index, 1965117504 : index, 2183238592 : index, 2401359680 : index, 2619480768 : index, 2837601856 : index, 3055722944 : index, 3273844032 : index, 3491965120 : index, 3710086208 : index, 3928207296 : index, 4146328384 : index, 4364449472 : index, 4582570560 : index, 4800691648 : index, 5018812736 : index, 5236933824 : index, 5455054912 : index, 5673176000 : index, 5891297088 : index, 6109418176 : index, 6327539264 : index, 6545660352 : index, 6763781440 : index, 6981902528 : index, 7200023616 : index, 7418144704 : index, 7636265792 : index, 7854386880 : index]} : i64 to index
        // Ordinals (6, 5) -> %32. The high word deliberately reuses ordinal 5,
        // the same one %27 uses. This matches the two value lists: every
        // %27/%32 pair shares its upper 32 bits (0 for entries 0-14, 1 for
        // entries 15-31). Presumably a compiler deduplication of identical
        // push constants — confirm.
        %28 = arith.extui %6 : i32 to i64
        %29 = arith.extui %5 : i32 to i64
        %30 = arith.shli %29, %c32_i64 : i64
        %31 = arith.ori %28, %30 : i64
        %32 = arith.index_castui %31 {stream.alignment = 128 : index, stream.values = [1210073856 : index, 1428194944 : index, 1646316032 : index, 1864437120 : index, 2082558208 : index, 2300679296 : index, 2518800384 : index, 2736921472 : index, 2955042560 : index, 3173163648 : index, 3391284736 : index, 3609405824 : index, 3827526912 : index, 4045648000 : index, 4263769088 : index, 4481890176 : index, 4700011264 : index, 4918132352 : index, 5136253440 : index, 5354374528 : index, 5572495616 : index, 5790616704 : index, 6008737792 : index, 6226858880 : index, 6444979968 : index, 6663101056 : index, 6881222144 : index, 7099343232 : index, 7317464320 : index, 7535585408 : index, 7753706496 : index, 7971827584 : index]} : i64 to index
        // Ordinals 7 and 8 are used directly as 32-bit values.
        %33 = arith.index_castui %7 : i32 to index
        %34 = arith.index_castui %8 : i32 to index
        // Ordinals (9, 10) -> %39.
        %35 = arith.extui %9 : i32 to i64
        %36 = arith.extui %10 : i32 to i64
        %37 = arith.shli %36, %c32_i64 : i64
        %38 = arith.ori %35, %37 : i64
        %39 = arith.index_castui %38 : i64 to index
        // Ordinals 11 and 12 become workload ordinals 0 and 1 below.
        %40 = arith.index_castui %11 : i32 to index
        %41 = arith.index_castui %12 : i32 to index
        // Attach range / divisibility assumptions to all nine values;
        // %42#i is the i-th operand (%17, %22, %27, %32, %33, %34, %39, %40, %41).
        %42:9 = util.assume.int
            %17<umin = 69206016, umax = 8654946304>,
            %22<umin = 80757888, umax = 73126888448>,
            %27[<umin = 1092633152, umax = 1092633152, udiv = 1092633152>, <umin = 1310754240, umax = 1310754240, udiv = 1310754240>, <umin = 1528875328, umax = 1528875328, udiv = 1528875328>, <umin = 1746996416, umax = 1746996416, udiv = 1746996416>, <umin = 1965117504, umax = 1965117504, udiv = 1965117504>, <umin = 2183238592, umax = 2183238592, udiv = 2183238592>, <umin = 2401359680, umax = 2401359680, udiv = 2401359680>, <umin = 2619480768, umax = 2619480768, udiv = 2619480768>, <umin = 2837601856, umax = 2837601856, udiv = 2837601856>, <umin = 3055722944, umax = 3055722944, udiv = 3055722944>, <umin = 3273844032, umax = 3273844032, udiv = 3273844032>, <umin = 3491965120, umax = 3491965120, udiv = 3491965120>, <umin = 3710086208, umax = 3710086208, udiv = 3710086208>, <umin = 3928207296, umax = 3928207296, udiv = 3928207296>, <umin = 4146328384, umax = 4146328384, udiv = 4146328384>, <umin = 4364449472, umax = 4364449472, udiv = 4364449472>, <umin = 4582570560, umax = 4582570560, udiv = 4582570560>, <umin = 4800691648, umax = 4800691648, udiv = 4800691648>, <umin = 5018812736, umax = 5018812736, udiv = 5018812736>, <umin = 5236933824, umax = 5236933824, udiv = 5236933824>, <umin = 5455054912, umax = 5455054912, udiv = 5455054912>, <umin = 5673176000, umax = 5673176000, udiv = 5673176000>, <umin = 5891297088, umax = 5891297088, udiv = 5891297088>, <umin = 6109418176, umax = 6109418176, udiv = 6109418176>, <umin = 6327539264, umax = 6327539264, udiv = 6327539264>, <umin = 6545660352, umax = 6545660352, udiv = 6545660352>, <umin = 6763781440, umax = 6763781440, udiv = 6763781440>, <umin = 6981902528, umax = 6981902528, udiv = 6981902528>, <umin = 7200023616, umax = 7200023616, udiv = 7200023616>, <umin = 7418144704, umax = 7418144704, udiv = 7418144704>, <umin = 7636265792, umax = 7636265792, udiv = 7636265792>, <umin = 7854386880, umax = 7854386880, udiv = 7854386880>],
            %32[<umin = 1210073856, umax = 1210073856, udiv = 1210073856>, <umin = 1428194944, umax = 1428194944, udiv = 1428194944>, <umin = 1646316032, umax = 1646316032, udiv = 1646316032>, <umin = 1864437120, umax = 1864437120, udiv = 1864437120>, <umin = 2082558208, umax = 2082558208, udiv = 2082558208>, <umin = 2300679296, umax = 2300679296, udiv = 2300679296>, <umin = 2518800384, umax = 2518800384, udiv = 2518800384>, <umin = 2736921472, umax = 2736921472, udiv = 2736921472>, <umin = 2955042560, umax = 2955042560, udiv = 2955042560>, <umin = 3173163648, umax = 3173163648, udiv = 3173163648>, <umin = 3391284736, umax = 3391284736, udiv = 3391284736>, <umin = 3609405824, umax = 3609405824, udiv = 3609405824>, <umin = 3827526912, umax = 3827526912, udiv = 3827526912>, <umin = 4045648000, umax = 4045648000, udiv = 4045648000>, <umin = 4263769088, umax = 4263769088, udiv = 4263769088>, <umin = 4481890176, umax = 4481890176, udiv = 4481890176>, <umin = 4700011264, umax = 4700011264, udiv = 4700011264>, <umin = 4918132352, umax = 4918132352, udiv = 4918132352>, <umin = 5136253440, umax = 5136253440, udiv = 5136253440>, <umin = 5354374528, umax = 5354374528, udiv = 5354374528>, <umin = 5572495616, umax = 5572495616, udiv = 5572495616>, <umin = 5790616704, umax = 5790616704, udiv = 5790616704>, <umin = 6008737792, umax = 6008737792, udiv = 6008737792>, <umin = 6226858880, umax = 6226858880, udiv = 6226858880>, <umin = 6444979968, umax = 6444979968, udiv = 6444979968>, <umin = 6663101056, umax = 6663101056, udiv = 6663101056>, <umin = 6881222144, umax = 6881222144, udiv = 6881222144>, <umin = 7099343232, umax = 7099343232, udiv = 7099343232>, <umin = 7317464320, umax = 7317464320, udiv = 7317464320>, <umin = 7535585408, umax = 7535585408, udiv = 7535585408>, <umin = 7753706496, umax = 7753706496, udiv = 7753706496>, <umin = 7971827584, umax = 7971827584, udiv = 7971827584>],
            %33[<umin = 524352, umax = 524352, udiv = 524352>, <umin = 524608, umax = 524608, udiv = 524608>, <umin = 524864, umax = 524864, udiv = 524864>, <umin = 525120, umax = 525120, udiv = 525120>, <umin = 525376, umax = 525376, udiv = 525376>, <umin = 525632, umax = 525632, udiv = 525632>, <umin = 525888, umax = 525888, udiv = 525888>, <umin = 526144, umax = 526144, udiv = 526144>, <umin = 526400, umax = 526400, udiv = 526400>, <umin = 526656, umax = 526656, udiv = 526656>, <umin = 526912, umax = 526912, udiv = 526912>, <umin = 527168, umax = 527168, udiv = 527168>, <umin = 527424, umax = 527424, udiv = 527424>, <umin = 527680, umax = 527680, udiv = 527680>, <umin = 527936, umax = 527936, udiv = 527936>, <umin = 528192, umax = 528192, udiv = 528192>, <umin = 528448, umax = 528448, udiv = 528448>, <umin = 528704, umax = 528704, udiv = 528704>, <umin = 528960, umax = 528960, udiv = 528960>, <umin = 529216, umax = 529216, udiv = 529216>, <umin = 529472, umax = 529472, udiv = 529472>, <umin = 529728, umax = 529728, udiv = 529728>, <umin = 529984, umax = 529984, udiv = 529984>, <umin = 530240, umax = 530240, udiv = 530240>, <umin = 530496, umax = 530496, udiv = 530496>, <umin = 530752, umax = 530752, udiv = 530752>, <umin = 531008, umax = 531008, udiv = 531008>, <umin = 531264, umax = 531264, udiv = 531264>, <umin = 531520, umax = 531520, udiv = 531520>, <umin = 531776, umax = 531776, udiv = 531776>, <umin = 532032, umax = 532032, udiv = 532032>, <umin = 532288, umax = 532288, udiv = 532288>],
            %34[<umin = 524416, umax = 524416, udiv = 524416>, <umin = 524672, umax = 524672, udiv = 524672>, <umin = 524928, umax = 524928, udiv = 524928>, <umin = 525184, umax = 525184, udiv = 525184>, <umin = 525440, umax = 525440, udiv = 525440>, <umin = 525696, umax = 525696, udiv = 525696>, <umin = 525952, umax = 525952, udiv = 525952>, <umin = 526208, umax = 526208, udiv = 526208>, <umin = 526464, umax = 526464, udiv = 526464>, <umin = 526720, umax = 526720, udiv = 526720>, <umin = 526976, umax = 526976, udiv = 526976>, <umin = 527232, umax = 527232, udiv = 527232>, <umin = 527488, umax = 527488, udiv = 527488>, <umin = 527744, umax = 527744, udiv = 527744>, <umin = 528000, umax = 528000, udiv = 528000>, <umin = 528256, umax = 528256, udiv = 528256>, <umin = 528512, umax = 528512, udiv = 528512>, <umin = 528768, umax = 528768, udiv = 528768>, <umin = 529024, umax = 529024, udiv = 529024>, <umin = 529280, umax = 529280, udiv = 529280>, <umin = 529536, umax = 529536, udiv = 529536>, <umin = 529792, umax = 529792, udiv = 529792>, <umin = 530048, umax = 530048, udiv = 530048>, <umin = 530304, umax = 530304, udiv = 530304>, <umin = 530560, umax = 530560, udiv = 530560>, <umin = 530816, umax = 530816, udiv = 530816>, <umin = 531072, umax = 531072, udiv = 531072>, <umin = 531328, umax = 531328, udiv = 531328>, <umin = 531584, umax = 531584, udiv = 531584>, <umin = 531840, umax = 531840, udiv = 531840>, <umin = 532096, umax = 532096, udiv = 532096>, <umin = 532352, umax = 532352, udiv = 532352>],
            %39<umin = 88097920, umax = 103184319488>,
            %40<umin = 32, umax = 131040, udiv = 32>,
            %41<umin = 128, umax = 524160, udiv = 128>
          : index, index, index, index, index, index, index, index, index
        // Binding views (binding, offset operand, type):
        //   %43: binding(1) @ %42#2, 14336x4096 f8  -> matmul RHS
        //   %44: binding(2) @ %42#4, f32 scalar     -> multiplies the matmul result
        //   %45: binding(2) @ %42#5, f32 scalar     -> multiplies %55
        //   %46: binding(1) @ %42#3, f32 scalar     -> final divisor
        //   %48: binding(0) @ %42#0, ?x4096 f8      -> matmul LHS
        //   %49: binding(0) @ %42#1, ?x14336 f32    -> elementwise operand %55
        //   %50: binding(3) @ %42#6, ?x14336 f8     -> result
        // The dynamic row count of every view is %47 (workload ordinal 1).
        %43 = hal.interface.binding.subspan layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%42#2) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<14336x4096xf8E4M3FNUZ>>
        %44 = hal.interface.binding.subspan layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%42#4) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>>
        %45 = hal.interface.binding.subspan layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%42#5) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>>
        %46 = hal.interface.binding.subspan layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%42#3) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>>
        %47 = iree_tensor_ext.dispatch.workload.ordinal %42#8, 1 : index
        %48 = hal.interface.binding.subspan layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%42#0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x4096xf8E4M3FNUZ>>{%47}
        %49 = hal.interface.binding.subspan layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%42#1) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x14336xf32>>{%47}
        %50 = hal.interface.binding.subspan layout(<constants = 13, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%42#6) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x14336xf8E4M3FNUZ>>{%47}
        %51 = iree_tensor_ext.dispatch.workload.ordinal %42#7, 0 : index
        %52 = iree_tensor_ext.dispatch.tensor.load %48, offsets = [0, 0], sizes = [%47, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x4096xf8E4M3FNUZ>>{%47} -> tensor<?x4096xf8E4M3FNUZ>
        %53 = iree_tensor_ext.dispatch.tensor.load %43, offsets = [0, 0], sizes = [14336, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<14336x4096xf8E4M3FNUZ>> -> tensor<14336x4096xf8E4M3FNUZ>
        %54 = iree_tensor_ext.dispatch.tensor.load %44, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
        %55 = iree_tensor_ext.dispatch.tensor.load %49, offsets = [0, 0], sizes = [%47, 14336], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x14336xf32>>{%47} -> tensor<?x14336xf32>
        %56 = iree_tensor_ext.dispatch.tensor.load %45, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
        %57 = iree_tensor_ext.dispatch.tensor.load %46, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
        // NOTE(review): the intermediate and output tensors are sized by
        // %58 = (workload ordinal 0) * 4, whereas the loaded inputs and the
        // final store are sized by %47 (workload ordinal 1). The assumed ranges
        // (%40: 32..131040 udiv 32, %41: 128..524160 udiv 128) are consistent
        // with %41 == 4 * %40, but nothing in this IR states that equality, so
        // the compiler cannot prove these dynamic dims match — confirm.
        %58 = arith.muli %51, %c4 overflow<nsw> : index
        %59 = tensor.empty(%58) : tensor<?x14336xf32>
        %60 = linalg.fill ins(%cst : f32) outs(%59 : tensor<?x14336xf32>) -> tensor<?x14336xf32>
        // Matmul with the RHS indexed (n, k):
        //   %61[m, n] = sum_k f32(%52[m, k]) * f32(%53[n, k]), starting from 0.
        %61 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%52, %53 : tensor<?x4096xf8E4M3FNUZ>, tensor<14336x4096xf8E4M3FNUZ>) outs(%60 : tensor<?x14336xf32>) {
        ^bb0(%in: f8E4M3FNUZ, %in_3: f8E4M3FNUZ, %out: f32):
          %64 = arith.extf %in : f8E4M3FNUZ to f32
          %65 = arith.extf %in_3 : f8E4M3FNUZ to f32
          %66 = arith.mulf %64, %65 : f32
          %67 = arith.addf %out, %66 : f32
          linalg.yield %67 : f32
        } -> tensor<?x14336xf32>
        %62 = tensor.empty(%58) : tensor<?x14336xf8E4M3FNUZ>
        // Elementwise epilogue, per element:
        //   x = %61 * %54
        //   y = %55 * %56
        //   r = (x * sigmoid(x)) * y / %57, with sigmoid(x) = 1 / (1 + exp(-x))
        // r is then clamped to [-240, 240] and truncated to f8E4M3FNUZ.
        // The lower clamp uses `ult`, which is also true for unordered operands,
        // so a NaN r is replaced by -240 rather than propagated.
        %63 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%61, %54, %55, %56, %57 : tensor<?x14336xf32>, tensor<f32>, tensor<?x14336xf32>, tensor<f32>, tensor<f32>) outs(%62 : tensor<?x14336xf8E4M3FNUZ>) {
        ^bb0(%in: f32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %out: f8E4M3FNUZ):
          %64 = arith.mulf %in, %in_3 : f32
          %65 = arith.mulf %in_4, %in_5 : f32
          %66 = arith.negf %64 : f32
          %67 = math.exp %66 : f32
          %68 = arith.addf %67, %cst_0 : f32
          %69 = arith.divf %cst_0, %68 : f32
          %70 = arith.mulf %69, %64 : f32
          %71 = arith.mulf %70, %65 : f32
          %72 = arith.divf %71, %in_6 : f32
          %73 = arith.cmpf ult, %72, %cst_1 : f32
          %74 = arith.select %73, %cst_1, %72 : f32
          %75 = arith.cmpf ugt, %74, %cst_2 : f32
          %76 = arith.select %75, %cst_2, %74 : f32
          %77 = arith.truncf %76 : f32 to f8E4M3FNUZ
          linalg.yield %77 : f8E4M3FNUZ
        } -> tensor<?x14336xf8E4M3FNUZ>
        iree_tensor_ext.dispatch.tensor.store %63, %50, offsets = [0, 0], sizes = [%47, 14336], strides = [1, 1] : tensor<?x14336xf8E4M3FNUZ> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x14336xf8E4M3FNUZ>>{%47}
        return
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.