Skip to content

Instantly share code, notes, and snippets.

@vsay01
Created January 22, 2020 02:05
Show Gist options
  • Save vsay01/749a689cd3c1d85cd809a2f2780827dd to your computer and use it in GitHub Desktop.
Save vsay01/749a689cd3c1d85cd809a2f2780827dd to your computer and use it in GitHub Desktop.
def train(start_epochs, n_epochs, valid_loss_min_input, loaders, model, optimizer, criterion, use_cuda, checkpoint_path, best_model_path):
    """Train ``model``, checkpoint after every epoch and track the best validation loss.

    Keyword arguments:
    start_epochs -- number of the first epoch to run (inclusive)
    n_epochs -- number of the last epoch to run (inclusive); if it is smaller
        than start_epochs no epoch runs and the model is returned untouched
    valid_loss_min_input -- best (lowest) validation loss seen so far; pass
        float('inf') for a fresh run. NOTE(review): values stored by earlier
        revisions of this function were additionally divided by the dataset
        size, so they are on a different scale -- confirm before resuming
        from an old checkpoint.
    loaders -- mapping with 'train' and 'test' entries, each an iterable of
        (data, target) batches; the 'test' entry is used for validation
    model -- network to optimise; called as model(data) and must provide
        train(), eval() and state_dict()
    optimizer -- optimizer providing zero_grad(), step() and state_dict()
    criterion -- callable criterion(output, target) returning a scalar loss
    use_cuda -- when true, every batch is moved to the GPU with .cuda()
    checkpoint_path -- forwarded unchanged to save_ckp (presumably the
        location of the per-epoch checkpoint)
    best_model_path -- forwarded unchanged to save_ckp (presumably the
        location of the best-model copy)
    returns trained model
    """
    # Imported locally so the function does not depend on a module-level import.
    import torch

    # initialize tracker for minimum validation loss
    valid_loss_min = valid_loss_min_input

    for epoch in range(start_epochs, n_epochs + 1):
        # running means of the per-batch losses for this epoch
        train_loss = 0.0
        valid_loss = 0.0

        ###################
        # train the model #
        ###################
        model.train()
        for batch_idx, (data, target) in enumerate(loaders['train']):
            # move to GPU
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # Incremental mean: after batch k this equals the average of the
            # first k batch losses. item() yields a plain float, so neither the
            # accumulator nor the checkpoint holds on to a (GPU) tensor.
            train_loss += (loss.item() - train_loss) / (batch_idx + 1)

        ######################
        # validate the model #
        ######################
        model.eval()
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loaders['test']):
                # move to GPU
                if use_cuda:
                    data, target = data.cuda(), target.cuda()
                # forward pass: compute predicted outputs by passing inputs to the model
                output = model(data)
                # calculate the batch loss
                loss = criterion(output, target)
                # update average validation loss
                valid_loss += (loss.item() - valid_loss) / (batch_idx + 1)

        # train_loss / valid_loss are already per-batch averages, so they are
        # reported as-is (dividing again by the dataset size would be wrong).

        # print training/validation statistics
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch,
            train_loss,
            valid_loss
        ))

        # Decide whether this epoch is the new best before building the
        # checkpoint, so that 'valid_loss_min' really holds the minimum even
        # when the current epoch did not improve on it.
        previous_min = valid_loss_min
        improved = valid_loss <= valid_loss_min
        if improved:
            valid_loss_min = valid_loss

        # create checkpoint variable and add important data
        checkpoint = {
            'epoch': epoch + 1,  # epoch to resume from
            'valid_loss_min': valid_loss_min,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        # save checkpoint
        save_ckp(checkpoint, False, checkpoint_path, best_model_path)

        # save the model if validation loss has decreased
        if improved:
            print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(previous_min,valid_loss))
            # save checkpoint as best model
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)

    # return trained model
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment