pytorch_lightning_distributed_training.py
import os
import multiprocessing
import torch
import torch.nn.functional as F
import torchmetrics
import pytorch_lightning as pl
from argparse import ArgumentParser
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers
# Download MNIST
# For new versions of TorchVision
# wget -O MNIST.tar.gz https://activeeon-public.s3.eu-west-2.amazonaws.com/datasets/MNIST.new.tar.gz
# For old versions of TorchVision (<= 1.0.1*)
# wget -O MNIST.tar.gz https://activeeon-public.s3.eu-west-2.amazonaws.com/datasets/MNIST.old.tar.gz
# tar -zxvf MNIST.tar.gz
def download_mnist(dataset_path):
    dataset_folder = os.path.join(dataset_path, "MNIST")
    dataset_file = dataset_folder + ".tar.gz"
    if not os.path.exists(dataset_folder):
        if not os.path.exists(dataset_file):
            print("Downloading MNIST data set to " + dataset_file)
            import wget
            wget.download("https://activeeon-public.s3.eu-west-2.amazonaws.com/datasets/MNIST.new.tar.gz", dataset_file)
        print("Extracting " + dataset_file)
        import tarfile
        tar = tarfile.open(dataset_file, "r:gz")
        tar.extractall(dataset_path)
        tar.close()
    else:
        print("Using existing MNIST data set at " + dataset_folder)

class LitMNIST(pl.LightningModule):

    def __init__(self, dataset_path='./', optimizer='adam', learning_rate=2e-4, batch_size=64, hidden_size=64):
        super().__init__()
        # Set our init args as class attributes
        self.data_dir = dataset_path
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        self.optimizer = optimizer.lower()
        self.batch_size = batch_size
        self.train_acc_metric = torchmetrics.Accuracy()
        self.val_acc_metric = torchmetrics.Accuracy()
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        self.train_loss = F.nll_loss(logits, y)
        self.train_acc = self.train_acc_metric(preds, y)
        # Calling self.log will surface up scalars for you in TensorBoard
        # https://pytorch-lightning.readthedocs.io/en/1.1.2/multi_gpu.html#synchronize-validation-and-test-logging
        # https://pytorch-lightning.readthedocs.io/en/1.0.8/logging.html#automatic-logging
        self.log('train_loss', self.train_loss, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
        self.log('train_acc', self.train_acc, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
        return self.train_loss

    def training_epoch_end(self, outs):
        print("Epoch training acc: ", self.train_acc_metric.compute())
        self.train_acc_metric.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        self.val_loss = F.nll_loss(logits, y)
        self.val_acc = self.val_acc_metric(preds, y)
        self.log('val_loss', self.val_loss, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
        self.log('val_acc', self.val_acc, prog_bar=False, on_step=True, on_epoch=True, sync_dist=True)
        return self.val_loss

    def validation_epoch_end(self, outs):
        print("Epoch validation acc: ", self.val_acc_metric.compute())
        self.val_acc_metric.reset()
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def get_metrics(self):
        # Metric on all batches using custom accumulation
        train_acc = self.train_acc_metric.compute()
        val_acc = self.val_acc_metric.compute()
        print(f"Training accuracy on all data: {train_acc}")
        print(f"Validation accuracy on all data: {val_acc}")
        metrics = {'train_acc': train_acc.item(), 'val_acc': val_acc.item()}
        return metrics

    def configure_optimizers(self):
        optimizer = None
        if self.optimizer == 'adam':
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        if self.optimizer == 'sgd':
            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        if self.optimizer == 'rmsprop':
            optimizer = torch.optim.RMSprop(self.parameters(), lr=self.learning_rate)
        assert optimizer is not None
        return optimizer
    def prepare_data(self):
        # The data set is downloaded and extracted beforehand by download_mnist(),
        # so download=False here only verifies that it is present.
        MNIST(self.data_dir, train=True, download=False)
        MNIST(self.data_dir, train=False, download=False)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.get_cpu_count(), pin_memory=self.get_pin_memory())

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.get_cpu_count(), pin_memory=self.get_pin_memory())

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.get_cpu_count(), pin_memory=self.get_pin_memory())

    def get_cpu_count(self):
        cpu_count = multiprocessing.cpu_count()
        return cpu_count

    def get_pin_memory(self):
        if torch.cuda.is_available():
            return True
        else:
            return False
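
# Quick local sanity check (sketch, kept commented out; assumes the MNIST data set is
# already available under './', e.g. via download_mnist('./')):
# model = LitMNIST(dataset_path='./')
# trainer = pl.Trainer(max_epochs=1, limit_train_batches=10, limit_val_batches=5)
# trainer.fit(model)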

def main(args):
    # Download the MNIST data set
    download_mnist(args.dataset_path)

    # Build the model (a small fully-connected network)
    model = LitMNIST(
        dataset_path=args.dataset_path,
        optimizer=args.optimizer,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size
    )

    # Check if TensorBoard logging is enabled
    logger = []
    if args.tensorboard_enabled:
        logger = loggers.TensorBoardLogger(args.tensorboard_logdir)
    callbacks = []
    checkpoint_callback = None
    if args.model_checkpoint_enabled:
        # returned metrics:
        # ['train_loss_step', 'train_acc_step', 'train_loss_epoch', 'train_acc_epoch', 'train_loss', 'train_acc',
        #  'val_loss_epoch', 'val_acc_epoch', 'val_loss', 'val_acc']
        checkpoint_callback = ModelCheckpoint(
            monitor='val_loss_epoch',
            dirpath=args.model_checkpoint_path,
            filename='sample-mnist-{epoch:02d}-{val_loss_epoch:.2f}'
        )
        callbacks.append(checkpoint_callback)

    # Use GPUs if they are available
    gpus = None
    plugins = None
    if torch.cuda.is_available():
        gpus = torch.cuda.device_count()
        # from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
        # plugin = DDPSequentialPlugin(balance=[1, 1, 1, 1])
        # plugins = [plugin]
    # Distributed modes
    # https://pytorch-lightning.readthedocs.io/en/1.1.2/multi_gpu.html#distributed-modes
    # Data Parallel (accelerator='dp') (multiple-gpus, 1 machine)
    # DistributedDataParallel (accelerator='ddp') (multiple-gpus across many machines - python script based)
    # DistributedDataParallel (accelerator='ddp_spawn') (multiple-gpus across many machines - spawn based)
    # DistributedDataParallel 2 (accelerator='ddp2') (DP in a machine, DDP across machines)
    # Horovod (accelerator='horovod') (multi-machine, multi-gpu, configured at runtime)
    # TPUs (tpu_cores=8|x) (tpu or TPU pod)
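    # For instance (illustrative sketch, not used here since the Trainer below is built
    # from the command-line arguments): a single-node, 2-GPU DDP run could be configured as
    # trainer = pl.Trainer(accelerator='ddp', gpus=2, num_nodes=1, max_epochs=3)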
    # DistributedDataParallel
    # python pytorch_lightning_distributed_training.py --accelerator ddp --gpus 1 --max_epochs 3 --model_checkpoint_enabled
    # https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#select-torch-distributed-backend
    # export PL_TORCH_DISTRIBUTED_BACKEND=gloo
    # python pytorch_lightning_distributed_training.py --accelerator ddp --gpus 2 --max_epochs 3 --model_checkpoint_enabled
    # srun --partition DL380_GPU --ntasks=2 --nodes 2 python pytorch_lightning_distributed_training.py --accelerator ddp --gpus 1 --num_nodes 2 --max_epochs 3
    # srun --partition DL380_GPU --ntasks=4 --nodes 2 python pytorch_lightning_distributed_training.py --accelerator ddp --gpus 2 --num_nodes 2 --max_epochs 3
    # trainer = pl.Trainer.from_argparse_args(args)

    # Train Horovod on GPU (number of GPUs / machines provided on command-line)
    # https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#horovod
    # horovodrun --gloo -np 4 python pytorch_lightning_distributed_training.py
    # trainer = pl.Trainer(accelerator='horovod', gpus=1)
    #
    # Makes all trainer options available from the command line
    # horovodrun --gloo -np 1 -H gpu01:1 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3
    # horovodrun --gloo -np 2 -H gpu01:2 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3
    # horovodrun --gloo -np 4 -H gpu01:4 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 5
    # horovodrun --gloo -np 4 -H gpu01:2,gpu02:2 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3
    # horovodrun --gloo -np 2 -H gpu01:1,gpu02:1 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3
    # horovodrun --gloo -np 3 -H gpu01:1,gpu02:1,gpu03:1 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 5
    # horovodrun --gloo -np 6 -H gpu01:2,gpu02:2,gpu03:2 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 5
    # horovodrun --gloo -np 12 -H gpu01:4,gpu02:4,gpu03:4 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 5
    # trainer = pl.Trainer.from_argparse_args(args, logger=logger)
    if args.accelerator == 'ddp':
        callbacks.append(EarlyStopping(monitor='val_loss'))  # EarlyStopping doesn't work with Horovod
        trainer = pl.Trainer.from_argparse_args(args, logger=[], callbacks=callbacks)  # The TensorBoard logger doesn't work with ddp
    elif args.accelerator == 'horovod':
        trainer = pl.Trainer.from_argparse_args(args, logger=logger, callbacks=callbacks)
    else:
        # Fall back to the default (single-process) setup for any other accelerator
        trainer = pl.Trainer.from_argparse_args(args, logger=logger, callbacks=callbacks)

    # final accuracies should be averaged
    trainer.fit(model)
    print("Metrics:\n", model.get_metrics())

    if args.model_checkpoint_enabled:
        print("best_model_path: ", checkpoint_callback.best_model_path)
        # automatically loads the best weights!
        # trainer.test()
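        # Reloading the best checkpoint later (sketch, kept commented out;
        # load_from_checkpoint is the standard LightningModule API):
        # best_model = LitMNIST.load_from_checkpoint(checkpoint_callback.best_model_path)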

# clear && HOROVOD_CACHE_CAPACITY=0 horovodrun --gloo -np 1 -H gpu01:1 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3 --dataset_path /nfs/activeeon/datasets --optimizer adam --learning_rate 0.01 --tensorboard_enabled --tensorboard_logdir logs/ --model_checkpoint_enabled
# clear && HOROVOD_CACHE_CAPACITY=0 horovodrun --gloo -np 2 -H gpu01:2 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3 --dataset_path /nfs/activeeon/datasets --optimizer adam --learning_rate 0.01
# clear && HOROVOD_CACHE_CAPACITY=0 horovodrun --gloo -np 2 -H gpu01:1,gpu02:1 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3 --dataset_path /nfs/activeeon/datasets --optimizer adam --learning_rate 0.01
# clear && HOROVOD_CACHE_CAPACITY=0 horovodrun --gloo -np 4 -H gpu01:2,gpu02:2 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3 --dataset_path /nfs/activeeon/datasets --optimizer adam --learning_rate 0.01
# clear && HOROVOD_CACHE_CAPACITY=0 horovodrun --gloo -np 6 -H gpu01:2,gpu02:2,gpu03:2 python pytorch_lightning_distributed_training.py --accelerator horovod --gpus 1 --max_epochs 3 --dataset_path /nfs/activeeon/datasets --optimizer adam --learning_rate 0.01

if __name__ == '__main__':
    parser = ArgumentParser()
    # add your own args
    parser.add_argument('--dataset_path', type=str, default="./")
    parser.add_argument('--optimizer', type=str, default="adam")
    parser.add_argument('--learning_rate', type=float, default=2e-4)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--tensorboard_enabled', action='store_true')
    parser.add_argument('--tensorboard_logdir', type=str, default="logs/")
    parser.add_argument('--model_checkpoint_enabled', action='store_true')
    parser.add_argument('--model_checkpoint_path', type=str, default="checkpoints/")
    # add all the available options to the trainer
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    print(args)
    main(args)