Gist by @pashu123 · Created February 6, 2025 16:46
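// IREE HAL executable dump: decode-step attention dispatch (batch size 1)
// compiled for ROCm/HIP on gfx942. The dispatch fuses what appears to be a
// rotary-style update of the query, f8E4M3FNUZ attention over a
// dynamic-length KV cache, and requantization of the f32 result back to f8.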
hal.executable public @decode_bs1$async_dispatch_21 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) {
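// Entry point: the workgroup count is derived at dispatch time from the
// dynamic workload size %arg1 rather than being fixed at compile time.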
hal.executable.export public @decode_bs1$async_dispatch_21_attention_8x4xDx1x2x64xf8E4M3FNUZ_generic ordinal(0) layout(#hal.pipeline.layout<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
^bb0(%arg0: !hal.device, %arg1: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @decode_bs1$async_dispatch_21_attention_8x4xDx1x2x64xf8E4M3FNUZ_generic() {
%c67117632 = arith.constant 67117632 : index
%c32_i64 = arith.constant 32 : i64
%c1 = arith.constant 1 : index
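// Clamp bounds for the final requantization: 240.0 is the largest finite
// magnitude representable in f8E4M3FNUZ.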
%cst = arith.constant -2.400000e+02 : f32
%cst_0 = arith.constant 2.400000e+02 : f32
%c0 = arith.constant 0 : index
%c67108928 = arith.constant 67108928 : index
%c67109184 = arith.constant 67109184 : index
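// Eight i32 push constants: buffer offsets (ordinals 0, 1, 2, 5), a 64-bit
// offset assembled from the low/high halves in ordinals 3 and 4, the f32
// attention scale bitcast from ordinal 6, and the dynamic KV length in
// ordinal 7.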
%0 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
%2 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
%3 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
%4 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
%5 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
%6 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(6) : i32
%7 = hal.interface.constant.load layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(7) : i32
%8 = arith.index_castui %0 : i32 to index
%9 = arith.index_castui %0 : i32 to index
%10 = arith.index_castui %1 : i32 to index
%11 = arith.index_castui %2 : i32 to index
%12 = arith.extui %3 : i32 to i64
%13 = arith.extui %4 : i32 to i64
%14 = arith.shli %13, %c32_i64 : i64
%15 = arith.ori %12, %14 : i64
%16 = arith.index_castui %15 {stream.alignment = 64 : index, stream.values = [1075847616 : index, 1293968512 : index, 1512089408 : index, 1730210304 : index, 1948331200 : index, 2166452096 : index, 2384572992 : index, 2602693888 : index, 2820814784 : index, 3038935680 : index, 3257056576 : index, 3475177472 : index, 3693298368 : index, 3911419264 : index, 4129540160 : index, 4347661056 : index, 4565781952 : index, 4783902848 : index, 5002023744 : index, 5220144640 : index, 5438265536 : index, 5656386432 : index, 5874507328 : index, 6092628224 : index, 6310749120 : index, 6528870016 : index, 6746990912 : index, 6965111808 : index, 7183232704 : index, 7401353600 : index, 7619474496 : index, 7837595392 : index]} : i64 to index
%17 = arith.index_castui %5 : i32 to index
%18 = arith.bitcast %6 : i32 to f32
%19 = arith.index_castui %7 : i32 to index
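// util.assume.int records the values each offset and the dynamic length may
// take (the lists appear to enumerate one entry per dispatch site), letting
// later passes assume alignment and divisibility.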
%20:7 = util.assume.int
%8[<umin = 12352, umax = 12352, udiv = 12352>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>],
%9[<umin = 12352, umax = 12352, udiv = 12352>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>, <umin = 22784, umax = 22784, udiv = 22784>],
%10<umin = 67183360, umax = 335684160>,
%11<umin = 67314432, umax = 872424000>,
%16[<umin = 1075847616, umax = 1075847616, udiv = 1075847616>, <umin = 1293968512, umax = 1293968512, udiv = 1293968512>, <umin = 1512089408, umax = 1512089408, udiv = 1512089408>, <umin = 1730210304, umax = 1730210304, udiv = 1730210304>, <umin = 1948331200, umax = 1948331200, udiv = 1948331200>, <umin = 2166452096, umax = 2166452096, udiv = 2166452096>, <umin = 2384572992, umax = 2384572992, udiv = 2384572992>, <umin = 2602693888, umax = 2602693888, udiv = 2602693888>, <umin = 2820814784, umax = 2820814784, udiv = 2820814784>, <umin = 3038935680, umax = 3038935680, udiv = 3038935680>, <umin = 3257056576, umax = 3257056576, udiv = 3257056576>, <umin = 3475177472, umax = 3475177472, udiv = 3475177472>, <umin = 3693298368, umax = 3693298368, udiv = 3693298368>, <umin = 3911419264, umax = 3911419264, udiv = 3911419264>, <umin = 4129540160, umax = 4129540160, udiv = 4129540160>, <umin = 4347661056, umax = 4347661056, udiv = 4347661056>, <umin = 4565781952, umax = 4565781952, udiv = 4565781952>, <umin = 4783902848, umax = 4783902848, udiv = 4783902848>, <umin = 5002023744, umax = 5002023744, udiv = 5002023744>, <umin = 5220144640, umax = 5220144640, udiv = 5220144640>, <umin = 5438265536, umax = 5438265536, udiv = 5438265536>, <umin = 5656386432, umax = 5656386432, udiv = 5656386432>, <umin = 5874507328, umax = 5874507328, udiv = 5874507328>, <umin = 6092628224, umax = 6092628224, udiv = 6092628224>, <umin = 6310749120, umax = 6310749120, udiv = 6310749120>, <umin = 6528870016, umax = 6528870016, udiv = 6528870016>, <umin = 6746990912, umax = 6746990912, udiv = 6746990912>, <umin = 6965111808, umax = 6965111808, udiv = 6965111808>, <umin = 7183232704, umax = 7183232704, udiv = 7183232704>, <umin = 7401353600, umax = 7401353600, udiv = 7401353600>, <umin = 7619474496, umax = 7619474496, udiv = 7619474496>, <umin = 7837595392, umax = 7837595392, udiv = 7837595392>],
%17[<umin = 0, umax = 0>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 10496, umax = 10496, udiv = 10496>, <umin = 8192, umax = 8192, udiv = 8192>],
%19<umin = 32, umax = 131040, udiv = 32>
: index, index, index, index, index, index, index
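// Bind the buffer views at the assumed offsets: binding 0 holds the inputs
// (query tensors, cos/sin tables, mask, and K/V caches), binding 1 the f32
// requantization scale, and binding 2 the f8 output.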
%21 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%20#0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x32x1x2x64xf8E4M3FNUZ>>
%22 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%20#1) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x1x1x2x64xf8E4M3FNUZ>>
%23 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67108928) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x2x64xbf16>>
%24 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67109184) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x2x64xbf16>>
%25 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%20#4) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<f32>>
%26 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%20#5) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<8x4x1x128xf8E4M3FNUZ>>
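// %27 is the dynamic KV sequence length, assumed above to be a multiple of
// 32 in [32, 131040], threaded in as workload ordinal 0.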
%27 = flow.dispatch.workload.ordinal %20#6, 0 : index
%28 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67117632) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<?xf8E4M3FNUZ>>{%27}
%29 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%20#2) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x?x1x2x64xf8E4M3FNUZ>>{%27}
%30 = hal.interface.binding.subspan layout(<constants = 8, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%20#3) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x128x?xf8E4M3FNUZ>>{%27}
%31 = flow.dispatch.tensor.load %28, offsets = [0], sizes = [%27], strides = [1] : !flow.dispatch.tensor<readonly:tensor<?xf8E4M3FNUZ>>{%27} -> tensor<?xf8E4M3FNUZ>
%32 = flow.dispatch.tensor.load %21, offsets = [0, 0, 0, 0, 0], sizes = [1, 32, 1, 2, 64], strides = [1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x32x1x2x64xf8E4M3FNUZ>> -> tensor<1x32x1x2x64xf8E4M3FNUZ>
%33 = flow.dispatch.tensor.load %22, offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 4, 1, 1, 2, 64], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x1x1x2x64xf8E4M3FNUZ>> -> tensor<8x4x1x1x2x64xf8E4M3FNUZ>
%34 = flow.dispatch.tensor.load %23, offsets = [0, 0, 0], sizes = [1, 2, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x2x64xbf16>> -> tensor<1x2x64xbf16>
%35 = flow.dispatch.tensor.load %24, offsets = [0, 0, 0], sizes = [1, 2, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x2x64xbf16>> -> tensor<1x2x64xbf16>
%36 = flow.dispatch.tensor.load %29, offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 4, %27, 1, 2, 64], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x?x1x2x64xf8E4M3FNUZ>>{%27} -> tensor<8x4x?x1x2x64xf8E4M3FNUZ>
%37 = flow.dispatch.tensor.load %30, offsets = [0, 0, 0, 0], sizes = [8, 4, 128, %27], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x128x?xf8E4M3FNUZ>>{%27} -> tensor<8x4x128x?xf8E4M3FNUZ>
%38 = flow.dispatch.tensor.load %25, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
%39 = tensor.empty() : tensor<8x4x1x128xf32>
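// Reshape for the fused region: expand the mask %31 to 1x?, drop the unit
// dims from the 8x4-head query %33, and flatten the cos/sin tables to 2x64.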
%expanded = tensor.expand_shape %31 [[0, 1]] output_shape [1, %27] : tensor<?xf8E4M3FNUZ> into tensor<1x?xf8E4M3FNUZ>
%collapsed = tensor.collapse_shape %33 [[0], [1, 2, 3], [4], [5]] : tensor<8x4x1x1x2x64xf8E4M3FNUZ> into tensor<8x4x2x64xf8E4M3FNUZ>
%collapsed_1 = tensor.collapse_shape %34 [[0, 1], [2]] : tensor<1x2x64xbf16> into tensor<2x64xbf16>
%collapsed_2 = tensor.collapse_shape %35 [[0, 1], [2]] : tensor<1x2x64xbf16> into tensor<2x64xbf16>
%40 = tensor.empty() : tensor<8x4x2x64xf8E4M3FNUZ>
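// This generic appears to apply rotary position embedding to the query
// directly in f8E4M3FNUZ: out = q * cos + rotate_half(q) * sin, where
// rotate_half reads the opposite half (index 1 - d2) of the 1x32-head query
// view %32 and negates values taken from the second half.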
%41 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%collapsed, %collapsed_1, %collapsed_2 : tensor<8x4x2x64xf8E4M3FNUZ>, tensor<2x64xbf16>, tensor<2x64xbf16>) outs(%40 : tensor<8x4x2x64xf8E4M3FNUZ>) {
^bb0(%in: f8E4M3FNUZ, %in_6: bf16, %in_7: bf16, %out: f8E4M3FNUZ):
%45 = linalg.index 0 : index
%46 = linalg.index 1 : index
%47 = linalg.index 2 : index
%48 = linalg.index 3 : index
%49 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 4)>()[%46, %45]
%50 = arith.subi %c1, %47 : index
%extracted = tensor.extract %32[%c0, %49, %c0, %50, %48] : tensor<1x32x1x2x64xf8E4M3FNUZ>
%51 = arith.negf %extracted : f8E4M3FNUZ
%52 = arith.cmpi eq, %50, %c1 : index
%53 = arith.select %52, %51, %extracted : f8E4M3FNUZ
%54 = arith.truncf %in_7 : bf16 to f8E4M3FNUZ
%55 = arith.mulf %53, %54 : f8E4M3FNUZ
%56 = arith.truncf %in_6 : bf16 to f8E4M3FNUZ
%57 = arith.mulf %in, %56 : f8E4M3FNUZ
%58 = arith.addf %57, %55 : f8E4M3FNUZ
linalg.yield %58 : f8E4M3FNUZ
} -> tensor<8x4x2x64xf8E4M3FNUZ>
%expanded_3 = tensor.expand_shape %41 [[0], [1, 2, 3], [4], [5]] output_shape [8, 4, 1, 1, 2, 64] : tensor<8x4x2x64xf8E4M3FNUZ> into tensor<8x4x1x1x2x64xf8E4M3FNUZ>
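// Fused attention: the rotated query %expanded_3 (8x4 heads, one decode
// position) against the dynamic-length K cache %36 and V cache %37, with f32
// scale %18 and the 1x? mask %expanded; accumulation is in f32 into the
// 8x4x1x128 output.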
%42 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d4, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d7, d4, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d3, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3)>]} ins(%expanded_3, %36, %37, %18, %expanded : tensor<8x4x1x1x2x64xf8E4M3FNUZ>, tensor<8x4x?x1x2x64xf8E4M3FNUZ>, tensor<8x4x128x?xf8E4M3FNUZ>, f32, tensor<1x?xf8E4M3FNUZ>) outs(%39 : tensor<8x4x1x128xf32>) {
^bb0(%arg0: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<8x4x1x128xf32>
%collapsed_4 = tensor.collapse_shape %42 [[0], [1, 2], [3]] : tensor<8x4x1x128xf32> into tensor<8x4x128xf32>
%43 = tensor.empty() : tensor<8x4x128xf8E4M3FNUZ>
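// Requantize the f32 attention output: divide by the scalar scale %38, clamp
// to [-240, 240] (the f8E4M3FNUZ finite range), and truncate to f8E4M3FNUZ.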
%44 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed_4, %38 : tensor<8x4x128xf32>, tensor<f32>) outs(%43 : tensor<8x4x128xf8E4M3FNUZ>) {
^bb0(%in: f32, %in_6: f32, %out: f8E4M3FNUZ):
%45 = arith.divf %in, %in_6 : f32
%46 = arith.cmpf ult, %45, %cst : f32
%47 = arith.select %46, %cst, %45 : f32
%48 = arith.cmpf ugt, %47, %cst_0 : f32
%49 = arith.select %48, %cst_0, %47 : f32
%50 = arith.truncf %49 : f32 to f8E4M3FNUZ
linalg.yield %50 : f8E4M3FNUZ
} -> tensor<8x4x128xf8E4M3FNUZ>
%expanded_5 = tensor.expand_shape %44 [[0], [1, 2], [3]] output_shape [8, 4, 1, 128] : tensor<8x4x128xf8E4M3FNUZ> into tensor<8x4x1x128xf8E4M3FNUZ>
flow.dispatch.tensor.store %expanded_5, %26, offsets = [0, 0, 0, 0], sizes = [8, 4, 1, 128], strides = [1, 1, 1, 1] : tensor<8x4x1x128xf8E4M3FNUZ> -> !flow.dispatch.tensor<writeonly:tensor<8x4x1x128xf8E4M3FNUZ>>
return
}
}
}
}