Created February 14, 2025 11:00
-
-
Save pashu123/e21bc74fafbc4ce3ae23b0adf3ac75b5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// IREE HAL executable for a single dispatch targeting gfx942 (ROCm HSACO).
// The dispatch builds a -inf/0 mask tensor, runs an f8E4M3FNUZ
// iree_linalg_ext.attention, then divides, clamps, transposes and truncates
// the f32 result back to f8E4M3FNUZ.
//
// The pipeline layout (14 i32 push constants, 4 storage-buffer bindings) is
// identical for every HAL interface op below, so it is hoisted into a single
// attribute alias instead of being repeated inline on each op.
#pipeline_layout = #hal.pipeline.layout<constants = 14, bindings = [
  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
  #hal.pipeline.binding<storage_buffer, ReadOnly>,
  #hal.pipeline.binding<storage_buffer, Indirect>
], flags = Indirect>

hal.executable public @prefill_bs1$async_dispatch_19 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {
      abi = "hip",
      iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <
        compute = fp64|fp32|fp16|int64|int32|int16|int8,
        storage = b64|b32|b16|b8,
        subgroup = shuffle|arithmetic,
        dot = dp4xi8toi32,
        mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>,
               <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>,
               <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>,
               <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>,
               <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>,
               <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>,
               <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
        subgroup_size_choices = [64],
        max_workgroup_sizes = [1024, 1024, 1024],
        max_thread_count_per_workgroup = 1024,
        max_workgroup_memory_bytes = 65536,
        max_workgroup_counts = [2147483647, 2147483647, 2147483647],
        max_load_instruction_bits = 128,
        simds_per_wgp = 4,
        vgpr_space_bits = 16384>>,
      ukernels = "none"}>) {
    // Workgroup count is derived from the four workload values; these line up
    // with the flow.dispatch.workload.ordinal ops (ordinals 0-3) in the body.
    hal.executable.export public @prefill_bs1$async_dispatch_19_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic ordinal(0) layout(#pipeline_layout) {
    ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3, %arg4
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @prefill_bs1$async_dispatch_19_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic() {
        // Static byte offset of the i8 mask tensor inside binding 0.
        %c67108864 = arith.constant 67108864 : index
        %c32_i64 = arith.constant 32 : i64
        // 0xFF800000 is the f32 bit pattern for -inf (masked-out value).
        %cst = arith.constant 0xFF800000 : f32
        %cst_0 = arith.constant 0.000000e+00 : f32
        // +/-240 is the largest finite magnitude representable in f8E4M3FNUZ.
        %cst_1 = arith.constant -2.400000e+02 : f32
        %cst_2 = arith.constant 2.400000e+02 : f32
        %c0 = arith.constant 0 : index

        // Push constants: ordinals 0-9 are five (lo, hi) i32 pairs, each
        // reassembled below into a 64-bit byte offset as lo | (hi << 32).
        // Ordinal 10 is an f32 bit pattern; ordinals 11-13 are dynamic sizes.
        %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
        %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
        %2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32
        %3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : i32
        %4 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : i32
        %5 = hal.interface.constant.load layout(#pipeline_layout) ordinal(5) : i32
        %6 = hal.interface.constant.load layout(#pipeline_layout) ordinal(6) : i32
        %7 = hal.interface.constant.load layout(#pipeline_layout) ordinal(7) : i32
        %8 = hal.interface.constant.load layout(#pipeline_layout) ordinal(8) : i32
        %9 = hal.interface.constant.load layout(#pipeline_layout) ordinal(9) : i32
        %10 = hal.interface.constant.load layout(#pipeline_layout) ordinal(10) : i32
        %11 = hal.interface.constant.load layout(#pipeline_layout) ordinal(11) : i32
        %12 = hal.interface.constant.load layout(#pipeline_layout) ordinal(12) : i32
        %13 = hal.interface.constant.load layout(#pipeline_layout) ordinal(13) : i32

        // Offset of the query tensor in binding 0 (ordinals 0/1).
        %14 = arith.extui %0 : i32 to i64
        %15 = arith.extui %1 : i32 to i64
        %16 = arith.shli %15, %c32_i64 : i64
        %17 = arith.ori %14, %16 : i64
        %18 = arith.index_castui %17 : i64 to index
        // Offset of the key tensor in binding 0 (ordinals 2/3).
        %19 = arith.extui %2 : i32 to i64
        %20 = arith.extui %3 : i32 to i64
        %21 = arith.shli %20, %c32_i64 : i64
        %22 = arith.ori %19, %21 : i64
        %23 = arith.index_castui %22 : i64 to index
        // Offset of the value tensor in binding 0 (ordinals 4/5).
        %24 = arith.extui %4 : i32 to i64
        %25 = arith.extui %5 : i32 to i64
        %26 = arith.shli %25, %c32_i64 : i64
        %27 = arith.ori %24, %26 : i64
        %28 = arith.index_castui %27 : i64 to index
        // Offset of the scalar divisor in binding 2 (ordinals 6/7); stream
        // records the 32 concrete offsets this value can take.
        %29 = arith.extui %6 : i32 to i64
        %30 = arith.extui %7 : i32 to i64
        %31 = arith.shli %30, %c32_i64 : i64
        %32 = arith.ori %29, %31 : i64
        %33 = arith.index_castui %32 {stream.alignment = 64 : index, stream.values = [
            1075847616 : index, 1293968512 : index, 1512089408 : index, 1730210304 : index,
            1948331200 : index, 2166452096 : index, 2384572992 : index, 2602693888 : index,
            2820814784 : index, 3038935680 : index, 3257056576 : index, 3475177472 : index,
            3693298368 : index, 3911419264 : index, 4129540160 : index, 4347661056 : index,
            4565781952 : index, 4783902848 : index, 5002023744 : index, 5220144640 : index,
            5438265536 : index, 5656386432 : index, 5874507328 : index, 6092628224 : index,
            6310749120 : index, 6528870016 : index, 6746990912 : index, 6965111808 : index,
            7183232704 : index, 7401353600 : index, 7619474496 : index, 7837595392 : index]} : i64 to index
        // Offset of the result tensor in binding 3 (ordinals 8/9).
        %34 = arith.extui %8 : i32 to i64
        %35 = arith.extui %9 : i32 to i64
        %36 = arith.shli %35, %c32_i64 : i64
        %37 = arith.ori %34, %36 : i64
        %38 = arith.index_castui %37 : i64 to index
        // Ordinal 10 reinterpreted as f32: the scalar `scale` operand of the
        // attention op.
        %39 = arith.bitcast %10 : i32 to f32
        // Dynamic sizes. NOTE(review): %40 and %42 are both derived from
        // ordinal 11 and carry identical assumptions; presumably CSE'd later.
        %40 = arith.index_castui %11 : i32 to index
        %41 = arith.index_castui %12 : i32 to index
        %42 = arith.index_castui %11 : i32 to index
        %43 = arith.index_castui %13 : i32 to index

        // Range/divisibility hints for the offsets and sizes. %28 and %33 each
        // carry 32 alternative assumptions (one per listed variant).
        %44:9 = util.assume.int
            %18<umin = 68027392, umax = 20995769344>,
            %23<umin = 68158464, umax = 21532509184>,
            %28[<umin = 67765248, umax = 19922289664>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>,
                <umin = 67634176, umax = 19385549824>],
            %33[<umin = 1075847616, umax = 1075847616, udiv = 1075847616>,
                <umin = 1293968512, umax = 1293968512, udiv = 1293968512>,
                <umin = 1512089408, umax = 1512089408, udiv = 1512089408>,
                <umin = 1730210304, umax = 1730210304, udiv = 1730210304>,
                <umin = 1948331200, umax = 1948331200, udiv = 1948331200>,
                <umin = 2166452096, umax = 2166452096, udiv = 2166452096>,
                <umin = 2384572992, umax = 2384572992, udiv = 2384572992>,
                <umin = 2602693888, umax = 2602693888, udiv = 2602693888>,
                <umin = 2820814784, umax = 2820814784, udiv = 2820814784>,
                <umin = 3038935680, umax = 3038935680, udiv = 3038935680>,
                <umin = 3257056576, umax = 3257056576, udiv = 3257056576>,
                <umin = 3475177472, umax = 3475177472, udiv = 3475177472>,
                <umin = 3693298368, umax = 3693298368, udiv = 3693298368>,
                <umin = 3911419264, umax = 3911419264, udiv = 3911419264>,
                <umin = 4129540160, umax = 4129540160, udiv = 4129540160>,
                <umin = 4347661056, umax = 4347661056, udiv = 4347661056>,
                <umin = 4565781952, umax = 4565781952, udiv = 4565781952>,
                <umin = 4783902848, umax = 4783902848, udiv = 4783902848>,
                <umin = 5002023744, umax = 5002023744, udiv = 5002023744>,
                <umin = 5220144640, umax = 5220144640, udiv = 5220144640>,
                <umin = 5438265536, umax = 5438265536, udiv = 5438265536>,
                <umin = 5656386432, umax = 5656386432, udiv = 5656386432>,
                <umin = 5874507328, umax = 5874507328, udiv = 5874507328>,
                <umin = 6092628224, umax = 6092628224, udiv = 6092628224>,
                <umin = 6310749120, umax = 6310749120, udiv = 6310749120>,
                <umin = 6528870016, umax = 6528870016, udiv = 6528870016>,
                <umin = 6746990912, umax = 6746990912, udiv = 6746990912>,
                <umin = 6965111808, umax = 6965111808, udiv = 6965111808>,
                <umin = 7183232704, umax = 7183232704, udiv = 7183232704>,
                <umin = 7401353600, umax = 7401353600, udiv = 7401353600>,
                <umin = 7619474496, umax = 7619474496, udiv = 7619474496>,
                <umin = 7837595392, umax = 7837595392, udiv = 7837595392>],
            %38<umin = 67896320, umax = 20459029504>,
            %40<umin = 1, umax = 4095>,
            %41<umin = 32, umax = 131040, udiv = 32>,
            %42<umin = 1, umax = 4095>,
            %43<umin = 32, umax = 131040, udiv = 32>
          : index, index, index, index, index, index, index, index, index

        // Binding 1: scalar i64 bound compared against the mask's last index.
        %45 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<i64>>
        // Binding 2: scalar f32 divisor applied in the epilogue.
        %46 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%44#3) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<f32>>
        // Workload values 0-3 (feed the export's workgroup-count region).
        %47 = flow.dispatch.workload.ordinal %44#5, 0 : index
        %48 = flow.dispatch.workload.ordinal %44#6, 1 : index
        %49 = flow.dispatch.workload.ordinal %44#7, 2 : index
        %50 = flow.dispatch.workload.ordinal %44#8, 3 : index
        // Binding 0 holds the i8 mask, query, key and value at distinct offsets.
        %51 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c67108864) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<?x32x?xi8>>{%47, %48}
        %52 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%44#0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x1x?x32x128xf8E4M3FNUZ>>{%49}
        %53 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%44#1) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x?x128xf8E4M3FNUZ>>{%50}
        %54 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%44#2) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<8x4x128x?xf8E4M3FNUZ>>{%48}
        // Binding 3: f8E4M3FNUZ result.
        %55 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%44#4) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<1x?x32x8x4x128xf8E4M3FNUZ>>{%47}

        %56 = flow.dispatch.tensor.load %51, offsets = [0, 0, 0], sizes = [%47, 32, %48], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x32x?xi8>>{%47, %48} -> tensor<?x32x?xi8>
        // Only the low bit of each i8 mask byte is kept.
        %57 = arith.trunci %56 : tensor<?x32x?xi8> to tensor<?x32x?xi1>
        %58 = flow.dispatch.tensor.load %45, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
        %59 = flow.dispatch.tensor.load %52, offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 4, 1, %49, 32, 128], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x1x?x32x128xf8E4M3FNUZ>>{%49} -> tensor<8x4x1x?x32x128xf8E4M3FNUZ>
        %60 = flow.dispatch.tensor.load %53, offsets = [0, 0, 0, 0], sizes = [8, 4, %50, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x?x128xf8E4M3FNUZ>>{%50} -> tensor<8x4x?x128xf8E4M3FNUZ>
        %61 = flow.dispatch.tensor.load %54, offsets = [0, 0, 0, 0], sizes = [8, 4, 128, %48], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<8x4x128x?xf8E4M3FNUZ>>{%48} -> tensor<8x4x128x?xf8E4M3FNUZ>
        %62 = flow.dispatch.tensor.load %46, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
        %63 = tensor.empty(%47) : tensor<1x?x32x8x4x128xf8E4M3FNUZ>
        %64 = tensor.empty(%47) : tensor<8x4x1x?x32x128xf32>
        %65 = tensor.empty(%47, %48) : tensor<?x32x?xf8E4M3FNUZ>

        // Mask construction: element is -inf when its mask bit is set OR its
        // last-dim index is >= the scalar from binding 1, otherwise 0.0.
        // NOTE(review): f8E4M3FNUZ has no infinity encoding, so truncating
        // -inf presumably produces NaN rather than -inf; a NaN mask entry
        // could poison the attention softmax. Confirm against arith.truncf
        // lowering for FNUZ types.
        %66 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%57, %58 : tensor<?x32x?xi1>, tensor<i64>) outs(%65 : tensor<?x32x?xf8E4M3FNUZ>) {
        ^bb0(%in: i1, %in_3: i64, %out: f8E4M3FNUZ):
          %69 = linalg.index 2 : index
          %70 = arith.index_cast %69 : index to i64
          %71 = arith.cmpi sge, %70, %in_3 : i64
          %72 = arith.ori %in, %71 : i1
          %73 = arith.select %72, %cst, %cst_0 : f32
          %74 = arith.truncf %73 : f32 to f8E4M3FNUZ
          linalg.yield %74 : f8E4M3FNUZ
        } -> tensor<?x32x?xf8E4M3FNUZ>
        %expanded = tensor.expand_shape %66 [[0, 1], [2], [3]] output_shape [1, %47, 32, %48] : tensor<?x32x?xf8E4M3FNUZ> into tensor<1x?x32x?xf8E4M3FNUZ>

        // Attention: Q (d0,d1,d2,d3,d4,d6), K (d0,d1,d7,d6), V (d0,d1,d5,d7),
        // scalar scale, mask (d2,d3,d4,d7) -> out (d0,d1,d2,d3,d4,d5).
        // NOTE(review): the d7 extent is %50 (ordinal 13) on K but %48
        // (ordinal 12) on V and the mask; these presumably agree at runtime,
        // but nothing here proves it — confirm with the producer.
        %67 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d7, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d3, d4, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5)>]} ins(%59, %60, %61, %39, %expanded : tensor<8x4x1x?x32x128xf8E4M3FNUZ>, tensor<8x4x?x128xf8E4M3FNUZ>, tensor<8x4x128x?xf8E4M3FNUZ>, f32, tensor<1x?x32x?xf8E4M3FNUZ>) outs(%64 : tensor<8x4x1x?x32x128xf32>) {
        ^bb0(%arg0: f32):
          iree_linalg_ext.yield %arg0 : f32
        } -> tensor<8x4x1x?x32x128xf32>

        // Epilogue: divide by the binding-2 scalar, clamp to [-240, 240],
        // truncate to f8E4M3FNUZ, and permute 8x4x1x?x32x128 into
        // 1x?x32x8x4x128. The compares are unordered (ult/ugt), so a NaN
        // quotient is clamped to -240.
        %68 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> ()>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%67, %62 : tensor<8x4x1x?x32x128xf32>, tensor<f32>) outs(%63 : tensor<1x?x32x8x4x128xf8E4M3FNUZ>) {
        ^bb0(%in: f32, %in_3: f32, %out: f8E4M3FNUZ):
          %69 = arith.divf %in, %in_3 : f32
          %70 = arith.cmpf ult, %69, %cst_1 : f32
          %71 = arith.select %70, %cst_1, %69 : f32
          %72 = arith.cmpf ugt, %71, %cst_2 : f32
          %73 = arith.select %72, %cst_2, %71 : f32
          %74 = arith.truncf %73 : f32 to f8E4M3FNUZ
          linalg.yield %74 : f8E4M3FNUZ
        } -> tensor<1x?x32x8x4x128xf8E4M3FNUZ>
        flow.dispatch.tensor.store %68, %55, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, %47, 32, 8, 4, 128], strides = [1, 1, 1, 1, 1, 1] : tensor<1x?x32x8x4x128xf8E4M3FNUZ> -> !flow.dispatch.tensor<writeonly:tensor<1x?x32x8x4x128xf8E4M3FNUZ>>{%47}
        return
      }
    }
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.