Last active: August 16, 2020 09:42
Save KerryHalupka/8f611165bc7ef9a62d948728fa776fbe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch.utils.tensorboard import SummaryWriter

# Train and validate a binary classifier for `num_epoch` epochs. Each epoch's
# loss, accuracy and wall-clock time are printed and logged to TensorBoard,
# and the model weights are saved whenever validation accuracy improves.
#
# This is a snippet: `timer`, `log_dir`, `ckpt_dir`, `num_epoch`, `model`,
# `device`, `train_gen`, `val_gen`, `optimizer` and `loss_fn` must be defined
# before it runs.
# NOTE(review): `timer` is presumably `timeit.default_timer` -- confirm.

# Setup tensorboard
file_writer = SummaryWriter(log_dir=log_dir + '/metrics')
best_val_acc = 0.0  # best validation accuracy seen so far, for model checkpointing

try:
    # Epoch loop
    for epoch in range(1, num_epoch + 1):
        start_time = timer()

        # Reset metrics. They are kept as plain Python floats (via .item()) so
        # the printed and logged values are numbers, not device tensors.
        train_loss = 0.0
        val_loss = 0.0
        train_correct = 0.0
        val_correct = 0.0

        # Training loop
        model.train()
        for inputs, targets in train_gen:
            # use GPU if available
            inputs = inputs.to(device).float()
            # Targets become a (N, 1) column.
            # NOTE(review): assumes the model emits one score per sample with
            # shape (N, 1); otherwise the comparison below would broadcast.
            targets = targets.to(device).view(-1, 1).float()

            # Training steps
            optimizer.zero_grad()  # clear gradients
            output = model(inputs)  # forward pass: predict outputs for each image
            loss = loss_fn(output, targets)  # calculate loss
            loss.backward()  # backward pass: compute gradient of the loss wrt model parameters
            optimizer.step()  # update parameters

            # Scale the per-batch loss by batch size so it can be averaged per
            # sample at the end (assumes loss_fn uses mean reduction -- confirm).
            train_loss += loss.item() * inputs.size(0)
            # NOTE(review): a 0.5 threshold assumes `output` is a probability
            # (sigmoid applied inside the model); with raw logits use 0.
            train_correct += ((output > 0.5) == targets).float().sum().item()

        # Validation loop: no gradients are needed anywhere in this pass.
        model.eval()
        with torch.no_grad():
            for inputs, targets in val_gen:
                # use GPU if available
                inputs = inputs.to(device).float()
                targets = targets.to(device).view(-1, 1).float()

                output = model(inputs)  # forward pass: predict outputs for each image
                loss = loss_fn(output, targets)  # calculate loss
                val_loss += loss.item() * inputs.size(0)  # update validation loss
                val_correct += ((output > 0.5) == targets).float().sum().item()  # update validation accuracy

        # Average losses and accuracy over the number of samples each sampler yields.
        train_loss = train_loss / len(train_gen.sampler)
        val_loss = val_loss / len(val_gen.sampler)
        train_acc = train_correct / len(train_gen.sampler)
        val_acc = val_correct / len(val_gen.sampler)
        end_time = timer()  # get time taken for epoch
        epoch_time = end_time - start_time

        # Display metrics at the end of each epoch.
        print(
            f'Epoch: {epoch} \tTraining Loss: {train_loss} '
            f'\tValidation Loss: {val_loss} \tTraining Accuracy: {train_acc} '
            f'\tValidation Accuracy: {val_acc} \t Time taken: {epoch_time}'
        )

        # Log metrics to tensorboard
        file_writer.add_scalar('Loss/train', train_loss, epoch)
        file_writer.add_scalar('Loss/validation', val_loss, epoch)
        file_writer.add_scalar('Accuracy/train', train_acc, epoch)
        file_writer.add_scalar('Accuracy/validation', val_acc, epoch)
        file_writer.add_scalar('epoch_time', epoch_time, epoch)

        # checkpoint if improved. `ckpt_dir` is used as a file-name prefix
        # (weights go to `<ckpt_dir>.pt`), not as a directory.
        if val_acc > best_val_acc:
            state_dict = model.state_dict()
            torch.save(state_dict, ckpt_dir + '.pt')
            best_val_acc = val_acc
finally:
    # Flush pending events and release the event file, even if training raises.
    file_writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.