Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created January 16, 2025 14:44
Show Gist options
  • Save pashu123/10780e337b22db84de618ec95367e704 to your computer and use it in GitHub Desktop.
Save pashu123/10780e337b22db84de618ec95367e704 to your computer and use it in GitHub Desktop.
// Indexing maps for the iree_linalg_ext.attention op below. All five maps
// share the same six iteration dimensions:
//   d0 - dim 0 of every tensor operand and of the result (extent 1 here)
//   d1 - dim 1 of every tensor operand and of the result (extent 128 here)
//   d2 - dim 2 of the query and of the result (extent 32 here)
//   d3 - dim 3 of the value and of the result (extent 64 here)
//   d4 - dim 3 of the query and of the key; does not index the result
//        (extent 64 here)
//   d5 - dim 2 of the key and of the value; does not index the result
//        (extent 32 here)
// The scale operand is a scalar, so its map has no results.
#query_map  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#key_map    = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#value_map  = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
#scale_map  = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
#output_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

module {
  // Applies iree_linalg_ext.attention to query/key/value tensors that all have
  // the static shape 1x128x32x64xf32, with a constant scalar scale of 0.125,
  // and returns a tensor of the same shape.
  // NOTE(review): 0.125 == 1/sqrt(64); presumably the conventional
  // 1/sqrt(head_dim) factor for the 64-wide d4 axis. Confirm.
  // NOTE(review): with these maps, dim 1 (extent 128) is carried through like a
  // batch axis and the key/value axis d5 runs over dim 2 (extent 32). If the
  // operands are actually laid out as (batch, seq=128, heads=32, head_dim=64),
  // the query/key/value/output maps would need permuting. Confirm the intent.
  func.func @attention(%query: tensor<1x128x32x64xf32>,
                       %key: tensor<1x128x32x64xf32>,
                       %value: tensor<1x128x32x64xf32>) -> tensor<1x128x32x64xf32> {
    %scale = arith.constant 1.250000e-01 : f32
    // Destination operand for the op. tensor.empty carries only a shape; its
    // element values are unspecified.
    %init = tensor.empty() : tensor<1x128x32x64xf32>
    %result = iree_linalg_ext.attention {
        indexing_maps = [#query_map, #key_map, #value_map, #scale_map, #output_map]}
        ins(%query, %key, %value, %scale : tensor<1x128x32x64xf32>, tensor<1x128x32x64xf32>, tensor<1x128x32x64xf32>, f32)
        outs(%init : tensor<1x128x32x64xf32>) {
    ^bb0(%score: f32):
      // The region receives a single f32 and yields it unchanged (identity).
      iree_linalg_ext.yield %score : f32
    } -> tensor<1x128x32x64xf32>
    return %result : tensor<1x128x32x64xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment