import torch
import torch.distributed  # explicit import so torch.distributed is guaranteed to be available
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os
import random
import numpy as np

def set_random_seeds(random_seed=0):
    # Fix all RNGs and force deterministic cuDNN kernels so that every process
    # initializes an identical model
    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def evaluate(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy

def main():
    num_epochs_default = 10000
    batch_size_default = 256  # 1024
    learning_rate_default = 0.1
    random_seed_default = 0
    model_dir_default = "saved_models"
    model_filename_default = "resnet_distributed.pth"

    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local_rank", type=int, default=os.getenv('LOCAL_RANK', -1), help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=num_epochs_default)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=batch_size_default)
    parser.add_argument("--learning_rate", type=float, help="Learning rate.", default=learning_rate_default)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=random_seed_default)
    parser.add_argument("--model_dir", type=str, help="Directory for saving models.", default=model_dir_default)
    parser.add_argument("--model_filename", type=str, help="Model filename.", default=model_filename_default)
    parser.add_argument("--resume", action="store_true", help="Resume training from saved checkpoint.")
    argv = parser.parse_args()

    local_rank = argv.local_rank
    num_epochs = argv.num_epochs
    batch_size = argv.batch_size
    learning_rate = argv.learning_rate
    random_seed = argv.random_seed
    model_dir = argv.model_dir
    model_filename = argv.model_filename
    resume = argv.resume
    # Create directories outside the PyTorch program
    # Do not create the directory here because doing so is not multiprocess safe
    '''
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    '''

    model_filepath = os.path.join(model_dir, model_filename)

    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)

    # Initializes the distributed backend, which takes care of synchronizing processes across nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
    # torch.distributed.init_process_group(backend="gloo")
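    # With the default init_method ("env://"), init_process_group reads MASTER_ADDR,
    # MASTER_PORT, RANK, and WORLD_SIZE from environment variables that the
    # torch.distributed.launch / torchrun launcher sets for every spawned process.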
    # Encapsulate the model on the GPU assigned to the current process
    model = torchvision.models.resnet18(pretrained=False)
    device = torch.device("cuda:{}".format(local_rank))
    model = model.to(device)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
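    # DistributedDataParallel broadcasts the parameters from rank 0 to all other
    # processes at construction time and averages gradients across processes with
    # an all-reduce during backward, so every replica takes identical optimizer steps.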
    # We only save the model from the process that uses device "cuda:0"
    # To resume, the checkpoint's tensors therefore live on "cuda:0" and must be remapped to the local device
    if resume:
        map_location = {"cuda:0": "cuda:{}".format(local_rank)}
        ddp_model.load_state_dict(torch.load(model_filepath, map_location=map_location))
    # Prepare the dataset and dataloader
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # The data should be downloaded in advance
    # download is set to False because concurrent downloads from multiple processes are not multiprocess safe
    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=False, transform=transform)
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=False, transform=transform)
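    # A one-time, single-process download before launching training might look like
    # this (a sketch; run it from the same working directory as the training job):
    #   python -c "import torchvision; torchvision.datasets.CIFAR10(root='data', download=True)"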
    # Restricts data loading to a subset of the dataset exclusive to the current process
    train_sampler = DistributedSampler(dataset=train_set)
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler, num_workers=8)
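    # Note: each process draws batch_size samples per step from its own shard, so
    # the effective global batch size is batch_size * world_size.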
    # The test loader does not need to follow the distributed sampling strategy
    test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=8)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
    # Loop over the dataset multiple times
    for epoch in range(num_epochs):

        print("Local Rank: {}, Epoch: {}, Training ...".format(local_rank, epoch))

        # Reshuffle the distributed sampler every epoch; without this call each
        # epoch replays the same data ordering
        train_sampler.set_epoch(epoch)

        # Save and evaluate the model routinely
        if epoch % 10 == 0:
            if local_rank == 0:
                accuracy = evaluate(model=ddp_model, device=device, test_loader=test_loader)
                torch.save(ddp_model.state_dict(), model_filepath)
                print("-" * 75)
                print("Epoch: {}, Accuracy: {}".format(epoch, accuracy))
                print("-" * 75)

        ddp_model.train()

        for data in train_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    main()
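
# Example launch commands (a sketch: the script name resnet_ddp.py, GPU counts,
# node addresses, and port below are illustrative placeholders, not part of the
# original gist).
#
# Single node with 4 GPUs:
#   python -m torch.distributed.launch --nproc_per_node=4 resnet_ddp.py
#
# Two nodes with 8 GPUs each (run once per node, with --node_rank 0 and 1):
#   python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 \
#       --master_addr="192.168.0.1" --master_port=1234 resnet_ddp.py
#
# On newer PyTorch releases, torchrun sets the LOCAL_RANK environment variable
# that this script falls back to, so the launcher can be swapped in directly:
#   torchrun --nproc_per_node=4 resnet_ddp.py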