Created
May 4, 2020 21:28
-
-
Save elumixor/5a1ab6ecec754b1fe0083d9bde64c9fb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
from torch.utils.data import SubsetRandomSampler, DataLoader | |
from time import sleep | |
from IPython.display import clear_output, display | |
import os | |
class Trainer:
    def __init__(self, model, train_data, run_name, batch_size=128, epochs=10, lr=0.001,
                 loss=None, optimizer=None, validation_split=0., resume='last', device=None):
        """Wrap a model with training/validation loops, checkpointing and resuming.

        Artifacts are written under ./runs/<run_name>/: per-epoch checkpoints in
        models/, per-batch losses in losses.npy and (if validating) per-epoch
        validation accuracies in accuracies.npy.

        Args:
            model: torch.nn.Module to train.
            train_data: dataset (indexable, with __len__) to draw batches from.
            run_name: name of the run; artifacts go to ./runs/<run_name>/.
            batch_size: mini-batch size for both training and validation loaders.
            epochs: total number of epochs (including previously completed ones).
            lr: learning rate for the default Adam optimizer.
            loss: loss module; defaults to CrossEntropyLoss(reduction='none').
            optimizer: optimizer; defaults to Adam over model.parameters().
            validation_split: fraction in (0, 1) held out for validation; 0 disables it.
            resume: 'last' resumes from the newest epoch checkpoint, 'best' from
                best.model; a falsy value requires an empty run directory.
            device: torch device; auto-detects CUDA when None.
        """
        self.model = model
        self.train_data = train_data
        self.dataset_size = len(train_data)
        self.epochs = epochs

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"Using {self.device} for training.")
        else:
            self.device = device
        # BUG FIX: the original called model.to(device) with the raw parameter,
        # which is a no-op when device is None; move to the resolved device.
        self.model.to(self.device)

        self.loss = loss if loss is not None else torch.nn.CrossEntropyLoss(reduction='none')
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(model.parameters(), lr=lr)

        self.do_validation = validation_split > 0
        self.validation_split = validation_split

        if self.do_validation:
            indices = list(range(self.dataset_size))
            split = int(np.floor(validation_split * self.dataset_size))
            # BUG FIX: the original swapped the two halves, training on the small
            # `validation_split` fraction and validating on the remainder. The
            # first `split` indices are the held-out validation set.
            val_indices, train_indices = indices[:split], indices[split:]
            train_sampler = SubsetRandomSampler(train_indices)
            val_sampler = SubsetRandomSampler(val_indices)

            self.train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, sampler=train_sampler)
            self.validation_loader = DataLoader(train_data, batch_size=batch_size, num_workers=2, sampler=val_sampler)
        else:
            self.train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True)
            self.validation_loader = None

        # Load previous model if there is one
        self.resume = resume
        self.name = run_name
        self.models_directory = f"./runs/{run_name}/models"
        self.run_directory = f"./runs/{run_name}/"

        runs_exist = os.path.isdir(self.run_directory)
        models_exist = runs_exist and os.path.isdir(self.models_directory)

        self.losses = []                 # [loss, step] per batch
        self.validation_accuracies = []  # [accuracy, step] per epoch
        self.losses_path = f"{self.run_directory}/losses.npy"
        self.accuracies_path = f"{self.run_directory}/accuracies.npy"
        self.epochs_passed = 0

        if resume:
            if run_name:
                if runs_exist:
                    if os.path.isfile(self.losses_path):
                        self.losses = np.load(self.losses_path, allow_pickle=True).tolist()
                        print("Previous losses loaded")
                    else:
                        print("Losses file was not found.")

                    if self.do_validation:
                        if os.path.isfile(self.accuracies_path):
                            self.validation_accuracies = np.load(self.accuracies_path, allow_pickle=True).tolist()
                            print("Previous validation accuracies loaded")
                        else:
                            # BUG FIX: originally this message also printed when
                            # validation was disabled (the else was attached to
                            # the combined `and` condition).
                            print("Accuracies file was not found.")

                if models_exist:
                    # Checkpoints are saved as <epoch>.model, so the numeric
                    # stems tell us which epochs completed. The original used
                    # len(self.losses), but losses are recorded per *batch*,
                    # which massively over-counted completed epochs on resume.
                    epoch_numbers = sorted(
                        int(stem) for stem in
                        (os.path.splitext(f)[0] for f in os.listdir(self.models_directory)
                         if f.endswith(".model"))
                        if stem.isdigit())
                    self.epochs_passed = epoch_numbers[-1] if epoch_numbers else 0

                    best_found = False
                    if resume == 'best':
                        best_path = f"{self.models_directory}/best.model"
                        if os.path.isfile(best_path):
                            print(f"Previous (best) model found. Using {best_path}")
                            model.load_state_dict(torch.load(best_path))
                            best_found = True
                        else:
                            print("Previous (best) model not found.")
                    if resume == 'last' or not best_found:
                        # BUG FIX: the original sorted filenames lexicographically,
                        # so "9.model" beat "10.model" and "best.model" sorted
                        # after every numeric checkpoint and was always chosen.
                        if epoch_numbers:
                            last_model = f"{self.models_directory}/{epoch_numbers[-1]}.model"
                            print(f"Previous (last) model found. Using {last_model}")
                            model.load_state_dict(torch.load(last_model))
                        else:
                            print(f"No previous model found. Will train from scratch.")
            else:
                print("Please specify run name to resume.")
        else:
            # Not resuming: refuse to clobber an existing, non-empty run.
            files_found = False
            for _, _, files in os.walk(self.run_directory):
                if files:
                    files_found = True
                    break
            if files_found:
                raise Exception(f"Run directory {self.run_directory} is not empty.")

        # Create directory for models if it does not exist
        if not models_exist:
            os.makedirs(self.models_directory)
def train(self): | |
print("Starting training") | |
self.model.train() | |
self.model.to(self.device) | |
if self.do_validation: | |
self.__train_with_validation() | |
else: | |
self.__train_without_validation() | |
def __train_with_validation(self): | |
best = self.model.state_dict() | |
torch.save(best, f'{self.models_directory}/best.model') | |
best_val_acc = float('-inf') | |
step = 0 | |
for epoch in range(self.epochs_passed + 1, self.epochs + 1): | |
completed = 0 | |
for data, labels in self.train_loader: | |
loss = self.__train_batch(data, labels) | |
completed += len(data) | |
step += len(data) | |
self.losses.append([loss, step]) | |
clear_output(wait=True) | |
print(f'Epoch {epoch}/{self.epochs}: {completed / self.dataset_size * 100:.2f}% completed.\n' | |
f'Loss: {loss}.') | |
torch.save(self.model.state_dict(), f'{self.models_directory}/{epoch}.model') | |
np.save(self.losses_path, np.array(self.losses)) | |
val_accuracy = self.__get_accuracy(self.validation_loader) | |
if val_accuracy > best_val_acc: | |
best_val_acc = val_accuracy | |
best = self.model.state_dict() | |
torch.save(best, f'{self.models_directory}/best.model') | |
print(f'Validation accuracy: {val_accuracy}') | |
self.validation_accuracies.append([val_accuracy, step]) | |
np.save(self.accuracies_path, self.validation_accuracies) | |
def __train_without_validation(self): | |
step = 0 | |
for epoch in range(self.epochs_passed + 1, self.epochs + 1): | |
completed = 0 | |
for data, labels in self.train_loader: | |
loss = self.__train_batch(data, labels) | |
completed += len(data) | |
step += len(data) | |
self.losses.append((loss, step)) | |
clear_output(wait=True) | |
print(f'Epoch {epoch}/{self.epochs}: {completed / self.dataset_size * 100:.2f}% completed.\n' | |
f'Loss: {loss}.') | |
torch.save(self.model.state_dict(), f'{self.models_directory}/{epoch}.model') | |
torch.save(self.losses, self.losses_path) | |
def __train_batch(self, data, labels): | |
data = data.to(self.device) | |
labels = labels.to(self.device) | |
out = self.model(data) | |
loss = self.loss(out, labels) | |
loss = loss.mean() | |
self.optimizer.zero_grad() | |
loss.backward() | |
self.optimizer.step() | |
return loss.item() | |
def __get_accuracy(self, loader): | |
self.model.eval() | |
with torch.no_grad(): | |
L = 0 | |
total = 0 | |
for data, targets in loader: | |
data = data.to(self.device) | |
targets = targets.to(self.device) | |
out = self.model(data) | |
_, predicted = torch.max(out, 1) | |
# Get the total number of the wrong predictions | |
batch_loss = (1 - (targets == predicted).float()).sum() | |
L += batch_loss | |
total += len(data) | |
L /= total | |
return L |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment