Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save farukcankaya/e8c8a4ffeb53a70e5ee5d0f96bdb2441 to your computer and use it in GitHub Desktop.
Save farukcankaya/e8c8a4ffeb53a70e5ee5d0f96bdb2441 to your computer and use it in GitHub Desktop.
def calculate_validation_loss(validation_dataset_loader, model, storage):
    """Evaluate one validation batch and log its losses to ``storage``.

    The next item is pulled from ``validation_dataset_loader`` with
    ``next()``, so the argument must be an iterator rather than a plain
    iterable. The model is called on that item with gradient tracking
    disabled. Each reduced loss is stored under its original key prefixed
    with ``val_``, and their sum is stored as ``total_val_loss``.

    Args:
        validation_dataset_loader: Iterator yielding validation batches.
        model: Callable that returns a dict of losses for a batch
            (presumably the model must be in a mode that returns losses;
            confirm against the caller).
        storage: Object exposing ``put_scalars(**kwargs)``. It is written
            to only when ``comm.is_main_process()`` is true.

    Raises:
        StopIteration: If ``validation_dataset_loader`` has no more items.
        AssertionError: If the summed loss contains a non-finite value; the
            raw loss dict is attached as the assertion message.
    """
    batch = next(validation_dataset_loader)

    with torch.no_grad():
        raw_losses = model(batch)
        total_loss = sum(raw_losses.values())
        assert torch.isfinite(total_loss).all(), raw_losses

        # The reduction is invoked on every process; only the logging below
        # is restricted to the main process.
        reduced = comm.reduce_dict(raw_losses)

        val_scalars = {}
        for name, value in reduced.items():
            val_scalars["val_" + name] = value.item()
        val_total = sum(val_scalars.values())

        if comm.is_main_process():
            storage.put_scalars(total_val_loss=val_total, **val_scalars)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment