Created February 14, 2025 11:12
-
-
Save pashu123/e22dba342a9bf78c0ee5ccb0522d3855 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Fused f8E4M3FNUZ attention dispatch lowered for the LLVMGPUVectorDistribute
// pipeline (one workgroup = one 64-lane subgroup).
//
// Structure, as visible below:
//   * Each scf.forall iteration (%arg0 in [0, 8), %arg1 in [0, 4),
//     %arg2 in [0, %48)) produces one 32x128 fp8 output tile written via %56.
//   * The inner scf.for walks %52 tiles of 32 rows of %53/%55. Per tile it:
//       - computes S = A * B^T, where A is the 32x128 tile read from %51 and
//         B is the 32x128 tile read from %53 (vector.contract reducing d1);
//       - scales S, adds an additive mask, and keeps a running row max
//         (%arg4), row sum (%arg5) and 32x128 accumulator (%arg6), all in the
//         exp2 domain (scores are pre-multiplied by log2(e));
//       - accumulates P * C, where C is the 128x32 tile of %55.
//     Presumably %51/%53/%55 are Q/K/V; confirm against the producing model.
//   * After the loop the accumulator is divided by the row sum and by the
//     f32 scalar in binding(2), clamped to [-240, 240] and truncated to fp8.
//
// Fix: the mask used to be built as f32 {0, -inf}, truncated to f8E4M3FNUZ,
// staged through workgroup memory and extended back to f32.
// f8E4M3FNUZ has no infinity encoding, so -inf turned into NaN.
// vector.multi_reduction <maximumf> then spread that NaN to every score of
// the row, and the unordered cmpf clamps in the epilogue silently mapped
// the NaNs to -240 in the output.
// The mask is now applied directly in f32, where -inf is representable.
// Unmasked entries (0.0 round-trips exactly through fp8) are unchanged.
func.func @prefill_bs1$async_dispatch_19_attention_8x4x1xDx32x128xf8E4M3FNUZ_generic() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {}>} {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %c32_i64 = arith.constant 32 : i64
  %c67108864 = arith.constant 67108864 : index
  %c32 = arith.constant 32 : index
  %c1 = arith.constant 1 : index
  // log2(e): scores are multiplied by it so exp2 can stand in for exp.
  %cst_0 = arith.constant 1.44269502 : f32
  %cst_1 = arith.constant 0.000000e+00 : f8E4M3FNUZ
  %cst_2 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
  // Lowest finite f32: initial running row max. Because it is finite, the
  // running max stays finite even when every score seen so far is -inf
  // (fully masked), so exp2(old_max - new_max) never computes -inf - (-inf).
  %cst_3 = arith.constant dense<-3.40282347E+38> : vector<32xf32>
  %cst_4 = arith.constant dense<0.000000e+00> : vector<32xf32>
  %cst_5 = arith.constant dense<0.000000e+00> : vector<32x32xf32>
  %c0_i8 = arith.constant 0 : i8
  %c0_i64 = arith.constant 0 : i64
  // f32 -inf (bit pattern 0xFF800000): additive value for masked-out scores.
  // It must stay in f32; f8E4M3FNUZ cannot represent infinities.
  %cst_6 = arith.constant dense<0xFF800000> : vector<32x32xf32>
  // Added uniformly to every score; a uniform shift leaves the softmax unchanged.
  // NOTE(review): 0.00416666688 ~= 1/240 -- confirm an add (not a multiply) is intended.
  %cst_7 = arith.constant dense<0.00416666688> : vector<32x32xf32>
  %cst_8 = arith.constant dense<1.44269502> : vector<32x32xf32>
  // +/-240 is the largest finite magnitude of f8E4M3FNUZ; used as clamps before truncf.
  %cst_9 = arith.constant dense<2.400000e+02> : vector<32x32xf32>
  %cst_10 = arith.constant dense<-2.400000e+02> : vector<32x128xf32>
  %cst_11 = arith.constant dense<2.400000e+02> : vector<32x128xf32>
  %cst_12 = arith.constant dense<1.000000e+00> : vector<128x32xf32>
  // Workgroup-memory staging buffers for the fp8 operands of the two contractions.
  %alloc = memref.alloc() : memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  %alloc_13 = memref.alloc() : memref<128x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  %alloc_14 = memref.alloc() : memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  %alloc_15 = memref.alloc() : memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  // Fourteen i32 push constants:
  //   * ordinals 0-9 are paired into five 64-bit binding offsets (lo | hi << 32);
  //   * ordinal 10 is bit-cast to an f32 scale;
  //   * ordinals 11-13 are index-typed dynamic sizes.
  %0 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
  %3 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
  %4 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
  %5 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
  %6 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(6) : i32
  %7 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(7) : i32
  %8 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(8) : i32
  %9 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(9) : i32
  %10 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(10) : i32
  %11 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(11) : i32
  %12 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(12) : i32
  %13 = hal.interface.constant.load layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(13) : i32
  // Offset of the %51 view (binding 0).
  %14 = arith.extui %0 : i32 to i64
  %15 = arith.extui %1 : i32 to i64
  %16 = arith.shli %15, %c32_i64 : i64
  %17 = arith.ori %14, %16 : i64
  %18 = arith.index_castui %17 : i64 to index
  // Offset of the %53 view (binding 0).
  %19 = arith.extui %2 : i32 to i64
  %20 = arith.extui %3 : i32 to i64
  %21 = arith.shli %20, %c32_i64 : i64
  %22 = arith.ori %19, %21 : i64
  %23 = arith.index_castui %22 : i64 to index
  // Offset of the %55 view (binding 0).
  %24 = arith.extui %4 : i32 to i64
  %25 = arith.extui %5 : i32 to i64
  %26 = arith.shli %25, %c32_i64 : i64
  %27 = arith.ori %24, %26 : i64
  %28 = arith.index_castui %27 : i64 to index
  // Offset of the f32 scalar in binding 2; one of 32 statically known values.
  %29 = arith.extui %6 : i32 to i64
  %30 = arith.extui %7 : i32 to i64
  %31 = arith.shli %30, %c32_i64 : i64
  %32 = arith.ori %29, %31 : i64
  %33 = arith.index_castui %32 {stream.alignment = 64 : index, stream.values = [1075847616 : index, 1293968512 : index, 1512089408 : index, 1730210304 : index, 1948331200 : index, 2166452096 : index, 2384572992 : index, 2602693888 : index, 2820814784 : index, 3038935680 : index, 3257056576 : index, 3475177472 : index, 3693298368 : index, 3911419264 : index, 4129540160 : index, 4347661056 : index, 4565781952 : index, 4783902848 : index, 5002023744 : index, 5220144640 : index, 5438265536 : index, 5656386432 : index, 5874507328 : index, 6092628224 : index, 6310749120 : index, 6528870016 : index, 6746990912 : index, 6965111808 : index, 7183232704 : index, 7401353600 : index, 7619474496 : index, 7837595392 : index]} : i64 to index
  // Offset of the output view %56 (binding 3).
  %34 = arith.extui %8 : i32 to i64
  %35 = arith.extui %9 : i32 to i64
  %36 = arith.shli %35, %c32_i64 : i64
  %37 = arith.ori %34, %36 : i64
  %38 = arith.index_castui %37 : i64 to index
  // Runtime score scale (bit-cast from push constant 10).
  %39 = arith.bitcast %10 : i32 to f32
  %40 = arith.index_castui %11 : i32 to index
  %41 = arith.index_castui %12 : i32 to index
  %42 = arith.index_castui %13 : i32 to index
  %43:8 = util.assume.int
      %18<umin = 68027392, umax = 20995769344>,
      %23<umin = 68158464, umax = 21532509184>,
      %28[<umin = 67765248, umax = 19922289664>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>, <umin = 67634176, umax = 19385549824>],
      %33[<umin = 1075847616, umax = 1075847616, udiv = 1075847616>, <umin = 1293968512, umax = 1293968512, udiv = 1293968512>, <umin = 1512089408, umax = 1512089408, udiv = 1512089408>, <umin = 1730210304, umax = 1730210304, udiv = 1730210304>, <umin = 1948331200, umax = 1948331200, udiv = 1948331200>, <umin = 2166452096, umax = 2166452096, udiv = 2166452096>, <umin = 2384572992, umax = 2384572992, udiv = 2384572992>, <umin = 2602693888, umax = 2602693888, udiv = 2602693888>, <umin = 2820814784, umax = 2820814784, udiv = 2820814784>, <umin = 3038935680, umax = 3038935680, udiv = 3038935680>, <umin = 3257056576, umax = 3257056576, udiv = 3257056576>, <umin = 3475177472, umax = 3475177472, udiv = 3475177472>, <umin = 3693298368, umax = 3693298368, udiv = 3693298368>, <umin = 3911419264, umax = 3911419264, udiv = 3911419264>, <umin = 4129540160, umax = 4129540160, udiv = 4129540160>, <umin = 4347661056, umax = 4347661056, udiv = 4347661056>, <umin = 4565781952, umax = 4565781952, udiv = 4565781952>, <umin = 4783902848, umax = 4783902848, udiv = 4783902848>, <umin = 5002023744, umax = 5002023744, udiv = 5002023744>, <umin = 5220144640, umax = 5220144640, udiv = 5220144640>, <umin = 5438265536, umax = 5438265536, udiv = 5438265536>, <umin = 5656386432, umax = 5656386432, udiv = 5656386432>, <umin = 5874507328, umax = 5874507328, udiv = 5874507328>, <umin = 6092628224, umax = 6092628224, udiv = 6092628224>, <umin = 6310749120, umax = 6310749120, udiv = 6310749120>, <umin = 6528870016, umax = 6528870016, udiv = 6528870016>, <umin = 6746990912, umax = 6746990912, udiv = 6746990912>, <umin = 6965111808, umax = 6965111808, udiv = 6965111808>, <umin = 7183232704, umax = 7183232704, udiv = 7183232704>, <umin = 7401353600, umax = 7401353600, udiv = 7401353600>, <umin = 7619474496, umax = 7619474496, udiv = 7619474496>, <umin = 7837595392, umax = 7837595392, udiv = 7837595392>],
      %38<umin = 67896320, umax = 20459029504>,
      %40<umin = 1, umax = 4095>,
      %41<umin = 32, umax = 131040, udiv = 32>,
      %42<umin = 32, umax = 131040, udiv = 32>
    : index, index, index, index, index, index, index, index
  // Binding 1: a single i64, compared below against absolute key positions.
  %44 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<i64, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %44, 64 : memref<i64, #hal.descriptor_type<storage_buffer>>
  // Binding 2: a single f32 that the normalized output is divided by.
  %45 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%43#3) flags(ReadOnly) : memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %45, 64 : memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %46 = flow.dispatch.workload.ordinal %43#5, 0 : index
  %47 = flow.dispatch.workload.ordinal %43#6, 1 : index
  %48 = flow.dispatch.workload.ordinal %43#5, 2 : index
  %49 = flow.dispatch.workload.ordinal %43#7, 3 : index
  // i8 mask view at a fixed byte offset of binding 0; only bit 0 of each byte is used (see %86).
  %50 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c67108864) flags("ReadOnly|Indirect") : memref<?x32x?xi8, strided<[?, ?, 1], offset: 67108864>, #hal.descriptor_type<storage_buffer>>{%46, %47}
  memref.assume_alignment %50, 64 : memref<?x32x?xi8, strided<[?, ?, 1], offset: 67108864>, #hal.descriptor_type<storage_buffer>>
  // LHS of the first contraction (presumably Q).
  %51 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%43#0) flags("ReadOnly|Indirect") : memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%48}
  memref.assume_alignment %51, 1 : memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  // Number of 32-row tiles in the inner loop (trip count).
  %52 = arith.divsi %49, %c32 : index
  // RHS of the first contraction (presumably K), tiled as ?x32x128.
  %53 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%43#1) flags("ReadOnly|Indirect") : memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%52}
  memref.assume_alignment %53, 1 : memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  // NOTE(review): the loop runs %52 (from ordinal 3) iterations, but the %55
  // view and the mask linearization below use %54 (from ordinal 1). This
  // assumes both workload sizes are equal at runtime -- confirm.
  %54 = arith.divsi %47, %c32 : index
  // RHS of the second contraction (presumably V), with the 128-wide dim ahead of the tiled dim.
  %55 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%43#2) flags("ReadOnly|Indirect") : memref<8x4x128x?x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%54}
  memref.assume_alignment %55, 1 : memref<8x4x128x?x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  // Output (binding 3).
  %56 = hal.interface.binding.subspan layout(<constants = 14, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%43#4) flags(Indirect) : memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%46}
  memref.assume_alignment %56, 1 : memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  scf.forall (%arg0, %arg1, %arg2) in (8, 4, %48) {
    gpu.barrier
    // 32x128 output tile for this (%arg0, %arg1, %arg2).
    %subview = memref.subview %56[0, %arg2, 0, %arg0, %arg1, 0] [1, 1, 32, 1, 1, 128] [1, 1, 1, 1, 1, 1] : memref<1x?x32x8x4x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x32x1x1x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    %subview_17 = memref.subview %subview[0, 0, 0, 0, 0, 0] [1, 1, 32, 1, 1, 128] [1, 1, 1, 1, 1, 1] : memref<1x1x32x1x1x128xf8E4M3FNUZ, strided<[?, 131072, 4096, 512, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
    // Load the 32x128 LHS tile once and stage it through %alloc_15 into the MFMA operand layout.
    %57 = vector.transfer_read %51[%arg0, %arg1, %c0, %arg2, %c0, %c0], %cst_1 {in_bounds = [true, true]} : memref<8x4x1x?x32x128xf8E4M3FNUZ, strided<[?, ?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x128xf8E4M3FNUZ>
    %58 = iree_vector_ext.to_layout %57 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [4, 1], outer_tile = [1, 1], thread_tile = [8, 8], element_tile = [1, 16], subgroup_strides = [0, 0], thread_strides = [8, 1]>) : vector<32x128xf8E4M3FNUZ>
    // Score scale folded with log2(e) for the exp2-based softmax.
    %59 = arith.mulf %39, %cst_0 : f32
    vector.transfer_write %58, %alloc_15[%c0, %c0] {in_bounds = [true, true]} : vector<32x128xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
    gpu.barrier
    %60 = vector.transfer_read %alloc_15[%c0, %c0], %cst_1 {in_bounds = [true, true]} : memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, vector<32x128xf8E4M3FNUZ>
    %61 = iree_vector_ext.to_layout %60 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 4], outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 8], subgroup_strides = [0, 0], thread_strides = [1, 16]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x128xf8E4M3FNUZ>
    %62 = iree_vector_ext.to_layout %cst_5 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 2], outer_tile = [1, 1], thread_tile = [4, 16], element_tile = [4, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x32xf32>
    // Length bound from binding 1; key positions >= this value are masked.
    %63 = vector.transfer_read %44[], %c0_i64 : memref<i64, #hal.descriptor_type<storage_buffer>>, vector<i64>
    %64 = vector.broadcast %63 : vector<i64> to vector<32x32xi64>
    %65 = vector.step : vector<32xindex>
    %66 = vector.broadcast %59 : f32 to vector<32x32xf32>
    // Online softmax over %52 tiles: %arg4 = running row max, %arg5 = running
    // row sum of exp2 terms, %arg6 = unnormalized 32x128 accumulator.
    %67:3 = scf.for %arg3 = %c0 to %52 step %c1 iter_args(%arg4 = %cst_3, %arg5 = %cst_4, %arg6 = %cst_2) -> (vector<32xf32>, vector<32xf32>, vector<32x128xf32>) {
      // Keeps the previous iteration's reads of %alloc / %alloc_13 / %alloc_14 ahead of this iteration's writes.
      gpu.barrier
      %subview_18 = memref.subview %55[%arg0, %arg1, 0, %arg3, 0] [1, 1, 128, 1, 32] [1, 1, 1, 1, 1] : memref<8x4x128x?x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x128x1x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      %subview_19 = memref.subview %subview_18[0, 0, 0, 0, 0] [1, 1, 128, 1, 32] [1, 1, 1, 1, 1] : memref<1x1x128x1x32xf8E4M3FNUZ, strided<[?, ?, ?, 32, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<128x32xf8E4M3FNUZ, strided<[?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
      %80 = vector.transfer_read %53[%arg0, %arg1, %arg3, %c0, %c0], %cst_1 {in_bounds = [true, true]} : memref<8x4x?x32x128xf8E4M3FNUZ, strided<[?, ?, 4096, 128, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<32x128xf8E4M3FNUZ>
      %81 = iree_vector_ext.to_layout %80 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [4, 1], outer_tile = [1, 1], thread_tile = [8, 8], element_tile = [1, 16], subgroup_strides = [0, 0], thread_strides = [8, 1]>) : vector<32x128xf8E4M3FNUZ>
      %82 = vector.transfer_read %subview_19[%c0, %c0], %cst_1 {in_bounds = [true, true]} : memref<128x32xf8E4M3FNUZ, strided<[?, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<128x32xf8E4M3FNUZ>
      %83 = iree_vector_ext.to_layout %82 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [4, 1], outer_tile = [1, 1], thread_tile = [32, 2], element_tile = [1, 16], subgroup_strides = [0, 0], thread_strides = [2, 1]>) : vector<128x32xf8E4M3FNUZ>
      // Absolute position of the first key column of this tile (%arg3 * 32).
      %84 = affine.linearize_index disjoint [%arg3, %c0] by (%54, 32) : index
      %85 = vector.transfer_read %50[%arg2, %c0, %84], %c0_i8 {in_bounds = [true, true]} : memref<?x32x?xi8, strided<[?, ?, 1], offset: 67108864>, #hal.descriptor_type<storage_buffer>>, vector<32x32xi8>
      // trunci keeps only bit 0 of each mask byte; presumably the bytes are 0/1.
      %86 = arith.trunci %85 : vector<32x32xi8> to vector<32x32xi1>
      // Per-column key positions; broadcasting along dim 0 gives [i][j] = %84 + j.
      %87 = vector.broadcast %84 : index to vector<32xindex>
      %88 = arith.addi %87, %65 : vector<32xindex>
      %89 = arith.index_cast %88 : vector<32xindex> to vector<32xi64>
      %90 = vector.broadcast %89 : vector<32xi64> to vector<32x32xi64>
      %91 = arith.cmpi sge, %90, %64 : vector<32x32xi64>
      // Masked if the mask bit is set OR the key position is past the bound.
      %92 = arith.ori %86, %91 : vector<32x32xi1>
      %93 = arith.select %92, %cst_6, %cst_5 : vector<32x32xi1>, vector<32x32xf32>
      // Apply the additive mask in f32 (scaled to the exp2 domain). It is NOT
      // round-tripped through f8E4M3FNUZ: that type has no -inf, so the
      // masked value would become NaN and poison the row via maximumf.
      %97 = arith.mulf %93, %cst_8 : vector<32x32xf32>
      vector.transfer_write %81, %alloc_14[%c0, %c0] {in_bounds = [true, true]} : vector<32x128xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
      vector.transfer_write %83, %alloc_13[%c0, %c0] {in_bounds = [true, true]} : vector<128x32xf8E4M3FNUZ>, memref<128x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
      gpu.barrier
      %98 = vector.transfer_read %alloc_14[%c0, %c0], %cst_1 {in_bounds = [true, true]} : memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>, vector<32x128xf8E4M3FNUZ>
      %99 = iree_vector_ext.to_layout %98 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 4], outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 8], subgroup_strides = [0, 0], thread_strides = [1, 16]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x128xf8E4M3FNUZ>
      // S[i][j] = sum_k LHS[i][k] * RHS[j][k]  (32x32 scores for this tile).
      %100 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %61, %99, %62 {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x128xf8E4M3FNUZ>, vector<32x128xf8E4M3FNUZ> into vector<32x32xf32>
      %101 = iree_vector_ext.to_layout %100 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 2], outer_tile = [1, 1], thread_tile = [4, 16], element_tile = [4, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x32xf32>
      %102 = arith.mulf %66, %101 : vector<32x32xf32>
      %103 = arith.addf %102, %cst_7 : vector<32x32xf32>
      %104 = arith.addf %103, %97 : vector<32x32xf32>
      // New running row max; finite because %arg4 starts at the lowest finite f32.
      %105 = vector.multi_reduction <maximumf>, %104, %arg4 [1] : vector<32x32xf32> to vector<32xf32>
      // Correction factor exp2(old_max - new_max) for the previous sum/accumulator.
      %106 = arith.subf %arg4, %105 : vector<32xf32>
      %107 = math.exp2 %106 : vector<32xf32>
      %108 = arith.mulf %107, %arg5 : vector<32xf32>
      // Broadcast + transpose gives [i][j] = %105[i]: subtract each row's max.
      %109 = vector.broadcast %105 : vector<32xf32> to vector<32x32xf32>
      %110 = vector.transpose %109, [1, 0] : vector<32x32xf32> to vector<32x32xf32>
      %111 = arith.subf %104, %110 : vector<32x32xf32>
      // Masked (-inf) entries give exp2(-inf) = 0.
      %112 = math.exp2 %111 : vector<32x32xf32>
      %113 = vector.multi_reduction <add>, %112, %108 [1] : vector<32x32xf32> to vector<32xf32>
      // Clamp P to the fp8 maximum before narrowing it for the second contraction.
      %114 = arith.minimumf %112, %cst_9 : vector<32x32xf32>
      %115 = arith.truncf %114 : vector<32x32xf32> to vector<32x32xf8E4M3FNUZ>
      // Row-wise rescale of the running accumulator: [i][j] = %107[i] * %arg6[i][j].
      %116 = vector.broadcast %107 : vector<32xf32> to vector<128x32xf32>
      %117 = vector.transpose %116, [1, 0] : vector<128x32xf32> to vector<32x128xf32>
      %118 = arith.mulf %117, %arg6 : vector<32x128xf32>
      vector.transfer_write %115, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x32xf8E4M3FNUZ>, memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
      gpu.barrier
      %119 = vector.transfer_read %alloc[%c0, %c0], %cst_1 {in_bounds = [true, true]} : memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, vector<32x32xf8E4M3FNUZ>
      %120 = iree_vector_ext.to_layout %119 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 1], outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 8], subgroup_strides = [0, 0], thread_strides = [1, 16]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x32xf8E4M3FNUZ>
      %121 = vector.transfer_read %alloc_13[%c0, %c0], %cst_1 {in_bounds = [true, true]} : memref<128x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>, vector<128x32xf8E4M3FNUZ>
      %122 = iree_vector_ext.to_layout %121 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [8, 1], outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 8], subgroup_strides = [0, 0], thread_strides = [1, 16]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<128x32xf8E4M3FNUZ>
      %123 = iree_vector_ext.to_layout %118 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 8], outer_tile = [1, 1], thread_tile = [4, 16], element_tile = [4, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x128xf32>
      // acc[i][j] += sum_k P[i][k] * RHS2[j][k]
      %124 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %120, %122, %123 {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x32xf8E4M3FNUZ>, vector<128x32xf8E4M3FNUZ> into vector<32x128xf32>
      %125 = iree_vector_ext.to_layout %124 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [2, 8], outer_tile = [1, 1], thread_tile = [4, 16], element_tile = [4, 1], subgroup_strides = [0, 0], thread_strides = [16, 1]>) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F8E4M3FNUZ>} : vector<32x128xf32>
      scf.yield %105, %113, %125 : vector<32xf32>, vector<32xf32>, vector<32x128xf32>
    }
    // Normalize each accumulator row by 1 / row_sum.
    %68 = vector.broadcast %67#1 : vector<32xf32> to vector<128x32xf32>
    %69 = arith.divf %cst_12, %68 : vector<128x32xf32>
    %70 = vector.transpose %69, [1, 0] : vector<128x32xf32> to vector<32x128xf32>
    %71 = arith.mulf %70, %67#2 : vector<32x128xf32>
    // Divide by the f32 scalar from binding 2.
    %72 = vector.transfer_read %45[], %cst : memref<f32, strided<[], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<f32>
    %73 = vector.broadcast %72 : vector<f32> to vector<32x128xf32>
    %74 = arith.divf %71, %73 : vector<32x128xf32>
    // Clamp to [-240, 240] before narrowing to fp8. The unordered predicates
    // (ult/ugt) map any NaN to -240, so NaNs would be silently hidden here.
    %75 = arith.cmpf ult, %74, %cst_10 : vector<32x128xf32>
    %76 = arith.select %75, %cst_10, %74 : vector<32x128xi1>, vector<32x128xf32>
    %77 = arith.cmpf ugt, %76, %cst_11 : vector<32x128xf32>
    %78 = arith.select %77, %cst_11, %76 : vector<32x128xi1>, vector<32x128xf32>
    %79 = arith.truncf %78 : vector<32x128xf32> to vector<32x128xf8E4M3FNUZ>
    vector.transfer_write %79, %subview_17[%c0, %c0] {in_bounds = [true, true]} : vector<32x128xf8E4M3FNUZ>, memref<32x128xf8E4M3FNUZ, strided<[4096, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  memref.dealloc %alloc_15 : memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  memref.dealloc %alloc_14 : memref<32x128xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  memref.dealloc %alloc_13 : memref<128x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  memref.dealloc %alloc : memref<32x32xf8E4M3FNUZ, #gpu.address_space<workgroup>>
  return
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.