@pashu123
Created June 17, 2025 18:02
module {
  func.func @flash_attention_func(%arg0: !torch.vtensor<[32,8,?,32],f16>, %arg1: !torch.vtensor<[32,8,?,32],f16>, %arg2: !torch.vtensor<[32,8,?,32],f16>) -> (!torch.vtensor<[32,8,?,32],f16>, !torch.vtensor<[32,8,?],f32>) {
    %float0.000000e00 = torch.constant.float 0.000000e+00
    %true = torch.constant.bool true
    %none = torch.constant.none
    %none_0 = torch.constant.none
    %0:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%arg0, %arg1, %arg2, %float0.000000e00, %true, %none, %none_0) : (!torch.vtensor<[32,8,?,32],f16>, !torch.vtensor<[32,8,?,32],f16>, !torch.vtensor<[32,8,?,32],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[32,8,?,32],f16>, !torch.vtensor<[32,8,?],f32>)
    return %0#0, %0#1 : !torch.vtensor<[32,8,?,32],f16>, !torch.vtensor<[32,8,?],f32>
  }
}
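The IR above calls `aten._scaled_dot_product_flash_attention_for_cpu` with `dropout_p = 0.0` and `is_causal = true`, and returns two results: the attention output (`[32,8,?,32]`, f16) and the per-query logsumexp of the attention scores (`[32,8,?]`, f32). As a rough semantics sketch (not the flash-attention kernel itself, just a plain NumPy reference of what the op computes; `sdpa_flash_reference` is a hypothetical name, and dropout is ignored since the IR passes 0.0):

```python
import numpy as np

def sdpa_flash_reference(q, k, v, is_causal=True):
    """Reference sketch of scaled dot-product attention with a causal mask.

    q, k, v: [B, H, S, D] arrays. Returns (output [B, H, S, D],
    logsumexp [B, H, S]), matching the two results in the IR above.
    """
    d = q.shape[-1]
    # Scaled attention scores: [B, H, S, S]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    if is_causal:
        s = scores.shape[-1]
        # Mask out keys that come after each query position.
        causal_mask = np.triu(np.ones((s, s), dtype=bool), k=1)
        scores = np.where(causal_mask, -np.inf, scores)
    # Numerically stable softmax over the key dimension.
    m = scores.max(axis=-1, keepdims=True)
    e = np.exp(scores - m)
    denom = e.sum(axis=-1, keepdims=True)
    out = (e / denom) @ v
    # Logsumexp over keys, the second result of the op.
    lse = (m + np.log(denom)).squeeze(-1)
    return out, lse
```

With the causal mask, the first query position attends only to itself, so its output row equals the first value row; that makes a convenient sanity check.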