@philtrade · Created February 19, 2020 03:18
Text classification training-accuracy problem in fastai distributed training, caused by samples not being shuffled across epochs.
#!/usr/bin/env python3
import argparse
import sys

import torch
from fastai.text import *
from fastai.distributed import *

def train(local_rank: int = None, epochs: int = 1):
    # When launched under torch.distributed, pin this process to its GPU
    # and join the NCCL process group (addresses come from the environment).
    if local_rank is not None:
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    path = untar_data(URLs.IMDB)
    bs = 48

    # Load the classifier DataBunch and fine-tuned encoder saved by the
    # standard fastai IMDB pipeline, then train the AWD_LSTM classifier.
    data_clas = load_data(path, 'data_clas.pkl', bs=bs)
    learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)
    learn.load_encoder('fine_tuned_enc')
    if local_rank is not None:
        learn = learn.to_distributed(local_rank)
    learn.fit_one_cycle(epochs, 2e-2, moms=(0.8, 0.7))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args = parser.parse_args()
    print(f"local_rank: {args.local_rank}", file=sys.stderr, flush=True)
    train(args.local_rank, 1)