Skip to content

Instantly share code, notes, and snippets.

@erykml
Last active January 25, 2019 21:56
Show Gist options
  • Select an option

  • Save erykml/52bba6103ba1aa233d3b2e2b480423c7 to your computer and use it in GitHub Desktop.

Select an option

Save erykml/52bba6103ba1aa233d3b2e2b480423c7 to your computer and use it in GitHub Desktop.
def train_cnn(model, train_loader, valid_loader,
              criterion, optimizer, n_epochs = 30, train_on_gpu = False,
              save_model_on_improvement = True, plot_loss = True):
    '''
    Train a CNN and validate it after every epoch. Can be run on CPU or GPU.

    If the criterion is BCEWithLogitsLoss, integer targets are converted to
    float tensors before the loss is computed (that loss requires targets of
    the same floating-point type as the model output).

    Inputs:
    model - architecture of the neural network defined using either Class approach or Sequential
    train_loader - loader of the dataset used for training (must expose `.dataset`)
    valid_loader - loader of the dataset used for validation (must expose `.dataset`)
    criterion - loss function
    optimizer - selected optimizer
    n_epochs - number of epochs
    train_on_gpu - boolean; whether to move each batch to the GPU with `.cuda()`.
                   The model itself is NOT moved here; the caller is expected
                   to have placed it on the GPU already.
    save_model_on_improvement - boolean; whether to save the model's state_dict to
                   'model_mvw.pt' (current working directory) whenever the validation
                   loss is less than or equal to the best validation loss seen so far
    plot_loss - boolean; whether to plot the train/validation loss over epochs

    Returns:
    model - the same model object, after the last epoch (not reloaded from the
            best checkpoint)

    Side effects:
    Prints one statistics line per epoch, optionally writes 'model_mvw.pt',
    and optionally draws on the current matplotlib figure.
    '''
    # Decide once whether targets need the float conversion; the criterion
    # does not change between batches.
    needs_float_target = isinstance(criterion, torch.nn.modules.loss.BCEWithLogitsLoss)

    # float('inf') instead of np.Inf: the latter was removed in NumPy 2.0
    valid_loss_min = float('inf')  # best validation loss seen so far
    train_losses, valid_losses = [], []

    for epoch in range(1, n_epochs + 1):
        # running sums of (batch loss * batch size) for this epoch
        train_loss = 0.0
        valid_loss = 0.0

        # CUDA kernels run asynchronously; synchronize so the timer brackets
        # the actual GPU work of this epoch
        if train_on_gpu:
            torch.cuda.synchronize()
        t0 = time.perf_counter()

        # train the model ----
        model.train()
        for data, target in train_loader:
            data, target = _prepare_batch(data, target, needs_float_target, train_on_gpu)
            # reset the gradients of all optimized variables
            optimizer.zero_grad()
            # 1. forward pass
            # NOTE(review): squeeze() drops every size-1 dim, which would also
            # drop the batch dim for a batch of one sample - confirm loaders
            # never yield such a batch (e.g. use drop_last) if that matters.
            output = model(data).squeeze()
            # 2. calculate the batch loss
            loss = criterion(output, target)
            # 3. backward pass
            loss.backward()
            # 4. perform a single optimization step (parameter update)
            optimizer.step()
            # weight by batch size so the epoch figure is a per-sample average
            # (assumes the criterion returns a batch mean - TODO confirm)
            train_loss += loss.item() * data.size(0)

        # validate the model ----
        model.eval()
        with torch.no_grad():  # no gradients needed for validation
            for data, target in valid_loader:
                data, target = _prepare_batch(data, target, needs_float_target, train_on_gpu)
                output = model(data).squeeze()
                loss = criterion(output, target)
                valid_loss += loss.item() * data.size(0)

        # calculate average losses
        train_loss = train_loss / len(train_loader.dataset)
        train_losses.append(train_loss)
        valid_loss = valid_loss / len(valid_loader.dataset)
        valid_losses.append(valid_loss)

        # time of the entire epoch: training + validation, not model saving
        if train_on_gpu:
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        # print training/validation statistics
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTraining time: {:.2f} s'.format(
            epoch, train_loss, valid_loss, t1 - t0))

        # track the best validation loss; checkpoint only if the caller asked for it
        if valid_loss <= valid_loss_min:
            if save_model_on_improvement:
                print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
                    valid_loss_min,
                    valid_loss))
                torch.save(model.state_dict(), 'model_mvw.pt')
            valid_loss_min = valid_loss

    if plot_loss:
        plt.plot(train_losses, label='Training loss')
        plt.plot(valid_losses, label='Validation loss')
        plt.legend(frameon=False)

    return model


def _prepare_batch(data, target, needs_float_target, train_on_gpu):
    '''Move a batch to the GPU if requested and cast integer targets to float if required.'''
    if train_on_gpu:
        data, target = data.cuda(), target.cuda()
    # .float() keeps the tensor on its current device, so one branch covers
    # both the CPU and the CUDA case
    if needs_float_target and not target.is_floating_point():
        target = target.float()
    return data, target
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment