Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created June 17, 2025 14:22
Show Gist options
  • Save pashu123/d1ac0b27dd90466410c68eed5a8687f6 to your computer and use it in GitHub Desktop.
Save pashu123/d1ac0b27dd90466410c68eed5a8687f6 to your computer and use it in GitHub Desktop.
// Dispatch region computing a gathered log-softmax over the last dimension.
// For each of 4 batch rows it gathers one slice from a dynamically-sized
// 4x?x128256 f16 input, then computes log(softmax(row)) and stores a
// 4x128256 f16 result. Lowered for GPU via the LLVMGPUVectorDistribute
// pipeline with workgroup size [64,1,1] and subgroup size 64.
func.func @softy_dispatch_0_softmax_4x128256xf32_generic() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>} {
// %cst: additive identity used to initialize the exp-sum accumulator.
%cst = arith.constant 0.000000e+00 : f32
// %cst_0: f32 quiet NaN bit pattern (0xFFC00000) used as the init value for
// the maxnumf reduction (maxnumf ignores NaN operands, so any real input
// replaces it).
%cst_0 = arith.constant 0xFFC00000 : f32
%c32_i64 = arith.constant 32 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i64 = arith.constant 0 : i64
%c4_i64 = arith.constant 4 : i64
// Per-row gather indices baked in as a 4-element i64 resource blob.
%cst_1 = arith.constant dense_resource<torch_tensor_4_torch.int64> : tensor<4xi64>
%c0 = arith.constant 0 : index
// Four 32-bit push constants; pairs (0,1) and (2,3) are reassembled below
// into two 64-bit values (lo | hi << 32).
%0 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
%2 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
%3 = hal.interface.constant.load layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
// Reassemble push constants 0+1 -> %8 (workload ordinal 0, used as an i64
// addend when wrapping negative gather indices).
%4 = arith.extui %0 : i32 to i64
%5 = arith.extui %1 : i32 to i64
%6 = arith.shli %5, %c32_i64 : i64
%7 = arith.ori %4, %6 : i64
%8 = arith.index_castui %7 : i64 to index
// Reassemble push constants 2+3 -> %13 (workload ordinal 1: the dynamic
// middle dimension of the input tensor).
%9 = arith.extui %2 : i32 to i64
%10 = arith.extui %3 : i32 to i64
%11 = arith.shli %10, %c32_i64 : i64
%12 = arith.ori %9, %11 : i64
%13 = arith.index_castui %12 : i64 to index
// Compiler-provided range fact: the dynamic dim fits in a double-precision
// safe integer (2^53 - 1).
%14 = util.assume.int %13<umin = 0, umax = 9007199254740991> : index
// binding(1): read-only 4xi64 tensor of per-row sequence positions.
%15 = hal.interface.binding.subspan layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xi64>>
// binding(2): write-only 4x128256 f16 output.
%16 = hal.interface.binding.subspan layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4x128256xf16>>
%17 = iree_tensor_ext.dispatch.workload.ordinal %14, 1 : index
// binding(0): read-only 4x?x128256 f16 input; %17 supplies the dynamic dim.
%18 = hal.interface.binding.subspan layout(<constants = 4, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4x?x128256xf16>>{%17}
%19 = iree_tensor_ext.dispatch.workload.ordinal %8, 0 : index
// Load the full dynamic input and the 4-element position tensor.
%20 = iree_tensor_ext.dispatch.tensor.load %18, offsets = [0, 0, 0], sizes = [4, %17, 128256], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4x?x128256xf16>>{%17} -> tensor<4x?x128256xf16>
%21 = iree_tensor_ext.dispatch.tensor.load %15, offsets = [0], sizes = [4], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xi64>> -> tensor<4xi64>
%22 = tensor.empty() : tensor<4x128256xf16>
// One workgroup per batch row (mapping = workgroup x, 4 iterations).
%23 = scf.forall (%arg0) in (4) shared_outs(%arg1 = %22) -> (tensor<4x128256xf16>) {
%extracted_slice = tensor.extract_slice %cst_1[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%extracted_slice_2 = tensor.extract_slice %21[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%24 = tensor.empty() : tensor<1x128256xf32>
// Gather generic #1 (%25): for each of the 128256 columns, pick
// %20[batch_idx, seq_idx, col] and widen f16 -> f32.
//   batch_idx = %cst_1[%arg0], wrapped by +4 if negative (Python-style
//               negative indexing over the size-4 dim).
//   seq_idx   = %21[%arg0] - 1, wrapped by +%19 if negative (the dynamic
//               dim's extent comes in via workload ordinal 0).
// NOTE(review): this exact gather body is repeated verbatim four times
// below (%25, %27, %32, %34) — presumably rematerialization to avoid
// keeping the 1x128256 f32 row live; confirm against the tiling pass.
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_2 : tensor<1xi64>, tensor<1xi64>) outs(%24 : tensor<1x128256xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: i64, %in_10: i64, %out: f32):
%42 = arith.subi %in_10, %c1_i64 : i64
%43 = linalg.index 1 : index
%44 = arith.cmpi slt, %in, %c0_i64 : i64
%45 = arith.addi %in, %c4_i64 : i64
%46 = arith.select %44, %45, %in : i64
%47 = arith.index_cast %46 : i64 to index
%48 = arith.cmpi slt, %42, %c0_i64 : i64
%49 = arith.index_castui %19 : index to i64
%50 = arith.addi %42, %49 : i64
%51 = arith.select %48, %50, %42 : i64
%52 = arith.index_cast %51 : i64 to index
%extracted = tensor.extract %20[%47, %52, %43] : tensor<4x?x128256xf16>
%53 = arith.extf %extracted : f16 to f32
linalg.yield %53 : f32
} -> tensor<1x128256xf32>
%extracted_slice_3 = tensor.extract_slice %cst_1[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%extracted_slice_4 = tensor.extract_slice %21[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%26 = tensor.empty() : tensor<1x128256xf32>
// Gather generic #2 (%27): identical body to %25; feeds the first max
// reduction (%30).
%27 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_3, %extracted_slice_4 : tensor<1xi64>, tensor<1xi64>) outs(%26 : tensor<1x128256xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: i64, %in_10: i64, %out: f32):
%42 = arith.subi %in_10, %c1_i64 : i64
%43 = linalg.index 1 : index
%44 = arith.cmpi slt, %in, %c0_i64 : i64
%45 = arith.addi %in, %c4_i64 : i64
%46 = arith.select %44, %45, %in : i64
%47 = arith.index_cast %46 : i64 to index
%48 = arith.cmpi slt, %42, %c0_i64 : i64
%49 = arith.index_castui %19 : index to i64
%50 = arith.addi %42, %49 : i64
%51 = arith.select %48, %50, %42 : i64
%52 = arith.index_cast %51 : i64 to index
%extracted = tensor.extract %20[%47, %52, %43] : tensor<4x?x128256xf16>
%53 = arith.extf %extracted : f16 to f32
linalg.yield %53 : f32
} -> tensor<1x128256xf32>
%28 = tensor.empty() : tensor<1xf32>
// %30: row maximum over the 128256 columns (partial_reduction config),
// seeded with the NaN init %cst_0 via maxnumf.
%29 = linalg.fill ins(%cst_0 : f32) outs(%28 : tensor<1xf32>) -> tensor<1xf32>
%30 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%27 : tensor<1x128256xf32>) outs(%29 : tensor<1xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: f32, %out: f32):
%42 = arith.maxnumf %in, %out : f32
linalg.yield %42 : f32
} -> tensor<1xf32>
%extracted_slice_5 = tensor.extract_slice %cst_1[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%extracted_slice_6 = tensor.extract_slice %21[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%31 = tensor.empty() : tensor<1x128256xf32>
// Gather generic #3 (%32): identical body again; feeds the exp-sum (%40).
%32 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_5, %extracted_slice_6 : tensor<1xi64>, tensor<1xi64>) outs(%31 : tensor<1x128256xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: i64, %in_10: i64, %out: f32):
%42 = arith.subi %in_10, %c1_i64 : i64
%43 = linalg.index 1 : index
%44 = arith.cmpi slt, %in, %c0_i64 : i64
%45 = arith.addi %in, %c4_i64 : i64
%46 = arith.select %44, %45, %in : i64
%47 = arith.index_cast %46 : i64 to index
%48 = arith.cmpi slt, %42, %c0_i64 : i64
%49 = arith.index_castui %19 : index to i64
%50 = arith.addi %42, %49 : i64
%51 = arith.select %48, %50, %42 : i64
%52 = arith.index_cast %51 : i64 to index
%extracted = tensor.extract %20[%47, %52, %43] : tensor<4x?x128256xf16>
%53 = arith.extf %extracted : f16 to f32
linalg.yield %53 : f32
} -> tensor<1x128256xf32>
%extracted_slice_7 = tensor.extract_slice %cst_1[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%extracted_slice_8 = tensor.extract_slice %21[%arg0] [1] [1] : tensor<4xi64> to tensor<1xi64>
%33 = tensor.empty() : tensor<1x128256xf32>
// Gather generic #4 (%34): identical body once more; feeds the second max
// reduction (%37), itself used only inside the exp-sum (%40).
%34 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_7, %extracted_slice_8 : tensor<1xi64>, tensor<1xi64>) outs(%33 : tensor<1x128256xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: i64, %in_10: i64, %out: f32):
%42 = arith.subi %in_10, %c1_i64 : i64
%43 = linalg.index 1 : index
%44 = arith.cmpi slt, %in, %c0_i64 : i64
%45 = arith.addi %in, %c4_i64 : i64
%46 = arith.select %44, %45, %in : i64
%47 = arith.index_cast %46 : i64 to index
%48 = arith.cmpi slt, %42, %c0_i64 : i64
%49 = arith.index_castui %19 : index to i64
%50 = arith.addi %42, %49 : i64
%51 = arith.select %48, %50, %42 : i64
%52 = arith.index_cast %51 : i64 to index
%extracted = tensor.extract %20[%47, %52, %43] : tensor<4x?x128256xf16>
%53 = arith.extf %extracted : f16 to f32
linalg.yield %53 : f32
} -> tensor<1x128256xf32>
%35 = tensor.empty() : tensor<1xf32>
// %37: second (duplicate) row-max reduction over gather #4.
%36 = linalg.fill ins(%cst_0 : f32) outs(%35 : tensor<1xf32>) -> tensor<1xf32>
%37 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%34 : tensor<1x128256xf32>) outs(%36 : tensor<1xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: f32, %out: f32):
%42 = arith.maxnumf %in, %out : f32
linalg.yield %42 : f32
} -> tensor<1xf32>
%38 = tensor.empty() : tensor<1xf32>
// %40: numerically-stable softmax denominator: sum over columns of
// exp(x - max), accumulated from the 0.0 fill.
%39 = linalg.fill ins(%cst : f32) outs(%38 : tensor<1xf32>) -> tensor<1xf32>
%40 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%32, %37 : tensor<1x128256xf32>, tensor<1xf32>) outs(%39 : tensor<1xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: f32, %in_10: f32, %out: f32):
%42 = arith.subf %in, %in_10 : f32
%43 = math.exp %42 : f32
%44 = arith.addf %43, %out : f32
linalg.yield %44 : f32
} -> tensor<1xf32>
%extracted_slice_9 = tensor.extract_slice %arg1[%arg0, 0] [1, 128256] [1, 1] : tensor<4x128256xf16> to tensor<1x128256xf16>
// %41: epilogue per column: p = exp(x - max) / sum, truncate to f16, then
// take log — i.e. log-softmax, with the log computed in f16 AFTER the
// truncation. NOTE(review): log(trunc(p)) in f16 loses precision vs.
// computing (x - max) - log(sum) in f32; presumably matches the source
// model's op order — confirm upstream.
%41 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%25, %30, %40 : tensor<1x128256xf32>, tensor<1xf32>, tensor<1xf32>) outs(%extracted_slice_9 : tensor<1x128256xf16>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 256], subgroup_basis = [[1, 1], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: f32, %in_10: f32, %in_11: f32, %out: f16):
%42 = arith.subf %in, %in_10 : f32
%43 = math.exp %42 : f32
%44 = arith.divf %43, %in_11 : f32
%45 = arith.truncf %44 : f32 to f16
%46 = math.log %45 : f16
linalg.yield %46 : f16
} -> tensor<1x128256xf16>
// Publish this row's 1x128256 result into the shared 4x128256 output.
scf.forall.in_parallel {
tensor.parallel_insert_slice %41 into %arg1[%arg0, 0] [1, 128256] [1, 1] : tensor<1x128256xf16> into tensor<4x128256xf16>
}
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
// Store the assembled result to the write-only output binding.
iree_tensor_ext.dispatch.tensor.store %23, %16, offsets = [0, 0], sizes = [4, 128256], strides = [1, 1] : tensor<4x128256xf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4x128256xf16>>
return
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment