Created January 16, 2025 14:49
-
-
Save pashu123/c21cecac12f18774cf060dd6a1e99eee to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Attention lowered to linalg.generic ops in a base-2 "online softmax" form.
// Operand roles are inferred from how they are used (confirm against whatever
// produced this IR): %arg0 = Q, %arg1 = K, %arg2 = V, all 1x128x32x64xf32.
//
//   scores = (Q * (0.125 * 1.44269502)) . K^T   contracting dim 3 (size 64)
//   P      = exp2(scores - rowmax(scores))        rowmax over dim 3 (size 32)
//   out    = (P . V) * (1 / rowsum(P))
//
// 1.44269502 ~= log2(e), so exp2(x * log2(e)) == exp(x). 0.125 == 1/sqrt(64),
// presumably the usual 1/sqrt(head_dim) scale -- TODO confirm.
// The running max / sum / accumulator start at -FLT_MAX / 0 / 0 and are rescaled
// by exp2(old_max - new_max), as one online-softmax update step would do; with
// these initial values the rescaled sum and accumulator are zero for finite
// inputs.

// Broadcast a scalar over a 4-D iteration space.
#scalar = affine_map<(d0, d1, d2, d3) -> ()>
// Identity maps for 4-D and 3-D tensors.
#ident4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#ident3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Drop the innermost dim: per-row values broadcast/reduced along dim 3.
#rowwise = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// 5-D maps shared by both batched matmuls (Q.K^T and P.V), named by the
// iteration dims they select.
#d0123 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
#d0143 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>
#d0124 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
module {
  func.func @attention(%arg0: tensor<1x128x32x64xf32>, %arg1: tensor<1x128x32x64xf32>, %arg2: tensor<1x128x32x64xf32>) -> tensor<1x128x32x64xf32> {
    // Scalar constants. -3.40282347E+38 is -FLT_MAX, the starting running max.
    %zero = arith.constant 0.000000e+00 : f32
    %neg_flt_max = arith.constant -3.40282347E+38 : f32
    %softmax_scale = arith.constant 1.250000e-01 : f32
    %log2e = arith.constant 1.44269502 : f32
    // Fold the softmax scale and the exp -> exp2 conversion into one factor on Q.
    %qk_scale = arith.mulf %softmax_scale, %log2e : f32

    // Destination tensors and initial values for the running statistics.
    %acc_empty = tensor.empty() : tensor<1x128x32x64xf32>
    %out_empty = tensor.empty() : tensor<1x128x32x64xf32>
    %row_empty = tensor.empty() : tensor<1x128x32xf32>
    %scores_empty = tensor.empty() : tensor<1x128x32x32xf32>
    %acc_init = linalg.fill ins(%zero : f32) outs(%acc_empty : tensor<1x128x32x64xf32>) -> tensor<1x128x32x64xf32>
    %max_init = linalg.fill ins(%neg_flt_max : f32) outs(%row_empty : tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
    %sum_init = linalg.fill ins(%zero : f32) outs(%row_empty : tensor<1x128x32xf32>) -> tensor<1x128x32xf32>
    %scores_init = linalg.fill ins(%zero : f32) outs(%scores_empty : tensor<1x128x32x32xf32>) -> tensor<1x128x32x32xf32>

    // Q * qk_scale, elementwise (written through %arg0 as the init operand).
    %q_scaled = linalg.generic {indexing_maps = [#scalar, #ident4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%qk_scale : f32) outs(%arg0 : tensor<1x128x32x64xf32>) {
    ^bb0(%s: f32, %q: f32):
      %v = arith.mulf %s, %q : f32
      linalg.yield %v : f32
    } -> tensor<1x128x32x64xf32>

    // scores[d0,d1,d2,d4] = sum_d3 Qs[d0,d1,d2,d3] * K[d0,d1,d4,d3]
    %scores = linalg.generic {indexing_maps = [#d0123, #d0143, #d0124], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%q_scaled, %arg1 : tensor<1x128x32x64xf32>, tensor<1x128x32x64xf32>) outs(%scores_init : tensor<1x128x32x32xf32>) {
    ^bb0(%q: f32, %k: f32, %acc: f32):
      %prod = arith.mulf %q, %k : f32
      %sum = arith.addf %prod, %acc : f32
      linalg.yield %sum : f32
    } -> tensor<1x128x32x32xf32>

    // Pass-through op; it yields its input unchanged (possibly a placeholder
    // for masking / score post-processing -- NOTE(review): confirm intent).
    %scores_post = linalg.generic {indexing_maps = [#ident4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%scores : tensor<1x128x32x32xf32>) {
    ^bb0(%x: f32):
      linalg.yield %x : f32
    } -> tensor<1x128x32x32xf32>

    // New running max: max(-FLT_MAX, max over dim 3 of scores).
    %row_max = linalg.generic {indexing_maps = [#ident4, #rowwise], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%scores_post : tensor<1x128x32x32xf32>) outs(%max_init : tensor<1x128x32xf32>) {
    ^bb0(%x: f32, %m: f32):
      %mx = arith.maximumf %x, %m : f32
      linalg.yield %mx : f32
    } -> tensor<1x128x32xf32>

    // Correction factor exp2(old_max - new_max) for previously accumulated state.
    %max_correction = linalg.generic {indexing_maps = [#ident3, #ident3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%row_max : tensor<1x128x32xf32>) outs(%max_init : tensor<1x128x32xf32>) {
    ^bb0(%new_m: f32, %old_m: f32):
      %delta = arith.subf %old_m, %new_m : f32
      %factor = math.exp2 %delta : f32
      linalg.yield %factor : f32
    } -> tensor<1x128x32xf32>

    // Rescale the previous running sum by the correction factor.
    %sum_rescaled = linalg.generic {indexing_maps = [#ident3, #ident3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%max_correction : tensor<1x128x32xf32>) outs(%sum_init : tensor<1x128x32xf32>) {
    ^bb0(%c: f32, %l: f32):
      %r = arith.mulf %c, %l : f32
      linalg.yield %r : f32
    } -> tensor<1x128x32xf32>

    // Unnormalized probabilities P = exp2(score - row_max).
    %probs = linalg.generic {indexing_maps = [#rowwise, #ident4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%row_max : tensor<1x128x32xf32>) outs(%scores_post : tensor<1x128x32x32xf32>) {
    ^bb0(%m: f32, %x: f32):
      %shifted = arith.subf %x, %m : f32
      %p = math.exp2 %shifted : f32
      linalg.yield %p : f32
    } -> tensor<1x128x32x32xf32>

    // Updated running sum: rescaled old sum + sum over dim 3 of P.
    %row_sum = linalg.generic {indexing_maps = [#ident4, #rowwise], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%probs : tensor<1x128x32x32xf32>) outs(%sum_rescaled : tensor<1x128x32xf32>) {
    ^bb0(%p: f32, %l: f32):
      %r = arith.addf %p, %l : f32
      linalg.yield %r : f32
    } -> tensor<1x128x32xf32>

    // Rescale the previous output accumulator by the correction factor.
    %acc_rescaled = linalg.generic {indexing_maps = [#rowwise, #ident4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%max_correction : tensor<1x128x32xf32>) outs(%acc_init : tensor<1x128x32x64xf32>) {
    ^bb0(%c: f32, %a: f32):
      %r = arith.mulf %c, %a : f32
      linalg.yield %r : f32
    } -> tensor<1x128x32x64xf32>

    // acc[d0,d1,d2,d3] += sum_d4 P[d0,d1,d2,d4] * V[d0,d1,d4,d3]
    %pv = linalg.generic {indexing_maps = [#d0124, #d0143, #d0123], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%probs, %arg2 : tensor<1x128x32x32xf32>, tensor<1x128x32x64xf32>) outs(%acc_rescaled : tensor<1x128x32x64xf32>) {
    ^bb0(%p: f32, %v: f32, %a: f32):
      %prod = arith.mulf %p, %v : f32
      %r = arith.addf %prod, %a : f32
      linalg.yield %r : f32
    } -> tensor<1x128x32x64xf32>

    // Final normalization: out = (1 / row_sum) * acc.
    %result = linalg.generic {indexing_maps = [#rowwise, #ident4, #ident4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%row_sum, %pv : tensor<1x128x32xf32>, tensor<1x128x32x64xf32>) outs(%out_empty : tensor<1x128x32x64xf32>) {
    ^bb0(%l: f32, %a: f32, %unused: f32):
      %one = arith.constant 1.000000e+00 : f32
      %inv_l = arith.divf %one, %l : f32
      %r = arith.mulf %inv_l, %a : f32
      linalg.yield %r : f32
    } -> tensor<1x128x32x64xf32>
    return %result : tensor<1x128x32x64xf32>
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.