Mar 10, 2024.
See pytorch.org/xla for up-to-date information and for how to train on multiple TPUs.
```shell
# Usually pre-installed on TPU instances
pip install 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-releases/index.html
```
```python
import torch
import torch_xla.core.xla_model as xm

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = xm.xla_device()

# Optional: automatic mixed precision.
# torch.autocast expects a device *type* string ('xla', 'cuda', 'cpu'), not a torch.device.
ctx = torch.autocast(device.type, dtype=torch.bfloat16)

model = Net(...)
model.to(device)  # move the model BEFORE creating the optimizer (see the pitfall below)
optimizer = torch.optim.AdamW(model.parameters(), ...)
dataloader = ...

for x, y in dataloader:
    optimizer.zero_grad()
    x, y = x.to(device), y.to(device)
    with ctx:
        loss = loss_fn(model(x), y)
    loss.backward()  # backward runs outside the autocast context
    optimizer.step()
    if device.type == "xla":
        # We must execute the computational graph because
        # XLA tensors are lazy, contrary to CPU/GPU tensors
        # (https://pytorch.org/xla/release/2.2/index.html#xla-tensors-are-lazy)
        xm.mark_step()
```
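The autocast pattern from the loop can be checked without a TPU. The sketch below runs it on CPU, with `torch.device("cpu")` standing in for `xm.xla_device()` (whose `.type` is `"xla"`) and a toy `nn.Linear` model and random data standing in for `Net` and the dataloader. It assumes a PyTorch version whose CPU autocast supports bfloat16.

```python
import torch
import torch.nn as nn

device = torch.device("cpu")  # stand-in for xm.xla_device(), whose .type is "xla"
model = nn.Linear(8, 4).to(device)  # toy model standing in for Net(...)
x = torch.randn(16, 8, device=device)
y = torch.randn(16, 4, device=device)

# torch.autocast takes the device *type* string, hence device.type.
ctx = torch.autocast(device.type, dtype=torch.bfloat16)
with ctx:
    out = model(x)  # linear layers run in bfloat16 under autocast
    loss = nn.functional.mse_loss(out.float(), y)
loss.backward()  # backward outside the autocast context

print(out.dtype)                # torch.bfloat16
print(model.weight.grad.dtype)  # torch.float32: grads match the fp32 parameters
```

The parameters stay in float32; autocast only lowers the precision of eligible ops inside the `with` block, so the optimizer still updates full-precision weights.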
Warning
Pitfall! If you move the model to the TPU after constructing its optimizer, as below, the model parameters will not be updated.
```python
model = Net(...)
optimizer = torch.optim.AdamW(model.parameters(), ...)
model.to(device)  # too late: the optimizer still holds the old CPU parameters
```
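The cause is that `model.to(device)` does not move parameters in place when the target is an XLA device: an XLA tensor cannot be swapped into an existing CPU parameter via `.data`, so `nn.Module` replaces each parameter with a new object, and the optimizer keeps updating the old CPU copies. On CPU/GPU, `.to()` swaps the data in place and the `Parameter` objects survive, which is why the same ordering happens to work there. In general, the `torch.optim` documentation advises moving a model to its device before constructing optimizers for it.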
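A cheap guard against this pitfall is to check that the optimizer still references the model's current parameter objects. The helper `optimizer_tracks_model` below is a hypothetical sketch (it is not part of PyTorch or torch_xla). It runs on CPU and simulates the XLA behavior by replacing a parameter with a new `nn.Parameter` after the optimizer has been built.

```python
import torch
import torch.nn as nn


def optimizer_tracks_model(model: nn.Module, optimizer: torch.optim.Optimizer) -> bool:
    """Hypothetical helper: True if every model parameter is an object the optimizer holds."""
    # Compare object identity, not values: the pitfall is about stale references.
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return all(id(p) in tracked for p in model.parameters())


model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
print(optimizer_tracks_model(model, optimizer))  # True

# Simulate what moving to an XLA device does: swap in a *new* Parameter object.
model.weight = nn.Parameter(model.weight.detach().clone())
print(optimizer_tracks_model(model, optimizer))  # False
```

Calling such a check once, right after building the optimizer on the TPU, would turn a silent "loss never decreases" bug into an immediate failure.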
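To save a checkpoint, use `xm.save` instead of `torch.save`: it transfers XLA tensors to the CPU before serializing them and, by default, writes only from the master process.
```python
# torch.save(module.state_dict(), path)
xm.save(module.state_dict(), path)
```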
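Because `xm.save` writes CPU tensors, loading the checkpoint back is ordinary PyTorch. The sketch below is a CPU-only round trip under stated assumptions: `torch.save` stands in for `xm.save`, `nn.Linear` stands in for `Net`, and the temporary file path is hypothetical. The point is the order: load the weights, move the model to the device, and only then construct the optimizer.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the TPU run: torch.save writes the same kind of CPU state dict
# that xm.save produces (xm.save moves XLA tensors to CPU before saving).
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")  # hypothetical path
saved = nn.Linear(4, 2)  # stand-in for the trained Net(...)
torch.save(saved.state_dict(), path)

# Loading: rebuild the model, load the CPU weights, then move it to the device
# *before* creating the optimizer, so the optimizer sees the final parameters.
device = torch.device("cpu")  # on a TPU: device = xm.xla_device()
model = nn.Linear(4, 2)
model.load_state_dict(torch.load(path, map_location="cpu"))
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
```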