Created
May 8, 2025 07:14
-
-
Save pashu123/0443f27cd15e79f07a059d136c9a3c8b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Row-wise softmax over a 1x1x16384x16384xf32 buffer, lowered by the
// LLVMGPUVectorDistribute pipeline.
// Workgroups are 512 threads with subgroup_size 64, i.e. 8 subgroups. The
// nested layouts used throughout tile vectors as
// subgroup_tile 8 x thread_tile 64 x element_tile 4 = 2048 elements.
// That is why every row of 16384 elements is processed in 8 chunks of 2048.
// The row is handled in three passes, each re-reading the input from binding 0:
//   1) running max, 2) sum of exp(x - max), 3) exp(x - max) / sum -> output.
// NOTE(review): this looks like an intermediate dump taken after
// bufferization. Several ops (the workgroup staging buffer and the
// memref.copy ops whose source and destination are the same slice) appear
// to be leftovers rather than required work; confirm against the pass
// pipeline that produced it.
| func.func @softmax_4d_dispatch_0_softmax_1x1x16384x16384xf32_dispatch_tensor_store() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [512, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>} { | |
// Zero vector: initial accumulator of the exp-sum pass (pass 2).
| %cst = arith.constant dense<0.000000e+00> : vector<2048xf32> | |
// 0xFFC00000 is a (negative) quiet-NaN bit pattern. arith.maxnumf returns the
// non-NaN operand when one input is NaN, so NaN acts as the identity of the
// running max in pass 1.
| %cst_0 = arith.constant dense<0xFFC00000> : vector<2048xf32> | |
// Chunk width (loop step) and row length (loop bound).
| %c2048 = arith.constant 2048 : index | |
| %c16384 = arith.constant 16384 : index | |
// Scalar 0.0: padding value for the transfer_reads and init of the scalar sum.
| %cst_1 = arith.constant 0.000000e+00 : f32 | |
// Scalar NaN: init of the scalar max reduction (same reasoning as %cst_0).
| %cst_2 = arith.constant 0xFFC00000 : f32 | |
| %c0 = arith.constant 0 : index | |
// Binding 0 is the read-only input. It is recast to an AMDGPU fat raw buffer
// with a 64-byte alignment assumption.
| %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>> | |
| %1 = amdgpu.fat_raw_buffer_cast %0 resetOffset : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>> to memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> | |
| memref.assume_alignment %1, 64 : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> | |
// Binding 1 is the output, prepared the same way.
| %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>> | |
| %3 = amdgpu.fat_raw_buffer_cast %2 resetOffset : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>> to memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> | |
| memref.assume_alignment %3, 64 : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> | |
// One iteration per row (16384 rows), distributed across workgroups along x
// (see the mapping attribute on the closing brace of this forall).
| scf.forall (%arg0) in (16384) { | |
// Destination row %arg0 of the output buffer.
| %subview = memref.subview %3[0, 0, %arg0, 0] [1, 1, 1, 16384] [1, 1, 1, 1] : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
// Pass 1: lane-wise running max over the 8 chunks of the row. The to_layout
// ops pin the loaded chunk, the carried accumulator and the result to the
// same 8x64x4 nested layout.
| %4 = scf.for %arg1 = %c0 to %c16384 step %c2048 iter_args(%arg2 = %cst_0) -> (vector<2048xf32>) { | |
| %15 = vector.transfer_read %1[%c0, %c0, %arg0, %arg1], %cst_1 {in_bounds = [true]} : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2048xf32> | |
| %16 = iree_vector_ext.to_layout %15 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| %17 = iree_vector_ext.to_layout %arg2 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| %18 = arith.maxnumf %16, %17 : vector<2048xf32> | |
| %19 = iree_vector_ext.to_layout %18 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| scf.yield %19 : vector<2048xf32> | |
| } | |
// Collapse the 2048 lane maxima into the scalar row max. Broadcast it to a
// 0-d vector with an empty layout, then splat it back to 2048 lanes for the
// elementwise subtraction in passes 2 and 3.
| %5 = vector.multi_reduction <maxnumf>, %4, %cst_2 [0] : vector<2048xf32> to f32 | |
| %6 = vector.broadcast %5 : f32 to vector<f32> | |
| %7 = iree_vector_ext.to_layout %6 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [], batch_tile = [], outer_tile = [], thread_tile = [], element_tile = [], subgroup_strides = [], thread_strides = []>) : vector<f32> | |
| %8 = vector.broadcast %7 : vector<f32> to vector<2048xf32> | |
// Pass 2: lane-wise accumulation of exp(x - max) over the row. The input row
// is re-read from binding 0.
| %9 = scf.for %arg1 = %c0 to %c16384 step %c2048 iter_args(%arg2 = %cst) -> (vector<2048xf32>) { | |
| %15 = vector.transfer_read %1[%c0, %c0, %arg0, %arg1], %cst_1 {in_bounds = [true]} : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2048xf32> | |
| %16 = iree_vector_ext.to_layout %15 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| %17 = iree_vector_ext.to_layout %arg2 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| %18 = arith.subf %16, %8 : vector<2048xf32> | |
| %19 = math.exp %18 : vector<2048xf32> | |
| %20 = arith.addf %19, %17 : vector<2048xf32> | |
| %21 = iree_vector_ext.to_layout %20 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| scf.yield %21 : vector<2048xf32> | |
| } | |
// Scalar sum of exp(x - max) for the row, broadcast and splatted the same way
// as the max above (%12 / %13).
| %10 = vector.multi_reduction <add>, %9, %cst_1 [0] : vector<2048xf32> to f32 | |
| %11 = vector.broadcast %10 : f32 to vector<f32> | |
// Workgroup-memory buffer covering a full row: 16384 x f32 = 65536 bytes.
// NOTE(review): in this function it is written in pass 3 but never read
// afterwards, except by a memref.copy onto itself. It may also sit at or above
// the target's workgroup-memory limit. Confirm whether it can be elided.
| %alloc = memref.alloc() : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>> | |
| %12 = iree_vector_ext.to_layout %11 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [], batch_tile = [], outer_tile = [], thread_tile = [], element_tile = [], subgroup_strides = [], thread_strides = []>) : vector<f32> | |
| %13 = vector.broadcast %12 : vector<f32> to vector<2048xf32> | |
// Pass 3: for each 2048-wide chunk, recompute exp(x - max), divide by the row
// sum, and write the chunk out. The loop carries the workgroup buffer (%arg2)
// and the output row (%arg3) but yields both back unchanged.
| %14:2 = scf.for %arg1 = %c0 to %c16384 step %c2048 iter_args(%arg2 = %alloc, %arg3 = %subview) -> (memref<1x1x1x16384xf32, #gpu.address_space<workgroup>>, memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>) { | |
// %subview_4: this chunk of the output row. %subview_5: the same chunk of the
// workgroup buffer.
| %subview_4 = memref.subview %arg3[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
| %subview_5 = memref.subview %arg2[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>> to memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>> | |
| %15 = vector.transfer_read %1[%c0, %c0, %arg0, %arg1], %cst_1 {in_bounds = [true]} : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2048xf32> | |
| %16 = iree_vector_ext.to_layout %15 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| %17 = arith.subf %16, %8 : vector<2048xf32> | |
| %18 = math.exp %17 : vector<2048xf32> | |
| %19 = arith.divf %18, %13 : vector<2048xf32> | |
// NOTE(review): the workgroup buffer receives %18 (exp before the division by
// the sum), while the global output below receives %19 (the normalized
// softmax value).
| %20 = iree_vector_ext.to_layout %18 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| vector.transfer_write %20, %subview_5[%c0, %c0, %c0, %c0] {in_bounds = [true]} : vector<2048xf32>, memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>> | |
| %21 = iree_vector_ext.to_layout %19 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32> | |
| vector.transfer_write %21, %subview_4[%c0, %c0, %c0, %c0] {in_bounds = [true]} : vector<2048xf32>, memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
// %subview_6 has the same source and operands as %subview_5. The
// barrier-fenced memref.copy below therefore copies the chunk onto itself.
| %subview_6 = memref.subview %arg2[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>> to memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>> | |
| gpu.barrier | |
| memref.copy %subview_5, %subview_6 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>> | |
| gpu.barrier | |
// Likewise, %subview_7 has the same source and operands as %subview_4, so
// this copy is also onto itself.
| %subview_7 = memref.subview %arg3[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
| memref.copy %subview_4, %subview_7 : memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
| scf.yield %arg2, %arg3 : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>>, memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
| } | |
// %14#1 is %subview yielded unchanged, i.e. the same row of %3 that
// %subview_3 selects. This copy is therefore onto itself as well.
| %subview_3 = memref.subview %3[0, 0, %arg0, 0] [1, 1, 1, 16384] [1, 1, 1, 1] : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
| memref.copy %14#1, %subview_3 : memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> | |
| } {mapping = [#iree_codegen.workgroup_mapping<x>]} | |
// Whole-buffer copy of the output %3 onto itself.
// NOTE(review): presumably a bufferization leftover; confirm it can be dropped.
| memref.copy %3, %3 : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> | |
| return | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment