FastAI v1 GANTrainer hangs when run under PyTorch DistributedDataParallel
#!/usr/bin/env python3
# Run this script as:
# (Yes, even with nproc_per_node=1, it'll trigger the bug)
#   python -m torch.distributed.launch --nproc_per_node=1 wgan_ddp.py
#
import argparse
from fastai.vision import *
from fastai.vision.gan import *
from fastai.distributed import *
import torch

def get_data(path, bs, size):
    # Noise -> image item list for WGAN training; normalize the targets only.
    return (GANItemList.from_folder(path, noise_sz=100)
            .split_none()
            .label_from_func(noop)
            .transform(tfms=[[crop_pad(size=size, row_pct=(0, 1), col_pct=(0, 1))], []],
                       size=size, tfm_y=True)
            .databunch(bs=bs)
            .normalize(stats=[torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5])],
                       do_x=False, do_y=True))

def train(local_rank: int, epochs: int = 1):
    path = untar_data(URLs.LSUN_BEDROOMS)
    bs = 128 * 14
    data = get_data(path, bs, 64)
    generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
    critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)
    learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                            opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)
    learn = learn.to_distributed(local_rank)  # wraps the model in DDP; this is where the hang shows up
    learn.fit(epochs, 2e-4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    train(args.local_rank, 1)
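A plausible mechanism, and a possible but unconfirmed mitigation: to_distributed wraps the whole GANModule in DistributedDataParallel, while GANTrainer backpropagates through only one of the two sub-networks (generator or critic) on any given step, so DDP's gradient reducer can block waiting on parameters that never receive gradients. The sketch below swaps learn.to_distributed(local_rank) for a manual wrap with find_unused_parameters=True, a standard DDP option for models where only a subset of parameters participates in each backward pass. This is an assumption about the cause, not a verified fix, and the helper name to_distributed_gan is hypothetical.

# Hedged sketch of a possible workaround (assumption, not a confirmed fix):
# wrap the GANModule manually so DDP tolerates per-step unused parameters.
from torch.nn.parallel import DistributedDataParallel as DDP

def to_distributed_gan(learn, local_rank: int):
    # find_unused_parameters=True lets DDP's reducer skip parameters that
    # receive no gradient in a given backward pass, which matches GANTrainer's
    # alternation between generator and critic inside one wrapped module.
    learn.model = DDP(learn.model,
                      device_ids=[local_rank],
                      output_device=local_rank,
                      find_unused_parameters=True)
    return learn

# Usage, inside train(), instead of learn.to_distributed(local_rank):
#   learn = to_distributed_gan(learn, args.local_rank)

Caveat: unlike to_distributed(), this does not attach fastai's DistributedTrainer callback, so each process would still iterate over the full dataset unless a DistributedSampler is set up separately.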
Hello, I've run into the same issue here.
Did you find a workaround for this hang? Thanks.