Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created December 12, 2022 17:36
Show Gist options
  • Save pashu123/4e8d358a9ad28cb547192fe32cc9f8e5 to your computer and use it in GitHub Desktop.
Save pashu123/4e8d358a9ad28cb547192fe32cc9f8e5 to your computer and use it in GitHub Desktop.
// Identity indexing map: the elementwise linalg.generic below reads and writes
// element (d0, d1, d2) of its rank-3 operands at the same position.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module attributes {torch.debug_module_name = "_lambda"} {
  // Computes a batched matmul of %arg0 and %arg1 scaled by 0.125:
  //   result[b, i, j] = 0.125 * sum_k %arg0[b, i, k] * %arg1[b, k, j]
  // with b in [0, 10), i, j in [0, 4096), k in [0, 64).
  // 0.125 == 1/sqrt(64) and 64 is the contraction size, so this is presumably
  // scaled-dot-product attention scoring — TODO confirm against the PyTorch
  // model this was exported from.
  //
  // Fix: the previous version had a further elementwise step,
  //   result = %3 + %empty * 0.0
  // where %empty came from `tensor.empty()`. The contents of `tensor.empty` are
  // unspecified. Under IEEE-754, NaN * 0.0 and Inf * 0.0 are NaN, so that term
  // could poison the output. It is also not safe to fold away without nnan/ninf
  // fast-math assumptions. A zero-weighted term contributes nothing to a finite
  // result, so it is dropped. This also saves a full extra read/write pass over
  // the 10x4096x4096xf16 result.
  func.func @forward(%arg0: tensor<10x4096x64xf16>, %arg1: tensor<10x64x4096xf16>) -> tensor<10x4096x4096xf16> {
    %cst = arith.constant 0.000000e+00 : f16
    %cst_0 = arith.constant 1.250000e-01 : f16
    // Destination tensor with unspecified contents. It is only used as an
    // `outs` operand that is either zero-filled (below) or never read by the
    // payload. It must never be consumed as input data.
    %0 = tensor.empty() : tensor<10x4096x4096xf16>
    // batch_matmul accumulates into its `outs` operand, so start from zero.
    %1 = linalg.fill ins(%cst : f16) outs(%0 : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
    // NOTE(review): the accumulator/output element type is f16, so at the IR
    // level the 64-term dot products are summed in f16. Confirm this precision
    // is acceptable (PyTorch GPU kernels typically accumulate f16 matmuls in f32).
    %2 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<10x4096x64xf16>, tensor<10x64x4096xf16>) outs(%1 : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
    // Elementwise scale by 0.125. The payload ignores %out, so the
    // uninitialized %0 is safe here as a shape-only destination.
    %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<10x4096x4096xf16>) outs(%0 : tensor<10x4096x4096xf16>) {
    ^bb0(%in: f16, %out: f16):
      %4 = arith.mulf %in, %cst_0 : f16
      linalg.yield %4 : f16
    } -> tensor<10x4096x4096xf16>
    return %3 : tensor<10x4096x4096xf16>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment