lightning_optuna_multigpu example
# https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_lightning_simple.py
"""
Optuna example that optimizes multi-layer perceptrons using PyTorch Lightning.

In this example, we optimize the validation accuracy of image classification on
FashionMNIST using PyTorch Lightning. We optimize the neural network architecture.
As it is too time-consuming to use the whole FashionMNIST dataset, we use a small
subset of it.

You can run this example as follows; pruning can be turned on and off with the
`--pruning` argument.

    $ python pytorch_lightning_simple.py [--pruning]

"""
import argparse
import os
from typing import List
from typing import Optional

from packaging import version
import pytorch_lightning as pl
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import datasets
from torchvision import transforms

import optuna
from optuna.integration import PyTorchLightningPruningCallback

if version.parse(pl.__version__) < version.parse("1.0.2"):
    raise RuntimeError("PyTorch Lightning>=1.0.2 is required for this example.")

PERCENT_VALID_EXAMPLES = 0.1
BATCHSIZE = 128
CLASSES = 10
EPOCHS = 2
DIR = os.getcwd()
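
# Plain MLP over flattened 28x28 images; the depth and each layer's width are
# supplied by the Optuna trial (see `objective` below).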
class Net(nn.Module):
    def __init__(self, dropout: float, output_dims: List[int]):
        super().__init__()
        layers: List[nn.Module] = []

        input_dim: int = 28 * 28
        for output_dim in output_dims:
            layers.append(nn.Linear(input_dim, output_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            input_dim = output_dim

        layers.append(nn.Linear(input_dim, CLASSES))

        self.layers: nn.Module = nn.Sequential(*layers)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        logits = self.layers(data)
        return F.log_softmax(logits, dim=1)
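
# LightningModule wrapper around `Net`. It logs "val_acc", which is read by both
# the Optuna pruning callback and the final objective value.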
class LightningNet(pl.LightningModule):
    def __init__(self, dropout: float, output_dims: List[int]):
        super().__init__()
        self.model = Net(dropout, output_dims)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.model(data.view(-1, 28 * 28))
    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        data, target = batch
        output = self(data)
        return F.nll_loss(output, target)

    def validation_step(self, batch, batch_idx: int) -> None:
        data, target = batch
        output = self(data)
        pred = output.argmax(dim=1, keepdim=True)
        accuracy = pred.eq(target.view_as(pred)).float().mean()
        self.log("val_acc", accuracy)
        self.log("hp_metric", accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self) -> optim.Optimizer:
        return optim.Adam(self.model.parameters())
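
# DataModule that downloads FashionMNIST and splits the 60k training images into
# 55k for training and 5k for validation.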
class FashionMNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None) -> None:
        self.mnist_test = datasets.FashionMNIST(
            self.data_dir, train=False, download=True, transform=transforms.ToTensor()
        )
        mnist_full = datasets.FashionMNIST(
            self.data_dir, train=True, download=True, transform=transforms.ToTensor()
        )
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.mnist_train, batch_size=self.batch_size, shuffle=True, pin_memory=True
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, shuffle=False, pin_memory=True
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, shuffle=False, pin_memory=True
        )
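
# Optuna objective: sample an architecture, train it briefly, and return the
# validation accuracy for the study to maximize.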
def objective(trial: optuna.trial.Trial) -> float:
    # We optimize the number of layers, hidden units in each layer and dropouts.
    n_layers = trial.suggest_int("n_layers", 1, 3)
    dropout = trial.suggest_float("dropout", 0.2, 0.5)
    output_dims = [
        trial.suggest_int("n_units_l{}".format(i), 4, 128, log=True) for i in range(n_layers)
    ]

    model = LightningNet(dropout, output_dims)
    datamodule = FashionMNISTDataModule(data_dir=DIR, batch_size=BATCHSIZE)
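    # `accelerator="ddp_spawn"` (the PyTorch Lightning 1.x spelling) runs training
    # in freshly spawned subprocesses, one per device, so the objective has to read
    # its result back on the main process after `fit` returns.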
    trainer = pl.Trainer(
        logger=True,
        limit_val_batches=PERCENT_VALID_EXAMPLES,
        checkpoint_callback=False,
        max_epochs=EPOCHS,
        gpus=1 if torch.cuda.is_available() else None,
        accelerator="ddp_spawn",
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )

    hyperparameters = dict(n_layers=n_layers, dropout=dropout, output_dims=output_dims)
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, datamodule=datamodule)
    # `trainer.spawn_callback_metrics` is not a real Trainer attribute. With
    # `ddp_spawn`, metrics logged via `self.log` are expected to be copied back
    # to the main process in `trainer.callback_metrics` after `fit` returns
    # (behavior varies across PyTorch Lightning 1.x versions).
    return trainer.callback_metrics["val_acc"].item()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch Lightning example.")
    parser.add_argument(
        "--pruning",
        "-p",
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising "
        "trials at the early stages of training.",
    )
    args = parser.parse_args()

    pruner: optuna.pruners.BasePruner = (
        optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    )

    study = optuna.create_study(direction="maximize", pruner=pruner)
    study.optimize(objective, n_trials=100, timeout=600)

    print("Number of finished trials: {}".format(len(study.trials)))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))