Skip to content

Instantly share code, notes, and snippets.

@KerryHalupka
Last active August 16, 2020 09:42
Show Gist options
  • Save KerryHalupka/8f611165bc7ef9a62d948728fa776fbe to your computer and use it in GitHub Desktop.
from torch.utils.tensorboard import SummaryWriter

# Binary-classification training loop with TensorBoard logging and
# best-validation-accuracy checkpointing.
# Relies on names defined earlier in the file: model, device, loss_fn,
# optimizer, train_gen, val_gen, log_dir, ckpt_dir, num_epoch, timer.

# Set up TensorBoard; scalars are written under <log_dir>/metrics.
file_writer = SummaryWriter(log_dir=log_dir + '/metrics')

best_val_acc = 0  # best validation accuracy seen so far, for model checkpointing

for epoch in range(1, num_epoch + 1):
    start_time = timer()

    # Reset running metrics at the start of every epoch.
    train_loss = 0.0
    val_loss = 0.0
    train_correct = 0.0
    val_correct = 0.0

    # ---- Training phase ----
    model.train()
    for inputs, targets in train_gen:
        # Move the batch to the configured device (GPU if available).
        inputs = inputs.to(device).float()
        # Reshape targets to (batch, 1) to match the model's single output.
        targets = targets.to(device).view(-1, 1).float()

        # Training steps
        optimizer.zero_grad()            # clear accumulated gradients
        output = model(inputs)           # forward pass: predict outputs
        loss = loss_fn(output, targets)  # calculate loss
        loss.backward()                  # backward pass: gradient of loss wrt parameters
        optimizer.step()                 # update parameters

        # Weight the batch loss by batch size so the epoch average is correct
        # even when the final batch is smaller than the rest.
        train_loss += loss.item() * inputs.size(0)
        # NOTE(review): thresholding at 0.5 assumes `output` is a probability
        # (model ends in a sigmoid) — confirm; for raw logits the threshold
        # should be 0. `.item()` keeps the running count a plain float.
        train_correct += ((output > 0.5) == targets).float().sum().item()

    # ---- Validation phase ----
    model.eval()
    with torch.no_grad():  # no gradients needed during evaluation
        for inputs, targets in val_gen:
            inputs = inputs.to(device).float()
            targets = targets.to(device).view(-1, 1).float()

            output = model(inputs)           # forward pass
            loss = loss_fn(output, targets)  # calculate loss

            val_loss += loss.item() * inputs.size(0)
            val_correct += ((output > 0.5) == targets).float().sum().item()

    # Average the accumulated metrics over the full dataset sizes.
    train_loss = train_loss / len(train_gen.sampler)
    val_loss = val_loss / len(val_gen.sampler)
    train_acc = train_correct / len(train_gen.sampler)
    val_acc = val_correct / len(val_gen.sampler)

    end_time = timer()  # time taken for this epoch

    # Display metrics at the end of each epoch.
    print(f'Epoch: {epoch} \tTraining Loss: {train_loss} \tValidation Loss: {val_loss} '
          f'\tTraining Accuracy: {train_acc} \tValidation Accuracy: {val_acc} '
          f'\t Time taken: {end_time - start_time}')

    # Log metrics to TensorBoard.
    file_writer.add_scalar('Loss/train', train_loss, epoch)
    file_writer.add_scalar('Loss/validation', val_loss, epoch)
    file_writer.add_scalar('Accuracy/train', train_acc, epoch)
    file_writer.add_scalar('Accuracy/validation', val_acc, epoch)
    file_writer.add_scalar('epoch_time', end_time - start_time, epoch)

    # Checkpoint whenever validation accuracy improves.
    if val_acc > best_val_acc:
        torch.save(model.state_dict(), ckpt_dir + '.pt')
        best_val_acc = val_acc

# Ensure all pending TensorBoard events are written to disk.
file_writer.flush()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment