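# MNIST training with PyTorch Ignite, with run tracking and metric logging
# handled by Weights & Biases (wandb).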
from argparse import ArgumentParser

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

import wandb
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from tqdm import tqdm
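
# A small LeNet-style CNN for MNIST: two convolution + max-pool stages with
# dropout, followed by two fully connected layers and a log-softmax output.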
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
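
# Build normalized train/validation DataLoaders over MNIST; 0.1307 and 0.3081
# are the dataset's standard per-channel mean and standard deviation.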
def get_data_loaders(train_batch_size, val_batch_size):
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader
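
# Train with Ignite engines, report progress via tqdm, and stream metrics to
# Weights & Biases. wandb.watch registers hooks on the model so gradients are
# logged as training runs.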
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    wandb.watch(model)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Move the model explicitly; recent Ignite releases no longer do this
    # inside create_supervised_trainer.
    model.to(device)
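
    # Ignite's factory functions provide the standard supervised train/eval
    # loops; the evaluator computes accuracy and average negative
    # log-likelihood over whatever loader it is run on.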
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )
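
    # Every `log_interval` iterations, show the current batch loss on the
    # progress bar and log it to wandb.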
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(log_interval)
        wandb.log({"train loss": engine.state.output})
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll)
        )
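
    # After each epoch, evaluate on the held-out set, reset the progress bar,
    # and log the validation metrics to wandb.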
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        pbar.n = pbar.last_print_n = 0
        # Read metrics from the evaluator, not the trainer: `engine` here is
        # the trainer, whose state carries no metrics. Log both values in one
        # call so they share a wandb step.
        wandb.log({"validation loss": avg_nll,
                   "validation accuracy": avg_accuracy})

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
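
# Parse hyperparameters, register them as the run config in wandb, and train.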
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--val_batch_size', type=int, default=1000,
                        help='input batch size for validation (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--log_interval', type=int, default=10,
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()

    wandb.init(config=args)
    run(args.batch_size, args.val_batch_size, args.epochs, args.lr,
        args.momentum, args.log_interval)