Created
March 4, 2025 14:04
-
-
Save pashu123/6ded73afe21fd55668f11ade7e84b820 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Softmax dispatch over the innermost (64-element) dimension of a 12x64x64
// buffer: reads f32 rows from binding 0, computes exp(x - max) / sum(exp(x - max))
// per row, truncates to f16 and writes the row to binding 1.
//
// The function takes no SSA arguments; its inputs come from the HAL interface
// (two i32 constants used as binding offsets, and two storage-buffer bindings).
// translation_info selects the LLVMGPUVectorDistribute pipeline with
// workgroup_size = [64, 1, 1] and subgroup_size = 64.
func.func @encode_prompts$async_dispatch_10_softmax_12x64x64xf32_generic() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = true, use_igemm_convolution = false>}>} { | |
// Zero splat; added to the exponentials before the sum reduction.
%cst = arith.constant dense<0.000000e+00> : vector<64xf32> | |
// Splat of the f32 bit pattern 0xFFC00000 (a NaN encoding); used as the second
// operand of the element-wise maxnumf below.
// NOTE(review): presumably chosen because arith.maxnumf returns the other
// operand when one operand is NaN, making this a neutral element - confirm
// against the arith dialect documentation.
%cst_0 = arith.constant dense<0xFFC00000> : vector<64xf32> | |
%c0 = arith.constant 0 : index | |
// Scalar zero: padding value for the transfer_read and accumulator for the
// add reduction.
%cst_1 = arith.constant 0.000000e+00 : f32 | |
// Scalar with the same 0xFFC00000 bit pattern: accumulator for the max reduction.
%cst_2 = arith.constant 0xFFC00000 : f32 | |
// Load the two i32 interface constants (ordinals 0 and 1) and cast them
// (unsigned) to index; they become the offsets of bindings 0 and 1.
%0 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32 | |
%1 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32 | |
%2 = arith.index_castui %0 : i32 to index | |
%3 = arith.index_castui %1 : i32 to index | |
// Attach umin/umax/udiv assumptions to both offsets. Each operand's list only
// contains two distinct values (196608 / 688128 for %2, 393216 / 884736 for %3).
// NOTE(review): the repeated entries presumably correspond to individual call
// sites of this dispatch - not verifiable from this function alone.
%4:2 = util.assume.int | |
%2[<umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 196608, umax = 196608, udiv = 196608>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>, <umin = 688128, umax = 688128, udiv = 688128>], | |
%3[<umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 393216, umax = 393216, udiv = 393216>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>, <umin = 884736, umax = 884736, udiv = 884736>] | |
: index, index | |
// Binding 0: read-only input viewed as 12x64x64 f32, at offset %4#0.
%5 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%4#0) flags("ReadOnly|Indirect") : memref<12x64x64xf32, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %5, 64 : memref<12x64x64xf32, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
// Binding 1: output viewed as 12x64x64 f16, at offset %4#1.
%6 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%4#1) flags(Indirect) : memref<12x64x64xf16, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
memref.assume_alignment %6, 64 : memref<12x64x64xf16, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
// One iteration per (%arg0, %arg1) row of the 12x64 outer dimensions; the
// mapping attribute at the end of the loop lists workgroup y, then x.
scf.forall (%arg0, %arg1) in (12, 64) { | |
// Load the 64-element f32 row at [%arg0, %arg1, 0].
%7 = vector.transfer_read %5[%arg0, %arg1, %c0], %cst_1 {in_bounds = [true]} : memref<12x64x64xf32, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<64xf32> | |
// --- Row maximum ---
// Anchor the row and the NaN splat to the same layout
// (thread_tile = [16], element_tile = [4], thread_strides = [1]).
%8 = iree_vector_ext.to_layout %7 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [1], outer_tile = [1], thread_tile = [16], element_tile = [4], subgroup_strides = [0], thread_strides = [1]>) : vector<64xf32> | |
%9 = iree_vector_ext.to_layout %cst_0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [1], outer_tile = [1], thread_tile = [16], element_tile = [4], subgroup_strides = [0], thread_strides = [1]>) : vector<64xf32> | |
// Element-wise maxnumf of the row with the NaN splat, then reduce to a scalar max.
%10 = arith.maxnumf %8, %9 : vector<64xf32> | |
%11 = iree_vector_ext.to_layout %10 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [1], outer_tile = [1], thread_tile = [16], element_tile = [4], subgroup_strides = [0], thread_strides = [1]>) : vector<64xf32> | |
%12 = vector.multi_reduction <maxnumf>, %11, %cst_2 [0] : vector<64xf32> to f32 | |
// Wrap the scalar max as a 0-d vector and give it an empty (rank-0) layout.
%13 = vector.broadcast %12 : f32 to vector<f32> | |
%14 = iree_vector_ext.to_layout %13 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [], batch_tile = [], outer_tile = [], thread_tile = [], element_tile = [], subgroup_strides = [], thread_strides = []>) : vector<f32> | |
// --- Sum of exponentials (uses the thread_tile = [16] layout) ---
%15 = iree_vector_ext.to_layout %cst to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [1], outer_tile = [1], thread_tile = [16], element_tile = [4], subgroup_strides = [0], thread_strides = [1]>) : vector<64xf32> | |
// %16 = row max broadcast to 64 lanes; reused by both subf ops below.
%16 = vector.broadcast %14 : vector<f32> to vector<64xf32> | |
// exp(x - max), added to the zero splat, then add-reduced to the scalar sum %21.
%17 = arith.subf %8, %16 : vector<64xf32> | |
%18 = math.exp %17 : vector<64xf32> | |
%19 = arith.addf %18, %15 : vector<64xf32> | |
%20 = iree_vector_ext.to_layout %19 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [1], outer_tile = [1], thread_tile = [16], element_tile = [4], subgroup_strides = [0], thread_strides = [1]>) : vector<64xf32> | |
%21 = vector.multi_reduction <add>, %20, %cst_1 [0] : vector<64xf32> to f32 | |
%22 = vector.broadcast %21 : f32 to vector<f32> | |
// --- Normalisation (uses a second layout) ---
// The same loaded row %7 is anchored again, this time with
// batch_tile = [8], thread_tile = [1], element_tile = [8], thread_strides = [0].
%23 = iree_vector_ext.to_layout %7 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [8], outer_tile = [1], thread_tile = [1], element_tile = [8], subgroup_strides = [0], thread_strides = [0]>) : vector<64xf32> | |
%24 = iree_vector_ext.to_layout %22 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [], batch_tile = [], outer_tile = [], thread_tile = [], element_tile = [], subgroup_strides = [], thread_strides = []>) : vector<f32> | |
// %25 = row sum broadcast to 64 lanes.
%25 = vector.broadcast %24 : vector<f32> to vector<64xf32> | |
// Recompute exp(x - max) from %23 and divide by the sum.
// NOTE(review): %23 carries the batch_tile = [8] anchor while %16 is also
// consumed by %17, whose other operand carries the thread_tile = [16] anchor;
// confirm how layout analysis is expected to resolve the two anchors here.
%26 = arith.subf %23, %16 : vector<64xf32> | |
%27 = math.exp %26 : vector<64xf32> | |
%28 = arith.divf %27, %25 : vector<64xf32> | |
// NOTE(review): %29 and %30 apply the identical layout back to back.
%29 = iree_vector_ext.to_layout %28 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [8], outer_tile = [1], thread_tile = [1], element_tile = [8], subgroup_strides = [0], thread_strides = [0]>) : vector<64xf32> | |
%30 = iree_vector_ext.to_layout %29 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [8], outer_tile = [1], thread_tile = [1], element_tile = [8], subgroup_strides = [0], thread_strides = [0]>) : vector<64xf32> | |
// Narrow the normalised row from f32 to f16 and anchor it to the same layout.
%31 = arith.truncf %30 : vector<64xf32> to vector<64xf16> | |
%32 = iree_vector_ext.to_layout %31 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1], batch_tile = [8], outer_tile = [1], thread_tile = [1], element_tile = [8], subgroup_strides = [0], thread_strides = [0]>) : vector<64xf16> | |
// Take the 1x1x64 output row at [%arg0, %arg1, 0] and store the f16 result into it.
%subview = memref.subview %6[%arg0, %arg1, 0] [1, 1, 64] [1, 1, 1] : memref<12x64x64xf16, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> to memref<1x1x64xf16, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
vector.transfer_write %32, %subview[%c0, %c0, %c0] {in_bounds = [true]} : vector<64xf16>, memref<1x1x64xf16, strided<[4096, 64, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> | |
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]} | |
return | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment