Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created December 7, 2022 15:15
Show Gist options
  • Save pashu123/e44ce15ff8cca0d4518a4fcc841181ef to your computer and use it in GitHub Desktop.
Save pashu123/e44ce15ff8cca0d4518a4fcc841181ef to your computer and use it in GitHub Desktop.
// Identity map over the 3-D iteration space (d0, d1, d2).
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Broadcast/reduce map: collapses the last dimension onto a size-1 dim
// (keepdim-style), so a 10x9216x1 tensor is read/written once per (d0, d1).
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
module attributes {torch.debug_module_name = "_lambda"} {
  // Softmax of %arg0 along its last dimension:
  //   out[i, j, k] = exp(x[i, j, k] - m[i, j]) / sum_k exp(x[i, j, k] - m[i, j])
  // where m[i, j] = max_k x[i, j, k].
  //
  // Numerics (why this is not the naive exp(x) / sum(exp(x)) in f16):
  //  * Subtracting the row max keeps every exp argument <= 0, so exp never
  //    overflows f16 (f16 max is 65504, i.e. exp overflows for x > ~11.09).
  //  * The row sum is accumulated in f32: with an f16 accumulator a 9216-wide
  //    sum can exceed 65504 (-> inf), and once it passes 2048 adding values
  //    near 1.0 is rounded away (f16 has an 11-bit significand).
  //  * The division is done in f32 and the quotient truncated back to f16.
  //
  // %arg1, %arg2 and %arg3 are not used by this function; they are kept so the
  // signature is unchanged for callers.
  // NOTE(review): if this IR is a byte-exact reproducer for a compiler issue,
  // this rewrite changes the ops it exercises — keep the original in that case.
  func.func @forward(%arg0: tensor<10x9216x9216xf16>, %arg1: tensor<1xf16>, %arg2: tensor<2x77x1024xf16>, %arg3: tensor<f32>) -> tensor<10x9216x9216xf16> {
    // -inf in f16 (bit pattern 0xFC00): identity element for the max reduction.
    %neg_inf = arith.constant 0xFC00 : f16
    // 0.0 in f32: identity element for the f32 sum reduction.
    %zero = arith.constant 0.000000e+00 : f32
    %empty = tensor.empty() : tensor<10x9216x9216xf16>
    %row_f16 = tensor.empty() : tensor<10x9216x1xf16>
    %row_f32 = tensor.empty() : tensor<10x9216x1xf32>

    // Step 1: per-row maximum over the last dimension.
    // arith.maxf was renamed arith.maximumf in later MLIR releases.
    %max_init = linalg.fill ins(%neg_inf : f16) outs(%row_f16 : tensor<10x9216x1xf16>) -> tensor<10x9216x1xf16>
    %max = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<10x9216x9216xf16>) outs(%max_init : tensor<10x9216x1xf16>) {
    ^bb0(%in: f16, %acc: f16):
      %mx = arith.maxf %in, %acc : f16
      linalg.yield %mx : f16
    } -> tensor<10x9216x1xf16>

    // Step 2: shifted exponentials, exp(x - rowmax); every value is in [0, 1].
    %exp = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %max : tensor<10x9216x9216xf16>, tensor<10x9216x1xf16>) outs(%empty : tensor<10x9216x9216xf16>) {
    ^bb0(%in: f16, %rowmax: f16, %out: f16):
      %shifted = arith.subf %in, %rowmax : f16
      %e = math.exp %shifted : f16
      linalg.yield %e : f16
    } -> tensor<10x9216x9216xf16>

    // Step 3: per-row sum of the exponentials, accumulated in f32.
    %sum_init = linalg.fill ins(%zero : f32) outs(%row_f32 : tensor<10x9216x1xf32>) -> tensor<10x9216x1xf32>
    %sum = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%exp : tensor<10x9216x9216xf16>) outs(%sum_init : tensor<10x9216x1xf32>) {
    ^bb0(%in: f16, %acc: f32):
      %wide = arith.extf %in : f16 to f32
      %s = arith.addf %wide, %acc : f32
      linalg.yield %s : f32
    } -> tensor<10x9216x1xf32>

    // Step 4: normalize in f32, then narrow the result back to f16.
    %result = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%exp, %sum : tensor<10x9216x9216xf16>, tensor<10x9216x1xf32>) outs(%empty : tensor<10x9216x9216xf16>) {
    ^bb0(%in: f16, %denom: f32, %out: f16):
      %num = arith.extf %in : f16 to f32
      %q = arith.divf %num, %denom : f32
      %narrow = arith.truncf %q : f32 to f16
      linalg.yield %narrow : f16
    } -> tensor<10x9216x9216xf16>
    return %result : tensor<10x9216x9216xf16>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment