// Created February 11, 2025 07:41.
// Save pashu123/7dd521b5fefb663fce2a9d247d484998 to your computer and use it in GitHub Desktop.
// This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
// Learn more about bidirectional Unicode characters.
// Indexing-map aliases shared by the linalg.generic ops below.
//
// 4-D iteration space -> no results: a scalar operand broadcast to every point.
#map = affine_map<(d0, d1, d2, d3) -> ()>
// 4-D identity: an operand indexed exactly like the iteration space.
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// 5-D iteration space projections used for batched matmul-like contractions.
// With d3 as the reduction dim, #map2 x #map3 -> #map4 computes
//   C[d0, d1, d2, d4] += A[d0, d1, d2, d3] * B[d0, d1, d4, d3]   (A @ B^T);
// with d4 as the reduction dim, #map4 x #map3 -> #map2 computes
//   C[d0, d1, d2, d3] += A[d0, d1, d2, d4] * B[d0, d1, d4, d3]   (A @ B).
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
// 4-D -> first three dims: drops the innermost dim, so a 3-D operand is either
// reduced into (when d3 is a reduction) or broadcast along d3.
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// 3-D identity.
#map6 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  // Attention over the full key sequence in a single pass:
  //   out = softmax(0.125 * Q @ K^T) @ V      (softmax over the key axis)
  // expressed in "online softmax" form (running max, running sum, rescaled
  // accumulator) and evaluated in base 2: exp(x) == exp2(x * log2(e)), so
  // log2(e) is folded into the scale applied to Q once up front.
  //
  //   %arg0 : Q, indexed [d0, d1, q, c]   (1x32x64x96)
  //   %arg1 : K, indexed [d0, d1, k, c]   (1x32x64x96)
  //   %arg2 : V, indexed [d0, d1, k, c]   (1x32x64x96)
  // Returns [d0, d1, q, c] (1x32x64x96). The two leading dims are carried
  // through unchanged (presumably batch and heads -- not established here).
  func.func @attention(%arg0: tensor<1x32x64x96xf32>, %arg1: tensor<1x32x64x96xf32>, %arg2: tensor<1x32x64x96xf32>) -> tensor<1x32x64x96xf32> {
    // NOTE(review): 0.125 == 1/sqrt(64), but the contracted (head) dim below
    // is 96, for which the conventional scale is 1/sqrt(96) ~= 0.102. Confirm
    // the intended scale.
    %scale = arith.constant 1.250000e-01 : f32
    // log2(e), so that exp2(x * log2(e)) == exp(x).
    %log2e = arith.constant 1.44269502 : f32
    %zero = arith.constant 0.000000e+00 : f32
    // -FLT_MAX stands in for -inf as the initial running row maximum.
    %neg_flt_max = arith.constant -3.40282347E+38 : f32
    %scale_log2e = arith.mulf %scale, %log2e : f32

    %out_empty = tensor.empty() : tensor<1x32x64x96xf32>
    %acc_empty = tensor.empty() : tensor<1x32x64x96xf32>
    %row_empty = tensor.empty() : tensor<1x32x64xf32>
    %scores_empty = tensor.empty() : tensor<1x32x64x64xf32>

    // Initial online-softmax state: accumulator = 0, row max = -FLT_MAX,
    // row sum = 0.
    %acc_init = linalg.fill ins(%zero : f32) outs(%acc_empty : tensor<1x32x64x96xf32>) -> tensor<1x32x64x96xf32>
    %max_init = linalg.fill ins(%neg_flt_max : f32) outs(%row_empty : tensor<1x32x64xf32>) -> tensor<1x32x64xf32>
    %sum_init = linalg.fill ins(%zero : f32) outs(%row_empty : tensor<1x32x64xf32>) -> tensor<1x32x64xf32>
    %scores_init = linalg.fill ins(%zero : f32) outs(%scores_empty : tensor<1x32x64x64xf32>) -> tensor<1x32x64x64xf32>

    // Q * (scale * log2(e)), elementwise with a broadcast scalar.
    %q_scaled = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%scale_log2e : f32) outs(%arg0 : tensor<1x32x64x96xf32>) {
    ^bb0(%in: f32, %out: f32):
      %t0 = arith.mulf %in, %out : f32
      linalg.yield %t0 : f32
    } -> tensor<1x32x64x96xf32>

    // scores[d0, d1, q, k] = sum_c q_scaled[d0, d1, q, c] * K[d0, d1, k, c]
    // (d3 = c is the reduction dim, d4 = k is parallel).
    %scores = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%q_scaled, %arg1 : tensor<1x32x64x96xf32>, tensor<1x32x64x96xf32>) outs(%scores_init : tensor<1x32x64x64xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %t0 = arith.mulf %in, %in_0 : f32
      %t1 = arith.addf %t0, %out : f32
      linalg.yield %t1 : f32
    } -> tensor<1x32x64x64xf32>

    // New running max: max over k of scores, seeded with the initial max.
    %row_max = linalg.generic {indexing_maps = [#map1, #map5], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%scores : tensor<1x32x64x64xf32>) outs(%max_init : tensor<1x32x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %t0 = arith.maximumf %in, %out : f32
      linalg.yield %t0 : f32
    } -> tensor<1x32x64xf32>

    // Online-softmax correction factor exp2(old_max - new_max). With the
    // initial old_max of -FLT_MAX this is 0 for any finite new_max; it is kept
    // to mirror the general update that rescales previous state.
    %correction = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel", "parallel"]} ins(%row_max : tensor<1x32x64xf32>) outs(%max_init : tensor<1x32x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %t0 = arith.subf %out, %in : f32
      %t1 = math.exp2 %t0 : f32
      linalg.yield %t1 : f32
    } -> tensor<1x32x64xf32>

    // Previous running sum rescaled by the correction factor.
    %sum_rescaled = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel", "parallel"]} ins(%correction : tensor<1x32x64xf32>) outs(%sum_init : tensor<1x32x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %t0 = arith.mulf %in, %out : f32
      linalg.yield %t0 : f32
    } -> tensor<1x32x64xf32>

    // Unnormalized probabilities: exp2(scores - row_max), row_max broadcast
    // along k. Subtracting the max keeps exp2 from overflowing.
    %probs = linalg.generic {indexing_maps = [#map5, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%row_max : tensor<1x32x64xf32>) outs(%scores : tensor<1x32x64x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %t0 = arith.subf %out, %in : f32
      %t1 = math.exp2 %t0 : f32
      linalg.yield %t1 : f32
    } -> tensor<1x32x64x64xf32>

    // New running sum: rescaled previous sum + sum over k of probs.
    %row_sum = linalg.generic {indexing_maps = [#map1, #map5], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%probs : tensor<1x32x64x64xf32>) outs(%sum_rescaled : tensor<1x32x64xf32>) {
    ^bb0(%in: f32, %out: f32):
      %t0 = arith.addf %in, %out : f32
      linalg.yield %t0 : f32
    } -> tensor<1x32x64xf32>

    // Previous accumulator rescaled by the correction factor (broadcast
    // along c).
    %acc_rescaled = linalg.generic {indexing_maps = [#map5, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%correction : tensor<1x32x64xf32>) outs(%acc_init : tensor<1x32x64x96xf32>) {
    ^bb0(%in: f32, %out: f32):
      %t0 = arith.mulf %in, %out : f32
      linalg.yield %t0 : f32
    } -> tensor<1x32x64x96xf32>

    // acc[d0, d1, q, c] += sum_k probs[d0, d1, q, k] * V[d0, d1, k, c]
    // (d4 = k is the reduction dim).
    %acc = linalg.generic {indexing_maps = [#map4, #map3, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%probs, %arg2 : tensor<1x32x64x64xf32>, tensor<1x32x64x96xf32>) outs(%acc_rescaled : tensor<1x32x64x96xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %t0 = arith.mulf %in, %in_0 : f32
      %t1 = arith.addf %t0, %out : f32
      linalg.yield %t1 : f32
    } -> tensor<1x32x64x96xf32>

    // Normalize: out = (1 / row_sum) * acc, row_sum broadcast along c.
    %result = linalg.generic {indexing_maps = [#map5, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%row_sum, %acc : tensor<1x32x64xf32>, tensor<1x32x64x96xf32>) outs(%out_empty : tensor<1x32x64x96xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %t_one = arith.constant 1.000000e+00 : f32
      %t0 = arith.divf %t_one, %in : f32
      %t1 = arith.mulf %t0, %in_0 : f32
      linalg.yield %t1 : f32
    } -> tensor<1x32x64x96xf32>
    return %result : tensor<1x32x64x96xf32>
  }
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.