Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created April 3, 2025 12:45
Show Gist options
  • Save pashu123/bfb2bcce16d74577817fbbdc87eba7d5 to your computer and use it in GitHub Desktop.
Save pashu123/bfb2bcce16d74577817fbbdc87eba7d5 to your computer and use it in GitHub Desktop.
// Per workgroup: reads a 16-row slab of binding 0 (32000x4096 f16) and the
// single row of binding 1 (1x4096 f16), multiplies them elementwise in f32
// over eight 512-wide column tiles, reduces the 512 lanes per row, truncates
// to f16 and writes 16 results into binding 2 (32000x1 f16).
func.func @matvec_dispatch_0_matmul_transpose_b_32000x1x4096_f16xf16xf32() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [128, 1, 1] subgroup_size = 64>} {
  // Loop bounds over the 4096-wide column dimension, walked in tiles of 512.
  %zero = arith.constant 0 : index
  %k_tile = arith.constant 512 : index
  %k_size = arith.constant 4096 : index

  // Padding operand required by vector.transfer_read; every read below marks
  // all of its dimensions as in_bounds.
  %pad_f16 = arith.constant 0.000000e+00 : f16
  // Zero-filled starting values for the f32 loop accumulator and for the
  // final per-row reduction.
  %acc_init = arith.constant dense<0.000000e+00> : vector<16x1x512xf32>
  %row_sum_init = arith.constant dense<0.000000e+00> : vector<16x1xf32>

  // Binding 0 (read-only): 32000x4096 f16, cast to a fat raw buffer.
  %mat_binding = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%zero) flags("ReadOnly|Indirect") : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>
  %mat = amdgpu.fat_raw_buffer_cast %mat_binding resetOffset : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>> to memref<32000x4096xf16, #amdgpu.address_space<fat_raw_buffer>>
  memref.assume_alignment %mat, 64 : memref<32000x4096xf16, #amdgpu.address_space<fat_raw_buffer>>

  // Binding 1 (read-only): 1x4096 f16, cast to a fat raw buffer.
  %vec_binding = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%zero) flags("ReadOnly|Indirect") : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>
  %vec = amdgpu.fat_raw_buffer_cast %vec_binding resetOffset : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>> to memref<1x4096xf16, #amdgpu.address_space<fat_raw_buffer>>
  memref.assume_alignment %vec, 64 : memref<1x4096xf16, #amdgpu.address_space<fat_raw_buffer>>

  // Binding 2 (the only binding without the ReadOnly flag): 32000x1 f16.
  %out_binding = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%zero) flags(Indirect) : memref<32000x1xf16, #hal.descriptor_type<storage_buffer>>
  %out = amdgpu.fat_raw_buffer_cast %out_binding resetOffset : memref<32000x1xf16, #hal.descriptor_type<storage_buffer>> to memref<32000x1xf16, #amdgpu.address_space<fat_raw_buffer>>
  memref.assume_alignment %out, 64 : memref<32000x1xf16, #amdgpu.address_space<fat_raw_buffer>>
  // Rank-reduced 1-D view of the whole output (the unit trailing dim is dropped).
  %out_1d = memref.subview %out[0, 0] [32000, 1] [1, 1] : memref<32000x1xf16, #amdgpu.address_space<fat_raw_buffer>> to memref<32000xf16, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>

  // First row handled by this workgroup: workgroup id along x, times 16.
  %workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 2000 : index
  %row_base = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%workgroup_id_x]

  // Accumulate elementwise products over columns [0, 4096) in steps of 512,
  // carrying a 16x1x512 f32 accumulator between iterations.
  %partial = scf.for %k = %zero to %k_size step %k_tile iter_args(%acc = %acc_init) -> (vector<16x1x512xf32>) {
    // 16 rows x 512 columns of binding 0, with a broadcast unit middle dim.
    %mat_tile_f16 = vector.transfer_read %mat[%row_base, %k], %pad_f16 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0, d1)>} : memref<32000x4096xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<16x1x512xf16>
    // The same 512 columns of binding 1, broadcast along the leading dim of 16.
    %vec_tile_f16 = vector.transfer_read %vec[%zero, %k], %pad_f16 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d0, d1)>} : memref<1x4096xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<16x1x512xf16>
    // Widen both operands to f32 before multiplying and accumulating.
    %mat_tile = arith.extf %mat_tile_f16 : vector<16x1x512xf16> to vector<16x1x512xf32>
    %vec_tile = arith.extf %vec_tile_f16 : vector<16x1x512xf16> to vector<16x1x512xf32>
    %prod = arith.mulf %mat_tile, %vec_tile : vector<16x1x512xf32>
    %acc_next = arith.addf %acc, %prod : vector<16x1x512xf32>
    scf.yield %acc_next : vector<16x1x512xf32>
  }

  // Sum the 512 lanes of each row (dim 2), then narrow the result to f16.
  %row_sums = vector.multi_reduction <add>, %partial, %row_sum_init [2] : vector<16x1x512xf32> to vector<16x1xf32>
  %row_sums_f16 = arith.truncf %row_sums : vector<16x1xf32> to vector<16x1xf16>

  // Drop the unit dim and store the 16 results starting at %row_base.
  %result = vector.shape_cast %row_sums_f16 : vector<16x1xf16> to vector<16xf16>
  vector.transfer_write %result, %out_1d[%row_base] {in_bounds = [true]} : vector<16xf16>, memref<32000xf16, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
  return
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment