DDP with PyTorch
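Two variants of the same FashionMNIST training script follow. The first, pytorch-ddp-all.py, initializes a wandb run in every process and relies on run grouping (group="DDP") to collect the per-rank metrics; the second, pytorch-ddp-rank0.py, creates and logs to a run on rank 0 only.

pytorch-ddp-all.py: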
# Usage:
#   python -m torch.distributed.launch \
#       --nproc_per_node 2 \
#       --nnodes 1 \
#       --node_rank 0 \
#       pytorch-ddp-all.py \
#           --epochs 10 \
#           --batch 512 \
#           --entity <ENTITY> \
#           --project <PROJECT>
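# (Note: torch.distributed.launch has since been deprecated in favor of
# torchrun. torchrun exposes the local rank through the LOCAL_RANK
# environment variable rather than a --local_rank argument, so switching
# launchers would also require a small change to parse_args below.)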
# IMPORTS
import os
import wandb
import argparse
import numpy as np
from datetime import datetime

import torch
import torchvision
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.transforms as transforms
def parse_args():
    """
    Parse arguments given to the script.

    Returns:
        The parsed argument object.
    """
    parser = argparse.ArgumentParser()
    # Populated automatically by `torch.distributed.launch`
    parser.add_argument(
        "--local_rank", type=int, default=-1, metavar="N", help="Local process rank."
    )
    parser.add_argument(
        "--epochs",
        default=2,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--batch",
        default=32,
        type=int,
        metavar="N",
        help="number of data samples in one batch",
    )
    parser.add_argument(
        "--entity",
        type=str,
        help="wandb entity",
    )
    parser.add_argument(
        "--project",
        type=str,
        help="wandb project",
    )
    args = parser.parse_args()
    return args
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # 28x28x1 -> 28x28x16 -> 14x14x16 after pooling
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 14x14x16 -> 14x14x32 -> 7x7x32 after pooling
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Hence the 7 * 7 * 32 flattened features below
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
def train(args, run=None):
    """
    Train method for the model.

    Args:
        args: The parsed argument object
        run: The wandb run object
    """
    # Check to see if local_rank is 0
    args.is_master = args.local_rank == 0

    # set the device
    args.device = torch.device(args.local_rank)

    # initialize PyTorch distributed using environment variables
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(args.local_rank)

    # set the seed for all GPUs (also make sure to set the seed for random, numpy, etc.)
    torch.cuda.manual_seed_all(42)
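    # (Sketch, per the note above: seeding the CPU-side RNGs as well keeps
    # weight init and any NumPy-based preprocessing reproducible; 42 here
    # simply mirrors the CUDA seed and is an arbitrary choice.)
    torch.manual_seed(42)
    np.random.seed(42)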
    # initialize model
    model = ConvNet()
    # send your model to GPU
    model = model.to(args.device)

    batch_size = args.batch
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.device)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    # Wrap the model
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank
    )

    # watch gradients (in this variant every rank owns its own wandb run)
    if run:
        run.watch(model)

    # Data loading code
    train_dataset = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=transforms.ToTensor(), download=True
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler,
    )
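    # (Note: shuffle stays False because DistributedSampler already shuffles
    # and shards the data across processes; --batch is therefore the
    # per-process batch size, so the effective global batch is
    # batch * nproc_per_node * nnodes.)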
print("Training Begins ...") | |
start = datetime.now() | |
total_step = len(train_loader) | |
for epoch in range(args.epochs): | |
batch_loss = [] | |
for i, (images, labels) in enumerate(train_loader): | |
images = images.cuda(non_blocking=True) | |
labels = labels.cuda(non_blocking=True) | |
# Forward pass | |
outputs = model(images) | |
loss = criterion(outputs, labels) | |
# Backward and optimize | |
optimizer.zero_grad() | |
loss.backward() | |
optimizer.step() | |
batch_loss.append(loss.item()) | |
if (i + 1) % 100 == 0 and args.is_master: | |
print( | |
"Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format( | |
epoch + 1, args.epochs, i + 1, total_step, loss.item() | |
) | |
) | |
if run: | |
run.log({"batch_loss": loss.item()}) | |
if run: | |
run.log({"epoch": epoch, "loss": np.mean(batch_loss)}) | |
print("Training complete in: " + str(datetime.now() - start)) | |
if __name__ == "__main__":
    # get args
    args = parse_args()

    # Initialize run
    run = wandb.init(
        entity=args.entity,
        project=args.project,
        group="DDP",
    )

    # Train model with DDP
    train(args, run)
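pytorch-ddp-rank0.py — identical to the script above, except that wandb.init is called only when local_rank is 0 and every logging call is guarded by args.is_master: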
# Usage:
#   python -m torch.distributed.launch \
#       --nproc_per_node 2 \
#       --nnodes 1 \
#       --node_rank 0 \
#       pytorch-ddp-rank0.py \
#           --epochs 10 \
#           --batch 512 \
#           --entity <ENTITY> \
#           --project <PROJECT>
# IMPORTS
import os
import wandb
import argparse
import numpy as np
from datetime import datetime

import torch
import torchvision
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.transforms as transforms
def parse_args():
    """
    Parse arguments given to the script.

    Returns:
        The parsed argument object.
    """
    parser = argparse.ArgumentParser()
    # Populated automatically by `torch.distributed.launch`
    parser.add_argument(
        "--local_rank", type=int, default=-1, metavar="N", help="Local process rank."
    )
    parser.add_argument(
        "--epochs",
        default=2,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--batch",
        default=32,
        type=int,
        metavar="N",
        help="number of data samples in one batch",
    )
    parser.add_argument(
        "--entity",
        type=str,
        help="wandb entity",
    )
    parser.add_argument(
        "--project",
        type=str,
        help="wandb project",
    )
    args = parser.parse_args()
    return args
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
def train(args, run=None):
    """
    Train method for the model.

    Args:
        args: The parsed argument object
        run: The wandb run object
    """
    # Check to see if local_rank is 0
    args.is_master = args.local_rank == 0

    # set the device
    args.device = torch.device(args.local_rank)

    # initialize PyTorch distributed using environment variables
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(args.local_rank)

    # set the seed for all GPUs (also make sure to set the seed for random, numpy, etc.)
    torch.cuda.manual_seed_all(42)
    # initialize model
    model = ConvNet()
    # send your model to GPU
    model = model.to(args.device)

    batch_size = args.batch
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.device)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    # Wrap the model
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank
    )

    # watch gradients only for rank 0
    if args.is_master and run:
        run.watch(model)

    # Data loading code
    train_dataset = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=transforms.ToTensor(), download=True
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler,
    )
print("Training Begins ...") | |
start = datetime.now() | |
total_step = len(train_loader) | |
for epoch in range(args.epochs): | |
batch_loss = [] | |
for i, (images, labels) in enumerate(train_loader): | |
images = images.cuda(non_blocking=True) | |
labels = labels.cuda(non_blocking=True) | |
# Forward pass | |
outputs = model(images) | |
loss = criterion(outputs, labels) | |
# Backward and optimize | |
optimizer.zero_grad() | |
loss.backward() | |
optimizer.step() | |
batch_loss.append(loss.item()) | |
if (i + 1) % 100 == 0 and args.is_master: | |
print( | |
"Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format( | |
epoch + 1, args.epochs, i + 1, total_step, loss.item() | |
) | |
) | |
if args.is_master and run: | |
run.log({"batch_loss": loss.item()}) | |
if args.is_master and run: | |
run.log({"epoch": epoch, "loss": np.mean(batch_loss)}) | |
print("Training complete in: " + str(datetime.now() - start)) | |
if __name__ == "__main__":
    # get args
    args = parse_args()

    if args.local_rank == 0:
        # Initialize wandb run
        run = wandb.init(
            entity=args.entity,
            project=args.project,
        )
        # Train model with DDP
        train(args, run)
    else:
        train(args)
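Note: both variants train identically; they differ only in instrumentation. The all-processes variant surfaces per-rank losses as grouped wandb runs, while the rank-0 variant keeps the dashboard to a single run whose loss reflects only rank 0's shard of the data.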