@philtrade
Last active February 14, 2022 09:35
FastAI v1 GANTrainer hangs under PyTorch DistributedDataParallel
#!/usr/bin/env python3
# Run this script as:
# (Yes, even with nproc_per_node=1, it'll trigger the bug)
# python -m torch.distributed.launch --nproc_per_node=1 wgan_ddp.py
#
import argparse
from fastai.vision import *
from fastai.vision.gan import *
from fastai.distributed import *
import torch
def get_data(path, bs, size):
    # Noise vectors as inputs, cropped real images as targets.
    return (GANItemList.from_folder(path, noise_sz=100)
            .split_none()
            .label_from_func(noop)
            .transform(tfms=[[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], []],
                       size=size, tfm_y=True)
            .databunch(bs=bs)
            .normalize(stats=[torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])],
                       do_x=False, do_y=True))

def train(local_rank:int, epochs:int=1):
    path = untar_data(URLs.LSUN_BEDROOMS)
    bs = 128 * 14
    data = get_data(path, bs, 64)
    generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
    critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
    learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                            opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
    # Wrapping the learner for DDP is what triggers the hang during fit().
    learn = learn.to_distributed(local_rank)
    learn.fit(epochs, 2e-4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    train(args.local_rank, 1)
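Note on a possible workaround (an untested sketch, not a confirmed fix): to_distributed wraps the whole GAN model in torch.nn.parallel.DistributedDataParallel, but GANTrainer alternates between training the generator and the critic, so on any given backward pass one sub-module's parameters receive no gradients. DDP assumes every parameter participates in each backward pass, and that mismatch can raise errors or stall training. One direction worth trying is to wrap the model manually with find_unused_parameters=True (a standard DDP option) instead of calling to_distributed. The helper below is hypothetical, and GANTrainer may still misbehave because the DDP wrapper hides the generator and critic behind a .module attribute.

# Hypothetical workaround sketch -- replace `learn = learn.to_distributed(local_rank)`
# with a manual DDP wrap that tolerates parameters receiving no gradient in a step.
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_gan_learner_for_ddp(learn, local_rank):   # hypothetical helper, not fastai API
    learn.model = DDP(learn.model.cuda(local_rank),
                      device_ids=[local_rank],
                      output_device=local_rank,
                      find_unused_parameters=True)  # don't block on idle sub-module params
    return learn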
@foobarhe

Hello, I'm hitting the same issue here.
Did you find a workaround for the hang? Thanks.
