pytorch_lightning_ddp_gradient_checkpointing_bug
# This reproduces a pytorch_lightning issue
# where gradient checkpointing + DDP results in a NaN loss.
#
# * Run with gpus=1 and it works fine.
# * Run with gpus=4 and the loss becomes NaN quickly.
#
# See also https://forums.pytorchlightning.ai/t/gradient-checkpointing-ddp-nan/398
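#
# A possible cause (an assumption, not confirmed in this gist): Lightning's
# "ddp" accelerator wraps the model in DistributedDataParallel, and PyTorch's
# (reentrant) torch.utils.checkpoint is documented as unsupported when DDP
# runs with find_unused_parameters=True. See the workaround sketch after the
# script.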
import torch
import torch.utils.checkpoint  # needed below; not pulled in by `import torch` alone
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class MergeLayer(torch.nn.Module):
    """Concatenates its inputs and applies a linear layer, checkpointing the
    forward pass whenever gradients are required."""

    def __init__(self, in_size, out_size):
        super().__init__()
        self.layer = torch.nn.Linear(in_size, out_size)

    def apply_forward(self, xs):
        y = torch.cat(xs, dim=1)
        return self.layer(y)

    def _apply_forward_splat(self, *xs):
        # checkpoint replays the saved inputs as positional arguments,
        # so re-pack them into a tuple for apply_forward.
        return self.apply_forward(xs)

    def forward(self, xs):
        # Checkpointing only propagates gradients when at least one input
        # requires grad, so fall back to a plain forward pass otherwise.
        if any(x.requires_grad for x in xs):
            return torch.utils.checkpoint.checkpoint(self._apply_forward_splat, *xs)
        else:
            return self.apply_forward(xs)

class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(32, 32)
        self.merge = MergeLayer(64, 32)  # 64 = 32 (input) + 32 (layer1 output)

    def forward(self, x):
        x1 = F.leaky_relu(self.layer1(x))
        xs = [x, x1]
        return self.merge(xs)

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = F.mse_loss(output, batch)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)

num_samples = 10000
train = RandomDataset(32, num_samples)
train = DataLoader(train, batch_size=32)

model = BoringModel()

# Initialize a trainer
trainer = pl.Trainer(
    max_epochs=10,
    progress_bar_refresh_rate=20,
    accelerator="ddp",
    gpus=4,  # NaN loss
    # gpus=1,  # works
)

# Train the model ⚡
trainer.fit(model, train)
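
A possible workaround, sketched below (an assumption based on the PyTorch DDP documentation, not verified against this exact Lightning version): the reentrant torch.utils.checkpoint is only supported under DistributedDataParallel when find_unused_parameters=False, while Lightning's "ddp" accelerator reportedly defaulted to True at the time. If the installed Lightning release ships DDPPlugin, the flag can be turned off via the plugins argument; the rest of the script stays unchanged.

# Workaround sketch: force find_unused_parameters=False in DDP.
# Assumes a pytorch_lightning release that provides DDPPlugin.
from pytorch_lightning.plugins import DDPPlugin

trainer = pl.Trainer(
    max_epochs=10,
    progress_bar_refresh_rate=20,
    accelerator="ddp",
    gpus=4,
    plugins=[DDPPlugin(find_unused_parameters=False)],
)
trainer.fit(model, train)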