Created May 8, 2025 07:14
Save pashu123/0443f27cd15e79f07a059d136c9a3c8b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Softmax over the innermost dimension of a 1x1x16384x16384xf32 tensor, as
// lowered by IREE's LLVMGPUVectorDistribute pipeline
// (workgroup_size = [512, 1, 1], subgroup_size = 64).
//
// Input is read from binding 0 (ReadOnly) and the result is written to
// binding 1. Each iteration of the workgroup-mapped scf.forall handles one of
// the 16384 rows in three passes over 8 chunks of 2048 elements:
//   1. m    = max_j x[j]            (lane-wise arith.maxnumf, then reduce)
//   2. s    = sum_j exp(x[j] - m)   (lane-wise arith.addf, then reduce)
//   3. y[j] = exp(x[j] - m) / s     (written to binding 1)
// Subtracting the row max before math.exp keeps the exponent from overflowing.
//
// Every vector<2048xf32> carries the same nested layout:
//   subgroup_tile 8 x thread_tile 64 x element_tile 4 = 2048,
// i.e. 8 subgroups x 64 lanes = the 512 workgroup threads, 4 elements each.
func.func @softmax_4d_dispatch_0_softmax_1x1x16384x16384xf32_dispatch_tensor_store() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [512, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>} {
  // Accumulator initial values. 0xFFC00000 is a NaN bit pattern; arith.maxnumf
  // returns the non-NaN operand, so NaN behaves as the identity of the max.
  %cst = arith.constant dense<0.000000e+00> : vector<2048xf32>
  %cst_0 = arith.constant dense<0xFFC00000> : vector<2048xf32>
  %c2048 = arith.constant 2048 : index
  %c16384 = arith.constant 16384 : index
  %cst_1 = arith.constant 0.000000e+00 : f32
  %cst_2 = arith.constant 0xFFC00000 : f32
  %c0 = arith.constant 0 : index
  // Binding 0: read-only input, viewed as a fat raw buffer, 64-byte aligned.
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>>
  %1 = amdgpu.fat_raw_buffer_cast %0 resetOffset : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>> to memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>
  memref.assume_alignment %1, 64 : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>
  // Binding 1: output, viewed as a fat raw buffer, 64-byte aligned.
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>>
  %3 = amdgpu.fat_raw_buffer_cast %2 resetOffset : memref<1x1x16384x16384xf32, #hal.descriptor_type<storage_buffer>> to memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>
  memref.assume_alignment %3, 64 : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>
  // One row (%arg0) per iteration, mapped to workgroups along x.
  scf.forall (%arg0) in (16384) {
    // Output row %arg0.
    %subview = memref.subview %3[0, 0, %arg0, 0] [1, 1, 1, 16384] [1, 1, 1, 1] : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
    // Pass 1: lane-wise running max over the 8 chunks of 2048 elements.
    %4 = scf.for %arg1 = %c0 to %c16384 step %c2048 iter_args(%arg2 = %cst_0) -> (vector<2048xf32>) {
      %15 = vector.transfer_read %1[%c0, %c0, %arg0, %arg1], %cst_1 {in_bounds = [true]} : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2048xf32>
      %16 = iree_vector_ext.to_layout %15 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      %17 = iree_vector_ext.to_layout %arg2 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      %18 = arith.maxnumf %16, %17 : vector<2048xf32>
      %19 = iree_vector_ext.to_layout %18 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      scf.yield %19 : vector<2048xf32>
    }
    // Reduce the 2048 partial maxima to the row max and broadcast it back.
    %5 = vector.multi_reduction <maxnumf>, %4, %cst_2 [0] : vector<2048xf32> to f32
    %6 = vector.broadcast %5 : f32 to vector<f32>
    %7 = iree_vector_ext.to_layout %6 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [], batch_tile = [], outer_tile = [], thread_tile = [], element_tile = [], subgroup_strides = [], thread_strides = []>) : vector<f32>
    %8 = vector.broadcast %7 : vector<f32> to vector<2048xf32>
    // Pass 2: lane-wise running sum of exp(x - max).
    %9 = scf.for %arg1 = %c0 to %c16384 step %c2048 iter_args(%arg2 = %cst) -> (vector<2048xf32>) {
      %15 = vector.transfer_read %1[%c0, %c0, %arg0, %arg1], %cst_1 {in_bounds = [true]} : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2048xf32>
      %16 = iree_vector_ext.to_layout %15 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      %17 = iree_vector_ext.to_layout %arg2 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      %18 = arith.subf %16, %8 : vector<2048xf32>
      %19 = math.exp %18 : vector<2048xf32>
      %20 = arith.addf %19, %17 : vector<2048xf32>
      %21 = iree_vector_ext.to_layout %20 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      scf.yield %21 : vector<2048xf32>
    }
    // Reduce the partial sums to the row denominator.
    %10 = vector.multi_reduction <add>, %9, %cst_1 [0] : vector<2048xf32> to f32
    %11 = vector.broadcast %10 : f32 to vector<f32>
    // NOTE(review): 16384 x f32 = 64 KiB of workgroup memory. Pass 3 stores the
    // *unnormalized* exp (%18, not %19) into it, the only other access is the
    // self-copy between the two gpu.barriers, and %14#0 is never used after
    // the loop. Looks like a bufferization leftover; confirm it is dead and
    // check it fits the target's workgroup-memory limit.
    %alloc = memref.alloc() : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>>
    %12 = iree_vector_ext.to_layout %11 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [], batch_tile = [], outer_tile = [], thread_tile = [], element_tile = [], subgroup_strides = [], thread_strides = []>) : vector<f32>
    %13 = vector.broadcast %12 : vector<f32> to vector<2048xf32>
    // Pass 3: y = exp(x - max) / sum, written chunk by chunk to the output row.
    %14:2 = scf.for %arg1 = %c0 to %c16384 step %c2048 iter_args(%arg2 = %alloc, %arg3 = %subview) -> (memref<1x1x1x16384xf32, #gpu.address_space<workgroup>>, memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>) {
      // Output chunk and the matching workgroup-memory chunk.
      %subview_4 = memref.subview %arg3[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
      %subview_5 = memref.subview %arg2[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>> to memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>>
      %15 = vector.transfer_read %1[%c0, %c0, %arg0, %arg1], %cst_1 {in_bounds = [true]} : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<2048xf32>
      %16 = iree_vector_ext.to_layout %15 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      %17 = arith.subf %16, %8 : vector<2048xf32>
      %18 = math.exp %17 : vector<2048xf32>
      %19 = arith.divf %18, %13 : vector<2048xf32>
      %20 = iree_vector_ext.to_layout %18 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      vector.transfer_write %20, %subview_5[%c0, %c0, %c0, %c0] {in_bounds = [true]} : vector<2048xf32>, memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>>
      %21 = iree_vector_ext.to_layout %19 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [8], batch_tile = [1], outer_tile = [1], thread_tile = [64], element_tile = [4], subgroup_strides = [1], thread_strides = [1]>) : vector<2048xf32>
      vector.transfer_write %21, %subview_4[%c0, %c0, %c0, %c0] {in_bounds = [true]} : vector<2048xf32>, memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
      // %subview_6 is the same slice as %subview_5, so this barrier-guarded
      // copy writes the chunk onto itself.
      %subview_6 = memref.subview %arg2[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>> to memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>>
      gpu.barrier
      memref.copy %subview_5, %subview_6 {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>> to memref<1x1x1x2048xf32, strided<[16384, 16384, 16384, 1], offset: ?>, #gpu.address_space<workgroup>>
      gpu.barrier
      // %subview_7 is the same slice as %subview_4: another self-copy.
      %subview_7 = memref.subview %arg3[0, 0, 0, %arg1] [1, 1, 1, 2048] [1, 1, 1, 1] : memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
      memref.copy %subview_4, %subview_7 : memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x2048xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
      scf.yield %arg2, %arg3 : memref<1x1x1x16384xf32, #gpu.address_space<workgroup>>, memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
    }
    // %14#1 is %subview yielded unchanged, and %subview_3 is the same row of
    // %3, so this copy is also a self-copy.
    %subview_3 = memref.subview %3[0, 0, %arg0, 0] [1, 1, 1, 16384] [1, 1, 1, 1] : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
    memref.copy %14#1, %subview_3 : memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x1x16384xf32, strided<[268435456, 268435456, 16384, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
  // Copies the whole output buffer onto itself; presumably left for later
  // canonicalization to fold away (confirm).
  memref.copy %3, %3 : memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<1x1x16384x16384xf32, #amdgpu.address_space<fat_raw_buffer>>
  return
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.