// Gist by @davidberard98, created January 16, 2025 19:52.
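// Triton GPU IR (TTGIR) dump for matmul_sparse_kernel. The dump appears to be taken
// after Triton's software pipeliner has run: async-copy tokens are carried through the
// scf.for loop below. The kernel targets NVIDIA Hopper (ttg.target = "cuda:90") with
// 8 warps per CTA, accumulates a 128x256 output tile with warp-group MMA
// (#mma, versionMajor = 3), and stages operands through shared memory via
// ttg.async_copy_global_to_local.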
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#loc = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0)
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
#shared2 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
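// Kernel signature (argument roles inferred from their uses below, not from the source):
// %arg0, %arg1, %arg2 look like the A, B, and C f16 pointers; %arg3 an i64 index pointer
// gathered to form K-dimension block offsets; %arg4, %arg5, %arg6 the M, N, K extents;
// %arg7, %arg8, %arg9 the K-strides of A and B and the row stride of C.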
tt.func public @matmul_sparse_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg9: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0)) attributes {noinline = false} {
%c4_i32 = arith.constant 4 : i32 loc(#loc1)
%c192_i32 = arith.constant 192 : i32 loc(#loc1)
%c3_i32 = arith.constant 3 : i32 loc(#loc1)
%c2_i32 = arith.constant 2 : i32 loc(#loc1)
%c-1_i32 = arith.constant -1 : i32 loc(#loc1)
%c128_i32 = arith.constant 128 : i32 loc(#loc1)
%c256_i32 = arith.constant 256 : i32 loc(#loc1)
%c64_i32 = arith.constant 64 : i32 loc(#loc1)
%c0_i32 = arith.constant 0 : i32 loc(#loc1)
%c1_i32 = arith.constant 1 : i32 loc(#loc1)
%c63_i32 = arith.constant 63 : i32 loc(#loc1)
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1)
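// Output tile coordinates: program id (x, y) selects a 128x256 tile; row offsets are
// (pid_x * 128 + arange(0, 128)) % M and column offsets are (pid_y * 256 + arange(0, 256)) % N.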
%0 = tt.get_program_id x : i32 loc(#loc2)
%1 = tt.get_program_id y : i32 loc(#loc3)
%2 = arith.muli %0, %c128_i32 : i32 loc(#loc4)
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc5)
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc5)
%5 = tt.splat %2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc6)
%6 = tt.splat %2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc6)
%7 = arith.addi %5, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc6)
%8 = arith.addi %6, %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc6)
%9 = tt.splat %arg4 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc7)
%10 = arith.remsi %7, %9 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc7)
%11 = arith.muli %1, %c256_i32 : i32 loc(#loc8)
%12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc9)
%13 = tt.splat %11 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10)
%14 = arith.addi %13, %12 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10)
%15 = tt.splat %arg5 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc11)
%16 = arith.remsi %14, %15 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc11)
%17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked2> loc(#loc12)
%18 = tt.expand_dims %10 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc13)
%19 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked> loc(#loc14)
%20 = tt.addptr %19, %18 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi32, #blocked> loc(#loc14)
%21 = tt.expand_dims %16 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc15)
%22 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<1x256x!tt.ptr<f16>, #blocked1> loc(#loc16)
%23 = tt.addptr %22, %21 : tensor<1x256x!tt.ptr<f16>, #blocked1>, tensor<1x256xi32, #blocked1> loc(#loc16)
%24 = arith.addi %arg6, %c63_i32 : i32 loc(#loc45)
%25 = arith.divsi %24, %c64_i32 : i32 loc(#loc46)
%26 = arith.extsi %arg7 : i32 to i64 loc(#loc20)
%27 = tt.splat %26 : i64 -> tensor<1x64xi64, #blocked> loc(#loc20)
%28 = tt.broadcast %20 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked> loc(#loc21)
%29 = arith.extsi %arg8 : i32 to i64 loc(#loc22)
%30 = tt.splat %29 : i64 -> tensor<64x1xi64, #blocked1> loc(#loc22)
%31 = tt.broadcast %23 : tensor<1x256x!tt.ptr<f16>, #blocked1> -> tensor<64x256x!tt.ptr<f16>, #blocked1> loc(#loc23)
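// Shared-memory multibuffering: two 2-deep buffers for the gathered 64xi64 index vectors
// (one per consumer layout), plus 3-deep buffers for the 128x64 A tiles and 64x256 B tiles.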
%32 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc24)
%33 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc24)
%34 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> loc(#loc25)
%35 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> loc(#loc26)
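// Pipeline prologue: prefetch the index vectors for the first few K iterations and the
// A/B tiles for iterations 0 and 1 via masked async copies; the masks (%36, %46, %57, %86)
// guard against loop trip counts shorter than the pipeline depth.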
%36 = arith.cmpi sgt, %25, %c0_i32 : i32 loc(#loc27)
%37 = tt.splat %arg3 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
%38 = tt.addptr %37, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
%39 = ttg.memdesc_subview %32[%c0_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%40 = tt.splat %36 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
%41 = ttg.async_copy_global_to_local %38, %39 mask %40 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%42 = ttg.async_commit_group %41 loc(#loc24)
%43 = ttg.memdesc_subview %33[%c0_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%44 = ttg.async_copy_global_to_local %38, %43 mask %40 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%45 = ttg.async_commit_group %44 loc(#loc24)
%46 = arith.cmpi sgt, %25, %c1_i32 : i32 loc(#loc27)
%47 = tt.addptr %arg3, %c64_i32 : !tt.ptr<i64>, i32 loc(#loc29)
%48 = tt.splat %47 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
%49 = tt.addptr %48, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
%50 = ttg.memdesc_subview %32[%c1_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%51 = tt.splat %46 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
%52 = ttg.async_copy_global_to_local %49, %50 mask %51 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%53 = ttg.async_commit_group %52 loc(#loc24)
%54 = ttg.memdesc_subview %33[%c1_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%55 = ttg.async_copy_global_to_local %49, %54 mask %51 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%56 = ttg.async_commit_group %55 loc(#loc24)
%57 = arith.cmpi sgt, %25, %c2_i32 : i32 loc(#loc27)
%58 = ttg.async_wait %42 {num = 3 : i32} loc(#loc24)
%59 = ttg.local_load %39 token %58 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc24)
%60 = ttg.async_wait %45 {num = 2 : i32} loc(#loc24)
%61 = ttg.local_load %43 token %60 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24)
%62 = tt.expand_dims %59 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> loc(#loc30)
%63 = arith.muli %62, %27 : tensor<1x64xi64, #blocked> loc(#loc20)
%64 = tt.broadcast %63 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> loc(#loc21)
%65 = tt.addptr %28, %64 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> loc(#loc21)
%66 = ttg.memdesc_subview %34[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
%67 = tt.splat %36 : i1 -> tensor<128x64xi1, #blocked> loc(#loc27)
%68 = ttg.async_copy_global_to_local %65, %66 mask %67 : tensor<128x64x!tt.ptr<f16>, #blocked> -> <128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
%69 = ttg.async_commit_group %68 loc(#loc25)
%70 = tt.expand_dims %61 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> loc(#loc31)
%71 = arith.muli %70, %30 : tensor<64x1xi64, #blocked1> loc(#loc22)
%72 = tt.broadcast %71 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> loc(#loc23)
%73 = tt.addptr %31, %72 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> loc(#loc23)
%74 = ttg.memdesc_subview %35[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
%75 = tt.splat %36 : i1 -> tensor<64x256xi1, #blocked1> loc(#loc27)
%76 = ttg.async_copy_global_to_local %73, %74 mask %75 : tensor<64x256x!tt.ptr<f16>, #blocked1> -> <64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
%77 = ttg.async_commit_group %76 loc(#loc26)
%78 = tt.addptr %arg3, %c128_i32 : !tt.ptr<i64>, i32 loc(#loc29)
%79 = tt.splat %78 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
%80 = tt.addptr %79, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
%81 = tt.splat %57 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
%82 = ttg.async_copy_global_to_local %80, %39 mask %81 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%83 = ttg.async_commit_group %82 loc(#loc24)
%84 = ttg.async_copy_global_to_local %80, %43 mask %81 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%85 = ttg.async_commit_group %84 loc(#loc24)
%86 = arith.cmpi sgt, %25, %c3_i32 : i32 loc(#loc27)
%87 = ttg.async_wait %53 {num = 5 : i32} loc(#loc24)
%88 = ttg.local_load %50 token %87 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc24)
%89 = ttg.async_wait %56 {num = 4 : i32} loc(#loc24)
%90 = ttg.local_load %54 token %89 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24)
%91 = tt.expand_dims %88 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> loc(#loc30)
%92 = arith.muli %91, %27 : tensor<1x64xi64, #blocked> loc(#loc20)
%93 = tt.broadcast %92 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> loc(#loc21)
%94 = tt.addptr %28, %93 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> loc(#loc21)
%95 = ttg.memdesc_subview %34[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
%96 = tt.splat %46 : i1 -> tensor<128x64xi1, #blocked> loc(#loc27)
%97 = ttg.async_copy_global_to_local %94, %95 mask %96 : tensor<128x64x!tt.ptr<f16>, #blocked> -> <128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
%98 = ttg.async_commit_group %97 loc(#loc25)
%99 = tt.expand_dims %90 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> loc(#loc31)
%100 = arith.muli %99, %30 : tensor<64x1xi64, #blocked1> loc(#loc22)
%101 = tt.broadcast %100 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> loc(#loc23)
%102 = tt.addptr %31, %101 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> loc(#loc23)
%103 = ttg.memdesc_subview %35[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
%104 = tt.splat %46 : i1 -> tensor<64x256xi1, #blocked1> loc(#loc27)
%105 = ttg.async_copy_global_to_local %102, %103 mask %104 : tensor<64x256x!tt.ptr<f16>, #blocked1> -> <64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
%106 = ttg.async_commit_group %105 loc(#loc26)
%107 = tt.addptr %arg3, %c192_i32 : !tt.ptr<i64>, i32 loc(#loc29)
%108 = tt.splat %107 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
%109 = tt.addptr %108, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
%110 = tt.splat %86 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
%111 = ttg.async_copy_global_to_local %109, %50 mask %110 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%112 = ttg.async_commit_group %111 loc(#loc24)
%113 = ttg.async_copy_global_to_local %109, %54 mask %110 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%114 = ttg.async_commit_group %113 loc(#loc24)
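// Software-pipelined main loop over ceil(K / 64) iterations. The iter_args carry the f32
// accumulator, circular buffer indices, and async-copy tokens. Each iteration waits on the
// copies for its stage, issues an asynchronous warp_group_dot on the resident A/B tiles,
// then prefetches the index vector and A/B tiles for later iterations; those prefetches are
// masked off for the final iterations.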
%115:11 = scf.for %arg10 = %c0_i32 to %25 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %c1_i32, %arg13 = %c1_i32, %arg14 = %c1_i32, %arg15 = %c-1_i32, %arg16 = %77, %arg17 = %106, %arg18 = %83, %arg19 = %112, %arg20 = %85, %arg21 = %114) -> (tensor<128x256xf32, #mma>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 {
%136 = arith.subi %25, %c4_i32 : i32 loc(#loc27)
%137 = arith.cmpi slt, %arg10, %136 : i32 loc(#loc27)
%138 = arith.subi %25, %c2_i32 : i32 loc(#loc27)
%139 = arith.cmpi slt, %arg10, %138 : i32 loc(#loc27)
%140 = arith.addi %arg15, %c1_i32 : i32 loc(#loc27)
%141 = arith.cmpi slt, %140, %c3_i32 : i32 loc(#loc27)
%142 = arith.select %141, %140, %c0_i32 : i32 loc(#loc27)
%143 = ttg.memdesc_subview %34[%142, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
%144 = ttg.async_wait %arg16 {num = 6 : i32} loc(#loc25)
%145 = ttg.memdesc_subview %35[%142, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
%146 = ttng.warp_group_dot %143, %145, %arg11 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc32)
%147:3 = ttng.warp_group_dot_wait %146, %143, %145 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc32)
%148 = arith.addi %arg14, %c1_i32 : i32 loc(#loc27)
%149 = arith.cmpi slt, %148, %c3_i32 : i32 loc(#loc27)
%150 = arith.select %149, %148, %c0_i32 : i32 loc(#loc27)
%151 = arith.addi %arg13, %c1_i32 : i32 loc(#loc27)
%152 = arith.cmpi slt, %151, %c2_i32 : i32 loc(#loc27)
%153 = arith.select %152, %151, %c0_i32 : i32 loc(#loc27)
%154 = ttg.async_wait %arg18 {num = 5 : i32} loc(#loc24)
%155 = ttg.memdesc_subview %32[%153, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%156 = ttg.local_load %155 token %154 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc24)
%157 = ttg.async_wait %arg20 {num = 4 : i32} loc(#loc24)
%158 = ttg.memdesc_subview %33[%153, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%159 = ttg.local_load %158 token %157 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24)
%160 = tt.expand_dims %156 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> loc(#loc30)
%161 = arith.muli %160, %27 : tensor<1x64xi64, #blocked> loc(#loc20)
%162 = tt.broadcast %161 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> loc(#loc21)
%163 = tt.addptr %28, %162 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> loc(#loc21)
%164 = ttg.memdesc_subview %34[%150, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
%165 = tt.splat %139 : i1 -> tensor<128x64xi1, #blocked> loc(#loc27)
%166 = ttg.async_copy_global_to_local %163, %164 mask %165 : tensor<128x64x!tt.ptr<f16>, #blocked> -> <128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
%167 = ttg.async_commit_group %166 loc(#loc25)
%168 = tt.expand_dims %159 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> loc(#loc31)
%169 = arith.muli %168, %30 : tensor<64x1xi64, #blocked1> loc(#loc22)
%170 = tt.broadcast %169 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> loc(#loc23)
%171 = tt.addptr %31, %170 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> loc(#loc23)
%172 = ttg.memdesc_subview %35[%150, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
%173 = tt.splat %139 : i1 -> tensor<64x256xi1, #blocked1> loc(#loc27)
%174 = ttg.async_copy_global_to_local %171, %172 mask %173 : tensor<64x256x!tt.ptr<f16>, #blocked1> -> <64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
%175 = ttg.async_commit_group %174 loc(#loc26)
%176 = arith.addi %arg12, %c1_i32 : i32 loc(#loc27)
%177 = arith.cmpi slt, %176, %c2_i32 : i32 loc(#loc27)
%178 = arith.select %177, %176, %c0_i32 : i32 loc(#loc27)
%179 = arith.addi %arg10, %c4_i32 : i32 loc(#loc27)
%180 = arith.muli %179, %c64_i32 : i32 loc(#loc33)
%181 = tt.addptr %arg3, %180 : !tt.ptr<i64>, i32 loc(#loc29)
%182 = tt.splat %181 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
%183 = tt.addptr %182, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
%184 = ttg.memdesc_subview %32[%178, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%185 = tt.splat %137 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
%186 = ttg.async_copy_global_to_local %183, %184 mask %185 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%187 = ttg.async_commit_group %186 loc(#loc24)
%188 = ttg.memdesc_subview %33[%178, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%189 = ttg.async_copy_global_to_local %183, %188 mask %185 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
%190 = ttg.async_commit_group %189 loc(#loc24)
scf.yield %147#0, %178, %153, %150, %142, %arg17, %175, %arg19, %187, %arg21, %190 : tensor<128x256xf32, #mma>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token loc(#loc27)
} loc(#loc27)
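// Epilogue: drain the outstanding WGMMA and async copies, release the shared-memory buffers,
// truncate the f32 accumulator to f16, and perform a masked store of the 128x256 output tile.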
%116 = ttng.warp_group_dot_wait %115#0 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc27)
%117 = ttg.async_wait {num = 0 : i32} loc(#loc27)
ttg.local_dealloc %32 : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc27)
ttg.local_dealloc %33 : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc27)
ttg.local_dealloc %34 : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> loc(#loc27)
ttg.local_dealloc %35 : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> loc(#loc27)
%118 = arith.truncf %116 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc34)
%119 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc35)
%120 = tt.splat %arg9 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc36)
%121 = arith.muli %120, %119 : tensor<128x1xi32, #blocked1> loc(#loc36)
%122 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1> loc(#loc37)
%123 = tt.addptr %122, %121 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1> loc(#loc37)
%124 = tt.expand_dims %14 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc38)
%125 = tt.broadcast %123 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x256x!tt.ptr<f16>, #blocked1> loc(#loc39)
%126 = tt.broadcast %124 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> loc(#loc39)
%127 = tt.addptr %125, %126 : tensor<128x256x!tt.ptr<f16>, #blocked1>, tensor<128x256xi32, #blocked1> loc(#loc39)
%128 = tt.splat %arg4 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc40)
%129 = arith.cmpi slt, %119, %128 : tensor<128x1xi32, #blocked1> loc(#loc40)
%130 = tt.splat %arg5 : i32 -> tensor<1x256xi32, #blocked1> loc(#loc41)
%131 = arith.cmpi slt, %124, %130 : tensor<1x256xi32, #blocked1> loc(#loc41)
%132 = tt.broadcast %129 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc42)
%133 = tt.broadcast %131 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc42)
%134 = arith.andi %132, %133 : tensor<128x256xi1, #blocked1> loc(#loc42)
%135 = ttg.convert_layout %118 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> loc(#loc43)
tt.store %127, %135, %134 : tensor<128x256x!tt.ptr<f16>, #blocked1> loc(#loc43)
tt.return loc(#loc44)
} loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":182:26)
#loc3 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":183:26)
#loc4 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:23)
#loc5 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:51)
#loc6 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:38)
#loc7 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:68)
#loc8 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:23)
#loc9 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:51)
#loc10 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:38)
#loc11 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:68)
#loc12 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":187:26)
#loc13 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":188:29)
#loc14 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":188:21)
#loc15 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":189:29)
#loc16 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":189:21)
#loc17 = loc("/home/dberard/local/triton-env2/triton/python/triton/language/standard.py":40:22)
#loc18 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":193:33)
#loc19 = loc("/home/dberard/local/triton-env2/triton/python/triton/language/standard.py":40:28)
#loc20 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:53)
#loc21 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:29)
#loc22 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:53)
#loc23 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:29)
#loc24 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:20)
#loc25 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:20)
#loc26 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:20)
#loc27 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":193:22)
#loc28 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:49)
#loc29 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:30)
#loc30 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:42)
#loc31 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:42)
#loc32 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":203:35)
#loc33 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:34)
#loc34 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":205:23)
#loc35 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:41)
#loc36 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:33)
#loc37 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:21)
#loc38 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:72)
#loc39 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:52)
#loc40 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":213:33)
#loc41 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":213:58)
#loc42 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":213:39)
#loc43 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":214:21)
#loc44 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":214:4)
#loc45 = loc(callsite(#loc17 at #loc18))
#loc46 = loc(callsite(#loc19 at #loc18))