Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created February 11, 2025 07:41
Show Gist options
  • Save pashu123/7dd521b5fefb663fce2a9d247d484998 to your computer and use it in GitHub Desktop.
Save pashu123/7dd521b5fefb663fce2a9d247d484998 to your computer and use it in GitHub Desktop.
// Affine-map aliases used by the linalg.generic ops below. Each name says
// which iteration dimensions the map projects onto.
#scalar_bcast = affine_map<(d0, d1, d2, d3) -> ()>
#id4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#proj_d0123 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
#proj_d0143 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>
#proj_d0124 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
#row_d012 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#id3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  // Attention over three 1x32x64x96 f32 tensors, spelled out as explicit
  // linalg ops. %arg0 is contracted with %arg1 over the last (size-96) dim to
  // form scores; the exponentiated scores are contracted with %arg2 over the
  // size-64 dim to form the output, so %arg0/%arg1/%arg2 play the
  // query/key/value roles.
  // The body has the shape of one step of the online-softmax update (running
  // max, running sum, rescaled accumulator), seeded with max = -FLT_MAX,
  // sum = 0 and accumulator = 0. It uses exp2, with log2(e) folded into the
  // query scaling.
  func.func @attention(%arg0: tensor<1x32x64x96xf32>, %arg1: tensor<1x32x64x96xf32>, %arg2: tensor<1x32x64x96xf32>) -> tensor<1x32x64x96xf32> {
    // NOTE(review): 0.125 == 1/sqrt(64), but the contracted head dim here is 96
    // (1/sqrt(96) ~= 0.102). Confirm the scale is intentional.
    %scale = arith.constant 1.250000e-01 : f32
    %out_empty = tensor.empty() : tensor<1x32x64x96xf32>
    %acc_empty = tensor.empty() : tensor<1x32x64x96xf32>
    %row_empty = tensor.empty() : tensor<1x32x64xf32>
    %zero_acc = arith.constant 0.000000e+00 : f32
    // -FLT_MAX: initial value of the running row max.
    %neg_flt_max = arith.constant -3.40282347E+38 : f32
    %zero_sum = arith.constant 0.000000e+00 : f32
    // Initial state: accumulator = 0, running max = -FLT_MAX, running sum = 0.
    // The max and sum fills share the same tensor.empty destination.
    %acc_init = linalg.fill ins(%zero_acc : f32) outs(%acc_empty : tensor<1x32x64x96xf32>) -> tensor<1x32x64x96xf32>
    %max_init = linalg.fill ins(%neg_flt_max : f32) outs(%row_empty : tensor<1x32x64xf32>) -> tensor<1x32x64xf32>
    %sum_init = linalg.fill ins(%zero_sum : f32) outs(%row_empty : tensor<1x32x64xf32>) -> tensor<1x32x64xf32>
    // log2(e): exp2(x * log2(e)) == exp(x), so folding it into the scale lets
    // the softmax below use math.exp2.
    %log2e = arith.constant 1.44269502 : f32
    %scale_log2e = arith.mulf %scale, %log2e : f32
    // Q' = (scale * log2(e)) * %arg0, elementwise (scalar broadcast).
    %q_scaled = linalg.generic {indexing_maps = [#scalar_bcast, #id4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%scale_log2e : f32) outs(%arg0 : tensor<1x32x64x96xf32>) {
    ^bb0(%s: f32, %q: f32):
      %sq = arith.mulf %s, %q : f32
      linalg.yield %sq : f32
    } -> tensor<1x32x64x96xf32>
    %scores_empty = tensor.empty() : tensor<1x32x64x64xf32>
    %zero_scores = arith.constant 0.000000e+00 : f32
    %scores_init = linalg.fill ins(%zero_scores : f32) outs(%scores_empty : tensor<1x32x64x64xf32>) -> tensor<1x32x64x64xf32>
    // S[b, h, m, n] = sum_k Q'[b, h, m, k] * %arg1[b, h, n, k]
    // (d3 = k is the size-96 reduction dim; d4 = n is the size-64 parallel dim).
    %scores = linalg.generic {indexing_maps = [#proj_d0123, #proj_d0143, #proj_d0124], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%q_scaled, %arg1 : tensor<1x32x64x96xf32>, tensor<1x32x64x96xf32>) outs(%scores_init : tensor<1x32x64x64xf32>) {
    ^bb0(%q: f32, %k: f32, %partial: f32):
      %qk = arith.mulf %q, %k : f32
      %next = arith.addf %qk, %partial : f32
      linalg.yield %next : f32
    } -> tensor<1x32x64x64xf32>
    // Identity copy: yields each score unchanged, so it is a functional no-op.
    // NOTE(review): presumably left over from the lowering that produced this IR.
    %scores_copy = linalg.generic {indexing_maps = [#id4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%scores : tensor<1x32x64x64xf32>) {
    ^bb0(%v: f32):
      linalg.yield %v : f32
    } -> tensor<1x32x64x64xf32>
    // New running max: m_new[b, h, m] = max(m_old, max_n S[b, h, m, n]).
    %row_max = linalg.generic {indexing_maps = [#id4, #row_d012], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%scores_copy : tensor<1x32x64x64xf32>) outs(%max_init : tensor<1x32x64xf32>) {
    ^bb0(%v: f32, %m: f32):
      %mx = arith.maximumf %v, %m : f32
      linalg.yield %mx : f32
    } -> tensor<1x32x64xf32>
    // Correction factor alpha = exp2(m_old - m_new). With m_old = -FLT_MAX this
    // is typically exp2 of a huge negative number, i.e. 0.
    %max_correction = linalg.generic {indexing_maps = [#id3, #id3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%row_max : tensor<1x32x64xf32>) outs(%max_init : tensor<1x32x64xf32>) {
    ^bb0(%new_m: f32, %old_m: f32):
      %delta = arith.subf %old_m, %new_m : f32
      %alpha = math.exp2 %delta : f32
      linalg.yield %alpha : f32
    } -> tensor<1x32x64xf32>
    // Rescaled running sum alpha * l_old. l_old is the zero fill, so this is 0
    // for any finite alpha.
    %sum_rescaled = linalg.generic {indexing_maps = [#id3, #id3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%max_correction : tensor<1x32x64xf32>) outs(%sum_init : tensor<1x32x64xf32>) {
    ^bb0(%alpha: f32, %l: f32):
      %al = arith.mulf %alpha, %l : f32
      linalg.yield %al : f32
    } -> tensor<1x32x64xf32>
    // Unnormalized probabilities P = exp2(S - m_new), with m_new broadcast along n.
    %exp_scores = linalg.generic {indexing_maps = [#row_d012, #id4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%row_max : tensor<1x32x64xf32>) outs(%scores_copy : tensor<1x32x64x64xf32>) {
    ^bb0(%m: f32, %v: f32):
      %shifted = arith.subf %v, %m : f32
      %p = math.exp2 %shifted : f32
      linalg.yield %p : f32
    } -> tensor<1x32x64x64xf32>
    // New running sum: l_new = alpha * l_old + sum_n P.
    %row_sum = linalg.generic {indexing_maps = [#id4, #row_d012], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%exp_scores : tensor<1x32x64x64xf32>) outs(%sum_rescaled : tensor<1x32x64xf32>) {
    ^bb0(%p: f32, %l: f32):
      %l_next = arith.addf %p, %l : f32
      linalg.yield %l_next : f32
    } -> tensor<1x32x64xf32>
    // Rescaled accumulator alpha * acc_old (acc_old is the zero fill), with
    // alpha broadcast along the size-96 dim.
    %acc_rescaled = linalg.generic {indexing_maps = [#row_d012, #id4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%max_correction : tensor<1x32x64xf32>) outs(%acc_init : tensor<1x32x64x96xf32>) {
    ^bb0(%alpha: f32, %a: f32):
      %aa = arith.mulf %alpha, %a : f32
      linalg.yield %aa : f32
    } -> tensor<1x32x64x96xf32>
    // acc[b, h, m, d] += sum_n P[b, h, m, n] * %arg2[b, h, n, d]
    // (d4 = n is the size-64 reduction dim; d3 = d is the size-96 parallel dim).
    %acc = linalg.generic {indexing_maps = [#proj_d0124, #proj_d0143, #proj_d0123], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%exp_scores, %arg2 : tensor<1x32x64x64xf32>, tensor<1x32x64x96xf32>) outs(%acc_rescaled : tensor<1x32x64x96xf32>) {
    ^bb0(%p: f32, %v: f32, %a: f32):
      %pv = arith.mulf %p, %v : f32
      %a_next = arith.addf %pv, %a : f32
      linalg.yield %a_next : f32
    } -> tensor<1x32x64x96xf32>
    // Final normalization: out = (1 / l_new) * acc, with l_new broadcast along d.
    // The init element of %out_empty is ignored.
    %result = linalg.generic {indexing_maps = [#row_d012, #id4, #id4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%row_sum, %acc : tensor<1x32x64xf32>, tensor<1x32x64x96xf32>) outs(%out_empty : tensor<1x32x64x96xf32>) {
    ^bb0(%l: f32, %a: f32, %unused: f32):
      %one = arith.constant 1.000000e+00 : f32
      %inv_l = arith.divf %one, %l : f32
      %o = arith.mulf %inv_l, %a : f32
      linalg.yield %o : f32
    } -> tensor<1x32x64x96xf32>
    return %result : tensor<1x32x64x96xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment