// Created January 16, 2025 15:05.
// Gist pashu123/03d8ea3c1b881e42857e0fee4a524c86 (can be saved locally or opened in GitHub Desktop).
// GitHub notice: this file contains hidden or bidirectional Unicode text that may be interpreted or
// compiled differently than what appears below. To review, open the file in an editor that reveals
// hidden Unicode characters. Learn more about bidirectional Unicode characters.
// Indexing maps shared by the linalg.generic ops below.
// Scalar broadcast: a 0-d operand read at every point of a 4-D iteration space.
#map = affine_map<(d0, d1, d2, d3) -> ()>
// 4-D identity (elementwise access).
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// 5-D contraction maps; d0 and d1 are never contracted. In Q.K^T d3 is the
// reduced dimension, in P.V d4 is.
// Drops d4: Q operand of Q.K^T, result of P.V.
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
// Swaps the roles of d2 and d4: K operand of Q.K^T, V operand of P.V.
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>
// Drops d3: score result of Q.K^T, P operand of P.V.
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
// Drops the last of four dims: used for row reductions over it and for
// broadcasting per-row values back across it.
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// 3-D identity (elementwise on per-row values).
#map6 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Attention over bf16 tensors Q (%arg0), K (%arg1) and V (%arg2), written as
// linalg.generic ops following the shape of a single-step online-softmax
// (flash-attention style) decomposition. All operands are 1x128x32x64.
// With Q indexed [a, b, i, d] and K, V indexed [a, b, j, d]:
//   S[a,b,i,j] = c * sum_d Q[a,b,i,d] * K[a,b,j,d]        (f32 accumulation)
//   m[a,b,i]   = max_j S[a,b,i,j]
//   P[a,b,i,j] = exp2(S[a,b,i,j] - m[a,b,i])
//   l[a,b,i]   = sum_j P[a,b,i,j]
//   O[a,b,i,d] = (sum_j bf16(P[a,b,i,j]) * V[a,b,j,d]) / l[a,b,i]   (bf16 result)
// with c = 0.125 * log2(e). Pre-multiplying the logits by log2(e) makes exp2
// produce the natural-exponent softmax. 0.125 is presumably 1/sqrt(64) for the
// contracted size-64 dimension -- TODO confirm against the source model.
module {
  func.func @attention(%arg0: tensor<1x128x32x64xbf16>, %arg1: tensor<1x128x32x64xbf16>, %arg2: tensor<1x128x32x64xbf16>) -> tensor<1x128x32x64xbf16> {
    %zero = arith.constant 0.000000e+00 : f32
    // Most negative finite f32; seeds the running row max.
    %neg_max = arith.constant -3.40282347E+38 : f32

    // Freshly initialised online-softmax state: output accumulator (zeros),
    // row max (-FLT_MAX) and row sum (zeros).
    %out_bf16 = tensor.empty() : tensor<1x128x32x64xbf16>
    %acc_empty = tensor.empty() : tensor<1x128x32x64xf32>
    %row_empty = tensor.empty() : tensor<1x128x32xf32>
    %acc_init = linalg.fill ins(%zero : f32) outs(%acc_empty : tensor<1x128x32x64xf32>) -> tensor<1x128x32x64xf32>
    %max_init = linalg.fill ins(%neg_max : f32) outs(%row_empty : tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
    %sum_init = linalg.fill ins(%zero : f32) outs(%row_empty : tensor<1x128x32xf32>) -> tensor<1x128x32xf32>

    // Combined logit scale, computed in f32. In bf16, log2(e) rounds to
    // 1.4453125 (about 0.18% high), and rounding Q * scale back to bf16 adds a
    // per-element error. The scale is therefore applied to the f32 scores
    // rather than folded into Q.
    %scale = arith.constant 1.250000e-01 : f32
    %log2e = arith.constant 1.44269502 : f32
    %logit_scale = arith.mulf %scale, %log2e : f32

    // Q . K^T: contract the size-64 last dimension, accumulating in f32.
    %scores_empty = tensor.empty() : tensor<1x128x32x32xf32>
    %scores_init = linalg.fill ins(%zero : f32) outs(%scores_empty : tensor<1x128x32x32xf32>) -> tensor<1x128x32x32xf32>
    %qk = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%arg0, %arg1 : tensor<1x128x32x64xbf16>, tensor<1x128x32x64xbf16>) outs(%scores_init : tensor<1x128x32x32xf32>) {
    ^bb0(%q: bf16, %k: bf16, %acc: f32):
      %q_f32 = arith.extf %q : bf16 to f32
      %k_f32 = arith.extf %k : bf16 to f32
      %prod = arith.mulf %q_f32, %k_f32 : f32
      %dot = arith.addf %prod, %acc : f32
      linalg.yield %dot : f32
    } -> tensor<1x128x32x32xf32>

    // S = scores * logit_scale; the scalar scale is broadcast via #map.
    %scores = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%logit_scale : f32) outs(%qk : tensor<1x128x32x32xf32>) {
    ^bb0(%s: f32, %x: f32):
      %scaled = arith.mulf %s, %x : f32
      linalg.yield %scaled : f32
    } -> tensor<1x128x32x32xf32>

    // m = row max of S over the key dimension j, seeded with -FLT_MAX.
    %new_max = linalg.generic {indexing_maps = [#map1, #map5], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%scores : tensor<1x128x32x32xf32>) outs(%max_init : tensor<1x128x32xf32>) {
    ^bb0(%sv: f32, %mv: f32):
      %mx = arith.maximumf %sv, %mv : f32
      linalg.yield %mx : f32
    } -> tensor<1x128x32xf32>

    // Online-softmax correction factor exp2(old_max - new_max). The "old"
    // state here is the fresh initialisation above, so it rescales a zero
    // row sum and a zero accumulator.
    %correction = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel", "parallel"]} ins(%new_max : tensor<1x128x32xf32>) outs(%max_init : tensor<1x128x32xf32>) {
    ^bb0(%nm: f32, %om: f32):
      %diff = arith.subf %om, %nm : f32
      %factor = math.exp2 %diff : f32
      linalg.yield %factor : f32
    } -> tensor<1x128x32xf32>

    // Previous row sum rescaled by the correction factor.
    %sum_rescaled = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel", "parallel"]} ins(%correction : tensor<1x128x32xf32>) outs(%sum_init : tensor<1x128x32xf32>) {
    ^bb0(%cf: f32, %l_old: f32):
      %l_scaled = arith.mulf %cf, %l_old : f32
      linalg.yield %l_scaled : f32
    } -> tensor<1x128x32xf32>

    // P = exp2(S - m), with m broadcast across the key dimension.
    %probs = linalg.generic {indexing_maps = [#map5, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%new_max : tensor<1x128x32xf32>) outs(%scores : tensor<1x128x32x32xf32>) {
    ^bb0(%m: f32, %sij: f32):
      %shifted = arith.subf %sij, %m : f32
      %p = math.exp2 %shifted : f32
      linalg.yield %p : f32
    } -> tensor<1x128x32x32xf32>

    // l = rescaled previous sum + row sum of P.
    %row_sum = linalg.generic {indexing_maps = [#map1, #map5], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%probs : tensor<1x128x32x32xf32>) outs(%sum_rescaled : tensor<1x128x32xf32>) {
    ^bb0(%pv_in: f32, %l_acc: f32):
      %l_new = arith.addf %pv_in, %l_acc : f32
      linalg.yield %l_new : f32
    } -> tensor<1x128x32xf32>

    // P is fed to the second matmul as bf16; l keeps the unrounded f32 values.
    %probs_empty = tensor.empty() : tensor<1x128x32x32xbf16>
    %probs_bf16 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%probs : tensor<1x128x32x32xf32>) outs(%probs_empty : tensor<1x128x32x32xbf16>) {
    ^bb0(%pf: f32, %unused: bf16):
      %pb = arith.truncf %pf : f32 to bf16
      linalg.yield %pb : bf16
    } -> tensor<1x128x32x32xbf16>

    // Previous accumulator rescaled by the correction factor.
    %acc_rescaled = linalg.generic {indexing_maps = [#map5, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%correction : tensor<1x128x32xf32>) outs(%acc_init : tensor<1x128x32x64xf32>) {
    ^bb0(%cf2: f32, %a_old: f32):
      %a_scaled = arith.mulf %cf2, %a_old : f32
      linalg.yield %a_scaled : f32
    } -> tensor<1x128x32x64xf32>

    // acc += P . V: contract the key dimension j, accumulating in f32.
    %pv = linalg.generic {indexing_maps = [#map4, #map3, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%probs_bf16, %arg2 : tensor<1x128x32x32xbf16>, tensor<1x128x32x64xbf16>) outs(%acc_rescaled : tensor<1x128x32x64xf32>) {
    ^bb0(%pe: bf16, %v: bf16, %a: f32):
      %pe_f32 = arith.extf %pe : bf16 to f32
      %v_f32 = arith.extf %v : bf16 to f32
      %pv_prod = arith.mulf %pe_f32, %v_f32 : f32
      %a_new = arith.addf %pv_prod, %a : f32
      linalg.yield %a_new : f32
    } -> tensor<1x128x32x64xf32>

    // O = acc * (1 / l), truncated to bf16.
    %result = linalg.generic {indexing_maps = [#map5, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%row_sum, %pv : tensor<1x128x32xf32>, tensor<1x128x32x64xf32>) outs(%out_bf16 : tensor<1x128x32x64xbf16>) {
    ^bb0(%lv: f32, %av: f32, %ov: bf16):
      %c1 = arith.constant 1.000000e+00 : f32
      %inv = arith.divf %c1, %lv : f32
      %o = arith.mulf %inv, %av : f32
      %o_bf16 = arith.truncf %o : f32 to bf16
      linalg.yield %o_bf16 : bf16
    } -> tensor<1x128x32x64xbf16>
    return %result : tensor<1x128x32x64xbf16>
  }
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.