Skip to content

Instantly share code, notes, and snippets.

@bjacob
Created March 3, 2021 16:31
Show Gist options
Save bjacob/9ee064564657840bb348fb1a9dcde3bd to your computer and use it in GitHub Desktop.

YAML op config (Linalg OpDSL `!LinalgOpConfig`):

--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: mmt_kernel
  cpp_op_name: MmtKernelOp
  doc: |-
    Tiled matrix multiplication with a transposed right-hand side.

    All three operands are 4-D tiled tensors: lhs has shape (M, K, M0, K0),
    rhs has shape (N, K, N0, K0) (the right-hand side is stored N-major,
    i.e. transposed relative to a conventional K x N layout), and accum has
    shape (M, N, M0, N0). For every index (m, n, m0, n0) the op computes

      accum(m, n, m0, n0) += cast(lhs(m, k, m0, k0)) * cast(rhs(n, k, n0, k0))

    summed over the two reduction dimensions k and k0, where cast converts
    each input element to AccumType before the multiplication.
  implements:
  - LinalgContractionOpInterface
structured_op: !LinalgStructuredOpConfig
  args:
  - !<LinalgTensorDef>
    name: lhs
    usage: input
    shape: affine_map<()[M, N, K, M0, N0, K0] -> (M, K, M0, K0)>
    element_type_var: LhsType
  - !<LinalgTensorDef>
    name: rhs
    usage: input
    shape: affine_map<()[M, N, K, M0, N0, K0] -> (N, K, N0, K0)>
    element_type_var: RhsType
  - !<LinalgTensorDef>
    name: accum
    usage: output
    shape: affine_map<()[M, N, K, M0, N0, K0] -> (M, N, M0, N0)>
    element_type_var: AccumType
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(m, n, k, m0, n0, k0)[M, N, K, M0, N0, K0] -> (m, k, m0, k0)>
    - affine_map<(m, n, k, m0, n0, k0)[M, N, K, M0, N0, K0] -> (n, k, n0, k0)>
    - affine_map<(m, n, k, m0, n0, k0)[M, N, K, M0, N0, K0] -> (m, n, m0, n0)>
  iterator_types:
  - parallel
  - parallel
  - reduction
  - parallel
  - parallel
  - reduction
  assignments:
  - !ScalarAssign
    arg: accum
    value: !ScalarExpression
      scalar_apply:
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: accum
        - !ScalarExpression
          scalar_apply:
            fn_name: mul
            operands:
            - !ScalarExpression
              symbolic_cast:
                type_var: AccumType
                operands:
                - !ScalarExpression
                  scalar_arg: lhs
            - !ScalarExpression
              symbolic_cast:
                type_var: AccumType
                operands:
                - !ScalarExpression
                  scalar_arg: rhs

Output (mlir-opt run on ~/foo.mlir, adding one more pass each time):

benoitjacob@benoitjacob:~/iree-build-linux$ ./third_party/llvm-project/llvm/bin/mlir-opt ~/foo.mlir -split-input-file -linalg-generalize-named-ops
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
module  {
  func @testfunc(%arg0: tensor<?x?x8x4xi8>, %arg1: tensor<?x?x8x4xi8>, %arg2: tensor<?x?x8x8xi32>) -> tensor<?x?x8x8xi32> {
    %0 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x8x4xi8>, tensor<?x?x8x4xi8>) outs(%arg2 : tensor<?x?x8x8xi32>) {
    ^bb0(%arg3: i8, %arg4: i8, %arg5: i32):  // no predecessors
      %1 = sexti %arg3 : i8 to i32
      %2 = sexti %arg4 : i8 to i32
      %3 = muli %1, %2 : i32
      %4 = addi %arg5, %3 : i32
      linalg.yield %4 : i32
    } -> tensor<?x?x8x8xi32>
    return %0 : tensor<?x?x8x8xi32>
  }
}

