Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created February 14, 2025 02:27
Show Gist options
  • Save AmosLewis/e9a147afd209cf847f51027afa10a9bb to your computer and use it in GitHub Desktop.
/home/chi/src/iree-build/tools/iree-compile f8_attn_chi_castf32_roctorch.mlir \
--iree-hip-target=gfx942 \
-o=f8_attn_chi_castf32_roctorch_0213.vmfb \
--iree-hal-target-device=hip \
--iree-dispatch-creation-enable-aggressive-fusion=true \
--iree-global-opt-propagate-transposes=true \
--iree-opt-aggressively-propagate-transposes=true \
--iree-opt-data-tiling=false \
--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
--iree-hal-indirect-command-buffers=true \
--iree-stream-resource-memory-model=discrete \
--iree-hal-memoization=true \
--iree-opt-strip-assertions
failed to translate executables
f8_attn_chi_castf32_roctorch.mlir:45778:10: error: 'func.func' op failed to distribute
%1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed, %collapsed_1, %collapsed_2, %extracted, %arg4 : tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, f32, tensor<?x?xf8E4M3FNUZ>) outs(%cast : tensor<32x?x128xf32>) {
^
f8_attn_chi_castf32_roctorch.mlir:2706:12: note: called from
%914 = util.call @sharktank_masked_flash_attention_1_32_128_128_f8E4M3FNUZ_f32_f32(%909, %910, %911, %913, %912) : (tensor<1x32x?x128xf8E4M3FNUZ>, tensor<1x32x?x128xf8E4M3FNUZ>, tensor<1x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?xf8E4M3FNUZ>) -> tensor<1x32x?x128xf32>
^
f8_attn_chi_castf32_roctorch.mlir:45778:10: note: see current operation:
"func.func"() <{function_type = () -> (), sym_name = "prefill_bs1$async_dispatch_18_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic"}> ({
%0 = "arith.constant"() <{value = 7 : index}> : () -> index
%1 = "arith.constant"() <{value = 6 : index}> : () -> index
%2 = "arith.constant"() <{value = 5 : index}> : () -> index
%3 = "arith.constant"() <{value = 4 : index}> : () -> index
%4 = "arith.constant"() <{value = dense<0.000000e+00> : vector<1x8x1x1x8x1xf8E4M3FNUZ>}> : () -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%5 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x1x1x1x1x8xf8E4M3FNUZ>}> : () -> vector<2x1x1x1x1x8xf8E4M3FNUZ>
%6 = "arith.constant"() <{value = dense<0.000000e+00> : vector<8x2x1x1x1x4xf32>}> : () -> vector<8x2x1x1x1x4xf32>
%7 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x2x1x1x1x4xf32>}> : () -> vector<2x2x1x1x1x4xf32>
%8 = "arith.constant"() <{value = dense<0.000000e+00> : vector<8xf32>}> : () -> vector<8xf32>
%9 = "arith.constant"() <{value = dense<0xFF800000> : vector<2x1x4xf32>}> : () -> vector<2x1x4xf32>
%10 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x2x1x1x4x1xf8E4M3FNUZ>}> : () -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%11 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x4x1x1x1x8xf8E4M3FNUZ>}> : () -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%12 = "arith.constant"() <{value = 3 : index}> : () -> index
%13 = "arith.constant"() <{value = 2 : index}> : () -> index
%14 = "arith.constant"() <{value = dense<1.000000e+00> : vector<8x2x1x1x1x4xf32>}> : () -> vector<8x2x1x1x1x4xf32>
%15 = "arith.constant"() <{value = dense<2.400000e+02> : vector<2x8x1x1x4x1xf32>}> : () -> vector<2x8x1x1x4x1xf32>
%16 = "arith.constant"() <{value = dense<-2.400000e+02> : vector<2x8x1x1x4x1xf32>}> : () -> vector<2x8x1x1x4x1xf32>
%17 = "arith.constant"() <{value = dense<2.400000e+02> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%18 = "arith.constant"() <{value = dense<1.44269502> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%19 = "arith.constant"() <{value = dense<0.00416666688> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%20 = "arith.constant"() <{value = dense<0xFF800000> : vector<32x32xf32>}> : () -> vector<32x32xf32>
%21 = "arith.constant"() <{value = 0 : i64}> : () -> i64
%22 = "arith.constant"() <{value = 0 : i8}> : () -> i8
%23 = "arith.constant"() <{value = dense<0.000000e+00> : vector<32x32xf32>}> : () -> vector<32x32xf32>
%24 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%25 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x1x4xf32>}> : () -> vector<2x1x4xf32>
%26 = "arith.constant"() <{value = dense<-3.40282347E+38> : vector<2x1x4xf32>}> : () -> vector<2x1x4xf32>
%27 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x8x1x1x4x1xf32>}> : () -> vector<2x8x1x1x4x1xf32>
%28 = "arith.constant"() <{value = 0.000000e+00 : f8E4M3FNUZ}> : () -> f8E4M3FNUZ
%29 = "arith.constant"() <{value = 1.44269502 : f32}> : () -> f32
%30 = "arith.constant"() <{value = 1 : index}> : () -> index
%31 = "arith.constant"() <{value = 32 : index}> : () -> index
%32 = "arith.constant"() <{value = 67108864 : index}> : () -> index
%33 = "arith.constant"() <{value = 32 : i64}> : () -> i64
%34 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
%35 = "arith.constant"() <{value = 0 : index}> : () -> index
%36 = "gpu.thread_id"() <{dimension = #gpu<dim z>}> : () -> index
%37 = "gpu.thread_id"() <{dimension = #gpu<dim y>}> : () -> index
%38 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
%39 = "affine.linearize_index"(%36, %37, %38) <{disjoint, operandSegmentSizes = array<i32: 3, 0>, static_basis = array<i64: 1, 1, 64>}> : (index, index, index) -> index
%40 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%41 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%42 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%43 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%44 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%45 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
%46 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
%47 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
%48 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
%49 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 4 : index} : () -> i32
%50 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 5 : index} : () -> i32
%51 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 6 : index} : () -> i32
%52 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 7 : index} : () -> i32
%53 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 8 : index} : () -> i32
%54 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 9 : index} : () -> i32
%55 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 10 : index} : () -> i32
%56 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 11 : index} : () -> i32
%57 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 12 : index} : () -> i32
%58 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 13 : index} : () -> i32
%59 = "arith.extui"(%45) : (i32) -> i64
%60 = "arith.extui"(%46) : (i32) -> i64
%61 = "arith.shli"(%60, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%62 = "arith.ori"(%59, %61) : (i64, i64) -> i64
%63 = "arith.index_castui"(%62) : (i64) -> index
%64 = "arith.extui"(%47) : (i32) -> i64
%65 = "arith.extui"(%48) : (i32) -> i64
%66 = "arith.shli"(%65, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%67 = "arith.ori"(%64, %66) : (i64, i64) -> i64
%68 = "arith.index_castui"(%67) : (i64) -> index
%69 = "arith.extui"(%49) : (i32) -> i64
%70 = "arith.extui"(%50) : (i32) -> i64
%71 = "arith.shli"(%70, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%72 = "arith.ori"(%69, %71) : (i64, i64) -> i64
%73 = "arith.index_castui"(%72) : (i64) -> index
%74 = "arith.extui"(%51) : (i32) -> i64
%75 = "arith.extui"(%52) : (i32) -> i64
%76 = "arith.shli"(%75, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%77 = "arith.ori"(%74, %76) : (i64, i64) -> i64
%78 = "arith.index_castui"(%77) {stream.alignment = 64 : index, stream.values = [1075847616 : index, 1293968512 : index, 1512089408 : index, 1730210304 : index, 1948331200 : index, 2166452096 : index, 2384572992 : index, 2602693888 : index, 2820814784 : index, 3038935680 : index, 3257056576 : index, 3475177472 : index, 3693298368 : index, 3911419264 : index, 4129540160 : index, 4347661056 : index, 4565781952 : index, 4783902848 : index, 5002023744 : index, 5220144640 : index, 5438265536 : index, 5656386432 : index, 5874507328 : index, 6092628224 : index, 6310749120 : index, 6528870016 : index, 6746990912 : index, 6965111808 : index, 7183232704 : index, 7401353600 : index, 7619474496 : index, 7837595392 : index]} : (i64) -> index
%79 = "arith.extui"(%53) : (i32) -> i64
%80 = "arith.extui"(%54) : (i32) -> i64
%81 = "arith.shli"(%80, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%82 = "arith.ori"(%79, %81) : (i64, i64) -> i64
%83 = "arith.index_castui"(%82) : (i64) -> index
%84 = "arith.index_castui"(%55) : (i32) -> index
%85 = "arith.bitcast"(%56) : (i32) -> f32
%86 = "arith.index_castui"(%57) : (i32) -> index
%87 = "arith.index_castui"(%58) : (i32) -> index
%88:8 = "util.assume.int"(%63, %68, %73, %78, %83, %84, %86, %87) <{assumptions = [[#util.int.assumption<umin = 68027392, umax = 20995769344>], [#util.int.assumption<umin = 68158464, umax = 21532509184>], [#util.int.assumption<umin = 68355072, umax = 22337618944>], [#util.int.assumption<umin = 1075847616, umax = 1075847616, udiv = 1075847616>, #util.int.assumption<umin = 1293968512, umax = 1293968512, udiv = 1293968512>, #util.int.assumption<umin = 1512089408, umax = 1512089408, udiv = 1512089408>, #util.int.assumption<umin = 1730210304, umax = 1730210304, udiv = 1730210304>, #util.int.assumption<umin = 1948331200, umax = 1948331200, udiv = 1948331200>, #util.int.assumption<umin = 2166452096, umax = 2166452096, udiv = 2166452096>, #util.int.assumption<umin = 2384572992, umax = 2384572992, udiv = 2384572992>, #util.int.assumption<umin = 2602693888, umax = 2602693888, udiv = 2602693888>, #util.int.assumption<umin = 2820814784, umax = 2820814784, udiv = 2820814784>, #util.int.assumption<umin = 3038935680, umax = 3038935680, udiv = 3038935680>, #util.int.assumption<umin = 3257056576, umax = 3257056576, udiv = 3257056576>, #util.int.assumption<umin = 3475177472, umax = 3475177472, udiv = 3475177472>, #util.int.assumption<umin = 3693298368, umax = 3693298368, udiv = 3693298368>, #util.int.assumption<umin = 3911419264, umax = 3911419264, udiv = 3911419264>, #util.int.assumption<umin = 4129540160, umax = 4129540160, udiv = 4129540160>, #util.int.assumption<umin = 4347661056, umax = 4347661056, udiv = 4347661056>, #util.int.assumption<umin = 4565781952, umax = 4565781952, udiv = 4565781952>, #util.int.assumption<umin = 4783902848, umax = 4783902848, udiv = 4783902848>, #util.int.assumption<umin = 5002023744, umax = 5002023744, udiv = 5002023744>, #util.int.assumption<umin = 5220144640, umax = 5220144640, udiv = 5220144640>, #util.int.assumption<umin = 5438265536, umax = 5438265536, udiv = 5438265536>, #util.int.assumption<umin = 5656386432, umax = 5656386432, udiv = 
5656386432>, #util.int.assumption<umin = 5874507328, umax = 5874507328, udiv = 5874507328>, #util.int.assumption<umin = 6092628224, umax = 6092628224, udiv = 6092628224>, #util.int.assumption<umin = 6310749120, umax = 6310749120, udiv = 6310749120>, #util.int.assumption<umin = 6528870016, umax = 6528870016, udiv = 6528870016>, #util.int.assumption<umin = 6746990912, umax = 6746990912, udiv = 6746990912>, #util.int.assumption<umin = 6965111808, umax = 6965111808, udiv = 6965111808>, #util.int.assumption<umin = 7183232704, umax = 7183232704, udiv = 7183232704>, #util.int.assumption<umin = 7401353600, umax = 7401353600, udiv = 7401353600>, #util.int.assumption<umin = 7619474496, umax = 7619474496, udiv = 7619474496>, #util.int.assumption<umin = 7837595392, umax = 7837595392, udiv = 7837595392>], [#util.int.assumption<umin = 67896320, umax = 20459029504>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>], [#util.int.assumption<umin = 1, umax = 4095>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>]]}> : (index, index, index, index, index, index, index, index) -> (index, index, index, index, index, index, index, index)
%89 = "hal.interface.binding.subspan"(%35) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<i64, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%89) <{alignment = 64 : i32}> : (memref<i64, #hal.descriptor_type<storage_buffer>>) -> ()
%90 = "hal.interface.binding.subspan"(%88#3) {alignment = 64 : index, binding = 2 : index, descriptor_flags = 1 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%90) <{alignment = 64 : i32}> : (memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%91 = "flow.dispatch.workload.ordinal"(%88#5) <{ordinal = 0 : index}> : (index) -> index
%92 = "flow.dispatch.workload.ordinal"(%88#6) <{ordinal = 1 : index}> : (index) -> index
%93 = "flow.dispatch.workload.ordinal"(%88#6) <{ordinal = 2 : index}> : (index) -> index
%94 = "flow.dispatch.workload.ordinal"(%88#7) <{ordinal = 3 : index}> : (index) -> index
%95 = "hal.interface.binding.subspan"(%32, %92, %91) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 2>} : (index, index, index) -> memref<?x32x?xi8, strided<[?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%95) <{alignment = 64 : i32}> : (memref<?x32x?xi8, strided<[?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%96 = "hal.interface.binding.subspan"(%88#0, %93) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%96) <{alignment = 1 : i32}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%97 = "arith.divsi"(%94, %31) : (index, index) -> index
%98 = "hal.interface.binding.subspan"(%88#1, %97) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%98) <{alignment = 1 : i32}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%99 = "arith.divsi"(%91, %31) : (index, index) -> index
%100 = "hal.interface.binding.subspan"(%88#2, %99) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x32x8x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%100) <{alignment = 1 : i32}> : (memref<?x32x8x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%101 = "hal.interface.binding.subspan"(%88#4, %92) {alignment = 64 : index, binding = 3 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%101) <{alignment = 1 : i32}> : (memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
"scf.forall"(%93) <{mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>], operandSegmentSizes = array<i32: 0, 1, 0, 0>, staticLowerBound = array<i64: 0, 0, 0>, staticStep = array<i64: 1, 1, 1>, staticUpperBound = array<i64: 8, 4, -9223372036854775808>}> ({
^bb0(%arg0: index, %arg1: index, %arg2: index):
"gpu.barrier"() : () -> ()
%102 = "memref.subview"(%101, %arg2, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 3, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808, 0, -9223372036854775808, -9223372036854775808, 0>, static_sizes = array<i64: 1, 1, 32, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index) -> memref<1x1x32x1x1x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%103 = "memref.subview"(%102) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 32, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<1x1x32x1x1x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%104:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%105:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%106 = "affine.linearize_index"(%104#2, %35, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%107 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%108 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %106, %107, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%109 = "affine.linearize_index"(%104#2, %30, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%110 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%111 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %109, %110, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%112 = "affine.linearize_index"(%104#2, %13, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%113 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%114 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %112, %113, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%115 = "affine.linearize_index"(%104#2, %12, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%116 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%117 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %115, %116, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%118 = "arith.mulf"(%85, %29) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%119:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%120:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%121 = "affine.linearize_index"(%119#2, %35, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%122 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%108, %43, %121, %122) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%123 = "affine.linearize_index"(%119#2, %30, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%124 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%111, %43, %123, %124) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%125 = "affine.linearize_index"(%119#2, %13, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%126 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%114, %43, %125, %126) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%127 = "affine.linearize_index"(%119#2, %12, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%128 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%117, %43, %127, %128) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
"gpu.barrier"() : () -> ()
%129:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%130:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%131 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%132 = "affine.linearize_index"(%129#1, %35, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%133 = "vector.transfer_read"(%43, %131, %132, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%134 = "vector.insert_strided_slice"(%133, %11) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%135 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%136 = "affine.linearize_index"(%129#1, %30, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%137 = "vector.transfer_read"(%43, %135, %136, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%138 = "vector.insert_strided_slice"(%137, %134) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%139 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%140 = "affine.linearize_index"(%129#1, %13, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%141 = "vector.transfer_read"(%43, %139, %140, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%142 = "vector.insert_strided_slice"(%141, %138) <{offsets = [0, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%143 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%144 = "affine.linearize_index"(%129#1, %12, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%145 = "vector.transfer_read"(%43, %143, %144, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%146 = "vector.insert_strided_slice"(%145, %142) <{offsets = [0, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%147 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%148 = "affine.linearize_index"(%129#1, %35, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%149 = "vector.transfer_read"(%43, %147, %148, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%150 = "vector.insert_strided_slice"(%149, %146) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%151 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%152 = "affine.linearize_index"(%129#1, %30, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%153 = "vector.transfer_read"(%43, %151, %152, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%154 = "vector.insert_strided_slice"(%153, %150) <{offsets = [1, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%155 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%156 = "affine.linearize_index"(%129#1, %13, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%157 = "vector.transfer_read"(%43, %155, %156, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%158 = "vector.insert_strided_slice"(%157, %154) <{offsets = [1, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%159 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%160 = "affine.linearize_index"(%129#1, %12, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%161 = "vector.transfer_read"(%43, %159, %160, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%162 = "vector.insert_strided_slice"(%161, %158) <{offsets = [1, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%163 = "vector.transfer_read"(%89, %21) <{in_bounds = [], operandSegmentSizes = array<i32: 1, 0, 1, 0>, permutation_map = affine_map<() -> ()>}> : (memref<i64, #hal.descriptor_type<storage_buffer>>, i64) -> vector<i64>
%164 = "iree_vector_ext.to_simd"(%163) : (vector<i64>) -> vector<i64>
%165 = "vector.broadcast"(%164) : (vector<i64>) -> vector<32x32xi64>
%166 = "vector.step"() : () -> vector<32xindex>
%167 = "vector.broadcast"(%118) : (f32) -> vector<2x2x1x1x4x1xf32>
%168:3 = "scf.for"(%35, %97, %30, %26, %25, %27) ({
^bb0(%arg3: index, %arg4: vector<2x1x4xf32>, %arg5: vector<2x1x4xf32>, %arg6: vector<2x8x1x1x4x1xf32>):
"gpu.barrier"() : () -> ()
%325 = "memref.subview"(%100, %arg3, %arg0) <{operandSegmentSizes = array<i32: 1, 2, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, -9223372036854775808, 0>, static_sizes = array<i64: 1, 32, 1, 128>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<?x32x8x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> memref<1x32x1x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%326 = "memref.subview"(%325) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: 1, 32, 1, 128>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<1x32x1x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%327:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%328:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%329 = "affine.linearize_index"(%327#2, %35, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%330 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%331 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %329, %330, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%332 = "affine.linearize_index"(%327#2, %30, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%333 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%334 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %332, %333, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%335 = "affine.linearize_index"(%327#2, %13, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%336 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%337 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %335, %336, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%338 = "affine.linearize_index"(%327#2, %12, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%339 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%340 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %338, %339, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%341:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%342:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%343 = "affine.linearize_index"(%341#2, %35, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%344 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%345 = "vector.transfer_read"(%326, %343, %344, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%346 = "affine.linearize_index"(%341#2, %30, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%347 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%348 = "vector.transfer_read"(%326, %346, %347, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%349 = "affine.linearize_index"(%341#2, %13, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%350 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%351 = "vector.transfer_read"(%326, %349, %350, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%352 = "affine.linearize_index"(%341#2, %12, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%353 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%354 = "vector.transfer_read"(%326, %352, %353, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%355 = "affine.linearize_index"(%arg3, %35, %99) <{disjoint, operandSegmentSizes = array<i32: 2, 1>, static_basis = array<i64: -9223372036854775808, 32>}> : (index, index, index) -> index
%356 = "vector.transfer_read"(%95, %arg2, %35, %355, %22) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 3, 1, 0>, permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}> : (memref<?x32x?xi8, strided<[?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, i8) -> vector<32x32xi8>
%357 = "arith.trunci"(%356) : (vector<32x32xi8>) -> vector<32x32xi1>
%358 = "vector.broadcast"(%355) : (index) -> vector<32xindex>
%359 = "arith.addi"(%358, %166) <{overflowFlags = #arith.overflow<none>}> : (vector<32xindex>, vector<32xindex>) -> vector<32xindex>
%360 = "arith.index_cast"(%359) : (vector<32xindex>) -> vector<32xi64>
%361 = "vector.broadcast"(%360) : (vector<32xi64>) -> vector<32x32xi64>
%362 = "arith.cmpi"(%361, %165) <{predicate = 5 : i64}> : (vector<32x32xi64>, vector<32x32xi64>) -> vector<32x32xi1>
%363 = "arith.ori"(%357, %362) : (vector<32x32xi1>, vector<32x32xi1>) -> vector<32x32xi1>
%364 = "arith.select"(%363, %20, %23) : (vector<32x32xi1>, vector<32x32xf32>, vector<32x32xf32>) -> vector<32x32xf32>
%365 = "arith.truncf"(%364) : (vector<32x32xf32>) -> vector<32x32xf8E4M3FNUZ>
"vector.transfer_write"(%365, %44, %35, %35, %35) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 3, 0>, permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}> : (vector<32x32xf8E4M3FNUZ>, memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, index) -> ()
%366 = "memref.expand_shape"(%44) <{reassociation = [[0, 1], [2], [3, 4]], static_output_shape = array<i64: 1, 1, 32, 1, 32>}> : (memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> memref<1x1x32x1x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%367 = "memref.subview"(%366) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 32, 1, 32>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (memref<1x1x32x1x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>
%368:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%369:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%370 = "affine.linearize_index"(%368#2, %35, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%371 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%331, %42, %370, %371) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%372 = "affine.linearize_index"(%368#2, %30, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%373 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%334, %42, %372, %373) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%374 = "affine.linearize_index"(%368#2, %13, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%375 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%337, %42, %374, %375) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%376 = "affine.linearize_index"(%368#2, %12, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%377 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%340, %42, %376, %377) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%378:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%379:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%380 = "affine.linearize_index"(%378#2, %35, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%381 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%345, %41, %380, %381) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%382 = "affine.linearize_index"(%378#2, %30, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%383 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%348, %41, %382, %383) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%384 = "affine.linearize_index"(%378#2, %13, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%385 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%351, %41, %384, %385) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%386 = "affine.linearize_index"(%378#2, %12, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%387 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%354, %41, %386, %387) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%388:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%389:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%390 = "affine.linearize_index"(%388#2, %35, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%391 = "affine.linearize_index"(%388#1, %35, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%392 = "vector.transfer_read"(%367, %390, %391, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%393 = "vector.insert_strided_slice"(%392, %10) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%394 = "affine.linearize_index"(%388#2, %35, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%395 = "affine.linearize_index"(%388#1, %30, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%396 = "vector.transfer_read"(%367, %394, %395, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%397 = "vector.insert_strided_slice"(%396, %393) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%398 = "affine.linearize_index"(%388#2, %30, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%399 = "affine.linearize_index"(%388#1, %35, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%400 = "vector.transfer_read"(%367, %398, %399, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%401 = "vector.insert_strided_slice"(%400, %397) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%402 = "affine.linearize_index"(%388#2, %30, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%403 = "affine.linearize_index"(%388#1, %30, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%404 = "vector.transfer_read"(%367, %402, %403, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%405 = "vector.insert_strided_slice"(%404, %401) <{offsets = [1, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%406 = "arith.extf"(%405) : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf32>
%407 = "arith.mulf"(%406, %18) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
"gpu.barrier"() : () -> ()
%408:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%409:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%410 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%411 = "affine.linearize_index"(%408#1, %35, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%412 = "vector.transfer_read"(%42, %410, %411, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%413 = "vector.insert_strided_slice"(%412, %11) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%414 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%415 = "affine.linearize_index"(%408#1, %30, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%416 = "vector.transfer_read"(%42, %414, %415, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%417 = "vector.insert_strided_slice"(%416, %413) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%418 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%419 = "affine.linearize_index"(%408#1, %13, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%420 = "vector.transfer_read"(%42, %418, %419, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%421 = "vector.insert_strided_slice"(%420, %417) <{offsets = [0, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%422 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%423 = "affine.linearize_index"(%408#1, %12, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%424 = "vector.transfer_read"(%42, %422, %423, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%425 = "vector.insert_strided_slice"(%424, %421) <{offsets = [0, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%426 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%427 = "affine.linearize_index"(%408#1, %35, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%428 = "vector.transfer_read"(%42, %426, %427, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%429 = "vector.insert_strided_slice"(%428, %425) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%430 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%431 = "affine.linearize_index"(%408#1, %30, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%432 = "vector.transfer_read"(%42, %430, %431, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%433 = "vector.insert_strided_slice"(%432, %429) <{offsets = [1, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%434 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%435 = "affine.linearize_index"(%408#1, %13, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%436 = "vector.transfer_read"(%42, %434, %435, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%437 = "vector.insert_strided_slice"(%436, %433) <{offsets = [1, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%438 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%439 = "affine.linearize_index"(%408#1, %12, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%440 = "vector.transfer_read"(%42, %438, %439, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%441 = "vector.insert_strided_slice"(%440, %437) <{offsets = [1, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%442 = "vector.extract"(%24) <{static_position = array<i64: 0, 0>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%443 = "vector.extract"(%162) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%444 = "vector.extract"(%441) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%445 = "vector.shape_cast"(%443) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%446 = "vector.shape_cast"(%444) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%447 = "vector.shape_cast"(%442) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%448 = "amdgpu.mfma"(%445, %446, %447) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%449 = "vector.extract"(%162) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%450 = "vector.extract"(%441) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%451 = "vector.shape_cast"(%449) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%452 = "vector.shape_cast"(%450) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%453 = "amdgpu.mfma"(%451, %452, %448) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%454 = "vector.extract"(%162) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%455 = "vector.extract"(%441) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%456 = "vector.shape_cast"(%454) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%457 = "vector.shape_cast"(%455) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%458 = "amdgpu.mfma"(%456, %457, %453) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%459 = "vector.extract"(%162) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%460 = "vector.extract"(%441) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%461 = "vector.shape_cast"(%459) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%462 = "vector.shape_cast"(%460) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%463 = "amdgpu.mfma"(%461, %462, %458) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%464 = "vector.shape_cast"(%463) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%465 = "vector.insert"(%464, %24) <{static_position = array<i64: 0, 0>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%466 = "vector.extract"(%24) <{static_position = array<i64: 0, 1>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%467 = "vector.extract"(%162) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%468 = "vector.extract"(%441) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%469 = "vector.shape_cast"(%467) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%470 = "vector.shape_cast"(%468) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%471 = "vector.shape_cast"(%466) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%472 = "amdgpu.mfma"(%469, %470, %471) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%473 = "vector.extract"(%162) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%474 = "vector.extract"(%441) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%475 = "vector.shape_cast"(%473) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%476 = "vector.shape_cast"(%474) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%477 = "amdgpu.mfma"(%475, %476, %472) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%478 = "vector.extract"(%162) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%479 = "vector.extract"(%441) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%480 = "vector.shape_cast"(%478) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%481 = "vector.shape_cast"(%479) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%482 = "amdgpu.mfma"(%480, %481, %477) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%483 = "vector.extract"(%162) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%484 = "vector.extract"(%441) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%485 = "vector.shape_cast"(%483) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%486 = "vector.shape_cast"(%484) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%487 = "amdgpu.mfma"(%485, %486, %482) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%488 = "vector.shape_cast"(%487) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%489 = "vector.insert"(%488, %465) <{static_position = array<i64: 0, 1>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%490 = "vector.extract"(%24) <{static_position = array<i64: 1, 0>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%491 = "vector.extract"(%162) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%492 = "vector.extract"(%441) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%493 = "vector.shape_cast"(%491) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%494 = "vector.shape_cast"(%492) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%495 = "vector.shape_cast"(%490) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%496 = "amdgpu.mfma"(%493, %494, %495) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%497 = "vector.extract"(%162) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%498 = "vector.extract"(%441) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%499 = "vector.shape_cast"(%497) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%500 = "vector.shape_cast"(%498) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%501 = "amdgpu.mfma"(%499, %500, %496) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%502 = "vector.extract"(%162) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%503 = "vector.extract"(%441) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%504 = "vector.shape_cast"(%502) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%505 = "vector.shape_cast"(%503) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%506 = "amdgpu.mfma"(%504, %505, %501) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%507 = "vector.extract"(%162) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%508 = "vector.extract"(%441) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%509 = "vector.shape_cast"(%507) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%510 = "vector.shape_cast"(%508) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%511 = "amdgpu.mfma"(%509, %510, %506) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%512 = "vector.shape_cast"(%511) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%513 = "vector.insert"(%512, %489) <{static_position = array<i64: 1, 0>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%514 = "vector.extract"(%24) <{static_position = array<i64: 1, 1>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%515 = "vector.extract"(%162) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%516 = "vector.extract"(%441) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%517 = "vector.shape_cast"(%515) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%518 = "vector.shape_cast"(%516) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%519 = "vector.shape_cast"(%514) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%520 = "amdgpu.mfma"(%517, %518, %519) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%521 = "vector.extract"(%162) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%522 = "vector.extract"(%441) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%523 = "vector.shape_cast"(%521) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%524 = "vector.shape_cast"(%522) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%525 = "amdgpu.mfma"(%523, %524, %520) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%526 = "vector.extract"(%162) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%527 = "vector.extract"(%441) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%528 = "vector.shape_cast"(%526) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%529 = "vector.shape_cast"(%527) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%530 = "amdgpu.mfma"(%528, %529, %525) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%531 = "vector.extract"(%162) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%532 = "vector.extract"(%441) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%533 = "vector.shape_cast"(%531) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%534 = "vector.shape_cast"(%532) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%535 = "amdgpu.mfma"(%533, %534, %530) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%536 = "vector.shape_cast"(%535) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%537 = "vector.insert"(%536, %513) <{static_position = array<i64: 1, 1>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%538 = "arith.mulf"(%167, %537) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%539 = "arith.addf"(%538, %19) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%540 = "arith.addf"(%539, %407) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%541 = "vector.multi_reduction"(%540, %9) <{kind = #vector.kind<maximumf>, reduction_dims = array<i64: 1, 3, 5>}> : (vector<2x2x1x1x4x1xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%542 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%543 = "gpu.subgroup_reduce"(%542) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%544 = "vector.insert"(%543, %8) <{static_position = array<i64: 0>}> : (f32, vector<8xf32>) -> vector<8xf32>
%545 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%546 = "gpu.subgroup_reduce"(%545) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%547 = "vector.insert"(%546, %544) <{static_position = array<i64: 1>}> : (f32, vector<8xf32>) -> vector<8xf32>
%548 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%549 = "gpu.subgroup_reduce"(%548) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%550 = "vector.insert"(%549, %547) <{static_position = array<i64: 2>}> : (f32, vector<8xf32>) -> vector<8xf32>
%551 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%552 = "gpu.subgroup_reduce"(%551) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%553 = "vector.insert"(%552, %550) <{static_position = array<i64: 3>}> : (f32, vector<8xf32>) -> vector<8xf32>
%554 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%555 = "gpu.subgroup_reduce"(%554) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%556 = "vector.insert"(%555, %553) <{static_position = array<i64: 4>}> : (f32, vector<8xf32>) -> vector<8xf32>
%557 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%558 = "gpu.subgroup_reduce"(%557) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%559 = "vector.insert"(%558, %556) <{static_position = array<i64: 5>}> : (f32, vector<8xf32>) -> vector<8xf32>
%560 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%561 = "gpu.subgroup_reduce"(%560) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%562 = "vector.insert"(%561, %559) <{static_position = array<i64: 6>}> : (f32, vector<8xf32>) -> vector<8xf32>
%563 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%564 = "gpu.subgroup_reduce"(%563) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%565 = "vector.insert"(%564, %562) <{static_position = array<i64: 7>}> : (f32, vector<8xf32>) -> vector<8xf32>
%566 = "vector.shape_cast"(%565) : (vector<8xf32>) -> vector<2x1x4xf32>
%567 = "arith.maximumf"(%566, %arg4) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%568 = "arith.subf"(%arg4, %567) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%569 = "math.exp2"(%568) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>) -> vector<2x1x4xf32>
%570 = "arith.mulf"(%569, %arg5) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%571 = "vector.extract"(%567) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%572 = "vector.broadcast"(%571) : (vector<4xf32>) -> vector<1x4xf32>
%573 = "vector.insert"(%572, %7) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%574 = "vector.extract"(%567) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%575 = "vector.broadcast"(%574) : (vector<4xf32>) -> vector<1x4xf32>
%576 = "vector.insert"(%575, %573) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%577 = "vector.extract"(%567) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%578 = "vector.broadcast"(%577) : (vector<4xf32>) -> vector<1x4xf32>
%579 = "vector.insert"(%578, %576) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%580 = "vector.extract"(%567) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%581 = "vector.broadcast"(%580) : (vector<4xf32>) -> vector<1x4xf32>
%582 = "vector.insert"(%581, %579) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%583 = "vector.transpose"(%582) <{permutation = array<i64: 1, 0, 3, 2, 5, 4>}> : (vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x4x1xf32>
%584 = "arith.subf"(%540, %583) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%585 = "math.exp2"(%584) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%586 = "vector.multi_reduction"(%585, %25) <{kind = #vector.kind<add>, reduction_dims = array<i64: 1, 3, 5>}> : (vector<2x2x1x1x4x1xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%587 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%588 = "gpu.subgroup_reduce"(%587) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%589 = "vector.insert"(%588, %8) <{static_position = array<i64: 0>}> : (f32, vector<8xf32>) -> vector<8xf32>
%590 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%591 = "gpu.subgroup_reduce"(%590) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%592 = "vector.insert"(%591, %589) <{static_position = array<i64: 1>}> : (f32, vector<8xf32>) -> vector<8xf32>
%593 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%594 = "gpu.subgroup_reduce"(%593) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%595 = "vector.insert"(%594, %592) <{static_position = array<i64: 2>}> : (f32, vector<8xf32>) -> vector<8xf32>
%596 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%597 = "gpu.subgroup_reduce"(%596) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%598 = "vector.insert"(%597, %595) <{static_position = array<i64: 3>}> : (f32, vector<8xf32>) -> vector<8xf32>
%599 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%600 = "gpu.subgroup_reduce"(%599) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%601 = "vector.insert"(%600, %598) <{static_position = array<i64: 4>}> : (f32, vector<8xf32>) -> vector<8xf32>
%602 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%603 = "gpu.subgroup_reduce"(%602) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%604 = "vector.insert"(%603, %601) <{static_position = array<i64: 5>}> : (f32, vector<8xf32>) -> vector<8xf32>
%605 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%606 = "gpu.subgroup_reduce"(%605) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%607 = "vector.insert"(%606, %604) <{static_position = array<i64: 6>}> : (f32, vector<8xf32>) -> vector<8xf32>
%608 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%609 = "gpu.subgroup_reduce"(%608) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%610 = "vector.insert"(%609, %607) <{static_position = array<i64: 7>}> : (f32, vector<8xf32>) -> vector<8xf32>
%611 = "vector.shape_cast"(%610) : (vector<8xf32>) -> vector<2x1x4xf32>
%612 = "arith.addf"(%611, %570) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%613 = "arith.minimumf"(%585, %17) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%614 = "arith.truncf"(%613) : (vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%615 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%616 = "vector.broadcast"(%615) : (vector<4xf32>) -> vector<1x4xf32>
%617 = "vector.insert"(%616, %6) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%618 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%619 = "vector.broadcast"(%618) : (vector<4xf32>) -> vector<1x4xf32>
%620 = "vector.insert"(%619, %617) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%621 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%622 = "vector.broadcast"(%621) : (vector<4xf32>) -> vector<1x4xf32>
%623 = "vector.insert"(%622, %620) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%624 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%625 = "vector.broadcast"(%624) : (vector<4xf32>) -> vector<1x4xf32>
%626 = "vector.insert"(%625, %623) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%627 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%628 = "vector.broadcast"(%627) : (vector<4xf32>) -> vector<1x4xf32>
%629 = "vector.insert"(%628, %626) <{static_position = array<i64: 2, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%630 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%631 = "vector.broadcast"(%630) : (vector<4xf32>) -> vector<1x4xf32>
%632 = "vector.insert"(%631, %629) <{static_position = array<i64: 2, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%633 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%634 = "vector.broadcast"(%633) : (vector<4xf32>) -> vector<1x4xf32>
%635 = "vector.insert"(%634, %632) <{static_position = array<i64: 3, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%636 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%637 = "vector.broadcast"(%636) : (vector<4xf32>) -> vector<1x4xf32>
%638 = "vector.insert"(%637, %635) <{static_position = array<i64: 3, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%639 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%640 = "vector.broadcast"(%639) : (vector<4xf32>) -> vector<1x4xf32>
%641 = "vector.insert"(%640, %638) <{static_position = array<i64: 4, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%642 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%643 = "vector.broadcast"(%642) : (vector<4xf32>) -> vector<1x4xf32>
%644 = "vector.insert"(%643, %641) <{static_position = array<i64: 4, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%645 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%646 = "vector.broadcast"(%645) : (vector<4xf32>) -> vector<1x4xf32>
%647 = "vector.insert"(%646, %644) <{static_position = array<i64: 5, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%648 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%649 = "vector.broadcast"(%648) : (vector<4xf32>) -> vector<1x4xf32>
%650 = "vector.insert"(%649, %647) <{static_position = array<i64: 5, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%651 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%652 = "vector.broadcast"(%651) : (vector<4xf32>) -> vector<1x4xf32>
%653 = "vector.insert"(%652, %650) <{static_position = array<i64: 6, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%654 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%655 = "vector.broadcast"(%654) : (vector<4xf32>) -> vector<1x4xf32>
%656 = "vector.insert"(%655, %653) <{static_position = array<i64: 6, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%657 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%658 = "vector.broadcast"(%657) : (vector<4xf32>) -> vector<1x4xf32>
%659 = "vector.insert"(%658, %656) <{static_position = array<i64: 7, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%660 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%661 = "vector.broadcast"(%660) : (vector<4xf32>) -> vector<1x4xf32>
%662 = "vector.insert"(%661, %659) <{static_position = array<i64: 7, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%663 = "vector.transpose"(%662) <{permutation = array<i64: 1, 0, 3, 2, 5, 4>}> : (vector<8x2x1x1x1x4xf32>) -> vector<2x8x1x1x4x1xf32>
%664 = "arith.mulf"(%663, %arg6) <{fastmath = #arith.fastmath<none>}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%665:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%666:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%667 = "affine.linearize_index"(%665#2, %35, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%668 = "affine.linearize_index"(%665#1, %35, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%669 = "vector.extract"(%614) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%669, %40, %667, %668) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%670 = "affine.linearize_index"(%665#2, %35, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%671 = "affine.linearize_index"(%665#1, %30, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%672 = "vector.extract"(%614) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%672, %40, %670, %671) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%673 = "affine.linearize_index"(%665#2, %30, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%674 = "affine.linearize_index"(%665#1, %35, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%675 = "vector.extract"(%614) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%675, %40, %673, %674) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%676 = "affine.linearize_index"(%665#2, %30, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%677 = "affine.linearize_index"(%665#1, %30, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%678 = "vector.extract"(%614) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%678, %40, %676, %677) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
"gpu.barrier"() : () -> ()
%679:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%680:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%681 = "affine.linearize_index"(%679#2, %35, %35, %680#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%682 = "affine.linearize_index"(%679#1, %35, %35, %680#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%683 = "vector.transfer_read"(%40, %681, %682, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%684 = "vector.insert_strided_slice"(%683, %5) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<2x1x1x1x1x8xf8E4M3FNUZ>
%685 = "affine.linearize_index"(%679#2, %30, %35, %680#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%686 = "affine.linearize_index"(%679#1, %35, %35, %680#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%687 = "vector.transfer_read"(%40, %685, %686, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%688 = "vector.insert_strided_slice"(%687, %684) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<2x1x1x1x1x8xf8E4M3FNUZ>
%689:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%690:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%691 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%692 = "affine.linearize_index"(%689#1, %35, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%693 = "vector.transfer_read"(%41, %691, %692, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%694 = "vector.insert_strided_slice"(%693, %4) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%695 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%696 = "affine.linearize_index"(%689#1, %30, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%697 = "vector.transfer_read"(%41, %695, %696, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%698 = "vector.insert_strided_slice"(%697, %694) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%699 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%700 = "affine.linearize_index"(%689#1, %13, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%701 = "vector.transfer_read"(%41, %699, %700, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%702 = "vector.insert_strided_slice"(%701, %698) <{offsets = [0, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%703 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%704 = "affine.linearize_index"(%689#1, %12, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%705 = "vector.transfer_read"(%41, %703, %704, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%706 = "vector.insert_strided_slice"(%705, %702) <{offsets = [0, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%707 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%708 = "affine.linearize_index"(%689#1, %3, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%709 = "vector.transfer_read"(%41, %707, %708, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%710 = "vector.insert_strided_slice"(%709, %706) <{offsets = [0, 4, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%711 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%712 = "affine.linearize_index"(%689#1, %2, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%713 = "vector.transfer_read"(%41, %711, %712, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%714 = "vector.insert_strided_slice"(%713, %710) <{offsets = [0, 5, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%715 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%716 = "affine.linearize_index"(%689#1, %1, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%717 = "vector.transfer_read"(%41, %715, %716, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%718 = "vector.insert_strided_slice"(%717, %714) <{offsets = [0, 6, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%719 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%720 = "affine.linearize_index"(%689#1, %0, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%721 = "vector.transfer_read"(%41, %719, %720, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%722 = "vector.insert_strided_slice"(%721, %718) <{offsets = [0, 7, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%723 = "vector.extract"(%664) <{static_position = array<i64: 0, 0>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%724 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%725 = "vector.extract"(%722) <{static_position = array<i64: 0, 0>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%726 = "vector.shape_cast"(%724) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%727 = "vector.shape_cast"(%725) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%728 = "vector.shape_cast"(%723) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%729 = "amdgpu.mfma"(%726, %727, %728) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%730 = "vector.shape_cast"(%729) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%731 = "vector.insert"(%730, %27) <{static_position = array<i64: 0, 0>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%732 = "vector.extract"(%664) <{static_position = array<i64: 0, 1>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%733 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%734 = "vector.extract"(%722) <{static_position = array<i64: 0, 1>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%735 = "vector.shape_cast"(%733) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%736 = "vector.shape_cast"(%734) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%737 = "vector.shape_cast"(%732) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%738 = "amdgpu.mfma"(%735, %736, %737) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%739 = "vector.shape_cast"(%738) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%740 = "vector.insert"(%739, %731) <{static_position = array<i64: 0, 1>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%741 = "vector.extract"(%664) <{static_position = array<i64: 0, 2>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%742 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%743 = "vector.extract"(%722) <{static_position = array<i64: 0, 2>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%744 = "vector.shape_cast"(%742) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%745 = "vector.shape_cast"(%743) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%746 = "vector.shape_cast"(%741) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%747 = "amdgpu.mfma"(%744, %745, %746) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%748 = "vector.shape_cast"(%747) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%749 = "vector.insert"(%748, %740) <{static_position = array<i64: 0, 2>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%750 = "vector.extract"(%664) <{static_position = array<i64: 0, 3>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%751 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%752 = "vector.extract"(%722) <{static_position = array<i64: 0, 3>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%753 = "vector.shape_cast"(%751) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%754 = "vector.shape_cast"(%752) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%755 = "vector.shape_cast"(%750) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%756 = "amdgpu.mfma"(%753, %754, %755) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%757 = "vector.shape_cast"(%756) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%758 = "vector.insert"(%757, %749) <{static_position = array<i64: 0, 3>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%759 = "vector.extract"(%664) <{static_position = array<i64: 0, 4>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%760 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%761 = "vector.extract"(%722) <{static_position = array<i64: 0, 4>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%762 = "vector.shape_cast"(%760) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%763 = "vector.shape_cast"(%761) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%764 = "vector.shape_cast"(%759) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%765 = "amdgpu.mfma"(%762, %763, %764) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%766 = "vector.shape_cast"(%765) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%767 = "vector.insert"(%766, %758) <{static_position = array<i64: 0, 4>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%768 = "vector.extract"(%664) <{static_position = array<i64: 0, 5>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%769 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%770 = "vector.extract"(%722) <{static_position = array<i64: 0, 5>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%771 = "vector.shape_cast"(%769) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%772 = "vector.shape_cast"(%770) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%773 = "vector.shape_cast"(%768) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%774 = "amdgpu.mfma"(%771, %772, %773) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%775 = "vector.shape_cast"(%774) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%776 = "vector.insert"(%775, %767) <{static_position = array<i64: 0, 5>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%777 = "vector.extract"(%664) <{static_position = array<i64: 0, 6>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%778 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%779 = "vector.extract"(%722) <{static_position = array<i64: 0, 6>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%780 = "vector.shape_cast"(%778) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%781 = "vector.shape_cast"(%779) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%782 = "vector.shape_cast"(%777) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%783 = "amdgpu.mfma"(%780, %781, %782) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%784 = "vector.shape_cast"(%783) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%785 = "vector.insert"(%784, %776) <{static_position = array<i64: 0, 6>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%786 = "vector.extract"(%664) <{static_position = array<i64: 0, 7>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%787 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%788 = "vector.extract"(%722) <{static_position = array<i64: 0, 7>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%789 = "vector.shape_cast"(%787) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%790 = "vector.shape_cast"(%788) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%791 = "vector.shape_cast"(%786) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%792 = "amdgpu.mfma"(%789, %790, %791) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%793 = "vector.shape_cast"(%792) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%794 = "vector.insert"(%793, %785) <{static_position = array<i64: 0, 7>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%795 = "vector.extract"(%664) <{static_position = array<i64: 1, 0>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%796 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%797 = "vector.extract"(%722) <{static_position = array<i64: 0, 0>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%798 = "vector.shape_cast"(%796) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%799 = "vector.shape_cast"(%797) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%800 = "vector.shape_cast"(%795) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%801 = "amdgpu.mfma"(%798, %799, %800) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%802 = "vector.shape_cast"(%801) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%803 = "vector.insert"(%802, %794) <{static_position = array<i64: 1, 0>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%804 = "vector.extract"(%664) <{static_position = array<i64: 1, 1>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%805 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%806 = "vector.extract"(%722) <{static_position = array<i64: 0, 1>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%807 = "vector.shape_cast"(%805) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%808 = "vector.shape_cast"(%806) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%809 = "vector.shape_cast"(%804) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%810 = "amdgpu.mfma"(%807, %808, %809) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%811 = "vector.shape_cast"(%810) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%812 = "vector.insert"(%811, %803) <{static_position = array<i64: 1, 1>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%813 = "vector.extract"(%664) <{static_position = array<i64: 1, 2>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%814 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%815 = "vector.extract"(%722) <{static_position = array<i64: 0, 2>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%816 = "vector.shape_cast"(%814) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%817 = "vector.shape_cast"(%815) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%818 = "vector.shape_cast"(%813) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%819 = "amdgpu.mfma"(%816, %817, %818) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%820 = "vector.shape_cast"(%819) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%821 = "vector.insert"(%820, %812) <{static_position = array<i64: 1, 2>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%822 = "vector.extract"(%664) <{static_position = array<i64: 1, 3>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%823 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%824 = "vector.extract"(%722) <{static_position = array<i64: 0, 3>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%825 = "vector.shape_cast"(%823) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%826 = "vector.shape_cast"(%824) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%827 = "vector.shape_cast"(%822) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%828 = "amdgpu.mfma"(%825, %826, %827) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%829 = "vector.shape_cast"(%828) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%830 = "vector.insert"(%829, %821) <{static_position = array<i64: 1, 3>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%831 = "vector.extract"(%664) <{static_position = array<i64: 1, 4>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%832 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%833 = "vector.extract"(%722) <{static_position = array<i64: 0, 4>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%834 = "vector.shape_cast"(%832) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%835 = "vector.shape_cast"(%833) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%836 = "vector.shape_cast"(%831) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%837 = "amdgpu.mfma"(%834, %835, %836) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%838 = "vector.shape_cast"(%837) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%839 = "vector.insert"(%838, %830) <{static_position = array<i64: 1, 4>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%840 = "vector.extract"(%664) <{static_position = array<i64: 1, 5>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%841 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%842 = "vector.extract"(%722) <{static_position = array<i64: 0, 5>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%843 = "vector.shape_cast"(%841) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%844 = "vector.shape_cast"(%842) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%845 = "vector.shape_cast"(%840) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%846 = "amdgpu.mfma"(%843, %844, %845) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%847 = "vector.shape_cast"(%846) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%848 = "vector.insert"(%847, %839) <{static_position = array<i64: 1, 5>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%849 = "vector.extract"(%664) <{static_position = array<i64: 1, 6>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%850 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%851 = "vector.extract"(%722) <{static_position = array<i64: 0, 6>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%852 = "vector.shape_cast"(%850) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%853 = "vector.shape_cast"(%851) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%854 = "vector.shape_cast"(%849) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%855 = "amdgpu.mfma"(%852, %853, %854) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%856 = "vector.shape_cast"(%855) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%857 = "vector.insert"(%856, %848) <{static_position = array<i64: 1, 6>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%858 = "vector.extract"(%664) <{static_position = array<i64: 1, 7>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%859 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%860 = "vector.extract"(%722) <{static_position = array<i64: 0, 7>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%861 = "vector.shape_cast"(%859) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%862 = "vector.shape_cast"(%860) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%863 = "vector.shape_cast"(%858) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%864 = "amdgpu.mfma"(%861, %862, %863) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%865 = "vector.shape_cast"(%864) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%866 = "vector.insert"(%865, %857) <{static_position = array<i64: 1, 7>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
"scf.yield"(%567, %612, %866) : (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>) -> ()
}) : (index, index, index, vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>) -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>)
%169 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%170 = "vector.broadcast"(%169) : (vector<4xf32>) -> vector<1x4xf32>
%171 = "vector.insert"(%170, %6) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%172 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%173 = "vector.broadcast"(%172) : (vector<4xf32>) -> vector<1x4xf32>
%174 = "vector.insert"(%173, %171) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%175 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%176 = "vector.broadcast"(%175) : (vector<4xf32>) -> vector<1x4xf32>
%177 = "vector.insert"(%176, %174) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%178 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%179 = "vector.broadcast"(%178) : (vector<4xf32>) -> vector<1x4xf32>
%180 = "vector.insert"(%179, %177) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%181 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%182 = "vector.broadcast"(%181) : (vector<4xf32>) -> vector<1x4xf32>
%183 = "vector.insert"(%182, %180) <{static_position = array<i64: 2, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%184 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%185 = "vector.broadcast"(%184) : (vector<4xf32>) -> vector<1x4xf32>
%186 = "vector.insert"(%185, %183) <{static_position = array<i64: 2, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%187 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%188 = "vector.broadcast"(%187) : (vector<4xf32>) -> vector<1x4xf32>
%189 = "vector.insert"(%188, %186) <{static_position = array<i64: 3, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%190 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%191 = "vector.broadcast"(%190) : (vector<4xf32>) -> vector<1x4xf32>
%192 = "vector.insert"(%191, %189) <{static_position = array<i64: 3, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%193 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%194 = "vector.broadcast"(%193) : (vector<4xf32>) -> vector<1x4xf32>
%195 = "vector.insert"(%194, %192) <{static_position = array<i64: 4, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%196 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%197 = "vector.broadcast"(%196) : (vector<4xf32>) -> vector<1x4xf32>
%198 = "vector.insert"(%197, %195) <{static_position = array<i64: 4, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%199 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%200 = "vector.broadcast"(%199) : (vector<4xf32>) -> vector<1x4xf32>
%201 = "vector.insert"(%200, %198) <{static_position = array<i64: 5, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%202 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%203 = "vector.broadcast"(%202) : (vector<4xf32>) -> vector<1x4xf32>
%204 = "vector.insert"(%203, %201) <{static_position = array<i64: 5, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%205 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%206 = "vector.broadcast"(%205) : (vector<4xf32>) -> vector<1x4xf32>
%207 = "vector.insert"(%206, %204) <{static_position = array<i64: 6, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%208 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%209 = "vector.broadcast"(%208) : (vector<4xf32>) -> vector<1x4xf32>
%210 = "vector.insert"(%209, %207) <{static_position = array<i64: 6, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%211 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%212 = "vector.broadcast"(%211) : (vector<4xf32>) -> vector<1x4xf32>
%213 = "vector.insert"(%212, %210) <{static_position = array<i64: 7, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%214 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%215 = "vector.broadcast"(%214) : (vector<4xf32>) -> vector<1x4xf32>
%216 = "vector.insert"(%215, %213) <{static_position = array<i64: 7, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%217 = "arith.divf"(%14, %216) <{fastmath = #arith.fastmath<none>}> : (vector<8x2x1x1x1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%218 = "vector.transpose"(%217) <{permutation = array<i64: 1, 0, 3, 2, 5, 4>}> : (vector<8x2x1x1x1x4xf32>) -> vector<2x8x1x1x4x1xf32>
%219 = "arith.mulf"(%218, %168#2) <{fastmath = #arith.fastmath<none>}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%220 = "vector.transfer_read"(%90, %34) <{in_bounds = [], operandSegmentSizes = array<i32: 1, 0, 1, 0>, permutation_map = affine_map<() -> ()>}> : (memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>, f32) -> vector<f32>
%221 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%222 = "vector.broadcast"(%221) : (f32) -> vector<4x1xf32>
%223 = "vector.insert"(%222, %27) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%224 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%225 = "vector.broadcast"(%224) : (f32) -> vector<4x1xf32>
%226 = "vector.insert"(%225, %223) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%227 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%228 = "vector.broadcast"(%227) : (f32) -> vector<4x1xf32>
%229 = "vector.insert"(%228, %226) <{static_position = array<i64: 0, 2, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%230 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%231 = "vector.broadcast"(%230) : (f32) -> vector<4x1xf32>
%232 = "vector.insert"(%231, %229) <{static_position = array<i64: 0, 3, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%233 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%234 = "vector.broadcast"(%233) : (f32) -> vector<4x1xf32>
%235 = "vector.insert"(%234, %232) <{static_position = array<i64: 0, 4, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%236 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%237 = "vector.broadcast"(%236) : (f32) -> vector<4x1xf32>
%238 = "vector.insert"(%237, %235) <{static_position = array<i64: 0, 5, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%239 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%240 = "vector.broadcast"(%239) : (f32) -> vector<4x1xf32>
%241 = "vector.insert"(%240, %238) <{static_position = array<i64: 0, 6, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%242 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%243 = "vector.broadcast"(%242) : (f32) -> vector<4x1xf32>
%244 = "vector.insert"(%243, %241) <{static_position = array<i64: 0, 7, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%245 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%246 = "vector.broadcast"(%245) : (f32) -> vector<4x1xf32>
%247 = "vector.insert"(%246, %244) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%248 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%249 = "vector.broadcast"(%248) : (f32) -> vector<4x1xf32>
%250 = "vector.insert"(%249, %247) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%251 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%252 = "vector.broadcast"(%251) : (f32) -> vector<4x1xf32>
%253 = "vector.insert"(%252, %250) <{static_position = array<i64: 1, 2, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%254 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%255 = "vector.broadcast"(%254) : (f32) -> vector<4x1xf32>
%256 = "vector.insert"(%255, %253) <{static_position = array<i64: 1, 3, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%257 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%258 = "vector.broadcast"(%257) : (f32) -> vector<4x1xf32>
%259 = "vector.insert"(%258, %256) <{static_position = array<i64: 1, 4, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%260 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%261 = "vector.broadcast"(%260) : (f32) -> vector<4x1xf32>
%262 = "vector.insert"(%261, %259) <{static_position = array<i64: 1, 5, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%263 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%264 = "vector.broadcast"(%263) : (f32) -> vector<4x1xf32>
%265 = "vector.insert"(%264, %262) <{static_position = array<i64: 1, 6, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%266 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%267 = "vector.broadcast"(%266) : (f32) -> vector<4x1xf32>
%268 = "vector.insert"(%267, %265) <{static_position = array<i64: 1, 7, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%269 = "arith.divf"(%219, %268) <{fastmath = #arith.fastmath<none>}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%270 = "arith.cmpf"(%269, %16) <{fastmath = #arith.fastmath<none>, predicate = 11 : i64}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xi1>
%271 = "arith.select"(%270, %16, %269) : (vector<2x8x1x1x4x1xi1>, vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%272 = "arith.cmpf"(%271, %15) <{fastmath = #arith.fastmath<none>, predicate = 9 : i64}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xi1>
%273 = "arith.select"(%272, %15, %271) : (vector<2x8x1x1x4x1xi1>, vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%274 = "arith.truncf"(%273) : (vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf8E4M3FNUZ>
%275:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%276:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%277 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%278 = "affine.linearize_index"(%275#1, %35, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%279 = "vector.extract"(%274) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%279, %103, %277, %278) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%280 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%281 = "affine.linearize_index"(%275#1, %30, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%282 = "vector.extract"(%274) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%282, %103, %280, %281) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%283 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%284 = "affine.linearize_index"(%275#1, %13, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%285 = "vector.extract"(%274) <{static_position = array<i64: 0, 2, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%285, %103, %283, %284) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%286 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%287 = "affine.linearize_index"(%275#1, %12, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%288 = "vector.extract"(%274) <{static_position = array<i64: 0, 3, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%288, %103, %286, %287) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%289 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%290 = "affine.linearize_index"(%275#1, %3, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%291 = "vector.extract"(%274) <{static_position = array<i64: 0, 4, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%291, %103, %289, %290) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%292 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%293 = "affine.linearize_index"(%275#1, %2, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%294 = "vector.extract"(%274) <{static_position = array<i64: 0, 5, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%294, %103, %292, %293) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%295 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%296 = "affine.linearize_index"(%275#1, %1, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%297 = "vector.extract"(%274) <{static_position = array<i64: 0, 6, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%297, %103, %295, %296) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%298 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%299 = "affine.linearize_index"(%275#1, %0, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%300 = "vector.extract"(%274) <{static_position = array<i64: 0, 7, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%300, %103, %298, %299) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%301 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%302 = "affine.linearize_index"(%275#1, %35, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%303 = "vector.extract"(%274) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%303, %103, %301, %302) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%304 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%305 = "affine.linearize_index"(%275#1, %30, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%306 = "vector.extract"(%274) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%306, %103, %304, %305) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%307 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%308 = "affine.linearize_index"(%275#1, %13, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%309 = "vector.extract"(%274) <{static_position = array<i64: 1, 2, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%309, %103, %307, %308) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%310 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%311 = "affine.linearize_index"(%275#1, %12, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%312 = "vector.extract"(%274) <{static_position = array<i64: 1, 3, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%312, %103, %310, %311) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%313 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%314 = "affine.linearize_index"(%275#1, %3, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%315 = "vector.extract"(%274) <{static_position = array<i64: 1, 4, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%315, %103, %313, %314) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%316 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%317 = "affine.linearize_index"(%275#1, %2, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%318 = "vector.extract"(%274) <{static_position = array<i64: 1, 5, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%318, %103, %316, %317) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%319 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%320 = "affine.linearize_index"(%275#1, %1, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%321 = "vector.extract"(%274) <{static_position = array<i64: 1, 6, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%321, %103, %319, %320) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%322 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%323 = "affine.linearize_index"(%275#1, %0, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%324 = "vector.extract"(%274) <{static_position = array<i64: 1, 7, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%324, %103, %322, %323) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
"scf.forall.in_parallel"() ({
^bb0:
}) : () -> ()
}) : (index) -> ()
"memref.dealloc"(%44) : (memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%43) : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%42) : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%41) : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%40) : (memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"func.return"() : () -> ()
}) {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {}>} : () -> ()
%1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed, %collapsed_1, %collapsed_2, %extracted, %arg4 : tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, f32, tensor<?x?xf8E4M3FNUZ>) outs(%cast : tensor<32x?x128xf32>) {
^
f8_attn_chi_castf32_roctorch.mlir:45778:10: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
%1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed, %collapsed_1, %collapsed_2, %extracted, %arg4 : tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, f32, tensor<?x?xf8E4M3FNUZ>) outs(%cast : tensor<32x?x128xf32>) {
^
f8_attn_chi_castf32_roctorch.mlir:2706:12: note: called from
%914 = util.call @sharktank_masked_flash_attention_1_32_128_128_f8E4M3FNUZ_f32_f32(%909, %910, %911, %913, %912) : (tensor<1x32x?x128xf8E4M3FNUZ>, tensor<1x32x?x128xf8E4M3FNUZ>, tensor<1x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?xf8E4M3FNUZ>) -> tensor<1x32x?x128xf32>
^
f8_attn_chi_castf32_roctorch.mlir:45778:10: note: see current operation:
"hal.executable.variant"() ({
"hal.executable.export"() ({
^bb0(%arg7: !hal.device, %arg8: index, %arg9: index, %arg10: index, %arg11: index):
%867:3 = "flow.dispatch.workgroup_count_from_slice"(%arg8, %arg9, %arg10, %arg11) : (index, index, index, index) -> (index, index, index)
"hal.return"(%867#0, %867#1, %867#2) : (index, index, index) -> ()
}) {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index, sym_name = "prefill_bs1$async_dispatch_18_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic"} : () -> ()
"builtin.module"() ({
"func.func"() <{function_type = () -> (), sym_name = "prefill_bs1$async_dispatch_18_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic"}> ({
%0 = "arith.constant"() <{value = 7 : index}> : () -> index
%1 = "arith.constant"() <{value = 6 : index}> : () -> index
%2 = "arith.constant"() <{value = 5 : index}> : () -> index
%3 = "arith.constant"() <{value = 4 : index}> : () -> index
%4 = "arith.constant"() <{value = dense<0.000000e+00> : vector<1x8x1x1x8x1xf8E4M3FNUZ>}> : () -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%5 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x1x1x1x1x8xf8E4M3FNUZ>}> : () -> vector<2x1x1x1x1x8xf8E4M3FNUZ>
%6 = "arith.constant"() <{value = dense<0.000000e+00> : vector<8x2x1x1x1x4xf32>}> : () -> vector<8x2x1x1x1x4xf32>
%7 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x2x1x1x1x4xf32>}> : () -> vector<2x2x1x1x1x4xf32>
%8 = "arith.constant"() <{value = dense<0.000000e+00> : vector<8xf32>}> : () -> vector<8xf32>
%9 = "arith.constant"() <{value = dense<0xFF800000> : vector<2x1x4xf32>}> : () -> vector<2x1x4xf32>
%10 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x2x1x1x4x1xf8E4M3FNUZ>}> : () -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%11 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x4x1x1x1x8xf8E4M3FNUZ>}> : () -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%12 = "arith.constant"() <{value = 3 : index}> : () -> index
%13 = "arith.constant"() <{value = 2 : index}> : () -> index
%14 = "arith.constant"() <{value = dense<1.000000e+00> : vector<8x2x1x1x1x4xf32>}> : () -> vector<8x2x1x1x1x4xf32>
%15 = "arith.constant"() <{value = dense<2.400000e+02> : vector<2x8x1x1x4x1xf32>}> : () -> vector<2x8x1x1x4x1xf32>
%16 = "arith.constant"() <{value = dense<-2.400000e+02> : vector<2x8x1x1x4x1xf32>}> : () -> vector<2x8x1x1x4x1xf32>
%17 = "arith.constant"() <{value = dense<2.400000e+02> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%18 = "arith.constant"() <{value = dense<1.44269502> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%19 = "arith.constant"() <{value = dense<0.00416666688> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%20 = "arith.constant"() <{value = dense<0xFF800000> : vector<32x32xf32>}> : () -> vector<32x32xf32>
%21 = "arith.constant"() <{value = 0 : i64}> : () -> i64
%22 = "arith.constant"() <{value = 0 : i8}> : () -> i8
%23 = "arith.constant"() <{value = dense<0.000000e+00> : vector<32x32xf32>}> : () -> vector<32x32xf32>
%24 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x2x1x1x4x1xf32>}> : () -> vector<2x2x1x1x4x1xf32>
%25 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x1x4xf32>}> : () -> vector<2x1x4xf32>
%26 = "arith.constant"() <{value = dense<-3.40282347E+38> : vector<2x1x4xf32>}> : () -> vector<2x1x4xf32>
%27 = "arith.constant"() <{value = dense<0.000000e+00> : vector<2x8x1x1x4x1xf32>}> : () -> vector<2x8x1x1x4x1xf32>
%28 = "arith.constant"() <{value = 0.000000e+00 : f8E4M3FNUZ}> : () -> f8E4M3FNUZ
%29 = "arith.constant"() <{value = 1.44269502 : f32}> : () -> f32
%30 = "arith.constant"() <{value = 1 : index}> : () -> index
%31 = "arith.constant"() <{value = 32 : index}> : () -> index
%32 = "arith.constant"() <{value = 67108864 : index}> : () -> index
%33 = "arith.constant"() <{value = 32 : i64}> : () -> i64
%34 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
%35 = "arith.constant"() <{value = 0 : index}> : () -> index
%36 = "gpu.thread_id"() <{dimension = #gpu<dim z>}> : () -> index
%37 = "gpu.thread_id"() <{dimension = #gpu<dim y>}> : () -> index
%38 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
%39 = "affine.linearize_index"(%36, %37, %38) <{disjoint, operandSegmentSizes = array<i32: 3, 0>, static_basis = array<i64: 1, 1, 64>}> : (index, index, index) -> index
%40 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%41 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%42 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%43 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%44 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%45 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 0 : index} : () -> i32
%46 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 1 : index} : () -> i32
%47 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 2 : index} : () -> i32
%48 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 3 : index} : () -> i32
%49 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 4 : index} : () -> i32
%50 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 5 : index} : () -> i32
%51 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 6 : index} : () -> i32
%52 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 7 : index} : () -> i32
%53 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 8 : index} : () -> i32
%54 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 9 : index} : () -> i32
%55 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 10 : index} : () -> i32
%56 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 11 : index} : () -> i32
%57 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 12 : index} : () -> i32
%58 = "hal.interface.constant.load"() {layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, ordinal = 13 : index} : () -> i32
%59 = "arith.extui"(%45) : (i32) -> i64
%60 = "arith.extui"(%46) : (i32) -> i64
%61 = "arith.shli"(%60, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%62 = "arith.ori"(%59, %61) : (i64, i64) -> i64
%63 = "arith.index_castui"(%62) : (i64) -> index
%64 = "arith.extui"(%47) : (i32) -> i64
%65 = "arith.extui"(%48) : (i32) -> i64
%66 = "arith.shli"(%65, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%67 = "arith.ori"(%64, %66) : (i64, i64) -> i64
%68 = "arith.index_castui"(%67) : (i64) -> index
%69 = "arith.extui"(%49) : (i32) -> i64
%70 = "arith.extui"(%50) : (i32) -> i64
%71 = "arith.shli"(%70, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%72 = "arith.ori"(%69, %71) : (i64, i64) -> i64
%73 = "arith.index_castui"(%72) : (i64) -> index
%74 = "arith.extui"(%51) : (i32) -> i64
%75 = "arith.extui"(%52) : (i32) -> i64
%76 = "arith.shli"(%75, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%77 = "arith.ori"(%74, %76) : (i64, i64) -> i64
%78 = "arith.index_castui"(%77) {stream.alignment = 64 : index, stream.values = [1075847616 : index, 1293968512 : index, 1512089408 : index, 1730210304 : index, 1948331200 : index, 2166452096 : index, 2384572992 : index, 2602693888 : index, 2820814784 : index, 3038935680 : index, 3257056576 : index, 3475177472 : index, 3693298368 : index, 3911419264 : index, 4129540160 : index, 4347661056 : index, 4565781952 : index, 4783902848 : index, 5002023744 : index, 5220144640 : index, 5438265536 : index, 5656386432 : index, 5874507328 : index, 6092628224 : index, 6310749120 : index, 6528870016 : index, 6746990912 : index, 6965111808 : index, 7183232704 : index, 7401353600 : index, 7619474496 : index, 7837595392 : index]} : (i64) -> index
%79 = "arith.extui"(%53) : (i32) -> i64
%80 = "arith.extui"(%54) : (i32) -> i64
%81 = "arith.shli"(%80, %33) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
%82 = "arith.ori"(%79, %81) : (i64, i64) -> i64
%83 = "arith.index_castui"(%82) : (i64) -> index
%84 = "arith.index_castui"(%55) : (i32) -> index
%85 = "arith.bitcast"(%56) : (i32) -> f32
%86 = "arith.index_castui"(%57) : (i32) -> index
%87 = "arith.index_castui"(%58) : (i32) -> index
%88:8 = "util.assume.int"(%63, %68, %73, %78, %83, %84, %86, %87) <{assumptions = [[#util.int.assumption<umin = 68027392, umax = 20995769344>], [#util.int.assumption<umin = 68158464, umax = 21532509184>], [#util.int.assumption<umin = 68355072, umax = 22337618944>], [#util.int.assumption<umin = 1075847616, umax = 1075847616, udiv = 1075847616>, #util.int.assumption<umin = 1293968512, umax = 1293968512, udiv = 1293968512>, #util.int.assumption<umin = 1512089408, umax = 1512089408, udiv = 1512089408>, #util.int.assumption<umin = 1730210304, umax = 1730210304, udiv = 1730210304>, #util.int.assumption<umin = 1948331200, umax = 1948331200, udiv = 1948331200>, #util.int.assumption<umin = 2166452096, umax = 2166452096, udiv = 2166452096>, #util.int.assumption<umin = 2384572992, umax = 2384572992, udiv = 2384572992>, #util.int.assumption<umin = 2602693888, umax = 2602693888, udiv = 2602693888>, #util.int.assumption<umin = 2820814784, umax = 2820814784, udiv = 2820814784>, #util.int.assumption<umin = 3038935680, umax = 3038935680, udiv = 3038935680>, #util.int.assumption<umin = 3257056576, umax = 3257056576, udiv = 3257056576>, #util.int.assumption<umin = 3475177472, umax = 3475177472, udiv = 3475177472>, #util.int.assumption<umin = 3693298368, umax = 3693298368, udiv = 3693298368>, #util.int.assumption<umin = 3911419264, umax = 3911419264, udiv = 3911419264>, #util.int.assumption<umin = 4129540160, umax = 4129540160, udiv = 4129540160>, #util.int.assumption<umin = 4347661056, umax = 4347661056, udiv = 4347661056>, #util.int.assumption<umin = 4565781952, umax = 4565781952, udiv = 4565781952>, #util.int.assumption<umin = 4783902848, umax = 4783902848, udiv = 4783902848>, #util.int.assumption<umin = 5002023744, umax = 5002023744, udiv = 5002023744>, #util.int.assumption<umin = 5220144640, umax = 5220144640, udiv = 5220144640>, #util.int.assumption<umin = 5438265536, umax = 5438265536, udiv = 5438265536>, #util.int.assumption<umin = 5656386432, umax = 5656386432, udiv = 
5656386432>, #util.int.assumption<umin = 5874507328, umax = 5874507328, udiv = 5874507328>, #util.int.assumption<umin = 6092628224, umax = 6092628224, udiv = 6092628224>, #util.int.assumption<umin = 6310749120, umax = 6310749120, udiv = 6310749120>, #util.int.assumption<umin = 6528870016, umax = 6528870016, udiv = 6528870016>, #util.int.assumption<umin = 6746990912, umax = 6746990912, udiv = 6746990912>, #util.int.assumption<umin = 6965111808, umax = 6965111808, udiv = 6965111808>, #util.int.assumption<umin = 7183232704, umax = 7183232704, udiv = 7183232704>, #util.int.assumption<umin = 7401353600, umax = 7401353600, udiv = 7401353600>, #util.int.assumption<umin = 7619474496, umax = 7619474496, udiv = 7619474496>, #util.int.assumption<umin = 7837595392, umax = 7837595392, udiv = 7837595392>], [#util.int.assumption<umin = 67896320, umax = 20459029504>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>], [#util.int.assumption<umin = 1, umax = 4095>], [#util.int.assumption<umin = 32, umax = 131040, udiv = 32>]]}> : (index, index, index, index, index, index, index, index) -> (index, index, index, index, index, index, index, index)
%89 = "hal.interface.binding.subspan"(%35) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<i64, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%89) <{alignment = 64 : i32}> : (memref<i64, #hal.descriptor_type<storage_buffer>>) -> ()
%90 = "hal.interface.binding.subspan"(%88#3) {alignment = 64 : index, binding = 2 : index, descriptor_flags = 1 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 0>} : (index) -> memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%90) <{alignment = 64 : i32}> : (memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%91 = "flow.dispatch.workload.ordinal"(%88#5) <{ordinal = 0 : index}> : (index) -> index
%92 = "flow.dispatch.workload.ordinal"(%88#6) <{ordinal = 1 : index}> : (index) -> index
%93 = "flow.dispatch.workload.ordinal"(%88#6) <{ordinal = 2 : index}> : (index) -> index
%94 = "flow.dispatch.workload.ordinal"(%88#7) <{ordinal = 3 : index}> : (index) -> index
%95 = "hal.interface.binding.subspan"(%32, %92, %91) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 2>} : (index, index, index) -> memref<?x32x?xi8, strided<[?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%95) <{alignment = 64 : i32}> : (memref<?x32x?xi8, strided<[?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%96 = "hal.interface.binding.subspan"(%88#0, %93) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%96) <{alignment = 1 : i32}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%97 = "arith.divsi"(%94, %31) : (index, index) -> index
%98 = "hal.interface.binding.subspan"(%88#1, %97) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%98) <{alignment = 1 : i32}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%99 = "arith.divsi"(%91, %31) : (index, index) -> index
%100 = "hal.interface.binding.subspan"(%88#2, %99) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 3 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<?x32x8x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%100) <{alignment = 1 : i32}> : (memref<?x32x8x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
%101 = "hal.interface.binding.subspan"(%88#4, %92) {alignment = 64 : index, binding = 3 : index, descriptor_flags = 2 : i32, layout = #hal.pipeline.layout<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>, operandSegmentSizes = array<i32: 1, 1>} : (index, index) -> memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
"memref.assume_alignment"(%101) <{alignment = 1 : i32}> : (memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> ()
"scf.forall"(%93) <{mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>], operandSegmentSizes = array<i32: 0, 1, 0, 0>, staticLowerBound = array<i64: 0, 0, 0>, staticStep = array<i64: 1, 1, 1>, staticUpperBound = array<i64: 8, 4, -9223372036854775808>}> ({
^bb0(%arg0: index, %arg1: index, %arg2: index):
"gpu.barrier"() : () -> ()
%102 = "memref.subview"(%101, %arg2, %arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 3, 0, 0>, static_offsets = array<i64: 0, -9223372036854775808, 0, -9223372036854775808, -9223372036854775808, 0>, static_sizes = array<i64: 1, 1, 32, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index) -> memref<1x1x32x1x1x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%103 = "memref.subview"(%102) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 32, 1, 1, 128>, static_strides = array<i64: 1, 1, 1, 1, 1, 1>}> : (memref<1x1x32x1x1x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%104:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%105:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%106 = "affine.linearize_index"(%104#2, %35, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%107 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%108 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %106, %107, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%109 = "affine.linearize_index"(%104#2, %30, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%110 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%111 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %109, %110, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%112 = "affine.linearize_index"(%104#2, %13, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%113 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%114 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %112, %113, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%115 = "affine.linearize_index"(%104#2, %12, %35, %105#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%116 = "affine.linearize_index"(%104#1, %35, %35, %105#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%117 = "vector.transfer_read"(%96, %arg0, %arg1, %35, %arg2, %115, %116, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 6, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>}> : (memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%118 = "arith.mulf"(%85, %29) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%119:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%120:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%121 = "affine.linearize_index"(%119#2, %35, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%122 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%108, %43, %121, %122) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%123 = "affine.linearize_index"(%119#2, %30, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%124 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%111, %43, %123, %124) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%125 = "affine.linearize_index"(%119#2, %13, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%126 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%114, %43, %125, %126) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%127 = "affine.linearize_index"(%119#2, %12, %35, %120#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%128 = "affine.linearize_index"(%119#1, %35, %35, %120#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%117, %43, %127, %128) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
"gpu.barrier"() : () -> ()
%129:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%130:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%131 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%132 = "affine.linearize_index"(%129#1, %35, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%133 = "vector.transfer_read"(%43, %131, %132, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%134 = "vector.insert_strided_slice"(%133, %11) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%135 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%136 = "affine.linearize_index"(%129#1, %30, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%137 = "vector.transfer_read"(%43, %135, %136, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%138 = "vector.insert_strided_slice"(%137, %134) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%139 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%140 = "affine.linearize_index"(%129#1, %13, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%141 = "vector.transfer_read"(%43, %139, %140, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%142 = "vector.insert_strided_slice"(%141, %138) <{offsets = [0, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%143 = "affine.linearize_index"(%129#2, %35, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%144 = "affine.linearize_index"(%129#1, %12, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%145 = "vector.transfer_read"(%43, %143, %144, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%146 = "vector.insert_strided_slice"(%145, %142) <{offsets = [0, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%147 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%148 = "affine.linearize_index"(%129#1, %35, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%149 = "vector.transfer_read"(%43, %147, %148, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%150 = "vector.insert_strided_slice"(%149, %146) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%151 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%152 = "affine.linearize_index"(%129#1, %30, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%153 = "vector.transfer_read"(%43, %151, %152, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%154 = "vector.insert_strided_slice"(%153, %150) <{offsets = [1, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%155 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%156 = "affine.linearize_index"(%129#1, %13, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%157 = "vector.transfer_read"(%43, %155, %156, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%158 = "vector.insert_strided_slice"(%157, %154) <{offsets = [1, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%159 = "affine.linearize_index"(%129#2, %30, %35, %130#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%160 = "affine.linearize_index"(%129#1, %12, %35, %130#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%161 = "vector.transfer_read"(%43, %159, %160, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%162 = "vector.insert_strided_slice"(%161, %158) <{offsets = [1, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%163 = "vector.transfer_read"(%89, %21) <{in_bounds = [], operandSegmentSizes = array<i32: 1, 0, 1, 0>, permutation_map = affine_map<() -> ()>}> : (memref<i64, #hal.descriptor_type<storage_buffer>>, i64) -> vector<i64>
%164 = "iree_vector_ext.to_simd"(%163) : (vector<i64>) -> vector<i64>
%165 = "vector.broadcast"(%164) : (vector<i64>) -> vector<32x32xi64>
%166 = "vector.step"() : () -> vector<32xindex>
%167 = "vector.broadcast"(%118) : (f32) -> vector<2x2x1x1x4x1xf32>
%168:3 = "scf.for"(%35, %97, %30, %26, %25, %27) ({
^bb0(%arg3: index, %arg4: vector<2x1x4xf32>, %arg5: vector<2x1x4xf32>, %arg6: vector<2x8x1x1x4x1xf32>):
"gpu.barrier"() : () -> ()
%325 = "memref.subview"(%100, %arg3, %arg0) <{operandSegmentSizes = array<i32: 1, 2, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, -9223372036854775808, 0>, static_sizes = array<i64: 1, 32, 1, 128>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<?x32x8x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> memref<1x32x1x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%326 = "memref.subview"(%325) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: 1, 32, 1, 128>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<1x32x1x128xf8E4M3FNUZ, strided<[32768, 1024, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>) -> memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%327:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%328:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%329 = "affine.linearize_index"(%327#2, %35, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%330 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%331 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %329, %330, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%332 = "affine.linearize_index"(%327#2, %30, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%333 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%334 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %332, %333, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%335 = "affine.linearize_index"(%327#2, %13, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%336 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%337 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %335, %336, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%338 = "affine.linearize_index"(%327#2, %12, %35, %328#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%339 = "affine.linearize_index"(%327#1, %35, %35, %328#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%340 = "vector.transfer_read"(%98, %arg0, %arg1, %arg3, %338, %339, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 5, 1, 0>, permutation_map = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>}> : (memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%341:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%342:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%343 = "affine.linearize_index"(%341#2, %35, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%344 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%345 = "vector.transfer_read"(%326, %343, %344, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%346 = "affine.linearize_index"(%341#2, %30, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%347 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%348 = "vector.transfer_read"(%326, %346, %347, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%349 = "affine.linearize_index"(%341#2, %13, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%350 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%351 = "vector.transfer_read"(%326, %349, %350, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%352 = "affine.linearize_index"(%341#2, %12, %35, %342#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%353 = "affine.linearize_index"(%341#1, %35, %35, %342#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
%354 = "vector.transfer_read"(%326, %352, %353, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, strided<[1024, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, f8E4M3FNUZ) -> vector<1x16xf8E4M3FNUZ>
%355 = "affine.linearize_index"(%arg3, %35, %99) <{disjoint, operandSegmentSizes = array<i32: 2, 1>, static_basis = array<i64: -9223372036854775808, 32>}> : (index, index, index) -> index
%356 = "vector.transfer_read"(%95, %arg2, %35, %355, %22) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 3, 1, 0>, permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}> : (memref<?x32x?xi8, strided<[?, ?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index, index, i8) -> vector<32x32xi8>
%357 = "arith.trunci"(%356) : (vector<32x32xi8>) -> vector<32x32xi1>
%358 = "vector.broadcast"(%355) : (index) -> vector<32xindex>
%359 = "arith.addi"(%358, %166) <{overflowFlags = #arith.overflow<none>}> : (vector<32xindex>, vector<32xindex>) -> vector<32xindex>
%360 = "arith.index_cast"(%359) : (vector<32xindex>) -> vector<32xi64>
%361 = "vector.broadcast"(%360) : (vector<32xi64>) -> vector<32x32xi64>
%362 = "arith.cmpi"(%361, %165) <{predicate = 5 : i64}> : (vector<32x32xi64>, vector<32x32xi64>) -> vector<32x32xi1>
%363 = "arith.ori"(%357, %362) : (vector<32x32xi1>, vector<32x32xi1>) -> vector<32x32xi1>
%364 = "arith.select"(%363, %20, %23) : (vector<32x32xi1>, vector<32x32xf32>, vector<32x32xf32>) -> vector<32x32xf32>
%365 = "arith.truncf"(%364) : (vector<32x32xf32>) -> vector<32x32xf8E4M3FNUZ>
"vector.transfer_write"(%365, %44, %35, %35, %35) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 3, 0>, permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}> : (vector<32x32xf8E4M3FNUZ>, memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, index) -> ()
%366 = "memref.expand_shape"(%44) <{reassociation = [[0, 1], [2], [3, 4]], static_output_shape = array<i64: 1, 1, 32, 1, 32>}> : (memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> memref<1x1x32x1x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
%367 = "memref.subview"(%366) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 32, 1, 32>, static_strides = array<i64: 1, 1, 1, 1, 1>}> : (memref<1x1x32x1x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>
%368:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%369:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%370 = "affine.linearize_index"(%368#2, %35, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%371 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%331, %42, %370, %371) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%372 = "affine.linearize_index"(%368#2, %30, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%373 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%334, %42, %372, %373) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%374 = "affine.linearize_index"(%368#2, %13, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%375 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%337, %42, %374, %375) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%376 = "affine.linearize_index"(%368#2, %12, %35, %369#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%377 = "affine.linearize_index"(%368#1, %35, %35, %369#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%340, %42, %376, %377) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%378:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%379:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 8, 8>}> : (index) -> (index, index, index)
%380 = "affine.linearize_index"(%378#2, %35, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%381 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%345, %41, %380, %381) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%382 = "affine.linearize_index"(%378#2, %30, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%383 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%348, %41, %382, %383) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%384 = "affine.linearize_index"(%378#2, %13, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%385 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%351, %41, %384, %385) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%386 = "affine.linearize_index"(%378#2, %12, %35, %379#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 8, 1>}> : (index, index, index, index, index) -> index
%387 = "affine.linearize_index"(%378#1, %35, %35, %379#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 8, 16>}> : (index, index, index, index, index) -> index
"vector.transfer_write"(%354, %41, %386, %387) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<1x16xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%388:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%389:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%390 = "affine.linearize_index"(%388#2, %35, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%391 = "affine.linearize_index"(%388#1, %35, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%392 = "vector.transfer_read"(%367, %390, %391, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%393 = "vector.insert_strided_slice"(%392, %10) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%394 = "affine.linearize_index"(%388#2, %35, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%395 = "affine.linearize_index"(%388#1, %30, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%396 = "vector.transfer_read"(%367, %394, %395, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%397 = "vector.insert_strided_slice"(%396, %393) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%398 = "affine.linearize_index"(%388#2, %30, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%399 = "affine.linearize_index"(%388#1, %35, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%400 = "vector.transfer_read"(%367, %398, %399, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%401 = "vector.insert_strided_slice"(%400, %397) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%402 = "affine.linearize_index"(%388#2, %30, %35, %389#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%403 = "affine.linearize_index"(%388#1, %30, %35, %389#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%404 = "vector.transfer_read"(%367, %402, %403, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, strided<[32, 1]>, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<4x1xf8E4M3FNUZ>
%405 = "vector.insert_strided_slice"(%404, %401) <{offsets = [1, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<4x1xf8E4M3FNUZ>, vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%406 = "arith.extf"(%405) : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<2x2x1x1x4x1xf32>
%407 = "arith.mulf"(%406, %18) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
"gpu.barrier"() : () -> ()
%408:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%409:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%410 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%411 = "affine.linearize_index"(%408#1, %35, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%412 = "vector.transfer_read"(%42, %410, %411, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%413 = "vector.insert_strided_slice"(%412, %11) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%414 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%415 = "affine.linearize_index"(%408#1, %30, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%416 = "vector.transfer_read"(%42, %414, %415, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%417 = "vector.insert_strided_slice"(%416, %413) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%418 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%419 = "affine.linearize_index"(%408#1, %13, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%420 = "vector.transfer_read"(%42, %418, %419, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%421 = "vector.insert_strided_slice"(%420, %417) <{offsets = [0, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%422 = "affine.linearize_index"(%408#2, %35, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%423 = "affine.linearize_index"(%408#1, %12, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%424 = "vector.transfer_read"(%42, %422, %423, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%425 = "vector.insert_strided_slice"(%424, %421) <{offsets = [0, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%426 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%427 = "affine.linearize_index"(%408#1, %35, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%428 = "vector.transfer_read"(%42, %426, %427, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%429 = "vector.insert_strided_slice"(%428, %425) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%430 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%431 = "affine.linearize_index"(%408#1, %30, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%432 = "vector.transfer_read"(%42, %430, %431, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%433 = "vector.insert_strided_slice"(%432, %429) <{offsets = [1, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%434 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%435 = "affine.linearize_index"(%408#1, %13, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%436 = "vector.transfer_read"(%42, %434, %435, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%437 = "vector.insert_strided_slice"(%436, %433) <{offsets = [1, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%438 = "affine.linearize_index"(%408#2, %30, %35, %409#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%439 = "affine.linearize_index"(%408#1, %12, %35, %409#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 4, 1, 4, 8>}> : (index, index, index, index, index) -> index
%440 = "vector.transfer_read"(%42, %438, %439, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%441 = "vector.insert_strided_slice"(%440, %437) <{offsets = [1, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<2x4x1x1x1x8xf8E4M3FNUZ>
%442 = "vector.extract"(%24) <{static_position = array<i64: 0, 0>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%443 = "vector.extract"(%162) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%444 = "vector.extract"(%441) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%445 = "vector.shape_cast"(%443) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%446 = "vector.shape_cast"(%444) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%447 = "vector.shape_cast"(%442) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%448 = "amdgpu.mfma"(%445, %446, %447) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%449 = "vector.extract"(%162) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%450 = "vector.extract"(%441) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%451 = "vector.shape_cast"(%449) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%452 = "vector.shape_cast"(%450) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%453 = "amdgpu.mfma"(%451, %452, %448) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%454 = "vector.extract"(%162) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%455 = "vector.extract"(%441) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%456 = "vector.shape_cast"(%454) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%457 = "vector.shape_cast"(%455) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%458 = "amdgpu.mfma"(%456, %457, %453) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%459 = "vector.extract"(%162) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%460 = "vector.extract"(%441) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%461 = "vector.shape_cast"(%459) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%462 = "vector.shape_cast"(%460) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%463 = "amdgpu.mfma"(%461, %462, %458) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%464 = "vector.shape_cast"(%463) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%465 = "vector.insert"(%464, %24) <{static_position = array<i64: 0, 0>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%466 = "vector.extract"(%24) <{static_position = array<i64: 0, 1>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%467 = "vector.extract"(%162) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%468 = "vector.extract"(%441) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%469 = "vector.shape_cast"(%467) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%470 = "vector.shape_cast"(%468) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%471 = "vector.shape_cast"(%466) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%472 = "amdgpu.mfma"(%469, %470, %471) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%473 = "vector.extract"(%162) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%474 = "vector.extract"(%441) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%475 = "vector.shape_cast"(%473) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%476 = "vector.shape_cast"(%474) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%477 = "amdgpu.mfma"(%475, %476, %472) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%478 = "vector.extract"(%162) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%479 = "vector.extract"(%441) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%480 = "vector.shape_cast"(%478) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%481 = "vector.shape_cast"(%479) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%482 = "amdgpu.mfma"(%480, %481, %477) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%483 = "vector.extract"(%162) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%484 = "vector.extract"(%441) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%485 = "vector.shape_cast"(%483) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%486 = "vector.shape_cast"(%484) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%487 = "amdgpu.mfma"(%485, %486, %482) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%488 = "vector.shape_cast"(%487) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%489 = "vector.insert"(%488, %465) <{static_position = array<i64: 0, 1>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%490 = "vector.extract"(%24) <{static_position = array<i64: 1, 0>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%491 = "vector.extract"(%162) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%492 = "vector.extract"(%441) <{static_position = array<i64: 0, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%493 = "vector.shape_cast"(%491) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%494 = "vector.shape_cast"(%492) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%495 = "vector.shape_cast"(%490) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%496 = "amdgpu.mfma"(%493, %494, %495) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%497 = "vector.extract"(%162) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%498 = "vector.extract"(%441) <{static_position = array<i64: 0, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%499 = "vector.shape_cast"(%497) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%500 = "vector.shape_cast"(%498) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%501 = "amdgpu.mfma"(%499, %500, %496) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%502 = "vector.extract"(%162) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%503 = "vector.extract"(%441) <{static_position = array<i64: 0, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%504 = "vector.shape_cast"(%502) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%505 = "vector.shape_cast"(%503) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%506 = "amdgpu.mfma"(%504, %505, %501) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%507 = "vector.extract"(%162) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%508 = "vector.extract"(%441) <{static_position = array<i64: 0, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%509 = "vector.shape_cast"(%507) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%510 = "vector.shape_cast"(%508) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%511 = "amdgpu.mfma"(%509, %510, %506) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%512 = "vector.shape_cast"(%511) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%513 = "vector.insert"(%512, %489) <{static_position = array<i64: 1, 0>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%514 = "vector.extract"(%24) <{static_position = array<i64: 1, 1>}> : (vector<2x2x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%515 = "vector.extract"(%162) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%516 = "vector.extract"(%441) <{static_position = array<i64: 1, 0>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%517 = "vector.shape_cast"(%515) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%518 = "vector.shape_cast"(%516) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%519 = "vector.shape_cast"(%514) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%520 = "amdgpu.mfma"(%517, %518, %519) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%521 = "vector.extract"(%162) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%522 = "vector.extract"(%441) <{static_position = array<i64: 1, 1>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%523 = "vector.shape_cast"(%521) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%524 = "vector.shape_cast"(%522) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%525 = "amdgpu.mfma"(%523, %524, %520) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%526 = "vector.extract"(%162) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%527 = "vector.extract"(%441) <{static_position = array<i64: 1, 2>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%528 = "vector.shape_cast"(%526) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%529 = "vector.shape_cast"(%527) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%530 = "amdgpu.mfma"(%528, %529, %525) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%531 = "vector.extract"(%162) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%532 = "vector.extract"(%441) <{static_position = array<i64: 1, 3>}> : (vector<2x4x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%533 = "vector.shape_cast"(%531) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%534 = "vector.shape_cast"(%532) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%535 = "amdgpu.mfma"(%533, %534, %530) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%536 = "vector.shape_cast"(%535) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%537 = "vector.insert"(%536, %513) <{static_position = array<i64: 1, 1>}> : (vector<1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%538 = "arith.mulf"(%167, %537) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%539 = "arith.addf"(%538, %19) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%540 = "arith.addf"(%539, %407) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%541 = "vector.multi_reduction"(%540, %9) <{kind = #vector.kind<maximumf>, reduction_dims = array<i64: 1, 3, 5>}> : (vector<2x2x1x1x4x1xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%542 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%543 = "gpu.subgroup_reduce"(%542) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%544 = "vector.insert"(%543, %8) <{static_position = array<i64: 0>}> : (f32, vector<8xf32>) -> vector<8xf32>
%545 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%546 = "gpu.subgroup_reduce"(%545) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%547 = "vector.insert"(%546, %544) <{static_position = array<i64: 1>}> : (f32, vector<8xf32>) -> vector<8xf32>
%548 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%549 = "gpu.subgroup_reduce"(%548) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%550 = "vector.insert"(%549, %547) <{static_position = array<i64: 2>}> : (f32, vector<8xf32>) -> vector<8xf32>
%551 = "vector.extract"(%541) <{static_position = array<i64: 0, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%552 = "gpu.subgroup_reduce"(%551) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%553 = "vector.insert"(%552, %550) <{static_position = array<i64: 3>}> : (f32, vector<8xf32>) -> vector<8xf32>
%554 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%555 = "gpu.subgroup_reduce"(%554) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%556 = "vector.insert"(%555, %553) <{static_position = array<i64: 4>}> : (f32, vector<8xf32>) -> vector<8xf32>
%557 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%558 = "gpu.subgroup_reduce"(%557) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%559 = "vector.insert"(%558, %556) <{static_position = array<i64: 5>}> : (f32, vector<8xf32>) -> vector<8xf32>
%560 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%561 = "gpu.subgroup_reduce"(%560) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%562 = "vector.insert"(%561, %559) <{static_position = array<i64: 6>}> : (f32, vector<8xf32>) -> vector<8xf32>
%563 = "vector.extract"(%541) <{static_position = array<i64: 1, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%564 = "gpu.subgroup_reduce"(%563) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op maximumf>}> : (f32) -> f32
%565 = "vector.insert"(%564, %562) <{static_position = array<i64: 7>}> : (f32, vector<8xf32>) -> vector<8xf32>
%566 = "vector.shape_cast"(%565) : (vector<8xf32>) -> vector<2x1x4xf32>
%567 = "arith.maximumf"(%566, %arg4) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%568 = "arith.subf"(%arg4, %567) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%569 = "math.exp2"(%568) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>) -> vector<2x1x4xf32>
%570 = "arith.mulf"(%569, %arg5) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%571 = "vector.extract"(%567) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%572 = "vector.broadcast"(%571) : (vector<4xf32>) -> vector<1x4xf32>
%573 = "vector.insert"(%572, %7) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%574 = "vector.extract"(%567) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%575 = "vector.broadcast"(%574) : (vector<4xf32>) -> vector<1x4xf32>
%576 = "vector.insert"(%575, %573) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%577 = "vector.extract"(%567) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%578 = "vector.broadcast"(%577) : (vector<4xf32>) -> vector<1x4xf32>
%579 = "vector.insert"(%578, %576) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%580 = "vector.extract"(%567) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%581 = "vector.broadcast"(%580) : (vector<4xf32>) -> vector<1x4xf32>
%582 = "vector.insert"(%581, %579) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<1x4xf32>, vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x1x4xf32>
%583 = "vector.transpose"(%582) <{permutation = array<i64: 1, 0, 3, 2, 5, 4>}> : (vector<2x2x1x1x1x4xf32>) -> vector<2x2x1x1x4x1xf32>
%584 = "arith.subf"(%540, %583) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%585 = "math.exp2"(%584) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%586 = "vector.multi_reduction"(%585, %25) <{kind = #vector.kind<add>, reduction_dims = array<i64: 1, 3, 5>}> : (vector<2x2x1x1x4x1xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%587 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%588 = "gpu.subgroup_reduce"(%587) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%589 = "vector.insert"(%588, %8) <{static_position = array<i64: 0>}> : (f32, vector<8xf32>) -> vector<8xf32>
%590 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%591 = "gpu.subgroup_reduce"(%590) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%592 = "vector.insert"(%591, %589) <{static_position = array<i64: 1>}> : (f32, vector<8xf32>) -> vector<8xf32>
%593 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%594 = "gpu.subgroup_reduce"(%593) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%595 = "vector.insert"(%594, %592) <{static_position = array<i64: 2>}> : (f32, vector<8xf32>) -> vector<8xf32>
%596 = "vector.extract"(%586) <{static_position = array<i64: 0, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%597 = "gpu.subgroup_reduce"(%596) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%598 = "vector.insert"(%597, %595) <{static_position = array<i64: 3>}> : (f32, vector<8xf32>) -> vector<8xf32>
%599 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 0>}> : (vector<2x1x4xf32>) -> f32
%600 = "gpu.subgroup_reduce"(%599) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%601 = "vector.insert"(%600, %598) <{static_position = array<i64: 4>}> : (f32, vector<8xf32>) -> vector<8xf32>
%602 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 1>}> : (vector<2x1x4xf32>) -> f32
%603 = "gpu.subgroup_reduce"(%602) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%604 = "vector.insert"(%603, %601) <{static_position = array<i64: 5>}> : (f32, vector<8xf32>) -> vector<8xf32>
%605 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 2>}> : (vector<2x1x4xf32>) -> f32
%606 = "gpu.subgroup_reduce"(%605) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%607 = "vector.insert"(%606, %604) <{static_position = array<i64: 6>}> : (f32, vector<8xf32>) -> vector<8xf32>
%608 = "vector.extract"(%586) <{static_position = array<i64: 1, 0, 3>}> : (vector<2x1x4xf32>) -> f32
%609 = "gpu.subgroup_reduce"(%608) <{cluster_size = 16 : i32, cluster_stride = 1 : i32, op = #gpu<all_reduce_op add>}> : (f32) -> f32
%610 = "vector.insert"(%609, %607) <{static_position = array<i64: 7>}> : (f32, vector<8xf32>) -> vector<8xf32>
%611 = "vector.shape_cast"(%610) : (vector<8xf32>) -> vector<2x1x4xf32>
%612 = "arith.addf"(%611, %570) <{fastmath = #arith.fastmath<none>}> : (vector<2x1x4xf32>, vector<2x1x4xf32>) -> vector<2x1x4xf32>
%613 = "arith.minimumf"(%585, %17) <{fastmath = #arith.fastmath<none>}> : (vector<2x2x1x1x4x1xf32>, vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf32>
%614 = "arith.truncf"(%613) : (vector<2x2x1x1x4x1xf32>) -> vector<2x2x1x1x4x1xf8E4M3FNUZ>
%615 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%616 = "vector.broadcast"(%615) : (vector<4xf32>) -> vector<1x4xf32>
%617 = "vector.insert"(%616, %6) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%618 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%619 = "vector.broadcast"(%618) : (vector<4xf32>) -> vector<1x4xf32>
%620 = "vector.insert"(%619, %617) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%621 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%622 = "vector.broadcast"(%621) : (vector<4xf32>) -> vector<1x4xf32>
%623 = "vector.insert"(%622, %620) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%624 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%625 = "vector.broadcast"(%624) : (vector<4xf32>) -> vector<1x4xf32>
%626 = "vector.insert"(%625, %623) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%627 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%628 = "vector.broadcast"(%627) : (vector<4xf32>) -> vector<1x4xf32>
%629 = "vector.insert"(%628, %626) <{static_position = array<i64: 2, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%630 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%631 = "vector.broadcast"(%630) : (vector<4xf32>) -> vector<1x4xf32>
%632 = "vector.insert"(%631, %629) <{static_position = array<i64: 2, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%633 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%634 = "vector.broadcast"(%633) : (vector<4xf32>) -> vector<1x4xf32>
%635 = "vector.insert"(%634, %632) <{static_position = array<i64: 3, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%636 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%637 = "vector.broadcast"(%636) : (vector<4xf32>) -> vector<1x4xf32>
%638 = "vector.insert"(%637, %635) <{static_position = array<i64: 3, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%639 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%640 = "vector.broadcast"(%639) : (vector<4xf32>) -> vector<1x4xf32>
%641 = "vector.insert"(%640, %638) <{static_position = array<i64: 4, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%642 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%643 = "vector.broadcast"(%642) : (vector<4xf32>) -> vector<1x4xf32>
%644 = "vector.insert"(%643, %641) <{static_position = array<i64: 4, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%645 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%646 = "vector.broadcast"(%645) : (vector<4xf32>) -> vector<1x4xf32>
%647 = "vector.insert"(%646, %644) <{static_position = array<i64: 5, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%648 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%649 = "vector.broadcast"(%648) : (vector<4xf32>) -> vector<1x4xf32>
%650 = "vector.insert"(%649, %647) <{static_position = array<i64: 5, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%651 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%652 = "vector.broadcast"(%651) : (vector<4xf32>) -> vector<1x4xf32>
%653 = "vector.insert"(%652, %650) <{static_position = array<i64: 6, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%654 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%655 = "vector.broadcast"(%654) : (vector<4xf32>) -> vector<1x4xf32>
%656 = "vector.insert"(%655, %653) <{static_position = array<i64: 6, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%657 = "vector.extract"(%569) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%658 = "vector.broadcast"(%657) : (vector<4xf32>) -> vector<1x4xf32>
%659 = "vector.insert"(%658, %656) <{static_position = array<i64: 7, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%660 = "vector.extract"(%569) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%661 = "vector.broadcast"(%660) : (vector<4xf32>) -> vector<1x4xf32>
%662 = "vector.insert"(%661, %659) <{static_position = array<i64: 7, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%663 = "vector.transpose"(%662) <{permutation = array<i64: 1, 0, 3, 2, 5, 4>}> : (vector<8x2x1x1x1x4xf32>) -> vector<2x8x1x1x4x1xf32>
%664 = "arith.mulf"(%663, %arg6) <{fastmath = #arith.fastmath<none>}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%665:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%666:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%667 = "affine.linearize_index"(%665#2, %35, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%668 = "affine.linearize_index"(%665#1, %35, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%669 = "vector.extract"(%614) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%669, %40, %667, %668) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%670 = "affine.linearize_index"(%665#2, %35, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%671 = "affine.linearize_index"(%665#1, %30, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%672 = "vector.extract"(%614) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%672, %40, %670, %671) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%673 = "affine.linearize_index"(%665#2, %30, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%674 = "affine.linearize_index"(%665#1, %35, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%675 = "vector.extract"(%614) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%675, %40, %673, %674) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
%676 = "affine.linearize_index"(%665#2, %30, %35, %666#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%677 = "affine.linearize_index"(%665#1, %30, %35, %666#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%678 = "vector.extract"(%614) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<2x2x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%678, %40, %676, %677) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index) -> ()
"gpu.barrier"() : () -> ()
%679:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%680:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%681 = "affine.linearize_index"(%679#2, %35, %35, %680#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%682 = "affine.linearize_index"(%679#1, %35, %35, %680#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%683 = "vector.transfer_read"(%40, %681, %682, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%684 = "vector.insert_strided_slice"(%683, %5) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<2x1x1x1x1x8xf8E4M3FNUZ>
%685 = "affine.linearize_index"(%679#2, %30, %35, %680#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 16, 1>}> : (index, index, index, index, index) -> index
%686 = "affine.linearize_index"(%679#1, %35, %35, %680#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%687 = "vector.transfer_read"(%40, %685, %686, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<1x8xf8E4M3FNUZ>
%688 = "vector.insert_strided_slice"(%687, %684) <{offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<1x8xf8E4M3FNUZ>, vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<2x1x1x1x1x8xf8E4M3FNUZ>
%689:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%690:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%691 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%692 = "affine.linearize_index"(%689#1, %35, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%693 = "vector.transfer_read"(%41, %691, %692, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%694 = "vector.insert_strided_slice"(%693, %4) <{offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%695 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%696 = "affine.linearize_index"(%689#1, %30, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%697 = "vector.transfer_read"(%41, %695, %696, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%698 = "vector.insert_strided_slice"(%697, %694) <{offsets = [0, 1, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%699 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%700 = "affine.linearize_index"(%689#1, %13, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%701 = "vector.transfer_read"(%41, %699, %700, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%702 = "vector.insert_strided_slice"(%701, %698) <{offsets = [0, 2, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%703 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%704 = "affine.linearize_index"(%689#1, %12, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%705 = "vector.transfer_read"(%41, %703, %704, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%706 = "vector.insert_strided_slice"(%705, %702) <{offsets = [0, 3, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%707 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%708 = "affine.linearize_index"(%689#1, %3, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%709 = "vector.transfer_read"(%41, %707, %708, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%710 = "vector.insert_strided_slice"(%709, %706) <{offsets = [0, 4, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%711 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%712 = "affine.linearize_index"(%689#1, %2, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%713 = "vector.transfer_read"(%41, %711, %712, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%714 = "vector.insert_strided_slice"(%713, %710) <{offsets = [0, 5, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%715 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%716 = "affine.linearize_index"(%689#1, %1, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%717 = "vector.transfer_read"(%41, %715, %716, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%718 = "vector.insert_strided_slice"(%717, %714) <{offsets = [0, 6, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%719 = "affine.linearize_index"(%689#2, %35, %35, %690#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 1, 1, 4, 8>}> : (index, index, index, index, index) -> index
%720 = "affine.linearize_index"(%689#1, %0, %35, %690#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%721 = "vector.transfer_read"(%41, %719, %720, %28) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, index, index, f8E4M3FNUZ) -> vector<8x1xf8E4M3FNUZ>
%722 = "vector.insert_strided_slice"(%721, %718) <{offsets = [0, 7, 0, 0, 0, 0], strides = [1, 1]}> : (vector<8x1xf8E4M3FNUZ>, vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x8x1x1x8x1xf8E4M3FNUZ>
%723 = "vector.extract"(%664) <{static_position = array<i64: 0, 0>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%724 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%725 = "vector.extract"(%722) <{static_position = array<i64: 0, 0>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%726 = "vector.shape_cast"(%724) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%727 = "vector.shape_cast"(%725) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%728 = "vector.shape_cast"(%723) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%729 = "amdgpu.mfma"(%726, %727, %728) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%730 = "vector.shape_cast"(%729) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%731 = "vector.insert"(%730, %27) <{static_position = array<i64: 0, 0>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%732 = "vector.extract"(%664) <{static_position = array<i64: 0, 1>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%733 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%734 = "vector.extract"(%722) <{static_position = array<i64: 0, 1>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%735 = "vector.shape_cast"(%733) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%736 = "vector.shape_cast"(%734) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%737 = "vector.shape_cast"(%732) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%738 = "amdgpu.mfma"(%735, %736, %737) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%739 = "vector.shape_cast"(%738) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%740 = "vector.insert"(%739, %731) <{static_position = array<i64: 0, 1>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%741 = "vector.extract"(%664) <{static_position = array<i64: 0, 2>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%742 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%743 = "vector.extract"(%722) <{static_position = array<i64: 0, 2>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%744 = "vector.shape_cast"(%742) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%745 = "vector.shape_cast"(%743) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%746 = "vector.shape_cast"(%741) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%747 = "amdgpu.mfma"(%744, %745, %746) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%748 = "vector.shape_cast"(%747) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%749 = "vector.insert"(%748, %740) <{static_position = array<i64: 0, 2>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%750 = "vector.extract"(%664) <{static_position = array<i64: 0, 3>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%751 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%752 = "vector.extract"(%722) <{static_position = array<i64: 0, 3>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%753 = "vector.shape_cast"(%751) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%754 = "vector.shape_cast"(%752) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%755 = "vector.shape_cast"(%750) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%756 = "amdgpu.mfma"(%753, %754, %755) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%757 = "vector.shape_cast"(%756) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%758 = "vector.insert"(%757, %749) <{static_position = array<i64: 0, 3>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%759 = "vector.extract"(%664) <{static_position = array<i64: 0, 4>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%760 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%761 = "vector.extract"(%722) <{static_position = array<i64: 0, 4>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%762 = "vector.shape_cast"(%760) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%763 = "vector.shape_cast"(%761) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%764 = "vector.shape_cast"(%759) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%765 = "amdgpu.mfma"(%762, %763, %764) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%766 = "vector.shape_cast"(%765) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%767 = "vector.insert"(%766, %758) <{static_position = array<i64: 0, 4>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%768 = "vector.extract"(%664) <{static_position = array<i64: 0, 5>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%769 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%770 = "vector.extract"(%722) <{static_position = array<i64: 0, 5>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%771 = "vector.shape_cast"(%769) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%772 = "vector.shape_cast"(%770) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%773 = "vector.shape_cast"(%768) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%774 = "amdgpu.mfma"(%771, %772, %773) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%775 = "vector.shape_cast"(%774) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%776 = "vector.insert"(%775, %767) <{static_position = array<i64: 0, 5>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%777 = "vector.extract"(%664) <{static_position = array<i64: 0, 6>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%778 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%779 = "vector.extract"(%722) <{static_position = array<i64: 0, 6>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%780 = "vector.shape_cast"(%778) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%781 = "vector.shape_cast"(%779) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%782 = "vector.shape_cast"(%777) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%783 = "amdgpu.mfma"(%780, %781, %782) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%784 = "vector.shape_cast"(%783) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%785 = "vector.insert"(%784, %776) <{static_position = array<i64: 0, 6>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%786 = "vector.extract"(%664) <{static_position = array<i64: 0, 7>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%787 = "vector.extract"(%688) <{static_position = array<i64: 0, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%788 = "vector.extract"(%722) <{static_position = array<i64: 0, 7>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%789 = "vector.shape_cast"(%787) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%790 = "vector.shape_cast"(%788) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%791 = "vector.shape_cast"(%786) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%792 = "amdgpu.mfma"(%789, %790, %791) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%793 = "vector.shape_cast"(%792) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%794 = "vector.insert"(%793, %785) <{static_position = array<i64: 0, 7>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%795 = "vector.extract"(%664) <{static_position = array<i64: 1, 0>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%796 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%797 = "vector.extract"(%722) <{static_position = array<i64: 0, 0>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%798 = "vector.shape_cast"(%796) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%799 = "vector.shape_cast"(%797) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%800 = "vector.shape_cast"(%795) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%801 = "amdgpu.mfma"(%798, %799, %800) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%802 = "vector.shape_cast"(%801) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%803 = "vector.insert"(%802, %794) <{static_position = array<i64: 1, 0>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%804 = "vector.extract"(%664) <{static_position = array<i64: 1, 1>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%805 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%806 = "vector.extract"(%722) <{static_position = array<i64: 0, 1>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%807 = "vector.shape_cast"(%805) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%808 = "vector.shape_cast"(%806) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%809 = "vector.shape_cast"(%804) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%810 = "amdgpu.mfma"(%807, %808, %809) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%811 = "vector.shape_cast"(%810) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%812 = "vector.insert"(%811, %803) <{static_position = array<i64: 1, 1>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%813 = "vector.extract"(%664) <{static_position = array<i64: 1, 2>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%814 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%815 = "vector.extract"(%722) <{static_position = array<i64: 0, 2>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%816 = "vector.shape_cast"(%814) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%817 = "vector.shape_cast"(%815) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%818 = "vector.shape_cast"(%813) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%819 = "amdgpu.mfma"(%816, %817, %818) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%820 = "vector.shape_cast"(%819) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%821 = "vector.insert"(%820, %812) <{static_position = array<i64: 1, 2>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%822 = "vector.extract"(%664) <{static_position = array<i64: 1, 3>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%823 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%824 = "vector.extract"(%722) <{static_position = array<i64: 0, 3>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%825 = "vector.shape_cast"(%823) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%826 = "vector.shape_cast"(%824) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%827 = "vector.shape_cast"(%822) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%828 = "amdgpu.mfma"(%825, %826, %827) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%829 = "vector.shape_cast"(%828) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%830 = "vector.insert"(%829, %821) <{static_position = array<i64: 1, 3>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%831 = "vector.extract"(%664) <{static_position = array<i64: 1, 4>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%832 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%833 = "vector.extract"(%722) <{static_position = array<i64: 0, 4>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%834 = "vector.shape_cast"(%832) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%835 = "vector.shape_cast"(%833) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%836 = "vector.shape_cast"(%831) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%837 = "amdgpu.mfma"(%834, %835, %836) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%838 = "vector.shape_cast"(%837) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%839 = "vector.insert"(%838, %830) <{static_position = array<i64: 1, 4>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%840 = "vector.extract"(%664) <{static_position = array<i64: 1, 5>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%841 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%842 = "vector.extract"(%722) <{static_position = array<i64: 0, 5>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%843 = "vector.shape_cast"(%841) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%844 = "vector.shape_cast"(%842) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%845 = "vector.shape_cast"(%840) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%846 = "amdgpu.mfma"(%843, %844, %845) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%847 = "vector.shape_cast"(%846) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%848 = "vector.insert"(%847, %839) <{static_position = array<i64: 1, 5>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%849 = "vector.extract"(%664) <{static_position = array<i64: 1, 6>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%850 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%851 = "vector.extract"(%722) <{static_position = array<i64: 0, 6>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%852 = "vector.shape_cast"(%850) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%853 = "vector.shape_cast"(%851) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%854 = "vector.shape_cast"(%849) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%855 = "amdgpu.mfma"(%852, %853, %854) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%856 = "vector.shape_cast"(%855) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%857 = "vector.insert"(%856, %848) <{static_position = array<i64: 1, 6>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%858 = "vector.extract"(%664) <{static_position = array<i64: 1, 7>}> : (vector<2x8x1x1x4x1xf32>) -> vector<1x1x4x1xf32>
%859 = "vector.extract"(%688) <{static_position = array<i64: 1, 0>}> : (vector<2x1x1x1x1x8xf8E4M3FNUZ>) -> vector<1x1x1x8xf8E4M3FNUZ>
%860 = "vector.extract"(%722) <{static_position = array<i64: 0, 7>}> : (vector<1x8x1x1x8x1xf8E4M3FNUZ>) -> vector<1x1x8x1xf8E4M3FNUZ>
%861 = "vector.shape_cast"(%859) : (vector<1x1x1x8xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%862 = "vector.shape_cast"(%860) : (vector<1x1x8x1xf8E4M3FNUZ>) -> vector<8xf8E4M3FNUZ>
%863 = "vector.shape_cast"(%858) : (vector<1x1x4x1xf32>) -> vector<4xf32>
%864 = "amdgpu.mfma"(%861, %862, %863) <{abid = 0 : i32, blgp = #amdgpu<mfma_perm_b none>, blocks = 1 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32}> : (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
%865 = "vector.shape_cast"(%864) : (vector<4xf32>) -> vector<1x1x4x1xf32>
%866 = "vector.insert"(%865, %857) <{static_position = array<i64: 1, 7>}> : (vector<1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
"scf.yield"(%567, %612, %866) : (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>) -> ()
}) : (index, index, index, vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>) -> (vector<2x1x4xf32>, vector<2x1x4xf32>, vector<2x8x1x1x4x1xf32>)
%169 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%170 = "vector.broadcast"(%169) : (vector<4xf32>) -> vector<1x4xf32>
%171 = "vector.insert"(%170, %6) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%172 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%173 = "vector.broadcast"(%172) : (vector<4xf32>) -> vector<1x4xf32>
%174 = "vector.insert"(%173, %171) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%175 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%176 = "vector.broadcast"(%175) : (vector<4xf32>) -> vector<1x4xf32>
%177 = "vector.insert"(%176, %174) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%178 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%179 = "vector.broadcast"(%178) : (vector<4xf32>) -> vector<1x4xf32>
%180 = "vector.insert"(%179, %177) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%181 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%182 = "vector.broadcast"(%181) : (vector<4xf32>) -> vector<1x4xf32>
%183 = "vector.insert"(%182, %180) <{static_position = array<i64: 2, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%184 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%185 = "vector.broadcast"(%184) : (vector<4xf32>) -> vector<1x4xf32>
%186 = "vector.insert"(%185, %183) <{static_position = array<i64: 2, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%187 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%188 = "vector.broadcast"(%187) : (vector<4xf32>) -> vector<1x4xf32>
%189 = "vector.insert"(%188, %186) <{static_position = array<i64: 3, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%190 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%191 = "vector.broadcast"(%190) : (vector<4xf32>) -> vector<1x4xf32>
%192 = "vector.insert"(%191, %189) <{static_position = array<i64: 3, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%193 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%194 = "vector.broadcast"(%193) : (vector<4xf32>) -> vector<1x4xf32>
%195 = "vector.insert"(%194, %192) <{static_position = array<i64: 4, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%196 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%197 = "vector.broadcast"(%196) : (vector<4xf32>) -> vector<1x4xf32>
%198 = "vector.insert"(%197, %195) <{static_position = array<i64: 4, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%199 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%200 = "vector.broadcast"(%199) : (vector<4xf32>) -> vector<1x4xf32>
%201 = "vector.insert"(%200, %198) <{static_position = array<i64: 5, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%202 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%203 = "vector.broadcast"(%202) : (vector<4xf32>) -> vector<1x4xf32>
%204 = "vector.insert"(%203, %201) <{static_position = array<i64: 5, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%205 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%206 = "vector.broadcast"(%205) : (vector<4xf32>) -> vector<1x4xf32>
%207 = "vector.insert"(%206, %204) <{static_position = array<i64: 6, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%208 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%209 = "vector.broadcast"(%208) : (vector<4xf32>) -> vector<1x4xf32>
%210 = "vector.insert"(%209, %207) <{static_position = array<i64: 6, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%211 = "vector.extract"(%168#1) <{static_position = array<i64: 0, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%212 = "vector.broadcast"(%211) : (vector<4xf32>) -> vector<1x4xf32>
%213 = "vector.insert"(%212, %210) <{static_position = array<i64: 7, 0, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%214 = "vector.extract"(%168#1) <{static_position = array<i64: 1, 0>}> : (vector<2x1x4xf32>) -> vector<4xf32>
%215 = "vector.broadcast"(%214) : (vector<4xf32>) -> vector<1x4xf32>
%216 = "vector.insert"(%215, %213) <{static_position = array<i64: 7, 1, 0, 0>}> : (vector<1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%217 = "arith.divf"(%14, %216) <{fastmath = #arith.fastmath<none>}> : (vector<8x2x1x1x1x4xf32>, vector<8x2x1x1x1x4xf32>) -> vector<8x2x1x1x1x4xf32>
%218 = "vector.transpose"(%217) <{permutation = array<i64: 1, 0, 3, 2, 5, 4>}> : (vector<8x2x1x1x1x4xf32>) -> vector<2x8x1x1x4x1xf32>
%219 = "arith.mulf"(%218, %168#2) <{fastmath = #arith.fastmath<none>}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%220 = "vector.transfer_read"(%90, %34) <{in_bounds = [], operandSegmentSizes = array<i32: 1, 0, 1, 0>, permutation_map = affine_map<() -> ()>}> : (memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>, f32) -> vector<f32>
%221 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%222 = "vector.broadcast"(%221) : (f32) -> vector<4x1xf32>
%223 = "vector.insert"(%222, %27) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%224 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%225 = "vector.broadcast"(%224) : (f32) -> vector<4x1xf32>
%226 = "vector.insert"(%225, %223) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%227 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%228 = "vector.broadcast"(%227) : (f32) -> vector<4x1xf32>
%229 = "vector.insert"(%228, %226) <{static_position = array<i64: 0, 2, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%230 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%231 = "vector.broadcast"(%230) : (f32) -> vector<4x1xf32>
%232 = "vector.insert"(%231, %229) <{static_position = array<i64: 0, 3, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%233 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%234 = "vector.broadcast"(%233) : (f32) -> vector<4x1xf32>
%235 = "vector.insert"(%234, %232) <{static_position = array<i64: 0, 4, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%236 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%237 = "vector.broadcast"(%236) : (f32) -> vector<4x1xf32>
%238 = "vector.insert"(%237, %235) <{static_position = array<i64: 0, 5, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%239 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%240 = "vector.broadcast"(%239) : (f32) -> vector<4x1xf32>
%241 = "vector.insert"(%240, %238) <{static_position = array<i64: 0, 6, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%242 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%243 = "vector.broadcast"(%242) : (f32) -> vector<4x1xf32>
%244 = "vector.insert"(%243, %241) <{static_position = array<i64: 0, 7, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%245 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%246 = "vector.broadcast"(%245) : (f32) -> vector<4x1xf32>
%247 = "vector.insert"(%246, %244) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%248 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%249 = "vector.broadcast"(%248) : (f32) -> vector<4x1xf32>
%250 = "vector.insert"(%249, %247) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%251 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%252 = "vector.broadcast"(%251) : (f32) -> vector<4x1xf32>
%253 = "vector.insert"(%252, %250) <{static_position = array<i64: 1, 2, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%254 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%255 = "vector.broadcast"(%254) : (f32) -> vector<4x1xf32>
%256 = "vector.insert"(%255, %253) <{static_position = array<i64: 1, 3, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%257 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%258 = "vector.broadcast"(%257) : (f32) -> vector<4x1xf32>
%259 = "vector.insert"(%258, %256) <{static_position = array<i64: 1, 4, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%260 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%261 = "vector.broadcast"(%260) : (f32) -> vector<4x1xf32>
%262 = "vector.insert"(%261, %259) <{static_position = array<i64: 1, 5, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%263 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%264 = "vector.broadcast"(%263) : (f32) -> vector<4x1xf32>
%265 = "vector.insert"(%264, %262) <{static_position = array<i64: 1, 6, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%266 = "vector.extract"(%220) <{static_position = array<i64>}> : (vector<f32>) -> f32
%267 = "vector.broadcast"(%266) : (f32) -> vector<4x1xf32>
%268 = "vector.insert"(%267, %265) <{static_position = array<i64: 1, 7, 0, 0>}> : (vector<4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%269 = "arith.divf"(%219, %268) <{fastmath = #arith.fastmath<none>}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%270 = "arith.cmpf"(%269, %16) <{fastmath = #arith.fastmath<none>, predicate = 11 : i64}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xi1>
%271 = "arith.select"(%270, %16, %269) : (vector<2x8x1x1x4x1xi1>, vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%272 = "arith.cmpf"(%271, %15) <{fastmath = #arith.fastmath<none>, predicate = 9 : i64}> : (vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xi1>
%273 = "arith.select"(%272, %15, %271) : (vector<2x8x1x1x4x1xi1>, vector<2x8x1x1x4x1xf32>, vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf32>
%274 = "arith.truncf"(%273) : (vector<2x8x1x1x4x1xf32>) -> vector<2x8x1x1x4x1xf8E4M3FNUZ>
%275:4 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 1, 1, 64>}> : (index) -> (index, index, index, index)
%276:3 = "affine.delinearize_index"(%39) <{static_basis = array<i64: 4, 16>}> : (index) -> (index, index, index)
%277 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%278 = "affine.linearize_index"(%275#1, %35, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%279 = "vector.extract"(%274) <{static_position = array<i64: 0, 0, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%279, %103, %277, %278) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%280 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%281 = "affine.linearize_index"(%275#1, %30, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%282 = "vector.extract"(%274) <{static_position = array<i64: 0, 1, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%282, %103, %280, %281) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%283 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%284 = "affine.linearize_index"(%275#1, %13, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%285 = "vector.extract"(%274) <{static_position = array<i64: 0, 2, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%285, %103, %283, %284) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%286 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%287 = "affine.linearize_index"(%275#1, %12, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%288 = "vector.extract"(%274) <{static_position = array<i64: 0, 3, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%288, %103, %286, %287) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%289 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%290 = "affine.linearize_index"(%275#1, %3, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%291 = "vector.extract"(%274) <{static_position = array<i64: 0, 4, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%291, %103, %289, %290) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%292 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%293 = "affine.linearize_index"(%275#1, %2, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%294 = "vector.extract"(%274) <{static_position = array<i64: 0, 5, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%294, %103, %292, %293) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%295 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%296 = "affine.linearize_index"(%275#1, %1, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%297 = "vector.extract"(%274) <{static_position = array<i64: 0, 6, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%297, %103, %295, %296) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%298 = "affine.linearize_index"(%275#2, %35, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%299 = "affine.linearize_index"(%275#1, %0, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%300 = "vector.extract"(%274) <{static_position = array<i64: 0, 7, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%300, %103, %298, %299) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%301 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%302 = "affine.linearize_index"(%275#1, %35, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%303 = "vector.extract"(%274) <{static_position = array<i64: 1, 0, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%303, %103, %301, %302) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%304 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%305 = "affine.linearize_index"(%275#1, %30, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%306 = "vector.extract"(%274) <{static_position = array<i64: 1, 1, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%306, %103, %304, %305) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%307 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%308 = "affine.linearize_index"(%275#1, %13, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%309 = "vector.extract"(%274) <{static_position = array<i64: 1, 2, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%309, %103, %307, %308) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%310 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%311 = "affine.linearize_index"(%275#1, %12, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%312 = "vector.extract"(%274) <{static_position = array<i64: 1, 3, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%312, %103, %310, %311) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%313 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%314 = "affine.linearize_index"(%275#1, %3, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%315 = "vector.extract"(%274) <{static_position = array<i64: 1, 4, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%315, %103, %313, %314) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%316 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%317 = "affine.linearize_index"(%275#1, %2, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%318 = "vector.extract"(%274) <{static_position = array<i64: 1, 5, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%318, %103, %316, %317) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%319 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%320 = "affine.linearize_index"(%275#1, %1, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%321 = "vector.extract"(%274) <{static_position = array<i64: 1, 6, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%321, %103, %319, %320) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
%322 = "affine.linearize_index"(%275#2, %30, %35, %276#1, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 2, 1, 4, 4>}> : (index, index, index, index, index) -> index
%323 = "affine.linearize_index"(%275#1, %0, %35, %276#2, %35) <{disjoint, operandSegmentSizes = array<i32: 5, 0>, static_basis = array<i64: 1, 8, 1, 16, 1>}> : (index, index, index, index, index) -> index
%324 = "vector.extract"(%274) <{static_position = array<i64: 1, 7, 0, 0>}> : (vector<2x8x1x1x4x1xf8E4M3FNUZ>) -> vector<4x1xf8E4M3FNUZ>
"vector.transfer_write"(%324, %103, %322, %323) <{in_bounds = [true, true], operandSegmentSizes = array<i32: 1, 1, 2, 0>, permutation_map = affine_map<(d0, d1) -> (d0, d1)>}> : (vector<4x1xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, index, index) -> ()
"scf.forall.in_parallel"() ({
^bb0:
}) : () -> ()
}) : (index) -> ()
"memref.dealloc"(%44) : (memref<1x32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%43) : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%42) : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%41) : (memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"memref.dealloc"(%40) : (memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>) -> ()
"func.return"() : () -> ()
}) {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {}>} : () -> ()
}) : () -> ()
"hal.executable.variant_end"() : () -> ()
}) {sym_name = "rocm_hsaco_fb", target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>} : () -> ()
%1 = iree_linalg_ext.attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5]} ins(%collapsed, %collapsed_1, %collapsed_2, %extracted, %arg4 : tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, tensor<32x?x128xf8E4M3FNUZ>, f32, tensor<?x?xf8E4M3FNUZ>) outs(%cast : tensor<32x?x128xf32>) {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment