Triton + Linalg
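Before/after IR for a Triton matmul kernel whose inner tt.dot has been wrapped in a linalg.generic. In each listing, the payload module (tagged transform.target_tag = "payload") sits next to a transform.with_named_sequence module whose @main sequence tiles the linalg.generic via transform.structured.tile_using_forall with tile_sizes [16, 64], mapping the tiles to #gpu.thread<y> and #gpu.thread<x>, and then invokes the @cleanup canonicalization sequence. BEFORE is the input handed to the transform interpreter; AFTER is the result.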
BEFORE
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  module attributes {transform.target_tag = "payload"} {
    tt.func public @matmul_kernel_2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} {
      %c64_i32 = arith.constant 64 : i32
      %c64_i32_0 = arith.constant 64 : i32
      %c64_i32_1 = arith.constant 64 : i32
      %c1_i32 = arith.constant 1 : i32
      %0 = tt.get_program_id x : i32
      %1 = arith.ceildivsi %arg3, %c64_i32 : i32
      %2 = arith.ceildivsi %arg4, %c64_i32_0 : i32
      %c1_i32_2 = arith.constant 1 : i32
      %3 = arith.muli %2, %c1_i32_2 : i32
      %4 = arith.floordivsi %0, %3 : i32
      %c1_i32_3 = arith.constant 1 : i32
      %5 = arith.muli %4, %c1_i32_3 : i32
      %6 = arith.subi %1, %5 : i32
      %7 = arith.minsi %6, %c1_i32 : i32
      %8 = arith.remsi %0, %3 : i32
      %9 = arith.remsi %8, %7 : i32
      %10 = arith.addi %5, %9 : i32
      %11 = arith.remsi %0, %3 : i32
      %12 = arith.floordivsi %11, %7 : i32
      %c64_i32_4 = arith.constant 64 : i32
      %13 = arith.muli %10, %c64_i32_4 : i32
      %14 = tt.splat %13 : i32 -> tensor<64xi32>
      %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %16 = arith.addi %14, %15 : tensor<64xi32>
      %17 = tt.splat %arg3 : i32 -> tensor<64xi32>
      %18 = arith.remsi %16, %17 : tensor<64xi32>
      %c64_i32_5 = arith.constant 64 : i32
      %19 = arith.muli %12, %c64_i32_5 : i32
      %20 = tt.splat %19 : i32 -> tensor<64xi32>
      %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %22 = arith.addi %20, %21 : tensor<64xi32>
      %23 = tt.splat %arg4 : i32 -> tensor<64xi32>
      %24 = arith.remsi %22, %23 : tensor<64xi32>
      %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %26 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %27 = tt.splat %arg6 : i32 -> tensor<64x1xi32>
      %28 = arith.muli %26, %27 : tensor<64x1xi32>
      %29 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %30 = tt.splat %arg7 : i32 -> tensor<1x64xi32>
      %31 = arith.muli %29, %30 : tensor<1x64xi32>
      %32 = tt.broadcast %28 : tensor<64x1xi32> -> tensor<64x64xi32>
      %33 = tt.broadcast %31 : tensor<1x64xi32> -> tensor<64x64xi32>
      %34 = arith.addi %32, %33 : tensor<64x64xi32>
      %35 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %36 = tt.addptr %35, %34 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %37 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %38 = tt.splat %arg8 : i32 -> tensor<64x1xi32>
      %39 = arith.muli %37, %38 : tensor<64x1xi32>
      %40 = tt.expand_dims %24 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %41 = tt.splat %arg9 : i32 -> tensor<1x64xi32>
      %42 = arith.muli %40, %41 : tensor<1x64xi32>
      %43 = tt.broadcast %39 : tensor<64x1xi32> -> tensor<64x64xi32>
      %44 = tt.broadcast %42 : tensor<1x64xi32> -> tensor<64x64xi32>
      %45 = arith.addi %43, %44 : tensor<64x64xi32>
      %46 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %47 = tt.addptr %46, %45 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %c64_i32_6 = arith.constant 64 : i32
      %48 = arith.muli %arg7, %c64_i32_6 : i32
      %49 = tt.splat %48 : i32 -> tensor<64x64xi32>
      %c64_i32_7 = arith.constant 64 : i32
      %50 = arith.muli %arg8, %c64_i32_7 : i32
      %51 = tt.splat %50 : i32 -> tensor<64x64xi32>
      %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_8 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_9 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %52 = arith.ceildivsi %arg5, %c64_i32_1 : i32
      %53 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%36 : tensor<64x64x!tt.ptr<f32>>) outs(%47 : tensor<64x64x!tt.ptr<f32>>) {
      ^bb0(%in: !tt.ptr<f32>, %out: !tt.ptr<f32>):
        %72 = tt.splat %in : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
        %73 = tt.splat %out : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
        %74 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
        %c64_i32_10 = arith.constant 64 : i32
        %75 = arith.subi %arg5, %c64_i32_10 : i32
        %76 = tt.splat %75 : i32 -> tensor<1x64xi32>
        %77 = arith.cmpi slt, %74, %76 : tensor<1x64xi32>
        %78 = tt.broadcast %77 : tensor<1x64xi1> -> tensor<64x64xi1>
        %79 = tt.load %72, %78, %cst_9 : tensor<64x64x!tt.ptr<f32>>
        %80 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
        %c64_i32_11 = arith.constant 64 : i32
        %81 = arith.subi %arg5, %c64_i32_11 : i32
        %82 = tt.splat %81 : i32 -> tensor<64x1xi32>
        %83 = arith.cmpi slt, %80, %82 : tensor<64x1xi32>
        %84 = tt.broadcast %83 : tensor<64x1xi1> -> tensor<64x64xi1>
        %85 = tt.load %73, %84, %cst_8 : tensor<64x64x!tt.ptr<f32>>
        %cst_12 = arith.constant 0.000000e+00 : f32
        %86 = tt.splat %cst_12 : f32 -> tensor<64x64xf32>
        %87 = tt.dot %79, %85, %86 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
        linalg.yield %out : !tt.ptr<f32>
      } -> tensor<64x64x!tt.ptr<f32>>
      %54 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %55 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>>
      %56 = tt.splat %arg10 : i32 -> tensor<64x1xi32>
      %57 = arith.muli %56, %54 : tensor<64x1xi32>
      %58 = tt.addptr %55, %57 : tensor<64x1x!tt.ptr<f32>>, tensor<64x1xi32>
      %59 = tt.broadcast %58 : tensor<64x1x!tt.ptr<f32>> -> tensor<64x64x!tt.ptr<f32>>
      %60 = tt.expand_dims %22 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %61 = tt.splat %arg11 : i32 -> tensor<1x64xi32>
      %62 = arith.muli %61, %60 : tensor<1x64xi32>
      %63 = tt.broadcast %62 : tensor<1x64xi32> -> tensor<64x64xi32>
      %64 = tt.addptr %59, %63 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %65 = tt.splat %arg3 : i32 -> tensor<64x1xi32>
      %66 = arith.cmpi slt, %54, %65 : tensor<64x1xi32>
      %67 = tt.broadcast %66 : tensor<64x1xi1> -> tensor<64x64xi1>
      %68 = tt.splat %arg4 : i32 -> tensor<1x64xi32>
      %69 = arith.cmpi slt, %60, %68 : tensor<1x64xi32>
      %70 = tt.broadcast %69 : tensor<1x64xi1> -> tensor<64x64xi1>
      %71 = arith.andi %67, %70 : tensor<64x64xi1>
      tt.return
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @cleanup(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %0 {
        transform.apply_patterns.linalg.tiling_canonicalization
        transform.apply_patterns.scf.for_loop_canonicalization
        transform.apply_patterns.canonicalization
      } : !transform.any_op
      %1 = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_licm to %1 : !transform.any_op
      transform.apply_cse to %0 : !transform.any_op
      transform.yield
    }
    transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [16, 64](mapping = [#gpu.thread<y>, #gpu.thread<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.include @cleanup failures(propagate) (%arg0) : (!transform.any_op) -> ()
      transform.yield
    }
  }
}
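Applying @main rewrites the linalg.generic above into the scf.forall nest shown in the AFTER listing. The loop bounds and slice offsets fall out of simple tiling arithmetic; the following minimal Python sketch (illustrative only, not part of the gist or of any MLIR API) mirrors what tile_using_forall computes:

# The linalg.generic iterates a 64x64 space; tile_sizes [16, 64] split it into
# ceil(64/16) x ceil(64/64) = 4 x 1 tiles, which becomes the bound of
# scf.forall (%arg12, %arg13) in (4, 1) in the AFTER listing.
shape = (64, 64)
tile_sizes = (16, 64)
grid = tuple(-(-s // t) for s, t in zip(shape, tile_sizes))  # ceildiv -> (4, 1)
for ty in range(grid[0]):      # mapped to #gpu.thread<y>
    for tx in range(grid[1]):  # mapped to #gpu.thread<x>
        # #map and #map1 below: (d0) -> (d0 * 16) and (d0) -> (d0 * 64)
        offset = (ty * tile_sizes[0], tx * tile_sizes[1])
        # tensor.extract_slice pulls a 16x64 tile out at this offset;
        # tensor.parallel_insert_slice writes the result tile back.
        print(f"tile offset {offset}, size {tile_sizes}")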
AFTER
#map = affine_map<(d0) -> (d0 * 16)>
#map1 = affine_map<(d0) -> (d0 * 64)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
module {
  module attributes {transform.target_tag = "payload"} {
    tt.func public @matmul_kernel_2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} {
      %c64_i32 = arith.constant 64 : i32
      %c64_i32_0 = arith.constant 64 : i32
      %c64_i32_1 = arith.constant 64 : i32
      %c1_i32 = arith.constant 1 : i32
      %0 = tt.get_program_id x : i32
      %1 = arith.ceildivsi %arg3, %c64_i32 : i32
      %2 = arith.ceildivsi %arg4, %c64_i32_0 : i32
      %c1_i32_2 = arith.constant 1 : i32
      %3 = arith.muli %2, %c1_i32_2 : i32
      %4 = arith.floordivsi %0, %3 : i32
      %c1_i32_3 = arith.constant 1 : i32
      %5 = arith.muli %4, %c1_i32_3 : i32
      %6 = arith.subi %1, %5 : i32
      %7 = arith.minsi %6, %c1_i32 : i32
      %8 = arith.remsi %0, %3 : i32
      %9 = arith.remsi %8, %7 : i32
      %10 = arith.addi %5, %9 : i32
      %11 = arith.remsi %0, %3 : i32
      %12 = arith.floordivsi %11, %7 : i32
      %c64_i32_4 = arith.constant 64 : i32
      %13 = arith.muli %10, %c64_i32_4 : i32
      %14 = tt.splat %13 : i32 -> tensor<64xi32>
      %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %16 = arith.addi %14, %15 : tensor<64xi32>
      %17 = tt.splat %arg3 : i32 -> tensor<64xi32>
      %18 = arith.remsi %16, %17 : tensor<64xi32>
      %c64_i32_5 = arith.constant 64 : i32
      %19 = arith.muli %12, %c64_i32_5 : i32
      %20 = tt.splat %19 : i32 -> tensor<64xi32>
      %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %22 = arith.addi %20, %21 : tensor<64xi32>
      %23 = tt.splat %arg4 : i32 -> tensor<64xi32>
      %24 = arith.remsi %22, %23 : tensor<64xi32>
      %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %26 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %27 = tt.splat %arg6 : i32 -> tensor<64x1xi32>
      %28 = arith.muli %26, %27 : tensor<64x1xi32>
      %29 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %30 = tt.splat %arg7 : i32 -> tensor<1x64xi32>
      %31 = arith.muli %29, %30 : tensor<1x64xi32>
      %32 = tt.broadcast %28 : tensor<64x1xi32> -> tensor<64x64xi32>
      %33 = tt.broadcast %31 : tensor<1x64xi32> -> tensor<64x64xi32>
      %34 = arith.addi %32, %33 : tensor<64x64xi32>
      %35 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %36 = tt.addptr %35, %34 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %37 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %38 = tt.splat %arg8 : i32 -> tensor<64x1xi32>
      %39 = arith.muli %37, %38 : tensor<64x1xi32>
      %40 = tt.expand_dims %24 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %41 = tt.splat %arg9 : i32 -> tensor<1x64xi32>
      %42 = arith.muli %40, %41 : tensor<1x64xi32>
      %43 = tt.broadcast %39 : tensor<64x1xi32> -> tensor<64x64xi32>
      %44 = tt.broadcast %42 : tensor<1x64xi32> -> tensor<64x64xi32>
      %45 = arith.addi %43, %44 : tensor<64x64xi32>
      %46 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %47 = tt.addptr %46, %45 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %c64_i32_6 = arith.constant 64 : i32
      %48 = arith.muli %arg7, %c64_i32_6 : i32
      %49 = tt.splat %48 : i32 -> tensor<64x64xi32>
      %c64_i32_7 = arith.constant 64 : i32
      %50 = arith.muli %arg8, %c64_i32_7 : i32
      %51 = tt.splat %50 : i32 -> tensor<64x64xi32>
      %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_8 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_9 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %52 = arith.ceildivsi %arg5, %c64_i32_1 : i32
      %53 = scf.forall (%arg12, %arg13) in (4, 1) shared_outs(%arg14 = %47) -> (tensor<64x64x!tt.ptr<f32>>) {
        %72 = affine.apply #map(%arg12)
        %73 = affine.apply #map1(%arg13)
        %extracted_slice = tensor.extract_slice %36[%72, %73] [16, 64] [1, 1] : tensor<64x64x!tt.ptr<f32>> to tensor<16x64x!tt.ptr<f32>>
        %extracted_slice_10 = tensor.extract_slice %arg14[%72, %73] [16, 64] [1, 1] : tensor<64x64x!tt.ptr<f32>> to tensor<16x64x!tt.ptr<f32>>
        %74 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<16x64x!tt.ptr<f32>>) outs(%extracted_slice_10 : tensor<16x64x!tt.ptr<f32>>) {
        ^bb0(%in: !tt.ptr<f32>, %out: !tt.ptr<f32>):
          %75 = tt.splat %in : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
          %76 = tt.splat %out : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
          %77 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
          %c64_i32_11 = arith.constant 64 : i32
          %78 = arith.subi %arg5, %c64_i32_11 : i32
          %79 = tt.splat %78 : i32 -> tensor<1x64xi32>
          %80 = arith.cmpi slt, %77, %79 : tensor<1x64xi32>
          %81 = tt.broadcast %80 : tensor<1x64xi1> -> tensor<64x64xi1>
          %82 = tt.load %75, %81, %cst_9 : tensor<64x64x!tt.ptr<f32>>
          %83 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
          %c64_i32_12 = arith.constant 64 : i32
          %84 = arith.subi %arg5, %c64_i32_12 : i32
          %85 = tt.splat %84 : i32 -> tensor<64x1xi32>
          %86 = arith.cmpi slt, %83, %85 : tensor<64x1xi32>
          %87 = tt.broadcast %86 : tensor<64x1xi1> -> tensor<64x64xi1>
          %88 = tt.load %76, %87, %cst_8 : tensor<64x64x!tt.ptr<f32>>
          %cst_13 = arith.constant 0.000000e+00 : f32
          %89 = tt.splat %cst_13 : f32 -> tensor<64x64xf32>
          %90 = tt.dot %82, %88, %89 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
          linalg.yield %out : !tt.ptr<f32>
        } -> tensor<16x64x!tt.ptr<f32>>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %74 into %arg14[%72, %73] [16, 64] [1, 1] : tensor<16x64x!tt.ptr<f32>> into tensor<64x64x!tt.ptr<f32>>
        }
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
      %54 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %55 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>>
      %56 = tt.splat %arg10 : i32 -> tensor<64x1xi32>
      %57 = arith.muli %56, %54 : tensor<64x1xi32>
      %58 = tt.addptr %55, %57 : tensor<64x1x!tt.ptr<f32>>, tensor<64x1xi32>
      %59 = tt.broadcast %58 : tensor<64x1x!tt.ptr<f32>> -> tensor<64x64x!tt.ptr<f32>>
      %60 = tt.expand_dims %22 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %61 = tt.splat %arg11 : i32 -> tensor<1x64xi32>
      %62 = arith.muli %61, %60 : tensor<1x64xi32>
      %63 = tt.broadcast %62 : tensor<1x64xi32> -> tensor<64x64xi32>
      %64 = tt.addptr %59, %63 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %65 = tt.splat %arg3 : i32 -> tensor<64x1xi32>
      %66 = arith.cmpi slt, %54, %65 : tensor<64x1xi32>
      %67 = tt.broadcast %66 : tensor<64x1xi1> -> tensor<64x64xi1>
      %68 = tt.splat %arg4 : i32 -> tensor<1x64xi32>
      %69 = arith.cmpi slt, %60, %68 : tensor<1x64xi32>
      %70 = tt.broadcast %69 : tensor<1x64xi1> -> tensor<64x64xi1>
      %71 = arith.andi %67, %70 : tensor<64x64xi1>
      tt.return
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @cleanup(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %0 {
        transform.apply_patterns.linalg.tiling_canonicalization
        transform.apply_patterns.scf.for_loop_canonicalization
        transform.apply_patterns.canonicalization
      } : !transform.any_op
      %1 = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_licm to %1 : !transform.any_op
      transform.apply_cse to %0 : !transform.any_op
      transform.yield
    }
    transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [16, 64](mapping = [#gpu.thread<y>, #gpu.thread<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.include @cleanup failures(propagate) (%arg0) : (!transform.any_op) -> ()
      transform.yield
    }
  }
}
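The nested-module layout used here (a payload-tagged module beside a transform.with_named_sequence module) is the shape the MLIR transform interpreter consumes. A plausible invocation, assuming a build of mlir-opt (or an equivalent opt tool) that also registers the Triton dialect, would be along the lines of mlir-opt input.mlir --transform-interpreter='entry-point=main debug-payload-root-tag=payload'; the exact binary and flag spelling are an assumption, not taken from the gist.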