@malfet
Created August 21, 2025 19:23
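import os

# Enable Metal capture before importing torch, so the setting is already in place when the MPS device is initialized
os.environ["MTL_CAPTURE_ENABLED"] = "1"

import torch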
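# 2 x ((1 << 31) + 5) int8 tensor of ones (about 4 GiB): each row alone has more than INT32_MAX elements
a = torch.ones(2, (1 << 31) + 5, dtype=torch.int8, device='mps')
# Negative indices wrap around; with row length N the targets are (0, N-2), (1, N-1), (0, 0) and (1, 1)
index_0 = torch.tensor([0, -1, 0, 1], device=a.device)
index_1 = torch.tensor([-2, -1, 0, 1], device=a.device)
values = torch.tensor([12, 13, 10, 11], dtype=a.dtype, device=a.device)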
# Record a Metal GPU capture of the accumulating index_put_ kernel
with torch.mps.profiler.metal_capture("index_put"):
    a.index_put_((index_0, index_1), values, accumulate=True)
# Read back column N-2; a correct result is b == 1 and c == [13, 1]
b = a[1, -2].cpu()
c = a[:, -2].cpu()
print(b, c)
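The gist does not explain the shape it uses. Our reading, which the gist does not state, is that the row width `(1 << 31) + 5` is chosen so that most targeted elements sit at linear offsets beyond INT32_MAX, so the repro exercises 64-bit offset handling in the MPS `index_put_` path. The sketch below is a hypothetical pure-Python reference model, not part of PyTorch: the helpers `normalize`, `flat_offset`, `index_put_accumulate` and `read` are illustrative names. It computes the row-major offsets for a contiguous tensor and the values a correct implementation should return, without needing a Mac or 4 GiB of GPU memory.

```python
# Hypothetical pure-Python reference model of the repro (no torch or MPS needed).
INT32_MAX = 2**31 - 1

ROWS = 2
COLS = (1 << 31) + 5  # same row length as the repro's tensor


def normalize(i: int, size: int) -> int:
    """Wrap a possibly negative index the way Python/PyTorch indexing does."""
    if not -size <= i < size:
        raise IndexError(f"index {i} out of range for size {size}")
    return i % size


def flat_offset(row: int, col: int) -> int:
    """Row-major linear offset of (row, col) in a contiguous ROWS x COLS tensor."""
    return normalize(row, ROWS) * COLS + normalize(col, COLS)


def index_put_accumulate(updates, index_0, index_1, values):
    """Sparse model of a.index_put_((index_0, index_1), values, accumulate=True)
    applied to a tensor of ones: only touched elements are stored, keyed by offset."""
    for r, c, v in zip(index_0, index_1, values):
        off = flat_offset(r, c)
        updates[off] = updates.get(off, 1) + v
    return updates


def read(updates, row, col):
    """Value at (row, col); untouched elements are still 1."""
    return updates.get(flat_offset(row, col), 1)


index_0 = [0, -1, 0, 1]
index_1 = [-2, -1, 0, 1]
values = [12, 13, 10, 11]

updates = index_put_accumulate({}, index_0, index_1, values)

b = read(updates, 1, -2)
c = [read(updates, r, -2) for r in range(ROWS)]

for r, col in zip(index_0, index_1):
    off = flat_offset(r, col)
    print(f"({r:>2}, {col:>2}) -> offset {off:>10}  exceeds int32: {off > INT32_MAX}")
print("expected b:", b)
print("expected c:", c)
```

In this model three of the four target offsets exceed INT32_MAX, the largest being `2**32 + 9` for `(-1, -1)`, so an implementation that computes element offsets in 32-bit arithmetic would plausibly write to the wrong location. Comparing the MPS output of the script above with `expected b` and `expected c` is one way to check a fix.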