Skip to content

Instantly share code, notes, and snippets.

@afrendeiro
Created July 24, 2023 08:11
Show Gist options
  • Save afrendeiro/dc83d8b09f1e96b6a5af98b3d22a0302 to your computer and use it in GitHub Desktop.
Save afrendeiro/dc83d8b09f1e96b6a5af98b3d22a0302 to your computer and use it in GitHub Desktop.
Multi-GPU training with hf-accelerate
#!/usr/bin/env python
"""
A simple example of how to use the Accelerator API
to train a vision model on a dummy dataset.
Accelerator enables training on a single or multiple GPUs.
Run `accelerate config` once to set up your configuration file.
Run with `accelerate launch test_gpus_accelerate.py` to run on all GPUs.
"""
import fire
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from accelerate import Accelerator
class DummyDataset(Dataset):
    """Synthetic classification dataset that generates random samples on the fly.

    Every item is a freshly drawn ``(image, label)`` pair: the image is a
    uniform-random float tensor of shape ``image_shape`` and the label is a
    0-dim integer tensor drawn from ``[0, num_classes)``. Nothing is stored,
    so the same index returns different data on each access.

    Args:
        size: Number of samples reported by ``len()``.
        num_classes: Exclusive upper bound for the random labels.
        image_shape: Shape of each generated image tensor.

    Raises:
        ValueError: If ``size`` is negative or ``num_classes`` is not positive.
    """

    def __init__(
        self,
        size: int = 1_000_000,
        num_classes: int = 1000,
        image_shape: tuple[int, ...] = (3, 224, 224),
    ) -> None:
        super().__init__()
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be positive, got {num_classes}")
        self.size = size
        self.num_classes = num_classes
        self.image_shape = tuple(image_shape)

    def __len__(self) -> int:
        """Return the nominal number of samples."""
        return self.size

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return a random ``(image, label)`` pair for a valid index.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without this bounds check, Python's sequence-iteration fallback
        # (`for x in dataset`, `list(dataset)`) would never terminate, because
        # it relies on IndexError to detect the end of the sequence.
        if not -self.size <= idx < self.size:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.size}"
            )
        image = torch.rand(self.image_shape)
        # Index the 1-element tensor to obtain a 0-dim label tensor.
        label = torch.randint(0, self.num_classes, (1,))[0]
        return image, label
def test(
    model_name: str = "alexnet",
    batch_size: int = 2048,
    epochs: int = 3,
    num_workers: int = 8,
):
    """Train a torchvision model on random data through the Accelerator API.

    Builds the model, an Adam optimizer and a ``DataLoader`` over
    ``DummyDataset``, hands all three to ``accelerator.prepare`` and then runs
    a plain cross-entropy training loop. Progress bars are shown only on the
    local main process.

    Args:
        model_name: Name of a model constructor in ``torchvision.models``.
        batch_size: Batch size passed to the ``DataLoader``.
        epochs: Number of passes over the dataset.
        num_workers: Number of ``DataLoader`` worker processes.
    """
    accelerator = Accelerator()
    device = accelerator.device
    # Only one process per machine draws progress bars and reads the loss.
    show_progress = accelerator.is_local_main_process

    model = getattr(torchvision.models, model_name)(weights="DEFAULT").to(device)
    optimizer = torch.optim.Adam(model.parameters())
    dataset = DummyDataset()
    data = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    model, optimizer, data = accelerator.prepare(model, optimizer, data)
    model.train()

    epoch_bar = tqdm(
        range(epochs),
        position=0,
        leave=True,
        desc="Epochs",
        disable=not show_progress,
    )
    for epoch in epoch_bar:
        # A fresh bar per epoch: a tqdm bar is closed once its iterable is
        # exhausted, so reusing a single bar would display batch progress
        # for the first epoch only.
        batch_bar = tqdm(
            data,
            position=1,
            leave=False,
            desc="Batches",
            disable=not show_progress,
        )
        for source, targets in batch_bar:
            source = source.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            output = model(source)
            loss = F.cross_entropy(output, targets)
            if show_progress:
                # loss.item() copies the value to the host, so it is only
                # evaluated on the process that actually displays it.
                batch_bar.set_postfix(
                    {"images": source.shape[0], "loss": loss.item()}
                )
            accelerator.backward(loss)
            optimizer.step()
        # Valid evaluation (placeholder: no validation pass is implemented)
# Script entry point: python-fire exposes `test` as a command-line interface,
# so its keyword arguments can be overridden with flags such as
# `--batch_size=256 --epochs=1`.
if __name__ == "__main__":
    fire.Fire(test)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment