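// IREE HAL executable dump for prefill_bs4 dispatch 3, compiled for ROCm gfx942 via the rocm_hsaco_fb variant.
// For each of the 4 batch rows it gathers one 128256-element logits row out of a dynamically sized 4x?x128256 f16
// input, computes a row-wise softmax in f32, and stores the log of the result back as f16.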
hal.executable public @prefill_bs4$async_dispatch_3 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">, ukernels = "none"}>) {
    hal.executable.export public @prefill_bs4$async_dispatch_3_softmax_4x128256xf32_generic ordinal(0) layout(#hal.pipeline.layout<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) count(%arg0: !hal.device, %arg1: index) -> (index, index, index) {
      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice %arg1
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @prefill_bs4$async_dispatch_3_softmax_4x128256xf32_generic() {
        %c1_i64 = arith.constant 1 : i64
        %c0_i64 = arith.constant 0 : i64
        %c4_i64 = arith.constant 4 : i64
        %cst = arith.constant dense_resource<torch_tensor_4_torch.int64> : tensor<4xi64>
        %c0 = arith.constant 0 : index
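        // Two push constants: ordinal 0 is the byte offset of the logits subspan in binding 0; ordinal 1 is the
        // dynamic middle dimension of the logits tensor (asserted to be a multiple of 32, at most 8160),
        // presumably the padded sequence length.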
        %0 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
        %1 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4:2 = util.assume.int
            %2<umin = 5224448, umax = 537901056>,
            %3<umin = 32, umax = 8160, udiv = 32>
          : index, index
        %5 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xi64>>
        %6 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4x128256xf16>>
        %7 = iree_tensor_ext.dispatch.workload.ordinal %4#1, 0 : index
        %8 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%4#0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4x?x128256xf16>>{%7}
        %9 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0, 0], sizes = [4, %7, 128256], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4x?x128256xf16>>{%7} -> tensor<4x?x128256xf16>
        %10 = iree_tensor_ext.dispatch.tensor.load %5, offsets = [0], sizes = [4], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xi64>> -> tensor<4xi64>
        %11 = tensor.empty() : tensor<4x128256xf16>
        %12 = tensor.empty() : tensor<4x128256xf32>
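        // Gather one logits row per batch entry: the constant index tensor supplies the batch index (negative
        // values wrap by the batch size 4), and binding 1 supplies a per-row position from which 1 is subtracted
        // (negative values wrap by the dynamic dimension %7), which likely selects the last token of each
        // sequence. The selected f16 element is extended to f32.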
        %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cst, %10 : tensor<4xi64>, tensor<4xi64>) outs(%12 : tensor<4x128256xf32>) {
        ^bb0(%in: i64, %in_0: i64, %out: f32):
          %16 = arith.subi %in_0, %c1_i64 : i64
          %17 = linalg.index 1 : index
          %18 = arith.cmpi slt, %in, %c0_i64 : i64
          %19 = arith.addi %in, %c4_i64 : i64
          %20 = arith.select %18, %19, %in : i64
          %21 = arith.index_cast %20 : i64 to index
          %22 = arith.cmpi slt, %16, %c0_i64 : i64
          %23 = arith.index_castui %7 : index to i64
          %24 = arith.addi %16, %23 : i64
          %25 = arith.select %22, %24, %16 : i64
          %26 = arith.index_cast %25 : i64 to index
          %extracted = tensor.extract %9[%21, %26, %17] : tensor<4x?x128256xf16>
          %27 = arith.extf %extracted : f16 to f32
          linalg.yield %27 : f32
        } -> tensor<4x128256xf32>
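        // Softmax over dimension 1 (the 128256-wide vocabulary axis), computed in f32.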
        %14 = linalg.softmax dimension(1) ins(%13 : tensor<4x128256xf32>) outs(%12 : tensor<4x128256xf32>) -> tensor<4x128256xf32>
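        // Truncate each softmax value to f16, then apply math.log, producing f16 log-probabilities.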
        %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<4x128256xf32>) outs(%11 : tensor<4x128256xf16>) {
        ^bb0(%in: f32, %out: f16):
          %16 = arith.truncf %in : f32 to f16
          %17 = math.log %16 : f16
          linalg.yield %17 : f16
        } -> tensor<4x128256xf16>
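        // Write the 4x128256 f16 log-softmax result to the write-only output binding.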
        iree_tensor_ext.dispatch.tensor.store %15, %6, offsets = [0, 0], sizes = [4, 128256], strides = [1, 1] : tensor<4x128256xf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4x128256xf16>>
        return
      }
    }
  }
}