Skip to content

Instantly share code, notes, and snippets.

@elumixor
Created May 4, 2020 21:28
Show Gist options
  • Save elumixor/5a1ab6ecec754b1fe0083d9bde64c9fb to your computer and use it in GitHub Desktop.
import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler, DataLoader
from time import sleep
from IPython.display import clear_output, display
import os
class Trainer:
    """Training harness for a PyTorch classifier with optional validation,
    per-epoch checkpointing, and resume-from-checkpoint support.

    Checkpoints and metric history are stored under ``./runs/<run_name>/``:
    ``models/<epoch>.model`` per epoch, ``models/best.model`` for the best
    validation accuracy, and ``losses.npy`` / ``accuracies.npy`` histories.
    """

    def __init__(self, model, train_data, run_name, batch_size=128, epochs=10, lr=0.001,
                 loss=None, optimizer=None, validation_split=0., resume='last', device=None):
        """
        :param model: torch.nn.Module to train.
        :param train_data: dataset yielding (data, label) pairs.
        :param run_name: name of the run; determines the ./runs/<run_name>/ directory.
        :param batch_size: mini-batch size for both loaders.
        :param epochs: total number of epochs to train to (resume counts past epochs).
        :param lr: learning rate for the default Adam optimizer.
        :param loss: loss module; defaults to CrossEntropyLoss(reduction='none').
        :param optimizer: optimizer; defaults to Adam over model.parameters().
        :param validation_split: fraction in [0, 1) held out for validation.
        :param resume: 'last' | 'best' | falsy. Falsy requires an empty run directory.
        :param device: torch.device; auto-detects CUDA when None.
        """
        self.model = model
        self.train_data = train_data
        self.dataset_size = len(train_data)
        self.epochs = epochs

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"Using {self.device} for training.")
        else:
            self.device = device
        # BUG FIX: the original called model.to(device) with the raw argument,
        # which is None in the auto-detect branch, so the model was never moved.
        self.model.to(self.device)

        self.loss = loss if loss is not None else torch.nn.CrossEntropyLoss(reduction='none')
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(model.parameters(), lr=lr)

        self.do_validation = validation_split > 0
        self.validation_split = validation_split
        if self.do_validation:
            indices = list(range(self.dataset_size))
            split = int(np.floor(validation_split * self.dataset_size))
            # BUG FIX: the original assigned the slices the wrong way around,
            # giving the *validation* sampler the large remainder and training
            # only the small head of the dataset.
            val_indices, train_indices = indices[:split], indices[split:]
            train_sampler = SubsetRandomSampler(train_indices)
            val_sampler = SubsetRandomSampler(val_indices)
            self.train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, sampler=train_sampler)
            self.validation_loader = DataLoader(train_data, batch_size=batch_size, num_workers=2, sampler=val_sampler)
        else:
            self.train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True)
            self.validation_loader = None

        # Run-directory bookkeeping.
        self.resume = resume
        self.name = run_name
        self.models_directory = f"./runs/{run_name}/models"
        self.run_directory = f"./runs/{run_name}/"
        runs_exist = os.path.isdir(self.run_directory)
        models_exist = runs_exist and os.path.isdir(self.models_directory)

        self.losses = []                  # list of [loss, step] pairs, persisted to losses.npy
        self.validation_accuracies = []   # list of [accuracy, step] pairs, persisted to accuracies.npy
        self.losses_path = f"{self.run_directory}/losses.npy"
        self.accuracies_path = f"{self.run_directory}/accuracies.npy"
        self.epochs_passed = 0            # epochs already trained (derived from loaded losses)

        if resume:
            if run_name:
                if runs_exist:
                    self.__load_history()
                if models_exist:
                    self.__load_checkpoint(resume)
            else:
                print("Please specify run name to resume.")
        else:
            # Fresh run requested: refuse to clobber a non-empty run directory.
            files_found = False
            for _, _, files in os.walk(self.run_directory):
                if files:
                    files_found = True
                    break
            if files_found:
                raise Exception(f"Run directory {self.run_directory} is not empty.")

        # Create directory for models if it does not exist
        if not models_exist:
            os.makedirs(self.models_directory)

    def __load_history(self):
        """Load previously recorded loss/accuracy histories so training resumes mid-run."""
        if os.path.isfile(self.losses_path):
            self.losses = np.load(self.losses_path, allow_pickle=True).tolist()
            print("Previous losses loaded")
        else:
            print("Losses file was not found.")
        if self.do_validation and os.path.isfile(self.accuracies_path):
            self.validation_accuracies = np.load(self.accuracies_path, allow_pickle=True).tolist()
            print("Previous validation accuracies loaded")
        else:
            print("Accuracies file was not found.")
        # One loss entry is appended per step, one checkpoint per epoch;
        # epochs_passed mirrors the original behavior of counting loss entries.
        self.epochs_passed = len(self.losses)

    def __load_checkpoint(self, resume):
        """Restore model weights from the 'best' or the most recent checkpoint."""
        best_found = False
        if resume == 'best':
            best_path = f"{self.models_directory}/best.model"
            if os.path.isfile(best_path):
                print(f"Previous (best) model found. Using {best_path}")
                self.model.load_state_dict(torch.load(best_path))
                best_found = True
            else:
                print(f"Previous (best) model not found.")
        if resume == 'last' or not best_found:
            models = [os.path.join(self.models_directory, file)
                      for file in os.listdir(self.models_directory)
                      if file.endswith(".model")]

            # BUG FIX: the original lexicographic sort ordered "10.model" before
            # "2.model" (and "best.model" after all digits), so the wrong
            # checkpoint was picked as "last". Sort numerically by epoch,
            # keeping non-numeric names (e.g. best.model) before all epochs.
            def epoch_key(path):
                stem = os.path.splitext(os.path.basename(path))[0]
                return (1, int(stem)) if stem.isdigit() else (0, 0)

            models.sort(key=epoch_key)
            if len(models) > 0:
                last_model = models[-1]
                print(f"Previous (last) model found. Using {last_model}")
                self.model.load_state_dict(torch.load(last_model))
            else:
                print(f"No previous model found. Will train from scratch.")

    def train(self):
        """Run the full training loop (with validation if configured)."""
        print("Starting training")
        self.model.train()
        self.model.to(self.device)
        if self.do_validation:
            self.__train_with_validation()
        else:
            self.__train_without_validation()

    def __run_epoch(self, epoch, step):
        """Train one epoch, checkpoint the model, persist losses; return the new step count."""
        completed = 0
        for data, labels in self.train_loader:
            loss = self.__train_batch(data, labels)
            completed += len(data)
            step += len(data)
            self.losses.append([loss, step])
            clear_output(wait=True)
            print(f'Epoch {epoch}/{self.epochs}: {completed / self.dataset_size * 100:.2f}% completed.\n'
                  f'Loss: {loss}.')
        torch.save(self.model.state_dict(), f'{self.models_directory}/{epoch}.model')
        # CONSISTENCY FIX: losses are always saved with np.save; the original
        # no-validation path used torch.save, which resume (np.load) cannot read.
        np.save(self.losses_path, np.array(self.losses))
        return step

    def __train_with_validation(self):
        """Epoch loop that tracks validation accuracy and keeps the best checkpoint."""
        torch.save(self.model.state_dict(), f'{self.models_directory}/best.model')
        best_val_acc = float('-inf')
        step = 0
        for epoch in range(self.epochs_passed + 1, self.epochs + 1):
            step = self.__run_epoch(epoch, step)
            val_accuracy = self.__get_accuracy(self.validation_loader)
            if val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                torch.save(self.model.state_dict(), f'{self.models_directory}/best.model')
            print(f'Validation accuracy: {val_accuracy}')
            self.validation_accuracies.append([val_accuracy, step])
            np.save(self.accuracies_path, self.validation_accuracies)

    def __train_without_validation(self):
        """Plain epoch loop without a validation pass."""
        step = 0
        for epoch in range(self.epochs_passed + 1, self.epochs + 1):
            step = self.__run_epoch(epoch, step)

    def __train_batch(self, data, labels):
        """Run a single optimization step on one mini-batch; return the scalar loss."""
        data = data.to(self.device)
        labels = labels.to(self.device)
        out = self.model(data)
        # Default loss uses reduction='none', so reduce to a scalar here.
        loss = self.loss(out, labels)
        loss = loss.mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def __get_accuracy(self, loader):
        """Return the fraction of correct predictions over ``loader`` as a float in [0, 1]."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in loader:
                data = data.to(self.device)
                targets = targets.to(self.device)
                out = self.model(data)
                _, predicted = torch.max(out, 1)
                # BUG FIX: the original summed *wrong* predictions (error rate)
                # while callers maximized the value as "accuracy", so best.model
                # was actually the worst checkpoint. Count correct predictions.
                correct += (targets == predicted).sum().item()
                total += len(data)
        # BUG FIX: restore training mode; the original left the model in eval()
        # after every validation pass, freezing dropout/batchnorm for the rest
        # of training.
        self.model.train()
        return correct / total if total else 0.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment