Created October 2, 2022 22:39
Save farukcankaya/e8c8a4ffeb53a70e5ee5d0f96bdb2441 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def calculate_validation_loss(validation_dataset_loader, model, storage):
    """Compute the losses on one validation batch and record them in ``storage``.

    Pulls a single batch from ``validation_dataset_loader`` and runs ``model``
    on it with gradient tracking disabled. It then checks that the summed loss
    is finite and passes the loss dict through ``comm.reduce_dict``. On the
    main process only, it records each reduced loss under a ``"val_"``-prefixed
    name, plus their sum as ``total_val_loss``.

    Args:
        validation_dataset_loader: an *iterator* over validation batches.
            ``next()`` is called on it directly, so pass ``iter(loader)``
            rather than a bare DataLoader. Reuse the same iterator across calls
            so successive calls see different batches.
        model: callable mapping a batch to a dict of loss tensors. Each value
            must reduce to a single-element tensor, because ``.item()`` is
            called on it.
            NOTE(review): detectron2-style models only return a loss dict in
            training mode — confirm the caller does not switch to ``eval()``.
        storage: object exposing ``put_scalars(**scalars)``. Presumably a
            detectron2 ``EventStorage`` — confirm.

    Raises:
        AssertionError: if the summed (local, unreduced) loss is NaN or
            infinite. The exception payload is ``loss_dict``.
        StopIteration: if ``validation_dataset_loader`` is exhausted.
    """
    data = next(validation_dataset_loader)
    with torch.no_grad():
        loss_dict = model(data)
        losses = sum(loss_dict.values())
        # Explicit raise instead of ``assert``: an assert statement is removed
        # when Python runs with -O, which would silently disable this guard.
        # AssertionError with ``loss_dict`` as payload matches the original.
        if not torch.isfinite(losses).all():
            raise AssertionError(loss_dict)

        # reduce_dict is called on every process, before the main-process
        # check. If it is a collective op (as in detectron2.utils.comm), all
        # ranks must reach this line. TODO confirm which ``comm`` is in scope.
        loss_dict_reduced = {
            "val_" + name: value.item()
            for name, value in comm.reduce_dict(loss_dict).items()
        }
        losses_reduced = sum(loss_dict_reduced.values())
        if comm.is_main_process():
            storage.put_scalars(total_val_loss=losses_reduced, **loss_dict_reduced)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.