benoitjacob@benoitjacob:~/iree-build-linux$ ./third_party/llvm-project/llvm/bin/mlir-opt ~/foo.mlir -split-input-file -linalg-generalize-named-ops -linalg-bufferize
#map0 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
module  {
  func @testfunc(%arg0: tensor<?x?x8x4xi8>, %arg1: tensor<?x?x8x4xi8>, %arg2: tensor<?x?x8x8xi32>) -> tensor<?x?x8x8xi32> {
    %0 = tensor_to_memref %arg0 : memref<?x?x8x4xi8>
    %1 = tensor_to_memref %arg1 : memref<?x?x8x4xi8>
    %2 = tensor_to_memref %arg2 : memref<?x?x8x8xi32>
    %c0 = constant 0 : index
    %3 = dim %2, %c0 : memref<?x?x8x8xi32>
    %c1 = constant 1 : index
    %4 = dim %2, %c1 : memref<?x?x8x8xi32>
    %5 = alloc(%3, %4) : memref<?x?x8x8xi32>
    linalg.copy(%2, %5) : memref<?x?x8x8xi32>, memref<?x?x8x8xi32> 
    linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%0, %1 : memref<?x?x8x4xi8>, memref<?x?x8x4xi8>) outs(%5 : memref<?x?x8x8xi32>) {
    ^bb0(%arg3: i8, %arg4: i8, %arg5: i32):  // no predecessors
      %7 = sexti %arg3 : i8 to i32
      %8 = sexti %arg4 : i8 to i32
      %9 = muli %7, %8 : i32
      %10 = addi %arg5, %9 : i32
      linalg.yield %10 : i32
    }
    %6 = tensor_load %5 : memref<?x?x8x8xi32>
    return %6 : tensor<?x?x8x8xi32>
  }
}

benoitjacob@benoitjacob:~/iree-build-linux$ ./third_party/llvm-project/llvm/bin/mlir-opt ~/foo.mlir -split-input-file -linalg-generalize-named-ops -linalg-bufferize -convert-linalg-to-loops
module  {
  func @testfunc(%arg0: tensor<?x?x8x4xi8>, %arg1: tensor<?x?x8x4xi8>, %arg2: tensor<?x?x8x8xi32>) -> tensor<?x?x8x8xi32> {
    %c4 = constant 4 : index
    %c8 = constant 8 : index
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %0 = tensor_to_memref %arg0 : memref<?x?x8x4xi8>
    %1 = tensor_to_memref %arg1 : memref<?x?x8x4xi8>
    %2 = tensor_to_memref %arg2 : memref<?x?x8x8xi32>
    %3 = dim %arg2, %c0 : tensor<?x?x8x8xi32>
    %4 = dim %arg2, %c1 : tensor<?x?x8x8xi32>
    %5 = alloc(%3, %4) : memref<?x?x8x8xi32>
    %6 = dim %arg2, %c0 : tensor<?x?x8x8xi32>
    %7 = dim %arg2, %c1 : tensor<?x?x8x8xi32>
    scf.for %arg3 = %c0 to %6 step %c1 {
      scf.for %arg4 = %c0 to %7 step %c1 {
        scf.for %arg5 = %c0 to %c8 step %c1 {
          scf.for %arg6 = %c0 to %c8 step %c1 {
            %12 = load %2[%arg3, %arg4, %arg5, %arg6] : memref<?x?x8x8xi32>
            store %12, %5[%arg3, %arg4, %arg5, %arg6] : memref<?x?x8x8xi32>
          }
        }
      }
    }
    %8 = dim %arg0, %c0 : tensor<?x?x8x4xi8>
    %9 = dim %arg0, %c1 : tensor<?x?x8x4xi8>
    %10 = dim %arg1, %c0 : tensor<?x?x8x4xi8>
    scf.for %arg3 = %c0 to %8 step %c1 {
      scf.for %arg4 = %c0 to %10 step %c1 {
        scf.for %arg5 = %c0 to %9 step %c1 {
          scf.for %arg6 = %c0 to %c8 step %c1 {
            scf.for %arg7 = %c0 to %c8 step %c1 {
              scf.for %arg8 = %c0 to %c4 step %c1 {
                %12 = load %0[%arg3, %arg5, %arg6, %arg8] : memref<?x?x8x4xi8>
                %13 = load %1[%arg4, %arg5, %arg7, %arg8] : memref<?x?x8x4xi8>
                %14 = load %5[%arg3, %arg4, %arg6, %arg7] : memref<?x?x8x8xi32>
                %15 = sexti %12 : i8 to i32
                %16 = sexti %13 : i8 to i32
                %17 = muli %15, %16 : i32
                %18 = addi %14, %17 : i32
                store %18, %5[%arg3, %arg4, %arg6, %arg7] : memref<?x?x8x8xi32>
              }
            }
          }
        }
      }
    }
    %11 = tensor_load %5 : memref<?x?x8x8xi32>
    return %11 : tensor<?x?x8x8xi32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment