@pashu123
Created June 4, 2025 18:08
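
// Dynamic softmax over a 32x? f16 buffer, compiled with IREE's
// LLVMGPUVectorDistribute pipeline (see #translation below). This is likely a
// mid-pipeline snapshot, after bufferization but before vector distribution:
// one scf.forall iteration per row, then three tiled passes over the dynamic
// dimension computing the running max, the exp-sum, and the normalized result.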
#map = affine_map<(d0)[s0] -> (-d0 + s0, 32)>
#map1 = affine_map<()[s0, s1] -> (s0 + s1)>
#map2 = affine_map<(d0) -> (d0)>
#map3 = affine_map<(d0) -> ()>
#pipeline_layout = #hal.pipeline.layout<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [4, 1, 1] subgroup_size = 32, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>
module {
  func.func @dynamic_softmax() attributes {translation_info = #translation} {
    %c8 = arith.constant 8 : index
    %c1 = arith.constant 1 : index
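    // 0xFE00 is an f16 quiet NaN; arith.maxnumf returns its non-NaN operand,
    // so this value acts as the identity for the running-max reduction below.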
    %cst = arith.constant dense<0.000000e+00> : vector<32xf16>
    %cst_0 = arith.constant dense<0xFE00> : vector<32xf16>
    %c32 = arith.constant 32 : index
    %cst_1 = arith.constant 0.000000e+00 : f16
    %cst_2 = arith.constant 0xFE00 : f16
    %c32_i64 = arith.constant 32 : i64
    %c0 = arith.constant 0 : index
    %thread_id_x = gpu.thread_id x
    %alloc = memref.alloc() : memref<f16, #gpu.address_space<workgroup>>
    %alloc_3 = memref.alloc() : memref<f16, #gpu.address_space<workgroup>>
    %alloc_4 = memref.alloc() : memref<1x32xf16, #gpu.address_space<workgroup>>
    %alloc_5 = memref.alloc() : memref<1x32xf16, #gpu.address_space<workgroup>>
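    // The dynamic column count %7 arrives as two 32-bit push constants that
    // are recombined into one 64-bit value: low word | (high word << 32).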
    %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
    %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
    %2 = arith.extui %0 : i32 to i64
    %3 = arith.extui %1 : i32 to i64
    %4 = arith.shli %3, %c32_i64 : i64
    %5 = arith.ori %2, %4 : i64
    %6 = arith.index_castui %5 : i64 to index
    %7 = iree_tensor_ext.dispatch.workload.ordinal %6, 0 : index
    %8 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>{%7}
    %assume_align = memref.assume_alignment %8, 64 : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>
    %9 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>{%7}
    %assume_align_6 = memref.assume_alignment %9, 64 : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>
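    // One forall iteration per row of the 32x? input. The workgroup is
    // [4, 1, 1]: each of the 4 threads owns 8 contiguous f16 elements of a
    // 32-wide tile (see the (4, 8) linearize/delinearize ops below).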
    scf.forall (%arg0) in (32) {
      vector.transfer_write %cst_0, %alloc_4[%c0, %c0] {in_bounds = [true]} : vector<32xf16>, memref<1x32xf16, #gpu.address_space<workgroup>>
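      // Pass 1: running elementwise max over the dynamic dimension,
      // 32 columns per step, accumulated in %alloc_4.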
      scf.for %arg1 = %c0 to %7 step %c32 {
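        // %16 = valid columns in this tile: min(%7 - %arg1, 32). Each lane
        // builds its 8-wide mask from the (lane, element) coordinates of the
        // last valid element. E.g. %16 = 20 -> last index 19 = (lane 2,
        // elt 3): lanes 0-1 get a full mask of 8, lane 2 gets 4, lane 3 gets 0.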
        %16 = affine.min #map(%arg1)[%7]
        %subview = memref.subview %alloc_4[0, 0] [1, %16] [1, 1] : memref<1x32xf16, #gpu.address_space<workgroup>> to memref<?xf16, strided<[1]>, #gpu.address_space<workgroup>>
        %17:2 = affine.delinearize_index %thread_id_x into (4) : index, index
        %18 = arith.subi %16, %c1 : index
        %19:2 = affine.delinearize_index %18 into (4, 8) : index, index
        %20 = arith.addi %19#1, %c1 : index
        %21 = arith.cmpi eq, %17#1, %19#0 : index
        %22 = arith.cmpi slt, %17#1, %19#0 : index
        %23 = arith.select %22, %c8, %c0 : index
        %24 = arith.select %21, %20, %23 : index
        %25 = vector.create_mask %24 : vector<8xi1>
        %26 = affine.linearize_index disjoint [%17#1, %c0] by (4, 8) : index
        %27 = affine.apply #map1()[%arg1, %26]
        %28 = vector.transfer_read %assume_align[%arg0, %27], %cst_1, %25 {in_bounds = [true]} : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
        %alloc_7 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        vector.transfer_write %28, %alloc_7[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        %29 = vector.transfer_read %alloc_4[%c0, %26], %cst_1, %25 {in_bounds = [true]} : memref<1x32xf16, #gpu.address_space<workgroup>>, vector<8xf16>
        %alloc_8 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        vector.transfer_write %29, %alloc_8[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        %alloc_9 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        %alloc_10 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} ins(%alloc_7, %alloc_8 : memref<?xf16, #gpu.address_space<workgroup>>, memref<?xf16, #gpu.address_space<workgroup>>) outs(%alloc_9 : memref<?xf16, #gpu.address_space<workgroup>>) {
        ^bb0(%in: f16, %in_11: f16, %out: f16):
          %31 = arith.maxnumf %in, %in_11 : f16
          linalg.yield %31 : f16
        }
        %30 = vector.transfer_read %alloc_9[%26], %cst_1, %25 {in_bounds = [true]} : memref<?xf16, #gpu.address_space<workgroup>>, vector<8xf16>
        vector.transfer_write %30, %alloc_10[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        gpu.barrier
        memref.copy %alloc_10, %subview {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<?xf16, #gpu.address_space<workgroup>> to memref<?xf16, strided<[1]>, #gpu.address_space<workgroup>>
        gpu.barrier
      }
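      // Fold the 32 partial maxima into one scalar and stage it in workgroup
      // memory (%alloc_3) so pass 2 can broadcast it to every element.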
      %10 = vector.transfer_read %alloc_4[%c0, %c0], %cst_1 {in_bounds = [true]} : memref<1x32xf16, #gpu.address_space<workgroup>>, vector<32xf16>
      %11 = vector.multi_reduction <maxnumf>, %10, %cst_2 [0] : vector<32xf16> to f16
      vector.transfer_write %cst, %alloc_5[%c0, %c0] {in_bounds = [true]} : vector<32xf16>, memref<1x32xf16, #gpu.address_space<workgroup>>
      %12 = vector.broadcast %11 : f16 to vector<f16>
      vector.transfer_write %12, %alloc_3[] : vector<f16>, memref<f16, #gpu.address_space<workgroup>>
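      // Pass 2: accumulate exp(x - max) partial sums into %alloc_5, tile by
      // tile, using the same masking scheme as pass 1.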
      scf.for %arg1 = %c0 to %7 step %c32 {
        %16 = affine.min #map(%arg1)[%7]
        %subview = memref.subview %alloc_5[0, 0] [1, %16] [1, 1] : memref<1x32xf16, #gpu.address_space<workgroup>> to memref<?xf16, strided<[1]>, #gpu.address_space<workgroup>>
        %17:2 = affine.delinearize_index %thread_id_x into (4) : index, index
        %18 = arith.subi %16, %c1 : index
        %19:2 = affine.delinearize_index %18 into (4, 8) : index, index
        %20 = arith.addi %19#1, %c1 : index
        %21 = arith.cmpi eq, %17#1, %19#0 : index
        %22 = arith.cmpi slt, %17#1, %19#0 : index
        %23 = arith.select %22, %c8, %c0 : index
        %24 = arith.select %21, %20, %23 : index
        %25 = vector.create_mask %24 : vector<8xi1>
        %26 = affine.linearize_index disjoint [%17#1, %c0] by (4, 8) : index
        %27 = affine.apply #map1()[%arg1, %26]
        %28 = vector.transfer_read %assume_align[%arg0, %27], %cst_1, %25 {in_bounds = [true]} : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
        %alloc_7 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        vector.transfer_write %28, %alloc_7[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        %29 = vector.transfer_read %alloc_5[%c0, %26], %cst_1, %25 {in_bounds = [true]} : memref<1x32xf16, #gpu.address_space<workgroup>>, vector<8xf16>
        %alloc_8 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        vector.transfer_write %29, %alloc_8[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        %alloc_9 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        %alloc_10 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        linalg.generic {indexing_maps = [#map2, #map3, #map2, #map2], iterator_types = ["parallel"]} ins(%alloc_7, %alloc_3, %alloc_8 : memref<?xf16, #gpu.address_space<workgroup>>, memref<f16, #gpu.address_space<workgroup>>, memref<?xf16, #gpu.address_space<workgroup>>) outs(%alloc_9 : memref<?xf16, #gpu.address_space<workgroup>>) {
        ^bb0(%in: f16, %in_11: f16, %in_12: f16, %out: f16):
          %31 = arith.subf %in, %in_11 : f16
          %32 = math.exp %31 : f16
          %33 = arith.addf %32, %in_12 : f16
          linalg.yield %33 : f16
        }
        %30 = vector.transfer_read %alloc_9[%26], %cst_1, %25 {in_bounds = [true]} : memref<?xf16, #gpu.address_space<workgroup>>, vector<8xf16>
        vector.transfer_write %30, %alloc_10[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        gpu.barrier
        memref.copy %alloc_10, %subview {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<?xf16, #gpu.address_space<workgroup>> to memref<?xf16, strided<[1]>, #gpu.address_space<workgroup>>
        gpu.barrier
      }
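      // Reduce the 32 partial sums to the softmax denominator and stage it
      // in workgroup memory (%alloc) for pass 3.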
      %13 = vector.transfer_read %alloc_5[%c0, %c0], %cst_1 {in_bounds = [true]} : memref<1x32xf16, #gpu.address_space<workgroup>>, vector<32xf16>
      %14 = vector.multi_reduction <add>, %13, %cst_1 [0] : vector<32xf16> to f16
      %15 = vector.broadcast %14 : f16 to vector<f16>
      vector.transfer_write %15, %alloc[] : vector<f16>, memref<f16, #gpu.address_space<workgroup>>
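      // Pass 3: recompute exp(x - max), divide by the staged sum, and write
      // the normalized rows to the output binding (%assume_align_6).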
      scf.for %arg1 = %c0 to %7 step %c32 {
        %16 = affine.min #map(%arg1)[%7]
        %subview = memref.subview %assume_align_6[%arg0, %arg1] [1, %16] [1, 1] : memref<32x?xf16, #hal.descriptor_type<storage_buffer>> to memref<?xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        %17:2 = affine.delinearize_index %thread_id_x into (4) : index, index
        %18 = arith.subi %16, %c1 : index
        %19:2 = affine.delinearize_index %18 into (4, 8) : index, index
        %20 = arith.addi %19#1, %c1 : index
        %21 = arith.cmpi eq, %17#1, %19#0 : index
        %22 = arith.cmpi slt, %17#1, %19#0 : index
        %23 = arith.select %22, %c8, %c0 : index
        %24 = arith.select %21, %20, %23 : index
        %25 = vector.create_mask %24 : vector<8xi1>
        %26 = affine.linearize_index disjoint [%17#1, %c0] by (4, 8) : index
        %27 = affine.apply #map1()[%arg1, %26]
        %28 = vector.transfer_read %assume_align[%arg0, %27], %cst_1, %25 {in_bounds = [true]} : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
        %alloc_7 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        vector.transfer_write %28, %alloc_7[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        %29 = vector.transfer_read %assume_align_6[%arg0, %27], %cst_1, %25 {in_bounds = [true]} : memref<32x?xf16, #hal.descriptor_type<storage_buffer>>, vector<8xf16>
        %alloc_8 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        %alloc_9 = memref.alloc(%16) : memref<?xf16, #gpu.address_space<workgroup>>
        vector.transfer_write %29, %alloc_8[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        linalg.generic {indexing_maps = [#map2, #map3, #map3, #map2], iterator_types = ["parallel"]} ins(%alloc_7, %alloc_3, %alloc : memref<?xf16, #gpu.address_space<workgroup>>, memref<f16, #gpu.address_space<workgroup>>, memref<f16, #gpu.address_space<workgroup>>) outs(%alloc_8 : memref<?xf16, #gpu.address_space<workgroup>>) {
        ^bb0(%in: f16, %in_10: f16, %in_11: f16, %out: f16):
          %31 = arith.subf %in, %in_10 : f16
          %32 = math.exp %31 : f16
          %33 = arith.divf %32, %in_11 : f16
          linalg.yield %33 : f16
        }
        %30 = vector.transfer_read %alloc_8[%26], %cst_1, %25 {in_bounds = [true]} : memref<?xf16, #gpu.address_space<workgroup>>, vector<8xf16>
        vector.transfer_write %30, %alloc_9[%26], %25 {in_bounds = [true]} : vector<8xf16>, memref<?xf16, #gpu.address_space<workgroup>>
        gpu.barrier
        memref.copy %alloc_9, %subview {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<?xf16, #gpu.address_space<workgroup>> to memref<?xf16, strided<[1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        gpu.barrier
      }
    } {mapping = [#iree_codegen.workgroup_mapping<x>]}
    memref.dealloc %alloc_5 : memref<1x32xf16, #gpu.address_space<workgroup>>
    memref.dealloc %alloc_4 : memref<1x32xf16, #gpu.address_space<workgroup>>
    memref.dealloc %alloc_3 : memref<f16, #gpu.address_space<workgroup>>
    memref.dealloc %alloc : memref<f16, #gpu.address_space<workgroup>>
    return
  }
}