@pashu123
Created April 30, 2024 17:00
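module {
  // Embedding lookup for a batch of 4 token sequences of dynamic length.
  // %arg0: token ids [4, ?] (si64); %arg1: embedding table [128256, 4096] (f16).
  func.func @decode_bs4(%arg0: !torch.vtensor<[4,?],si64>, %arg1: !torch.vtensor<[128256,4096],f16>) -> !torch.vtensor<[4,?,4096],f32> {
    %false = torch.constant.bool false
    %false_0 = torch.constant.bool false
    %int-1 = torch.constant.int -1
    // 6 is the torch dtype code for float32.
    %int6 = torch.constant.int 6
    // Upcast the f16 table to f32 before the lookup.
    %0 = torch.prims.convert_element_type %arg1, %int6 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f32>
    // Operands after the indices: padding_idx = -1, scale_grad_by_freq = false, sparse = false.
    %1 = torch.aten.embedding %0, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f32>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>
  }
}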