@pashu123
Created February 14, 2025 11:00
hal.executable public @prefill_bs1$async_dispatch_19 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>) {
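// Targets gfx942 (MI300-class): 64-wide subgroups, MFMA intrinsics including the f8E4M3FNUZ/f8E5M2FNUZ variants, and 64 KiB of workgroup memory.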
hal.executable.export public @prefill_bs1$async_dispatch_19_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic ordinal(0) layout(#hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) {
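// The workgroup count is derived from the four dynamic workload dimensions supplied at dispatch time.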
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3, %arg4
hal.return %x, %y, %z : index, index, index
}
builtin.module {
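// Fused attention dispatch: builds an additive mask from an i8 boolean mask and a sequence-length bound,
// runs f8E4M3FNUZ attention with f32 accumulation, then rescales, clamps, and truncates the result back
// to f8E4M3FNUZ in a transposed output layout.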
func.func @prefill_bs1$async_dispatch_19_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic() {
%c67108864 = arith.constant 67108864 : index
%c32_i64 = arith.constant 32 : i64
%cst = arith.constant 0xFF800000 : f32
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant -2.400000e+02 : f32
%cst_2 = arith.constant 2.400000e+02 : f32
%c0 = arith.constant 0 : index
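// 0xFF800000 is -inf, used to mask out attention positions; -2.4e2/2.4e2 (+-240) are the clamp bounds
// applied before the final truncation to f8E4M3FNUZ.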
%0 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
%2 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
%3 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
%4 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
%5 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
%6 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(6) : i32
%7 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(7) : i32
%8 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(8) : i32
%9 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(9) : i32
%10 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(10) : i32
%11 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(11) : i32
%12 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(12) : i32
%13 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(13) : i32
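// Push constants 0-9 are consumed in pairs: each (low, high) i32 pair is zero-extended to i64, the high
// word shifted left by 32 and OR'd in, then cast to index to recover a 64-bit buffer offset.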
%14 = arith.extui %0 : i32 to i64
%15 = arith.extui %1 : i32 to i64
%16 = arith.shli %15, %c32_i64 : i64
%17 = arith.ori %14, %16 : i64
%18 = arith.index_castui %17 : i64 to index
%19 = arith.extui %2 : i32 to i64
%20 = arith.extui %3 : i32 to i64
%21 = arith.shli %20, %c32_i64 : i64
%22 = arith.ori %19, %21 : i64
%23 = arith.index_castui %22 : i64 to index
%24 = arith.extui %4 : i32 to i64
%25 = arith.extui %5 : i32 to i64
%26 = arith.shli %25, %c32_i64 : i64
%27 = arith.ori %24, %26 : i64
%28 = arith.index_castui %27 : i64 to index
%29 = arith.extui %6 : i32 to i64
%30 = arith.extui %7 : i32 to i64
%31 = arith.shli %30, %c32_i64 : i64
%32 = arith.ori %29, %31 : i64
%33 = arith.index_castui %32 {stream.alignment = 64 : index, stream.values = [1075847616 : index, 1293968512 : index, 1512089408 : index, 1730210304 : index, 1948331200 : index, 2166452096 : index, 2384572992 : index, 2602693888 : index, 2820814784 : index, 3038935680 : index, 3257056576 : index, 3475177472 : index, 3693298368 : index, 3911419264 : index, 4129540160 : index, 4347661056 : index, 4565781952 : index, 4783902848 : index, 5002023744 : index, 5220144640 : index, 5438265536 : index, 5656386432 : index, 5874507328 : index, 6092628224 : index, 6310749120 : index, 6528870016 : index, 6746990912 : index, 6965111808 : index, 7183232704 : index, 7401353600 : index, 7619474496 : index, 7837595392 : index]} : i64 to index
%34 = arith.extui %8 : i32 to i64
%35 = arith.extui %9 : i32 to i64
%36 = arith.shli %35, %c32_i64 : i64
%37 = arith.ori %34, %36 : i64
%38 = arith.index_castui %37 : i64 to index
%39 = arith.bitcast %10 : i32 to f32
%40 = arith.index_castui %11 : i32 to index
%41 = arith.index_castui %12 : i32 to index
%42 = arith.index_castui %11 : i32 to index
%43 = arith.index_castui %13 : i32 to index
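// util.assume.int attaches range and divisibility assumptions to the decoded offsets and dynamic sizes
// for downstream integer-range and alignment optimizations.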
%44:9 = util.assume.int
%18<umin = 68027392, umax = 20995769344>,
%23<umin = 68158464, umax = 21532509184>,
%28[<umin = 67765248, umax = 19922289664>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>],
%33[<umin = 1075847616, umax = 1075847616, udiv = 1075847616>, <umin = 1293968512, umax = 1293968512, udiv = 1293968512>, <umin = 1512089408, umax = 1512089408, udiv = 1512089408>, <umin = 1730210304, umax = 1730210304, udiv = 1730210304>, <umin = 1948331200, umax = 1948331200, udiv = 1948331200>, <umin = 2166452096, umax = 2166452096, udiv = 2166452096>, <umin = 2384572992, umax = 2384572992, udiv = 2384572992>, <umin = 2602693888, umax = 2602693888, udiv = 2602693888>, <umin = 2820814784, umax = 2820814784, udiv = 2820814784>, <umin = 3038935680, umax = 3038935680, udiv = 3038935680>, <umin = 3257056576, umax = 3257056576, udiv = 3257056576>, <umin = 3475177472, umax = 3475177472, udiv = 3475177472>, <umin = 3693298368, umax = 3693298368, udiv = 3693298368>, <umin = 3911419264, umax = 3911419264, udiv = 3911419264>, <umin = 4129540160, umax = 4129540160, udiv = 4129540160>, <umin = 4347661056, umax = 4347661056, udiv = 4347661056>, <umin = 4565781952, umax = 4565781952, udiv = 4565781952>, <umin = 4783902848, umax = 4783902848, udiv = 4783902848>, <umin = 5002023744, umax = 5002023744, udiv = 5002023744>, <umin = 5220144640, umax = 5220144640, udiv = 5220144640>, <umin = 5438265536, umax = 5438265536, udiv = 5438265536>, <umin = 5656386432, umax = 5656386432, udiv = 5656386432>, <umin = 5874507328, umax = 5874507328, udiv = 5874507328>, <umin = 6092628224, umax = 6092628224, udiv = 6092628224>, <umin = 6310749120, umax = 6310749120, udiv = 6310749120>, <umin = 6528870016, umax = 6528870016, udiv = 6528870016>, <umin = 6746990912, umax = 6746990912, udiv = 6746990912>, <umin = 6965111808, umax = 6965111808, udiv = 6965111808>, <umin = 7183232704, umax = 7183232704, udiv = 7183232704>, <umin = 7401353600, umax = 7401353600, udiv = 7401353600>, <umin = 7619474496, umax = 7619474496, udiv = 7619474496>, <umin = 7837595392, umax = 7837595392, udiv = 7837595392>],
%38<umin = 67896320, umax = 20459029504>,
%40<umin = 1, umax = 4095>,
%41<umin = 32, umax = 131040, udiv = 32>,
%42<umin = 1, umax = 4095>,
%43<umin = 32, umax = 131040, udiv = 32>
: index, index, index, index, index, index, index, index, index
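// Bindings: binding 0 carries the i8 mask and the f8 query/key/value tensors at different offsets,
// binding 1 the i64 sequence-length scalar, binding 2 the f32 output scale, binding 3 the write-only result.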
%45 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<i64>>
%46 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%44#3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<f32>>
%47 = flow.dispatch.workload.ordinal %44#5, 0 : index
%48 = flow.dispatch.workload.ordinal %44#6, 1 : index
%49 = flow.dispatch.workload.ordinal %44#7, 2 : index
%50 = flow.dispatch.workload.ordinal %44#8, 3 : index
%51 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67108864) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<?x32x?xi8>>{%47, %48}
%52 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%44#0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x1x?x32x128xf8E4M3FNUZ>>{%49}
%53 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%44#1) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x?x128xf8E4M3FNUZ>>{%50}
%54 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%44#2) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x128x?xf8E4M3FNUZ>>{%48}
%55 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%44#4) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x?x32x8x4x128xf8E4M3FNUZ>>{%47}
%56 = flow.dispatch.tensor.load %51, offsets = [0, 0, 0], sizes = [%47, 32, %48], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x32x?xi8>>{%47, %48} -> tensor<?x32x?xi8>
%57 = arith.trunci %56 : tensor<?x32x?xi8> to tensor<?x32x?xi1>
%58 = flow.dispatch.tensor.load %45, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
%59 = flow.dispatch.tensor.load %52, offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 4, 1, %49, 32, 128], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x1x?x32x128xf8E4M3FNUZ>>{%49} -> tensor<8x4x1x?x32x128xf8E4M3FNUZ>
%60 = flow.dispatch.tensor.load %53, offsets = [0, 0, 0, 0], sizes = [8, 4, %50, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x?x128xf8E4M3FNUZ>>{%50} -> tensor<8x4x?x128xf8E4M3FNUZ>
%61 = flow.dispatch.tensor.load %54, offsets = [0, 0, 0, 0], sizes = [8, 4, 128, %48], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x128x?xf8E4M3FNUZ>>{%48} -> tensor<8x4x128x?xf8E4M3FNUZ>
%62 = flow.dispatch.tensor.load %46, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
%63 = tensor.empty(%47) : tensor<1x?x32x8x4x128xf8E4M3FNUZ>
%64 = tensor.empty(%47) : tensor<8x4x1x?x32x128xf32>
%65 = tensor.empty(%47, %48) : tensor<?x32x?xf8E4M3FNUZ>
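// Build the additive attention mask: positions flagged in the i8 mask or at/after the sequence-length
// bound get -inf, everything else 0, then truncate to f8E4M3FNUZ.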
%66 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%57, %58 : tensor<?x32x?xi1>, tensor<i64>) outs(%65 : tensor<?x32x?xf8E4M3FNUZ>) {
^bb0(%in: i1, %in_3: i64, %out: f8E4M3FNUZ):
%69 = linalg.index 2 : index
%70 = arith.index_cast %69 : index to i64
%71 = arith.cmpi sge, %70, %in_3 : i64
%72 = arith.ori %in, %71 : i1
%73 = arith.select %72, %cst, %cst_0 : f32
%74 = arith.truncf %73 : f32 to f8E4M3FNUZ
linalg.yield %74 : f8E4M3FNUZ
} -> tensor<?x32x?xf8E4M3FNUZ>
%expanded = tensor.expand_shape %66 [[0, 1], [2], [3]] output_shape [1, %47, 32, %48] : tensor<?x32x?xf8E4M3FNUZ> into tensor<1x?x32x?xf8E4M3FNUZ>
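// Fused attention: Q (8x4x1x?x32x128) against K (8x4x?x128), scaled by %39 and offset by the expanded
// mask, then combined with V (8x4x128x?); the result accumulates in f32.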
%67 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%59, %60, %61, %39, %expanded : tensor<8x4x1x?x32x128xf8E4M3FNUZ>, tensor<8x4x?x128xf8E4M3FNUZ>, tensor<8x4x128x?xf8E4M3FNUZ>, f32, tensor<1x?x32x?xf8E4M3FNUZ>) outs(%64 : tensor<8x4x1x?x32x128xf32>) {
^bb0(%arg0: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<8x4x1x?x32x128xf32>
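// Divide the attention output by the loaded scale, clamp to [-240, 240] (the finite f8E4M3FNUZ range),
// truncate to f8, and transpose into the 1x?x32x8x4x128 output layout.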
%68 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%67, %62 : tensor<8x4x1x?x32x128xf32>, tensor<f32>) outs(%63 : tensor<1x?x32x8x4x128xf8E4M3FNUZ>) {
^bb0(%in: f32, %in_3: f32, %out: f8E4M3FNUZ):
%69 = arith.divf %in, %in_3 : f32
%70 = arith.cmpf ult, %69, %cst_1 : f32
%71 = arith.select %70, %cst_1, %69 : f32
%72 = arith.cmpf ugt, %71, %cst_2 : f32
%73 = arith.select %72, %cst_2, %71 : f32
%74 = arith.truncf %73 : f32 to f8E4M3FNUZ
linalg.yield %74 : f8E4M3FNUZ
} -> tensor<1x?x32x8x4x128xf8E4M3FNUZ>
flow.dispatch.tensor.store %68, %55, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, %47, 32, 8, 4, 128], strides = [1, 1, 1, 1, 1, 1] : tensor<1x?x32x8x4x128xf8E4M3FNUZ> -> !flow.dispatch.tensor<writeonly:tensor<1x?x32x8x4x128xf8E4M3FNUZ>>{%47}
return
}
}
}
}