@liangfu
Created November 26, 2024 19:51
from typing import List, Tuple

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel  # Required to register custom ops.


class PallasAttentionBackend:

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Mark each cache as a buffer donor so XLA can reuse its storage
            # for the in-place update instead of allocating a new buffer.
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]


def main():
    use_torch_compile = True
    device = xm.xla_device()
    kv_caches = [(
        torch.arange(4).reshape(1, 4).expand(4, 4).to(device=device),
        torch.arange(4).reshape(4, 1).expand(4, 4).to(device=device),
    )]
    src_to_dists = (
        torch.tensor(1, device=device), torch.tensor(2, device=device)
    )
    print(kv_caches)
    if use_torch_compile:
        # Use the method compiled via the decorator, copying block 0 over block 1.
        PallasAttentionBackend.copy_blocks(
            kv_caches,
            (torch.tensor(0, device=device), torch.tensor(1, device=device)))
    else:
        # Alternative path: compile copy_blocks explicitly with the OpenXLA backend.
        compiled_code = torch.compile(PallasAttentionBackend.copy_blocks,
                                      backend="openxla")
        compiled_code(kv_caches, src_to_dists)
    print(kv_caches)


if __name__ == "__main__":
    main()
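
# For reference, a minimal eager-mode sketch (not part of the gist) of the same
# block-copy indexing on plain CPU tensors, using the (0, 1) source/destination
# pair from the compiled path. `.clone()` is an added assumption here, since the
# expanded views above are not writable in eager mode.

import torch

k_cache = torch.arange(4).reshape(1, 4).expand(4, 4).clone()  # every row is [0, 1, 2, 3]
v_cache = torch.arange(4).reshape(4, 1).expand(4, 4).clone()  # row i is [i, i, i, i]
src, dst = torch.tensor(0), torch.tensor(1)

# copy_blocks writes block `src` over block `dst` along dim 1 of each cache.
k_cache[:, dst] = k_cache[:, src]
v_cache[:, dst] = v_cache[:, src]

print(k_cache)  # column 1 now equals column 0 (all zeros)
print(v_cache)  # rows are constant along dim 1, so v_cache is unchanged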