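// TTGIR (Triton GPU IR) dump of `matmul_sparse_kernel`: each program computes a
// 128x256 output tile, accumulating over K in 64-wide steps whose rows/columns
// are gathered through an i64 index array. Target is cuda:90 (Hopper) with 8
// warps; versionMajor = 3 in #mma below means the dot lowers to warp-group MMA
// (WGMMA) instructions.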
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 8], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
#loc = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0)
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>
#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
#shared2 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @matmul_sparse_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg4: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg7: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0), %arg9: i32 {tt.divisibility = 16 : i32} loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":157:0)) attributes {noinline = false} {
    %c4_i32 = arith.constant 4 : i32 loc(#loc1)
    %c192_i32 = arith.constant 192 : i32 loc(#loc1)
    %c3_i32 = arith.constant 3 : i32 loc(#loc1)
    %c2_i32 = arith.constant 2 : i32 loc(#loc1)
    %c-1_i32 = arith.constant -1 : i32 loc(#loc1)
    %c128_i32 = arith.constant 128 : i32 loc(#loc1)
    %c256_i32 = arith.constant 256 : i32 loc(#loc1)
    %c64_i32 = arith.constant 64 : i32 loc(#loc1)
    %c0_i32 = arith.constant 0 : i32 loc(#loc1)
    %c1_i32 = arith.constant 1 : i32 loc(#loc1)
    %c63_i32 = arith.constant 63 : i32 loc(#loc1)
    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> loc(#loc1)
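    // Grid mapping: program id x picks the 128-row M block, program id y the
    // 256-column N block. The remsi wraps keep the load offsets in bounds; the
    // real bounds check is the mask applied at the final store.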
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = tt.get_program_id y : i32 loc(#loc3)
    %2 = arith.muli %0, %c128_i32 : i32 loc(#loc4)
    %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc5)
    %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc5)
    %5 = tt.splat %2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc6)
    %6 = tt.splat %2 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc6)
    %7 = arith.addi %5, %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc6)
    %8 = arith.addi %6, %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc6)
    %9 = tt.splat %arg4 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc7)
    %10 = arith.remsi %7, %9 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc7)
    %11 = arith.muli %1, %c256_i32 : i32 loc(#loc8)
    %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc9)
    %13 = tt.splat %11 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10)
    %14 = arith.addi %13, %12 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc10)
    %15 = tt.splat %arg5 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc11)
    %16 = arith.remsi %14, %15 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> loc(#loc11)
    %17 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked2> loc(#loc12)
    %18 = tt.expand_dims %10 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> loc(#loc13)
    %19 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked> loc(#loc14)
    %20 = tt.addptr %19, %18 : tensor<128x1x!tt.ptr<f16>, #blocked>, tensor<128x1xi32, #blocked> loc(#loc14)
    %21 = tt.expand_dims %16 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc15)
    %22 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<1x256x!tt.ptr<f16>, #blocked1> loc(#loc16)
    %23 = tt.addptr %22, %21 : tensor<1x256x!tt.ptr<f16>, #blocked1>, tensor<1x256xi32, #blocked1> loc(#loc16)
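    // %20 and %23 above are the A-row and B-column base pointers, offset only
    // along M and N; the per-iteration K offsets are added later from the
    // gathered indices.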
    %24 = arith.addi %arg6, %c63_i32 : i32 loc(#loc45)
    %25 = arith.divsi %24, %c64_i32 : i32 loc(#loc46)
    %26 = arith.extsi %arg7 : i32 to i64 loc(#loc20)
    %27 = tt.splat %26 : i64 -> tensor<1x64xi64, #blocked> loc(#loc20)
    %28 = tt.broadcast %20 : tensor<128x1x!tt.ptr<f16>, #blocked> -> tensor<128x64x!tt.ptr<f16>, #blocked> loc(#loc21)
    %29 = arith.extsi %arg8 : i32 to i64 loc(#loc22)
    %30 = tt.splat %29 : i64 -> tensor<64x1xi64, #blocked1> loc(#loc22)
    %31 = tt.broadcast %23 : tensor<1x256x!tt.ptr<f16>, #blocked1> -> tensor<64x256x!tt.ptr<f16>, #blocked1> loc(#loc23)
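    // Shared-memory ring buffers for the software pipeline: the 64-element i64
    // index vector gets two double-buffered allocations (%32/%33) so it can be
    // read back in both the #blocked and #blocked1 slice layouts, while the
    // 128x64 A tiles and 64x256 B tiles are triple-buffered (consistent with a
    // num_stages = 3 pipeline).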
    %32 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc24)
    %33 = ttg.local_alloc : () -> !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc24)
    %34 = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> loc(#loc25)
    %35 = ttg.local_alloc : () -> !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> loc(#loc26)
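    // Pipeline prologue, part 1: prefetch the index vectors for iterations 0
    // and 1 from %arg3 into both index buffers, each copy masked by whether
    // that iteration exists (%36: trip count > 0, %46: trip count > 1), where
    // %25 = ceildiv(%arg6, 64) is the trip count.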
    %36 = arith.cmpi sgt, %25, %c0_i32 : i32 loc(#loc27)
    %37 = tt.splat %arg3 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
    %38 = tt.addptr %37, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
    %39 = ttg.memdesc_subview %32[%c0_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %40 = tt.splat %36 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
    %41 = ttg.async_copy_global_to_local %38, %39 mask %40 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %42 = ttg.async_commit_group %41 loc(#loc24)
    %43 = ttg.memdesc_subview %33[%c0_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %44 = ttg.async_copy_global_to_local %38, %43 mask %40 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %45 = ttg.async_commit_group %44 loc(#loc24)
    %46 = arith.cmpi sgt, %25, %c1_i32 : i32 loc(#loc27)
    %47 = tt.addptr %arg3, %c64_i32 : !tt.ptr<i64>, i32 loc(#loc29)
    %48 = tt.splat %47 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
    %49 = tt.addptr %48, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
    %50 = ttg.memdesc_subview %32[%c1_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %51 = tt.splat %46 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
    %52 = ttg.async_copy_global_to_local %49, %50 mask %51 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %53 = ttg.async_commit_group %52 loc(#loc24)
    %54 = ttg.memdesc_subview %33[%c1_i32, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %55 = ttg.async_copy_global_to_local %49, %54 mask %51 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %56 = ttg.async_commit_group %55 loc(#loc24)
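    // Prologue, part 2: wait on the iteration-0 indices, scale them by the K
    // strides (%27 from %arg7, %30 from %arg8) to form gather offsets, issue
    // the async copies of the first A and B tiles, and prefetch the
    // iteration-2 indices (masked by %57: trip count > 2).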
    %57 = arith.cmpi sgt, %25, %c2_i32 : i32 loc(#loc27)
    %58 = ttg.async_wait %42 {num = 3 : i32} loc(#loc24)
    %59 = ttg.local_load %39 token %58 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc24)
    %60 = ttg.async_wait %45 {num = 2 : i32} loc(#loc24)
    %61 = ttg.local_load %43 token %60 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24)
    %62 = tt.expand_dims %59 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> loc(#loc30)
    %63 = arith.muli %62, %27 : tensor<1x64xi64, #blocked> loc(#loc20)
    %64 = tt.broadcast %63 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> loc(#loc21)
    %65 = tt.addptr %28, %64 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> loc(#loc21)
    %66 = ttg.memdesc_subview %34[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
    %67 = tt.splat %36 : i1 -> tensor<128x64xi1, #blocked> loc(#loc27)
    %68 = ttg.async_copy_global_to_local %65, %66 mask %67 : tensor<128x64x!tt.ptr<f16>, #blocked> -> <128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
    %69 = ttg.async_commit_group %68 loc(#loc25)
    %70 = tt.expand_dims %61 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> loc(#loc31)
    %71 = arith.muli %70, %30 : tensor<64x1xi64, #blocked1> loc(#loc22)
    %72 = tt.broadcast %71 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> loc(#loc23)
    %73 = tt.addptr %31, %72 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> loc(#loc23)
    %74 = ttg.memdesc_subview %35[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
    %75 = tt.splat %36 : i1 -> tensor<64x256xi1, #blocked1> loc(#loc27)
    %76 = ttg.async_copy_global_to_local %73, %74 mask %75 : tensor<64x256x!tt.ptr<f16>, #blocked1> -> <64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
    %77 = ttg.async_commit_group %76 loc(#loc26)
    %78 = tt.addptr %arg3, %c128_i32 : !tt.ptr<i64>, i32 loc(#loc29)
    %79 = tt.splat %78 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
    %80 = tt.addptr %79, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
    %81 = tt.splat %57 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
    %82 = ttg.async_copy_global_to_local %80, %39 mask %81 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %83 = ttg.async_commit_group %82 loc(#loc24)
    %84 = ttg.async_copy_global_to_local %80, %43 mask %81 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %85 = ttg.async_commit_group %84 loc(#loc24)
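    // Prologue, part 3: the same for iteration 1 (A/B tiles into buffer
    // slot 1), plus the iteration-3 index prefetch (masked by %86).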
    %86 = arith.cmpi sgt, %25, %c3_i32 : i32 loc(#loc27)
    %87 = ttg.async_wait %53 {num = 5 : i32} loc(#loc24)
    %88 = ttg.local_load %50 token %87 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc24)
    %89 = ttg.async_wait %56 {num = 4 : i32} loc(#loc24)
    %90 = ttg.local_load %54 token %89 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24)
    %91 = tt.expand_dims %88 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> loc(#loc30)
    %92 = arith.muli %91, %27 : tensor<1x64xi64, #blocked> loc(#loc20)
    %93 = tt.broadcast %92 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> loc(#loc21)
    %94 = tt.addptr %28, %93 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> loc(#loc21)
    %95 = ttg.memdesc_subview %34[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
    %96 = tt.splat %46 : i1 -> tensor<128x64xi1, #blocked> loc(#loc27)
    %97 = ttg.async_copy_global_to_local %94, %95 mask %96 : tensor<128x64x!tt.ptr<f16>, #blocked> -> <128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
    %98 = ttg.async_commit_group %97 loc(#loc25)
    %99 = tt.expand_dims %90 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> loc(#loc31)
    %100 = arith.muli %99, %30 : tensor<64x1xi64, #blocked1> loc(#loc22)
    %101 = tt.broadcast %100 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> loc(#loc23)
    %102 = tt.addptr %31, %101 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> loc(#loc23)
    %103 = ttg.memdesc_subview %35[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
    %104 = tt.splat %46 : i1 -> tensor<64x256xi1, #blocked1> loc(#loc27)
    %105 = ttg.async_copy_global_to_local %102, %103 mask %104 : tensor<64x256x!tt.ptr<f16>, #blocked1> -> <64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
    %106 = ttg.async_commit_group %105 loc(#loc26)
    %107 = tt.addptr %arg3, %c192_i32 : !tt.ptr<i64>, i32 loc(#loc29)
    %108 = tt.splat %107 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
    %109 = tt.addptr %108, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
    %110 = tt.splat %86 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
    %111 = ttg.async_copy_global_to_local %109, %50 mask %110 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %112 = ttg.async_commit_group %111 loc(#loc24)
    %113 = ttg.async_copy_global_to_local %109, %54 mask %110 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
    %114 = ttg.async_commit_group %113 loc(#loc24)
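    // Pipelined main loop. iter_args carry the f32 accumulator, the rotating
    // ring-buffer slot indices (incremented modulo 2 or 3 each iteration via
    // the addi/cmpi/select triples below), and the async-copy tokens ordering
    // the waits inside the loop. Each iteration i waits on previously issued
    // copies, runs an asynchronous warp-group dot on the current A/B tiles,
    // and issues the tile loads for iteration i+2 (masked by %139) and the
    // index loads for iteration i+4 (masked by %137).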
    %115:11 = scf.for %arg10 = %c0_i32 to %25 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %c1_i32, %arg13 = %c1_i32, %arg14 = %c1_i32, %arg15 = %c-1_i32, %arg16 = %77, %arg17 = %106, %arg18 = %83, %arg19 = %112, %arg20 = %85, %arg21 = %114) -> (tensor<128x256xf32, #mma>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token) : i32 {
      %136 = arith.subi %25, %c4_i32 : i32 loc(#loc27)
      %137 = arith.cmpi slt, %arg10, %136 : i32 loc(#loc27)
      %138 = arith.subi %25, %c2_i32 : i32 loc(#loc27)
      %139 = arith.cmpi slt, %arg10, %138 : i32 loc(#loc27)
      %140 = arith.addi %arg15, %c1_i32 : i32 loc(#loc27)
      %141 = arith.cmpi slt, %140, %c3_i32 : i32 loc(#loc27)
      %142 = arith.select %141, %140, %c0_i32 : i32 loc(#loc27)
      %148 = arith.addi %arg14, %c1_i32 : i32 loc(#loc27)
      %149 = arith.cmpi slt, %148, %c3_i32 : i32 loc(#loc27)
      %150 = arith.select %149, %148, %c0_i32 : i32 loc(#loc27)
      %151 = arith.addi %arg13, %c1_i32 : i32 loc(#loc27)
      %152 = arith.cmpi slt, %151, %c2_i32 : i32 loc(#loc27)
      %153 = arith.select %152, %151, %c0_i32 : i32 loc(#loc27)
      %154 = ttg.async_wait %arg18 {num = 5 : i32} loc(#loc24)
      %155 = ttg.memdesc_subview %32[%153, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
      %156 = ttg.local_load %155 token %154 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc24)
      %157 = ttg.async_wait %arg20 {num = 4 : i32} loc(#loc24)
      %158 = ttg.memdesc_subview %33[%153, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
      %159 = ttg.local_load %158 token %157 : !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> -> tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> loc(#loc24)
      %143 = ttg.memdesc_subview %34[%142, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
      %144 = ttg.async_wait %arg16 {num = 6 : i32} loc(#loc25)
      %145 = ttg.memdesc_subview %35[%142, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
      %146 = ttng.warp_group_dot %143, %145, %arg11 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> * !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc32)
      %147:3 = ttng.warp_group_dot_wait %146, %143, %145 {pendings = 1 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc32)
      %160 = tt.expand_dims %156 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi64, #blocked> loc(#loc30)
      %161 = arith.muli %160, %27 : tensor<1x64xi64, #blocked> loc(#loc20)
      %162 = tt.broadcast %161 : tensor<1x64xi64, #blocked> -> tensor<128x64xi64, #blocked> loc(#loc21)
      %163 = tt.addptr %28, %162 : tensor<128x64x!tt.ptr<f16>, #blocked>, tensor<128x64xi64, #blocked> loc(#loc21)
      %164 = ttg.memdesc_subview %34[%150, %c0_i32, %c0_i32] : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
      %165 = tt.splat %139 : i1 -> tensor<128x64xi1, #blocked> loc(#loc27)
      %166 = ttg.async_copy_global_to_local %163, %164 mask %165 : tensor<128x64x!tt.ptr<f16>, #blocked> -> <128x64xf16, #shared1, #smem, mutable, 3x128x64> loc(#loc25)
      %167 = ttg.async_commit_group %166 loc(#loc25)
      %168 = tt.expand_dims %159 {axis = 1 : i32} : tensor<64xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi64, #blocked1> loc(#loc31)
      %169 = arith.muli %168, %30 : tensor<64x1xi64, #blocked1> loc(#loc22)
      %170 = tt.broadcast %169 : tensor<64x1xi64, #blocked1> -> tensor<64x256xi64, #blocked1> loc(#loc23)
      %171 = tt.addptr %31, %170 : tensor<64x256x!tt.ptr<f16>, #blocked1>, tensor<64x256xi64, #blocked1> loc(#loc23)
      %172 = ttg.memdesc_subview %35[%150, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
      %173 = tt.splat %139 : i1 -> tensor<64x256xi1, #blocked1> loc(#loc27)
      %174 = ttg.async_copy_global_to_local %171, %172 mask %173 : tensor<64x256x!tt.ptr<f16>, #blocked1> -> <64x256xf16, #shared2, #smem, mutable, 3x64x256> loc(#loc26)
      %175 = ttg.async_commit_group %174 loc(#loc26)
      %176 = arith.addi %arg12, %c1_i32 : i32 loc(#loc27)
      %177 = arith.cmpi slt, %176, %c2_i32 : i32 loc(#loc27)
      %178 = arith.select %177, %176, %c0_i32 : i32 loc(#loc27)
      %179 = arith.addi %arg10, %c4_i32 : i32 loc(#loc27)
      %180 = arith.muli %179, %c64_i32 : i32 loc(#loc33)
      %181 = tt.addptr %arg3, %180 : !tt.ptr<i64>, i32 loc(#loc29)
      %182 = tt.splat %181 : !tt.ptr<i64> -> tensor<64x!tt.ptr<i64>, #blocked2> loc(#loc28)
      %183 = tt.addptr %182, %17 : tensor<64x!tt.ptr<i64>, #blocked2>, tensor<64xi32, #blocked2> loc(#loc28)
      %184 = ttg.memdesc_subview %32[%178, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
      %185 = tt.splat %137 : i1 -> tensor<64xi1, #blocked2> loc(#loc27)
      %186 = ttg.async_copy_global_to_local %183, %184 mask %185 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
      %187 = ttg.async_commit_group %186 loc(#loc24)
      %188 = ttg.memdesc_subview %33[%178, %c0_i32] : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> -> !ttg.memdesc<64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
      %189 = ttg.async_copy_global_to_local %183, %188 mask %185 : tensor<64x!tt.ptr<i64>, #blocked2> -> <64xi64, #shared, #smem, mutable, 2x64> loc(#loc24)
      %190 = ttg.async_commit_group %189 loc(#loc24)
      scf.yield %147#0, %178, %153, %150, %142, %arg17, %175, %arg19, %187, %arg21, %190 : tensor<128x256xf32, #mma>, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token loc(#loc27)
    } loc(#loc27)
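    // Epilogue: drain the outstanding warp-group dot and async copies, free
    // the shared-memory buffers, truncate the f32 accumulator to f16, and
    // store the tile with an M/N bounds mask.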
    %116 = ttng.warp_group_dot_wait %115#0 {pendings = 0 : i32} : tensor<128x256xf32, #mma> loc(#loc27)
    %117 = ttg.async_wait {num = 0 : i32} loc(#loc27)
    ttg.local_dealloc %32 : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc27)
    ttg.local_dealloc %33 : !ttg.memdesc<2x64xi64, #shared, #smem, mutable> loc(#loc27)
    ttg.local_dealloc %34 : !ttg.memdesc<3x128x64xf16, #shared1, #smem, mutable> loc(#loc27)
    ttg.local_dealloc %35 : !ttg.memdesc<3x64x256xf16, #shared2, #smem, mutable> loc(#loc27)
    %118 = arith.truncf %116 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> loc(#loc34)
    %119 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> loc(#loc35)
    %120 = tt.splat %arg9 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc36)
    %121 = arith.muli %120, %119 : tensor<128x1xi32, #blocked1> loc(#loc36)
    %122 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1> loc(#loc37)
    %123 = tt.addptr %122, %121 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1> loc(#loc37)
    %124 = tt.expand_dims %14 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> loc(#loc38)
    %125 = tt.broadcast %123 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x256x!tt.ptr<f16>, #blocked1> loc(#loc39)
    %126 = tt.broadcast %124 : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> loc(#loc39)
    %127 = tt.addptr %125, %126 : tensor<128x256x!tt.ptr<f16>, #blocked1>, tensor<128x256xi32, #blocked1> loc(#loc39)
    %128 = tt.splat %arg4 : i32 -> tensor<128x1xi32, #blocked1> loc(#loc40)
    %129 = arith.cmpi slt, %119, %128 : tensor<128x1xi32, #blocked1> loc(#loc40)
    %130 = tt.splat %arg5 : i32 -> tensor<1x256xi32, #blocked1> loc(#loc41)
    %131 = arith.cmpi slt, %124, %130 : tensor<1x256xi32, #blocked1> loc(#loc41)
    %132 = tt.broadcast %129 : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc42)
    %133 = tt.broadcast %131 : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> loc(#loc42)
    %134 = arith.andi %132, %133 : tensor<128x256xi1, #blocked1> loc(#loc42)
    %135 = ttg.convert_layout %118 : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> loc(#loc43)
    tt.store %127, %135, %134 : tensor<128x256x!tt.ptr<f16>, #blocked1> loc(#loc43)
    tt.return loc(#loc44)
  } loc(#loc)
} loc(#loc)
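// Source-location table referenced by the loc(...) attributes above.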
#loc1 = loc(unknown)
#loc2 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":182:26)
#loc3 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":183:26)
#loc4 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:23)
#loc5 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:51)
#loc6 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:38)
#loc7 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":185:68)
#loc8 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:23)
#loc9 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:51)
#loc10 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:38)
#loc11 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":186:68)
#loc12 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":187:26)
#loc13 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":188:29)
#loc14 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":188:21)
#loc15 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":189:29)
#loc16 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":189:21)
#loc17 = loc("/home/dberard/local/triton-env2/triton/python/triton/language/standard.py":40:22)
#loc18 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":193:33)
#loc19 = loc("/home/dberard/local/triton-env2/triton/python/triton/language/standard.py":40:28)
#loc20 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:53)
#loc21 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:29)
#loc22 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:53)
#loc23 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:29)
#loc24 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:20)
#loc25 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:20)
#loc26 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:20)
#loc27 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":193:22)
#loc28 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:49)
#loc29 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:30)
#loc30 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":199:42)
#loc31 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":200:42)
#loc32 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":203:35)
#loc33 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":196:34)
#loc34 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":205:23)
#loc35 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:41)
#loc36 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:33)
#loc37 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:21)
#loc38 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:72)
#loc39 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":211:52)
#loc40 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":213:33)
#loc41 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":213:58)
#loc42 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":213:39)
#loc43 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":214:21)
#loc44 = loc("/home/dberard/local/fbsource/fbcode/scripts/dberard/matmul_sparse/matmul_sparse.py":214:4)
#loc45 = loc(callsite(#loc17 at #loc18))
#loc46 = loc(callsite(#loc19 at #loc18))