Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created January 13, 2025 09:45
Show Gist options
  • Save pashu123/7d857514119f493bf25e6a6139ca8634 to your computer and use it in GitHub Desktop.
Save pashu123/7d857514119f493bf25e6a6139ca8634 to your computer and use it in GitHub Desktop.
// Attention over three 20x4096x64 f16 operands, expressed as a single
// iree_linalg_ext.attention op.
//
// Iteration space (d0, d1, d2, d3, d4), as given by the indexing maps below:
//   query  %arg1  : (d0, d1, d2)
//   key    %arg2  : (d0, d3, d2)
//   value  %arg3  : (d0, d3, d4)
//   scale  %scale : ()            -- scalar, same value for every iteration
//   result        : (d0, d1, d4)
// d0 is shared by every operand and the result. d2 (shared by query and key)
// and d3 (shared by key and value) do not appear in the result map, so they
// are reduced over.
//
// NOTE(review): the op presumably computes softmax(scale * Q * K^T) * V;
// confirm against the iree_linalg_ext.attention op definition.
func.func @attention(%arg1: tensor<20x4096x64xf16>, %arg2: tensor<20x4096x64xf16>, %arg3: tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16> {
  // 0.125 == 1/sqrt(64), and 64 is the extent of d2 (the query/key shared
  // dimension); presumably the usual 1/sqrt(head_dim) scaling.
  %scale = arith.constant 0.125 : f16
  // Destination for the `outs` operand; fixes the result shape and element type.
  %empty = tensor.empty() : tensor<20x4096x64xf16>
  %result = iree_linalg_ext.attention {
      indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                       affine_map<(d0, d1, d2, d3, d4) -> ()>,
                       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
      ins(%arg1, %arg2, %arg3, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
      outs(%empty : tensor<20x4096x64xf16>) {
    // The region receives each score element as f32 (even though all inputs
    // are f16) and yields it unchanged, i.e. no score modification.
    ^bb0(%score: f32):
      iree_linalg_ext.yield %score : f32
  } -> tensor<20x4096x64xf16>
  return %result : tensor<20x4096x64xf16>
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment