Skip to content

Instantly share code, notes, and snippets.

@J3698
Last active July 1, 2021 04:27
Show Gist options
  • Save J3698/1612863d591815789a20e5defcf0b812 to your computer and use it in GitHub Desktop.
Save J3698/1612863d591815789a20e5defcf0b812 to your computer and use it in GitHub Desktop.
def train_epoch_style_loss(args, encoder, decoder, dataloader, val_dataloader,
                           optimizer, epoch_num, writer, run, device):
    """Run one training epoch of the decoder on the style-transfer loss.

    The encoder is put in eval mode and the decoder in train mode before the
    loop. Only ``optimizer`` is stepped; which parameters it updates is decided
    by the caller (NOTE(review): presumably decoder parameters only — confirm).

    Args:
        args: Config object. ``args.lambda_content`` and ``args.lambda_style``
            are passed to ``get_style_transfer_loss``. ``args`` itself is also
            forwarded to ``calc_num_batches`` and ``write_to_tensorboard``.
        encoder: Module kept in eval mode and passed to the loss function.
        decoder: Module being trained.
        dataloader: Iterable yielding ``(content_image, style_image)`` batches.
        val_dataloader: Forwarded unchanged to ``write_to_tensorboard``.
        optimizer: Optimizer zeroed and stepped once per batch.
        epoch_num: Current epoch index. Used to build the global iteration
            counter and as the step for the per-epoch ``'Loss/train'`` scalar.
        writer: Object exposing ``add_scalar`` (presumably a TensorBoard
            ``SummaryWriter``).
        run: Not used in this function.
        device: Target passed to ``.to()`` for each image batch.

    Returns:
        The sum of per-batch loss values over the epoch. The scalar written to
        ``'Loss/train'`` is the per-batch mean, which matches the progress bar.
    """
    encoder.eval()
    decoder.train()
    total_loss = 0
    batches_seen = 0
    num_batches = calc_num_batches(dataloader, args)
    progress_bar = tqdm.tqdm(enumerate(dataloader), total=num_batches,
                             dynamic_ncols=True)
    for i, (content_image, style_image) in progress_bar:
        # move the batch to the training device
        content_image = content_image.to(device)
        style_image = style_image.to(device)

        # training step
        optimizer.zero_grad()
        loss, _ = get_style_transfer_loss(encoder, decoder, content_image,
                                          style_image, args.lambda_content,
                                          args.lambda_style)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
        batches_seen = i + 1

        # logging; the global step assumes every epoch yields num_batches batches
        iteration = epoch_num * num_batches + i
        # NOTE(review): if write_to_tensorboard switches decoder to eval mode
        # for validation, it must restore train mode — confirm in that helper.
        write_to_tensorboard(iteration, args, encoder, decoder, val_dataloader,
                             writer, device)
        progress_bar.set_postfix({'epoch': f"{epoch_num}",
                                  'loss': f"{total_loss / batches_seen:.2f}"})

    # Log the mean per-batch loss rather than the raw sum. The curve is then
    # independent of how many batches an epoch has and agrees with the
    # progress bar. max(..., 1) avoids division by zero on an empty dataloader.
    writer.add_scalar('Loss/train', total_loss / max(batches_seen, 1), epoch_num)
    return total_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment