@bjacob
Last active March 5, 2021 21:23
diff --git a/google3/third_party/llvm/llvm-project/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py b/google3/third_party/llvm/llvm-project/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py
--- a/google3/third_party/llvm/llvm-project/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py
+++ b/google3/third_party/llvm/llvm-project/mlir/lib/Bindings/Python/mlir/tools/linalg_opdsl/ops/core_named_ops.py
@@ -68,3 +68,37 @@ def dot(A=TensorDef(T1, S.M), B=TensorDe
"""
implements(ContractionOpInterface)
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
+
+@linalg_structured_op
+def mmt_4d_kernel(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+                  rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+                  accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)):
+ """A lowering path for linalg.matmul towards efficient code generation on CPU.
+ The differences from linalg.matmul are:
+ * The right hand side is transposed, whence the 't' in 'mmt'. In other words, this op computes
+ `accumulator + lhs * transpose(rhs)` instead of `accumulator + lhs * rhs`. This transposition
+ brings RHS on an equal footing as LHS from the perspective of an efficient implementation:
+ now both are traversed row-wise by the inner accumulation loop, so we want the same
+ row-major layouts for both LHS and RHS. Without that transposition, the below discussion of
+ layouts would be complicated by having to describe LHS and RHS separately, and the actual
+ code would be accordingly more complicated.
+  * The input and output tensors have a 4D shape instead of a 2D shape. They are interpreted
+    as 2D matrices with one level of 2D tile subdivision, whence the 2+2=4 dimensions.
+    The inner tile dimensions are identified with '0' suffixes below; for instance, the LHS
+    shape (M, K, M0, K0) reads as: an MxK array of tiles, each of shape M0xK0.
+  * **(The whole point)** Whereas linalg.matmul is agnostic as to the actual layout of its input
+    tensors, this op comes with a recommendation that its input tensors be bufferized into a
+    row-major layout (meaning that the last-enumerated dimension is contiguous in memory)
+    with no inner striding (meaning no striding except possibly in the outermost dimension).
+    Because the 4D shape encodes a level of 2D tile subdivision as described above, this
+    row-major layout of the 4D tensor effectively means a tiled layout.
+    So whereas linalg.matmul is a high-level op that does not prescribe details of efficient
+    implementation, this op makes such a specific recommendation. By the time a matmul has
+    been lowered to this op, the effective layouts of the buffers to be consumed by the matmul
+    kernel are determined. Namely, the parameters controlling these layouts are the M0, K0, N0
+    values occurring in the shapes of this op's inputs. They are to be determined purely
+    by the CPU ISA, specifically by the shape of the SIMD instructions to be used by the
+    kernel and by the shape of the register space of that SIMD ISA.
+ """
+ implements(ContractionOpInterface)
+ C[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
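
To make the recommended layout concrete, here is a small NumPy sketch (not part of the patch) of the packing and of the contraction this op computes. The helper names `pack_tiled` and `mmt_4d_reference` and the tile sizes M0, K0, N0 below are made up for illustration; in practice the tile sizes would come from the target SIMD ISA as described in the docstring.

import numpy as np

def pack_tiled(matrix, tile_rows, tile_cols):
  """Packs a 2D (R, C) matrix into a 4D (R//tile_rows, C//tile_cols,
  tile_rows, tile_cols) array. Row-major order of the result is the tiled
  layout the docstring recommends: each tile_rows x tile_cols tile is
  contiguous in memory."""
  rows, cols = matrix.shape
  assert rows % tile_rows == 0 and cols % tile_cols == 0
  return (matrix
          .reshape(rows // tile_rows, tile_rows, cols // tile_cols, tile_cols)
          .transpose(0, 2, 1, 3))

def mmt_4d_reference(lhs4, rhs4, accum4):
  """accum[m, n, m0, n0] += lhs[m, k, m0, k0] * rhs[n, k, n0, k0],
  summing over both the outer (k) and inner (k0) contraction dims."""
  return accum4 + np.einsum('mkab,nkcb->mnac', lhs4, rhs4)

M, K, N = 16, 24, 32
M0, K0, N0 = 8, 4, 8  # Placeholder tile sizes; real values depend on the ISA.
lhs = np.random.randn(M, K)
rhs_t = np.random.randn(N, K)  # RHS is stored transposed: shape (N, K).
accum = np.random.randn(M, N)

got = mmt_4d_reference(pack_tiled(lhs, M0, K0),
                       pack_tiled(rhs_t, N0, K0),
                       pack_tiled(accum, M0, N0))
# The 4D tiled result must agree with the plain 2D accum + lhs * transpose(rhs).
want = pack_tiled(accum + lhs @ rhs_t.T, M0, N0)
np.testing.assert_allclose(got, want)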