Last active
October 31, 2019 18:03
-
-
Save philtrade/e0ecbfd33eed7dad80642318e806a90f to your computer and use it in GitHub Desktop.
Simulate race condition in fastai.callbacks.tracker.SaveModelCallback
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| The script simulates a potential race condition in SaveModelCallback, when every='improvement' is specified. | |
| When launched in PyTorch's DistributedDataParallel (DDP) mode on a single host with multiple GPUs: | |
| python3 -m torch.distributed.launch --nproc_per_node=3 test_barrier_load.py | |
| The master process (Rank 0) would sleep a few seconds before saving the model after the last epoch, in on_epoch_end(). | |
| Other processes would attempt to load in on_train_end(). Without synchronization the script would crash. | |
| When properly synchronized, other processes will wait for the master process to arrive at the post-write barrier as well, before proceeding to read the file, as in the following run on a single host with 3 GPUs, 3 epochs: | |
| $ python3 -m torch.distributed.launch --nproc_per_node=3 test_barrier_load.py | |
| epoch train_loss valid_loss accuracy time | |
| 0 0.651191 0.644135 0.555392 00:01 | |
| 1 0.623568 0.608578 0.697059 00:01 | |
| 2 0.602933 0.600173 0.734804 00:01 | |
| [Fri Nov 1 01:57:31 2019]: Rank 2 Attempting to load model. | |
| [Fri Nov 1 01:57:31 2019]: Rank 1 Attempting to load model. | |
| [Fri Nov 1 01:57:34 2019] Rank 0 Model saved. | |
| [Fri Nov 1 01:57:34 2019]: Rank 0 Attempting to load model. | |
| [Fri Nov 1 01:57:35 2019]: Rank 1 Loaded model. | |
| [Fri Nov 1 01:57:35 2019]: Rank 2 Loaded model. | |
| [Fri Nov 1 01:57:35 2019]: Rank 0 Loaded model. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import fastai, time, argparse, os | |
| from fastai.vision import * | |
| from fastai.distributed import * | |
| from fastai.callbacks.tracker import SaveModelCallback | |
| from fastai.torch_core import rank_distrib | |
| from torch.distributed import * | |
class SaveModelCallback_DDP(SaveModelCallback):
    "A `TrackerCallback` that saves the model when monitored quantity is best, safely under DDP."
    def __init__(self, learn:Learner, monitor:str='valid_loss', mode:str='auto', every:str='improvement', name:str='bestmodel'):
        super().__init__(learn, monitor=monitor, mode=mode, every=every, name=name)
        # True only when launched under torch.distributed (e.g. torch.distributed.launch).
        self.in_DDP = torch.distributed.is_available() and torch.distributed.is_initialized()
        # Delete any stale checkpoint so a successful load proves THIS run wrote the file.
        model_file = learn.path/f'{learn.model_dir}/{name}.pth'
        if model_file.is_file(): os.unlink(model_file)

    def on_epoch_end(self, epoch:int, n_epochs:int, **kwargs:Any)->None:
        "On the last epoch, save the model; rank 0 sleeps first to widen the race window."
        if epoch == n_epochs - 1:  # simulate saving the best model on the last epoch
            # Sleeping only in the master proc simulates GPU speed skew between ranks.
            if self.in_DDP and not rank_distrib(): time.sleep(3)
            self.learn.save(f'{self.name}')
            if self.in_DDP and not rank_distrib(): print(f'[{time.ctime()}] Rank {rank_distrib()} Model saved.')

    def on_train_end(self, **kwargs):
        "Load the best model."
        if self.every=="improvement":
            if self.in_DDP: print(f'[{time.ctime()}]: Rank {rank_distrib()} Attempting to load model.')
            # Post-write barrier: without it, non-zero ranks reach learn.load()
            # before rank 0 has finished saving and the load crashes. All ranks
            # wait here until the (deliberately delayed) save has completed.
            if self.in_DDP: torch.distributed.barrier()
            self.learn.load(f'{self.name}', purge=False)
            print(f'[{time.ctime()}]: Rank {rank_distrib()} Loaded model.')
# torch.distributed.launch injects --local_rank; absent means single-process run.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--local_rank", type=int)
cli = arg_parser.parse_args()
distributed = cli.local_rank is not None

if distributed:
    # Bind this process to its GPU and join the NCCL process group.
    torch.cuda.set_device(cli.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Small MNIST subset + tiny CNN: enough to trigger the save/load callback quickly.
data_path = untar_data(URLs.MNIST_SAMPLE)
bunch = ImageDataBunch.from_folder(data_path)
learner = Learner(bunch, simple_cnn((3,16,16,2)), metrics=[accuracy])
if distributed: learner = learner.to_distributed(cli.local_rank)

learner.fit_one_cycle(3, 1e-4, callbacks=[SaveModelCallback_DDP(learner, every='improvement', monitor='accuracy', name='best')])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment