Last active
November 20, 2024 08:03
-
-
Save minjang/25715aa9d618c6040a570c7188f03197 to your computer and use it in GitHub Desktop.
TTGIR (TritonGPU-dialect MLIR) for matmul_kernel (03-matrix-multiplication-cpu.py)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Layout attribute aliases referenced by the tensor types below. Both describe one thread per warp and one warp per CTA with order = [1, 0]; they differ only in sizePerThread ([1, 1] vs [1, 16]). | |
| // #blocked is the layout of the dot accumulator/result (and the parent of the dot-operand layouts); #blocked1 is the layout of the index, pointer, mask and loaded-value tensors. | |
| #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}> | |
| #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}> | |
| // Location (tutorial file, line 166, column 0) attached to the module, the function and every function argument. | |
| #loc = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0) | |
| // Module holding a single Triton kernel; the module attributes record 1 CTA, 1 warp, 1 thread per warp and target "cpu". | |
| module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cpu", "triton_gpu.threads-per-warp" = 1 : i32} { | |
| // matmul_kernel: %arg0 and %arg1 are the f32 pointers that are loaded from; %arg2 is the f32 pointer that is stored to. | |
| // %arg3 / %arg4 bound the row / column indices of the output tile, %arg5 bounds the reduction loop, and %arg6 / %arg7 / %arg8 are multiplied into the offsets for %arg0 / %arg1 / %arg2. | |
| // NOTE(review): these presumably correspond to a_ptr, b_ptr, c_ptr, M, N, K and three strides of the tutorial kernel -- confirm against the Python source. | |
| tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg1: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg2: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg3: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg4: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg5: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg6: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg7: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg8: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0)) attributes {noinline = false} { | |
| // Constants: scalars 8, 16, 15, 1, 0; zero-filled 16x16 f32 tiles (%cst initialises the accumulator, %cst_0 is the fill value of the masked loads); %cst_1 is a dense 16 used to advance the %arg0 pointers. | |
| %c8_i32 = arith.constant 8 : i32 loc(#loc1) | |
| %c16_i32 = arith.constant 16 : i32 loc(#loc1) | |
| %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc1) | |
| %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> loc(#loc1) | |
| %c15_i32 = arith.constant 15 : i32 loc(#loc1) | |
| %c1_i32 = arith.constant 1 : i32 loc(#loc1) | |
| %c0_i32 = arith.constant 0 : i32 loc(#loc1) | |
| %cst_1 = arith.constant dense<16> : tensor<16x16xi32, #blocked1> loc(#loc1) | |
| // Map the 1-D program id %0 to a (row-tile %11, column-tile %13) pair. %2 = (%arg3 + 15) / 16 and %4 = (%arg4 + 15) / 16 are the tile counts per dimension. | |
| // Programs are grouped in bands of up to 8 row-tiles: %5 = %4 * 8 programs per band, %6 is the band index, %7 its first row-tile, %9 = min(%2 - %7, 8) its height. | |
| // Within a band: %11 = %7 + (%0 rem %9) and %13 = (%0 rem %5) / %9. | |
| %0 = tt.get_program_id x : i32 loc(#loc2) | |
| %1 = arith.addi %arg3, %c15_i32 : i32 loc(#loc56) | |
| %2 = arith.divsi %1, %c16_i32 : i32 loc(#loc57) | |
| %3 = arith.addi %arg4, %c15_i32 : i32 loc(#loc58) | |
| %4 = arith.divsi %3, %c16_i32 : i32 loc(#loc59) | |
| %5 = arith.muli %4, %c8_i32 : i32 loc(#loc7) | |
| %6 = arith.divsi %0, %5 : i32 loc(#loc8) | |
| %7 = arith.muli %6, %c8_i32 : i32 loc(#loc9) | |
| %8 = arith.subi %2, %7 : i32 loc(#loc10) | |
| %9 = arith.minsi %8, %c8_i32 : i32 loc(#loc11) | |
| %10 = arith.remsi %0, %9 : i32 loc(#loc12) | |
| %11 = arith.addi %7, %10 : i32 loc(#loc13) | |
| %12 = arith.remsi %0, %5 : i32 loc(#loc14) | |
| %13 = arith.divsi %12, %9 : i32 loc(#loc15) | |
| // Row indices %18 = %11 * 16 + [0, 16) and column indices %23 = %13 * 16 + [0, 16). %20 and %25 are the same indices wrapped with remsi by %arg3 / %arg4; the wrapped forms are used only for the load addresses. | |
| %14 = arith.muli %11, %c16_i32 : i32 loc(#loc16) | |
| %15 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc17) | |
| %16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc17) | |
| %17 = tt.splat %14 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18) | |
| %18 = arith.addi %17, %15 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18) | |
| %19 = tt.splat %arg3 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc19) | |
| %20 = arith.remsi %18, %19 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc19) | |
| %21 = arith.muli %13, %c16_i32 : i32 loc(#loc20) | |
| %22 = tt.splat %21 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21) | |
| %23 = arith.addi %22, %16 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21) | |
| %24 = tt.splat %arg4 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22) | |
| %25 = arith.remsi %23, %24 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22) | |
| // Initial 16x16 pointers for the first operand: %34 = %arg0 + %20[:, None] * %arg6 + range16[None, :]. | |
| %26 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc23) | |
| %27 = tt.splat %arg6 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc24) | |
| %28 = arith.muli %26, %27 : tensor<16x1xi32, #blocked1> loc(#loc24) | |
| %29 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc25) | |
| %30 = tt.broadcast %28 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc26) | |
| %31 = tt.broadcast %29 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc26) | |
| %32 = arith.addi %30, %31 : tensor<16x16xi32, #blocked1> loc(#loc26) | |
| %33 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc27) | |
| %34 = tt.addptr %33, %32 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc27) | |
| // Initial 16x16 pointers for the second operand: %43 = %arg1 + range16[:, None] * %arg7 + %25[None, :]. | |
| %35 = tt.expand_dims %15 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc28) | |
| %36 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc29) | |
| %37 = arith.muli %35, %36 : tensor<16x1xi32, #blocked1> loc(#loc29) | |
| %38 = tt.expand_dims %25 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc30) | |
| %39 = tt.broadcast %37 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc31) | |
| %40 = tt.broadcast %38 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc31) | |
| %41 = arith.addi %39, %40 : tensor<16x16xi32, #blocked1> loc(#loc31) | |
| %42 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc32) | |
| %43 = tt.addptr %42, %41 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc32) | |
| // Loop trip count %45 = (%arg5 + 15) / 16; %47 = %arg7 * 16 splatted to 16x16 is the per-iteration advance of the second operand's pointers. | |
| %44 = arith.addi %arg5, %c15_i32 : i32 loc(#loc60) | |
| %45 = arith.divsi %44, %c16_i32 : i32 loc(#loc61) | |
| %46 = arith.muli %arg7, %c16_i32 : i32 loc(#loc34) | |
| %47 = tt.splat %46 : i32 -> tensor<16x16xi32, #blocked1> loc(#loc35) | |
| // Reduction loop carrying the accumulator (%arg10, initially the zero tile %cst) and both pointer tiles (%arg11, %arg12). | |
| %48:3 = scf.for %arg9 = %c0_i32 to %45 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %34, %arg12 = %43) -> (tensor<16x16xf32, #blocked>, tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16x!tt.ptr<f32>, #blocked1>) : i32 { | |
| // %67 = %arg5 - %arg9 * 16. Lanes whose range index (columns for the first load, rows for the second) is not < %67 are masked off and take the fill value %cst_0 (0.0). | |
| %66 = arith.muli %arg9, %c16_i32 : i32 loc(#loc37) | |
| %67 = arith.subi %arg5, %66 : i32 loc(#loc38) | |
| %68 = tt.splat %67 : i32 -> tensor<1x16xi32, #blocked1> loc(#loc39) | |
| %69 = arith.cmpi slt, %29, %68 : tensor<1x16xi32, #blocked1> loc(#loc39) | |
| %70 = tt.broadcast %69 : tensor<1x16xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc40) | |
| %71 = tt.load %arg11, %70, %cst_0 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc40) | |
| %72 = tt.splat %67 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41) | |
| %73 = arith.cmpi slt, %35, %72 : tensor<16x1xi32, #blocked1> loc(#loc41) | |
| %74 = tt.broadcast %73 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc42) | |
| %75 = tt.load %arg12, %74, %cst_0 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc42) | |
| // Convert both loaded tiles to dot-operand layouts (opIdx 0 and 1, parent #blocked) and feed them to tt.dot with %arg10 as the accumulator operand. | |
| %76 = triton_gpu.convert_layout %71 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> loc(#loc40) | |
| %77 = triton_gpu.convert_layout %75 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> loc(#loc42) | |
| %78 = tt.dot %76, %77, %arg10, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #blocked> loc(#loc43) | |
| // Advance the first operand's pointers by 16 elements (%cst_1) and the second's by %arg7 * 16 (%47), then yield the new loop state. | |
| %79 = tt.addptr %arg11, %cst_1 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc44) | |
| %80 = tt.addptr %arg12, %47 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc35) | |
| scf.yield %78, %79, %80 : tensor<16x16xf32, #blocked>, tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc45) | |
| } loc(#loc36) | |
| // Output pointers are built from the unwrapped indices: %57 = %arg2 + %arg8 * %18[:, None] + %23[None, :]. | |
| %49 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc46) | |
| %50 = tt.splat %arg8 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc47) | |
| %51 = arith.muli %50, %49 : tensor<16x1xi32, #blocked1> loc(#loc47) | |
| %52 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>, #blocked1> loc(#loc48) | |
| %53 = tt.addptr %52, %51 : tensor<16x1x!tt.ptr<f32>, #blocked1>, tensor<16x1xi32, #blocked1> loc(#loc48) | |
| %54 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc49) | |
| %55 = tt.broadcast %53 : tensor<16x1x!tt.ptr<f32>, #blocked1> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc50) | |
| %56 = tt.broadcast %54 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc50) | |
| %57 = tt.addptr %55, %56 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc50) | |
| // Store mask %64 = (%18[:, None] < %arg3) & (%23[None, :] < %arg4). | |
| %58 = tt.splat %arg3 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc51) | |
| %59 = arith.cmpi slt, %49, %58 : tensor<16x1xi32, #blocked1> loc(#loc51) | |
| %60 = tt.splat %arg4 : i32 -> tensor<1x16xi32, #blocked1> loc(#loc52) | |
| %61 = arith.cmpi slt, %54, %60 : tensor<1x16xi32, #blocked1> loc(#loc52) | |
| %62 = tt.broadcast %59 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc53) | |
| %63 = tt.broadcast %61 : tensor<1x16xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc53) | |
| %64 = arith.andi %62, %63 : tensor<16x16xi1, #blocked1> loc(#loc53) | |
| // Convert the final accumulator from #blocked to #blocked1 so it matches the pointer tile's layout, then perform the masked store. | |
| %65 = triton_gpu.convert_layout %48#0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1> loc(#loc54) | |
| tt.store %57, %65, %64 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc54) | |
| tt.return loc(#loc55) | |
| } loc(#loc) | |
| } loc(#loc) | |
| // Location aliases referenced by the loc(...) annotations on the operations above. #loc1 is unknown (it is the location of the constants); #loc2..#loc55 are file:line:column positions. | |
| #loc1 = loc(unknown) | |
| #loc2 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":188:24) | |
| #loc3 = loc("/data/users/minjang/triton-oss/triton-cpu/python/triton/language/standard.py":40:22) | |
| #loc4 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":189:27) | |
| #loc5 = loc("/data/users/minjang/triton-oss/triton-cpu/python/triton/language/standard.py":40:28) | |
| #loc6 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":190:27) | |
| #loc7 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":191:38) | |
| #loc8 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":192:22) | |
| #loc9 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":193:29) | |
| #loc10 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":194:35) | |
| #loc11 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":194:48) | |
| #loc12 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":195:33) | |
| #loc13 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":195:27) | |
| #loc14 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":196:19) | |
| #loc15 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":196:40) | |
| #loc16 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:23) | |
| #loc17 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:51) | |
| #loc18 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:38) | |
| #loc19 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:68) | |
| #loc20 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:23) | |
| #loc21 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:38) | |
| #loc22 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:68) | |
| #loc23 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:30) | |
| #loc24 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:41) | |
| #loc25 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:60) | |
| #loc26 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:53) | |
| #loc27 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:22) | |
| #loc28 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:29) | |
| #loc29 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:40) | |
| #loc30 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:60) | |
| #loc31 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:52) | |
| #loc32 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:22) | |
| #loc33 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":217:33) | |
| #loc34 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:33) | |
| #loc35 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:18) | |
| #loc36 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":217:22) | |
| #loc37 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:59) | |
| #loc38 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:55) | |
| #loc39 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:51) | |
| #loc40 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:20) | |
| #loc41 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":222:51) | |
| #loc42 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":222:20) | |
| #loc43 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":224:35) | |
| #loc44 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":226:18) | |
| #loc45 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:8) | |
| #loc46 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:41) | |
| #loc47 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:33) | |
| #loc48 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:21) | |
| #loc49 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:72) | |
| #loc50 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:52) | |
| #loc51 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:33) | |
| #loc52 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:58) | |
| #loc53 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:39) | |
| #loc54 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":239:21) | |
| #loc55 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":239:4) | |
| // Callsite locations: positions inside standard.py line 40 (#loc3, #loc5) as reached from tutorial lines 189, 190 and 217 (#loc4, #loc6, #loc33). They annotate the three add-15 / divide-by-16 pairs in the kernel. | |
| #loc56 = loc(callsite(#loc3 at #loc4)) | |
| #loc57 = loc(callsite(#loc5 at #loc4)) | |
| #loc58 = loc(callsite(#loc3 at #loc6)) | |
| #loc59 = loc(callsite(#loc5 at #loc6)) | |
| #loc60 = loc(callsite(#loc3 at #loc33)) | |
| #loc61 = loc(callsite(#loc5 at #loc33)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment