# CIFAR-10 training-speed benchmark: apex AMP (FP32 vs. mixed precision) on one or more GPUs
import os
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from pytorch_models import Layer10CNN, WideResNet
from apex import amp
import numpy as np
import datetime
import time
import pickle

def dataloaders(batch_size):
    # torchvision's ToTensor already scales pixel values to [0, 1]
    trans = transforms.Compose([
        transforms.ToTensor()
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=trans)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=4)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=trans)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=4)
    # Tuning num_workers (WideResNet, batch size 512, opt_level O1):
    # 1 GPU:  workers=1: 78s / workers=4: 84s / workers=8: 90s
    # 2 GPUs: workers=1: 46s / workers=4: 48s / workers=8: 55s
    # i.e. num_workers=1 was actually the fastest
    return trainloader, testloader

def train(batch_size, network, use_device, opt_level):
    if network == 0:
        model = Layer10CNN()
    elif network == 1:
        model = WideResNet()
    device = "cuda"
    torch.backends.cudnn.benchmark = True
    model = model.to(device)
    train_loader, test_loader = dataloaders(batch_size)
    criterion = torch.nn.CrossEntropyLoss()
    initial_lr = 0.1 * batch_size / 128  # linear LR scaling: 0.1 per 128 samples
    optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)
    # opt_level: O1 = mixed precision (recommended), O2 = almost FP16, O3 = pure FP16
    if opt_level != "O0":
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    # wrap with DataParallel only after amp.initialize, as apex requires
    if use_device == "multigpu":
        model = torch.nn.DataParallel(model)

    result = {}
    result["train_begin"] = datetime.datetime.now()
    result["times"] = []
    result["val_acc"] = []
    result["loss"] = []

    for epoch in range(3):
        start_time = time.time()
        # train
        model.train()
        train_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if opt_level == "O0":
                loss.backward()
            else:
                # scale the loss so FP16 gradients do not underflow
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= i + 1  # mean loss per batch
        scheduler.step()  # advance the LR schedule (a no-op before epoch 50)
        # validation
        model.eval()
        with torch.no_grad():
            correct, total = 0, 0
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, pred = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()
            val_acc = correct / total
        # log
        elapsed = time.time() - start_time
        result["times"].append(elapsed)
        result["loss"].append(train_loss)
        result["val_acc"].append(val_acc)
        print(f"Epoch {epoch+1} loss = {train_loss:.06} val_acc = {val_acc:.04} | {elapsed:0.4}s")

    result["train_end"] = datetime.datetime.now()
    os.makedirs("result", exist_ok=True)  # make sure the output directory exists
    with open(f"result/{use_device}_{network}_{batch_size}.pkl", "wb") as fp:
        pickle.dump(result, fp)

if __name__ == "__main__":
    train(512, 1, "multigpu", "O0")  # FP32 baseline
    train(512, 1, "multigpu", "O1")  # mixed precision
    # per-epoch times (WideResNet, batch size 512, multi-GPU):
    # O0: 71.17s 59.82s 61.78s
    # O1: 62.01s 50.67s 50.81s
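
For reference, apex.amp has since been deprecated in favor of the AMP API built into PyTorch itself (torch.cuda.amp, available from PyTorch 1.6). Below is a minimal sketch of an equivalent O1-style training step using the native API; it assumes the same model, optimizer, criterion, train_loader, and device as in the script above.

# Minimal sketch: native PyTorch mixed precision, replacing apex's O1 mode.
# Assumes model, optimizer, criterion, train_loader, device exist as above.
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    with autocast():  # run the forward pass in mixed precision
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()  # scale the loss so FP16 grads don't underflow
    scaler.step(optimizer)         # unscale gradients, then optimizer.step()
    scaler.update()                # adapt the loss scale for the next iteration

Note that with DataParallel the PyTorch docs require autocast to be applied inside the model's forward method rather than around the replicated call, so this sketch assumes a single GPU.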

pytorch_models.py (the module imported by the script above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# 10-layer network (9 conv layers + 1 fully connected)
class Layer10CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = self.create_block(3, 64, 3, False)
        self.block2 = self.create_block(64, 128, 3, True)
        self.block3 = self.create_block(128, 256, 3, True)
        self.pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(256, 10)

    def create_block(self, in_ch, out_ch, reps, initial_pool):
        # (Conv -> BN -> ReLU) x reps, optionally preceded by 2x2 average pooling
        layers = []
        if initial_pool:
            layers.append(nn.AvgPool2d(2))
        for i in range(reps):
            in_n = in_ch if i == 0 else out_ch
            layers.append(nn.Conv2d(in_n, out_ch, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
| # WideResNet 28-10 | |
| class ResidualBlock(nn.Module): | |
| def __init__(self, in_ch, ch, stride, conv_before_skip): | |
| super().__init__() | |
| self.conv_before_skip = conv_before_skip | |
| if conv_before_skip: | |
| self.skip_conv = nn.Conv2d(in_ch, ch, 1, stride=stride) | |
| self.skip_bn = nn.BatchNorm2d(ch) | |
| self.main_conv1 = nn.Conv2d(in_ch, ch, 3, stride=stride, padding=1) | |
| self.main_bn1 = nn.BatchNorm2d(ch) | |
| self.main_conv2 = nn.Conv2d(ch, ch, 3, padding=1) | |
| self.main_bn2 = nn.BatchNorm2d(ch) | |
| def forward(self, x): | |
| if self.conv_before_skip: | |
| skip = F.relu(self.skip_bn(self.skip_conv(x)), True) | |
| else: | |
| skip = x | |
| main = F.relu(self.main_bn1(self.main_conv1(x)), True) | |
| main = F.relu(self.main_bn2(self.main_conv2(main)), True) | |
| return main + skip | |

class WideResNet(nn.Module):
    def __init__(self, N=4, k=10):
        super().__init__()
        self.initial_conv = nn.Conv2d(3, 16, 3, padding=1)
        self.initial_bn = nn.BatchNorm2d(16)
        self.block1 = self.create_residual_blocks(16, 16 * k, N, 1)
        self.block2 = self.create_residual_blocks(16 * k, 32 * k, N, 2)
        self.block3 = self.create_residual_blocks(32 * k, 64 * k, N, 2)
        self.pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * k, 10)

    def create_residual_blocks(self, in_ch, out_ch, N, stride):
        layers = []
        for i in range(N):
            if i == 0:
                layers.append(ResidualBlock(in_ch, out_ch, stride, True))
            else:
                layers.append(ResidualBlock(out_ch, out_ch, 1, False))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.initial_bn(self.initial_conv(x)), True)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
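
As a quick sanity check of both model definitions (a sketch; the dummy input matches the 3x32x32 CIFAR-10 images used by the training script):

# Sketch: both models should map a CIFAR-10-shaped batch to 10 class logits.
import torch
from pytorch_models import Layer10CNN, WideResNet

x = torch.randn(8, 3, 32, 32)  # dummy batch of 8 CIFAR-10 images
for net in (Layer10CNN(), WideResNet()):
    out = net(x)
    print(type(net).__name__, tuple(out.shape))  # both print (8, 10)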