@makslevental
Created March 5, 2025 14:22
Triton + Linalg
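
A before/after dump of a Triton matmul kernel run through an MLIR transform-dialect tiling script. The payload module (tagged transform.target_tag = "payload") holds the tt.func; the sibling transform.with_named_sequence module defines @main, which matches the kernel's linalg.generic and tiles it with transform.structured.tile_using_forall using tile sizes [16, 64] mapped to #gpu.thread<y> and #gpu.thread<x>, then includes @cleanup to run tiling and loop canonicalization patterns, LICM, and CSE.
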
BEFORE
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  module attributes {transform.target_tag = "payload"} {
    tt.func public @matmul_kernel_2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} {
      %c64_i32 = arith.constant 64 : i32
      %c64_i32_0 = arith.constant 64 : i32
      %c64_i32_1 = arith.constant 64 : i32
      %c1_i32 = arith.constant 1 : i32
      %0 = tt.get_program_id x : i32
      %1 = arith.ceildivsi %arg3, %c64_i32 : i32
      %2 = arith.ceildivsi %arg4, %c64_i32_0 : i32
      %c1_i32_2 = arith.constant 1 : i32
      %3 = arith.muli %2, %c1_i32_2 : i32
      %4 = arith.floordivsi %0, %3 : i32
      %c1_i32_3 = arith.constant 1 : i32
      %5 = arith.muli %4, %c1_i32_3 : i32
      %6 = arith.subi %1, %5 : i32
      %7 = arith.minsi %6, %c1_i32 : i32
      %8 = arith.remsi %0, %3 : i32
      %9 = arith.remsi %8, %7 : i32
      %10 = arith.addi %5, %9 : i32
      %11 = arith.remsi %0, %3 : i32
      %12 = arith.floordivsi %11, %7 : i32
      %c64_i32_4 = arith.constant 64 : i32
      %13 = arith.muli %10, %c64_i32_4 : i32
      %14 = tt.splat %13 : i32 -> tensor<64xi32>
      %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %16 = arith.addi %14, %15 : tensor<64xi32>
      %17 = tt.splat %arg3 : i32 -> tensor<64xi32>
      %18 = arith.remsi %16, %17 : tensor<64xi32>
      %c64_i32_5 = arith.constant 64 : i32
      %19 = arith.muli %12, %c64_i32_5 : i32
      %20 = tt.splat %19 : i32 -> tensor<64xi32>
      %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %22 = arith.addi %20, %21 : tensor<64xi32>
      %23 = tt.splat %arg4 : i32 -> tensor<64xi32>
      %24 = arith.remsi %22, %23 : tensor<64xi32>
      %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %26 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %27 = tt.splat %arg6 : i32 -> tensor<64x1xi32>
      %28 = arith.muli %26, %27 : tensor<64x1xi32>
      %29 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %30 = tt.splat %arg7 : i32 -> tensor<1x64xi32>
      %31 = arith.muli %29, %30 : tensor<1x64xi32>
      %32 = tt.broadcast %28 : tensor<64x1xi32> -> tensor<64x64xi32>
      %33 = tt.broadcast %31 : tensor<1x64xi32> -> tensor<64x64xi32>
      %34 = arith.addi %32, %33 : tensor<64x64xi32>
      %35 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %36 = tt.addptr %35, %34 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %37 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %38 = tt.splat %arg8 : i32 -> tensor<64x1xi32>
      %39 = arith.muli %37, %38 : tensor<64x1xi32>
      %40 = tt.expand_dims %24 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %41 = tt.splat %arg9 : i32 -> tensor<1x64xi32>
      %42 = arith.muli %40, %41 : tensor<1x64xi32>
      %43 = tt.broadcast %39 : tensor<64x1xi32> -> tensor<64x64xi32>
      %44 = tt.broadcast %42 : tensor<1x64xi32> -> tensor<64x64xi32>
      %45 = arith.addi %43, %44 : tensor<64x64xi32>
      %46 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %47 = tt.addptr %46, %45 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %c64_i32_6 = arith.constant 64 : i32
      %48 = arith.muli %arg7, %c64_i32_6 : i32
      %49 = tt.splat %48 : i32 -> tensor<64x64xi32>
      %c64_i32_7 = arith.constant 64 : i32
      %50 = arith.muli %arg8, %c64_i32_7 : i32
      %51 = tt.splat %50 : i32 -> tensor<64x64xi32>
      %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_8 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_9 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %52 = arith.ceildivsi %arg5, %c64_i32_1 : i32
      %53 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%36 : tensor<64x64x!tt.ptr<f32>>) outs(%47 : tensor<64x64x!tt.ptr<f32>>) {
      ^bb0(%in: !tt.ptr<f32>, %out: !tt.ptr<f32>):
        %72 = tt.splat %in : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
        %73 = tt.splat %out : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
        %74 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
        %c64_i32_10 = arith.constant 64 : i32
        %75 = arith.subi %arg5, %c64_i32_10 : i32
        %76 = tt.splat %75 : i32 -> tensor<1x64xi32>
        %77 = arith.cmpi slt, %74, %76 : tensor<1x64xi32>
        %78 = tt.broadcast %77 : tensor<1x64xi1> -> tensor<64x64xi1>
        %79 = tt.load %72, %78, %cst_9 : tensor<64x64x!tt.ptr<f32>>
        %80 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
        %c64_i32_11 = arith.constant 64 : i32
        %81 = arith.subi %arg5, %c64_i32_11 : i32
        %82 = tt.splat %81 : i32 -> tensor<64x1xi32>
        %83 = arith.cmpi slt, %80, %82 : tensor<64x1xi32>
        %84 = tt.broadcast %83 : tensor<64x1xi1> -> tensor<64x64xi1>
        %85 = tt.load %73, %84, %cst_8 : tensor<64x64x!tt.ptr<f32>>
        %cst_12 = arith.constant 0.000000e+00 : f32
        %86 = tt.splat %cst_12 : f32 -> tensor<64x64xf32>
        %87 = tt.dot %79, %85, %86 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
        linalg.yield %out : !tt.ptr<f32>
      } -> tensor<64x64x!tt.ptr<f32>>
      %54 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %55 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>>
      %56 = tt.splat %arg10 : i32 -> tensor<64x1xi32>
      %57 = arith.muli %56, %54 : tensor<64x1xi32>
      %58 = tt.addptr %55, %57 : tensor<64x1x!tt.ptr<f32>>, tensor<64x1xi32>
      %59 = tt.broadcast %58 : tensor<64x1x!tt.ptr<f32>> -> tensor<64x64x!tt.ptr<f32>>
      %60 = tt.expand_dims %22 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %61 = tt.splat %arg11 : i32 -> tensor<1x64xi32>
      %62 = arith.muli %61, %60 : tensor<1x64xi32>
      %63 = tt.broadcast %62 : tensor<1x64xi32> -> tensor<64x64xi32>
      %64 = tt.addptr %59, %63 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %65 = tt.splat %arg3 : i32 -> tensor<64x1xi32>
      %66 = arith.cmpi slt, %54, %65 : tensor<64x1xi32>
      %67 = tt.broadcast %66 : tensor<64x1xi1> -> tensor<64x64xi1>
      %68 = tt.splat %arg4 : i32 -> tensor<1x64xi32>
      %69 = arith.cmpi slt, %60, %68 : tensor<1x64xi32>
      %70 = tt.broadcast %69 : tensor<1x64xi1> -> tensor<64x64xi1>
      %71 = arith.andi %67, %70 : tensor<64x64xi1>
      tt.return
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @cleanup(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %0 {
        transform.apply_patterns.linalg.tiling_canonicalization
        transform.apply_patterns.scf.for_loop_canonicalization
        transform.apply_patterns.canonicalization
      } : !transform.any_op
      %1 = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_licm to %1 : !transform.any_op
      transform.apply_cse to %0 : !transform.any_op
      transform.yield
    }
    transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [16, 64](mapping = [#gpu.thread<y>, #gpu.thread<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.include @cleanup failures(propagate) (%arg0) : (!transform.any_op) -> ()
      transform.yield
    }
  }
}
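
The transform module above is the script that produces the AFTER IR below; the payload tt.func is untouched input. As a minimal, self-contained sketch of the same step, here is the identical @main sequence applied to an illustrative plain-linalg payload (the @copy func and file layout are assumptions for demonstration, not part of this gist); assuming upstream mlir-opt option names, a file like this can be driven with -transform-interpreter='entry-point=main debug-payload-root-tag=payload':

module {
  module attributes {transform.target_tag = "payload"} {
    // Stand-in payload (assumed): an elementwise copy as a 64x64 linalg.generic.
    func.func @copy(%in: tensor<64x64xf32>, %init: tensor<64x64xf32>) -> tensor<64x64xf32> {
      %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%in : tensor<64x64xf32>) outs(%init : tensor<64x64xf32>) {
      ^bb0(%a: f32, %b: f32):
        linalg.yield %a : f32
      } -> tensor<64x64xf32>
      return %0 : tensor<64x64xf32>
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) {
      // Same tiling step as the gist: 16x64 tiles, mapped to GPU threads y/x.
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [16, 64](mapping = [#gpu.thread<y>, #gpu.thread<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}
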
AFTER
#map = affine_map<(d0) -> (d0 * 16)>
#map1 = affine_map<(d0) -> (d0 * 64)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
module {
  module attributes {transform.target_tag = "payload"} {
    tt.func public @matmul_kernel_2(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) attributes {noinline = false} {
      %c64_i32 = arith.constant 64 : i32
      %c64_i32_0 = arith.constant 64 : i32
      %c64_i32_1 = arith.constant 64 : i32
      %c1_i32 = arith.constant 1 : i32
      %0 = tt.get_program_id x : i32
      %1 = arith.ceildivsi %arg3, %c64_i32 : i32
      %2 = arith.ceildivsi %arg4, %c64_i32_0 : i32
      %c1_i32_2 = arith.constant 1 : i32
      %3 = arith.muli %2, %c1_i32_2 : i32
      %4 = arith.floordivsi %0, %3 : i32
      %c1_i32_3 = arith.constant 1 : i32
      %5 = arith.muli %4, %c1_i32_3 : i32
      %6 = arith.subi %1, %5 : i32
      %7 = arith.minsi %6, %c1_i32 : i32
      %8 = arith.remsi %0, %3 : i32
      %9 = arith.remsi %8, %7 : i32
      %10 = arith.addi %5, %9 : i32
      %11 = arith.remsi %0, %3 : i32
      %12 = arith.floordivsi %11, %7 : i32
      %c64_i32_4 = arith.constant 64 : i32
      %13 = arith.muli %10, %c64_i32_4 : i32
      %14 = tt.splat %13 : i32 -> tensor<64xi32>
      %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %16 = arith.addi %14, %15 : tensor<64xi32>
      %17 = tt.splat %arg3 : i32 -> tensor<64xi32>
      %18 = arith.remsi %16, %17 : tensor<64xi32>
      %c64_i32_5 = arith.constant 64 : i32
      %19 = arith.muli %12, %c64_i32_5 : i32
      %20 = tt.splat %19 : i32 -> tensor<64xi32>
      %21 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %22 = arith.addi %20, %21 : tensor<64xi32>
      %23 = tt.splat %arg4 : i32 -> tensor<64xi32>
      %24 = arith.remsi %22, %23 : tensor<64xi32>
      %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %26 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %27 = tt.splat %arg6 : i32 -> tensor<64x1xi32>
      %28 = arith.muli %26, %27 : tensor<64x1xi32>
      %29 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %30 = tt.splat %arg7 : i32 -> tensor<1x64xi32>
      %31 = arith.muli %29, %30 : tensor<1x64xi32>
      %32 = tt.broadcast %28 : tensor<64x1xi32> -> tensor<64x64xi32>
      %33 = tt.broadcast %31 : tensor<1x64xi32> -> tensor<64x64xi32>
      %34 = arith.addi %32, %33 : tensor<64x64xi32>
      %35 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %36 = tt.addptr %35, %34 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %37 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %38 = tt.splat %arg8 : i32 -> tensor<64x1xi32>
      %39 = arith.muli %37, %38 : tensor<64x1xi32>
      %40 = tt.expand_dims %24 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %41 = tt.splat %arg9 : i32 -> tensor<1x64xi32>
      %42 = arith.muli %40, %41 : tensor<1x64xi32>
      %43 = tt.broadcast %39 : tensor<64x1xi32> -> tensor<64x64xi32>
      %44 = tt.broadcast %42 : tensor<1x64xi32> -> tensor<64x64xi32>
      %45 = arith.addi %43, %44 : tensor<64x64xi32>
      %46 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
      %47 = tt.addptr %46, %45 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %c64_i32_6 = arith.constant 64 : i32
      %48 = arith.muli %arg7, %c64_i32_6 : i32
      %49 = tt.splat %48 : i32 -> tensor<64x64xi32>
      %c64_i32_7 = arith.constant 64 : i32
      %50 = arith.muli %arg8, %c64_i32_7 : i32
      %51 = tt.splat %50 : i32 -> tensor<64x64xi32>
      %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_8 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %cst_9 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
      %52 = arith.ceildivsi %arg5, %c64_i32_1 : i32
      %53 = scf.forall (%arg12, %arg13) in (4, 1) shared_outs(%arg14 = %47) -> (tensor<64x64x!tt.ptr<f32>>) {
        %72 = affine.apply #map(%arg12)
        %73 = affine.apply #map1(%arg13)
        %extracted_slice = tensor.extract_slice %36[%72, %73] [16, 64] [1, 1] : tensor<64x64x!tt.ptr<f32>> to tensor<16x64x!tt.ptr<f32>>
        %extracted_slice_10 = tensor.extract_slice %arg14[%72, %73] [16, 64] [1, 1] : tensor<64x64x!tt.ptr<f32>> to tensor<16x64x!tt.ptr<f32>>
        %74 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<16x64x!tt.ptr<f32>>) outs(%extracted_slice_10 : tensor<16x64x!tt.ptr<f32>>) {
        ^bb0(%in: !tt.ptr<f32>, %out: !tt.ptr<f32>):
          %75 = tt.splat %in : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
          %76 = tt.splat %out : !tt.ptr<f32> -> tensor<64x64x!tt.ptr<f32>>
          %77 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
          %c64_i32_11 = arith.constant 64 : i32
          %78 = arith.subi %arg5, %c64_i32_11 : i32
          %79 = tt.splat %78 : i32 -> tensor<1x64xi32>
          %80 = arith.cmpi slt, %77, %79 : tensor<1x64xi32>
          %81 = tt.broadcast %80 : tensor<1x64xi1> -> tensor<64x64xi1>
          %82 = tt.load %75, %81, %cst_9 : tensor<64x64x!tt.ptr<f32>>
          %83 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
          %c64_i32_12 = arith.constant 64 : i32
          %84 = arith.subi %arg5, %c64_i32_12 : i32
          %85 = tt.splat %84 : i32 -> tensor<64x1xi32>
          %86 = arith.cmpi slt, %83, %85 : tensor<64x1xi32>
          %87 = tt.broadcast %86 : tensor<64x1xi1> -> tensor<64x64xi1>
          %88 = tt.load %76, %87, %cst_8 : tensor<64x64x!tt.ptr<f32>>
          %cst_13 = arith.constant 0.000000e+00 : f32
          %89 = tt.splat %cst_13 : f32 -> tensor<64x64xf32>
          %90 = tt.dot %82, %88, %89 : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
          linalg.yield %out : !tt.ptr<f32>
        } -> tensor<16x64x!tt.ptr<f32>>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %74 into %arg14[%72, %73] [16, 64] [1, 1] : tensor<16x64x!tt.ptr<f32>> into tensor<64x64x!tt.ptr<f32>>
        }
      } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
      %54 = tt.expand_dims %16 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %55 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x1x!tt.ptr<f32>>
      %56 = tt.splat %arg10 : i32 -> tensor<64x1xi32>
      %57 = arith.muli %56, %54 : tensor<64x1xi32>
      %58 = tt.addptr %55, %57 : tensor<64x1x!tt.ptr<f32>>, tensor<64x1xi32>
      %59 = tt.broadcast %58 : tensor<64x1x!tt.ptr<f32>> -> tensor<64x64x!tt.ptr<f32>>
      %60 = tt.expand_dims %22 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %61 = tt.splat %arg11 : i32 -> tensor<1x64xi32>
      %62 = arith.muli %61, %60 : tensor<1x64xi32>
      %63 = tt.broadcast %62 : tensor<1x64xi32> -> tensor<64x64xi32>
      %64 = tt.addptr %59, %63 : tensor<64x64x!tt.ptr<f32>>, tensor<64x64xi32>
      %65 = tt.splat %arg3 : i32 -> tensor<64x1xi32>
      %66 = arith.cmpi slt, %54, %65 : tensor<64x1xi32>
      %67 = tt.broadcast %66 : tensor<64x1xi1> -> tensor<64x64xi1>
      %68 = tt.splat %arg4 : i32 -> tensor<1x64xi32>
      %69 = arith.cmpi slt, %60, %68 : tensor<1x64xi32>
      %70 = tt.broadcast %69 : tensor<1x64xi1> -> tensor<64x64xi1>
      %71 = arith.andi %67, %70 : tensor<64x64xi1>
      tt.return
    }
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @cleanup(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %0 {
        transform.apply_patterns.linalg.tiling_canonicalization
        transform.apply_patterns.scf.for_loop_canonicalization
        transform.apply_patterns.canonicalization
      } : !transform.any_op
      %1 = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op
      transform.apply_licm to %1 : !transform.any_op
      transform.apply_cse to %0 : !transform.any_op
      transform.yield
    }
    transform.named_sequence @main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [16, 64](mapping = [#gpu.thread<y>, #gpu.thread<x>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.include @cleanup failures(propagate) (%arg0) : (!transform.any_op) -> ()
      transform.yield
    }
  }
}
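
Net effect: the 64x64 linalg.generic over pointer tensors becomes an scf.forall (%arg12, %arg13) in (4, 1), since 64/16 = 4 tiles fit along the thread-y dimension and 64/64 = 1 along thread-x. #map and #map1 turn the loop indices into element offsets (d0 * 16 and d0 * 64), tensor.extract_slice carves the 16x64 input and output tiles, and tensor.parallel_insert_slice writes each result tile back into the shared output. The body of the generic (the masked tt.loads and the tt.dot) is carried over verbatim, and the mapping attribute preserves the #gpu.thread<y>/<x> assignment for later GPU lowering.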