import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms

class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
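
# Why fc1's in_features is 9216: a 28x28 MNIST image becomes 26x26 after
# conv1 (3x3 kernel, stride 1), 24x24 after conv2, then 12x12 after the
# 2x2 max pool, leaving 64 channels: 64 * 12 * 12 = 9216. A quick
# standalone sanity check (a sketch, not part of the original gist):
#
#   m = MNISTModel()
#   x = torch.randn(1, 1, 28, 28)
#   x = F.max_pool2d(F.relu(m.conv2(F.relu(m.conv1(x)))), 2)
#   assert torch.flatten(x, 1).shape == (1, 9216)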

def train():
    # Initialize the process group; NCCL is the standard backend for GPU training
    dist.init_process_group(backend="nccl")

    # torchrun sets these environment variables for each worker process
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Bind this process to its GPU
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    print(f"Running on rank {rank} (local_rank: {local_rank})")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # Note: every rank downloads to ./data here; for real jobs, pre-download
    # the data or gate the download behind a barrier to avoid races.
    dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)

    # Each rank draws a disjoint 1/world_size shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    model = MNISTModel().to(device)
    model = DDP(model, device_ids=[local_rank])
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(1, 11):
        # Reshuffle shard assignment each epoch so ranks see different data
        sampler.set_epoch(epoch)
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0:
                print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                      f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

    # Save from rank 0 only; model.module unwraps the DDP container so the
    # checkpoint holds plain parameter names
    if rank == 0:
        torch.save(model.module.state_dict(), "mnist_model.pth")
        print("Model saved as mnist_model.pth")

    dist.destroy_process_group()


if __name__ == "__main__":
    train()
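
The script reads LOCAL_RANK, RANK, and WORLD_SIZE from the environment, which torchrun sets for each worker process. Assuming the file is saved as mnist_ddp.py (the gist's filename isn't shown), a single-node launch on four GPUs looks like:

    torchrun --nproc_per_node=4 mnist_ddp.py

For multi-node runs, add --nnodes, --node_rank, and a rendezvous endpoint (e.g. --rdzv_endpoint=host:port). Because rank 0 saves model.module.state_dict(), the checkpoint can be loaded straight into a plain, non-DDP MNISTModel for inference:

    model = MNISTModel()
    model.load_state_dict(torch.load("mnist_model.pth", map_location="cpu"))
    model.eval()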