Created
July 16, 2021 09:56
-
-
Save Chris-hughes10/0cf9030076346b2b38db2279ccd91e90 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from objdetecteval.metrics.coco_metrics import get_coco_stats | |
@patch
def validation_epoch_end(self: EfficientDetModel, outputs):
    """Summarise one validation epoch as a mean loss plus COCO statistics.

    Args:
        outputs: Collection of per-step results. Each item is indexed with
            ``"loss"`` here, and the whole collection is also handed to
            ``self.aggregate_prediction_outputs``.

    Returns:
        A dict with ``"val_loss"`` (mean of the stacked per-step losses) and
        ``"metrics"`` (the ``'All'`` entry of the ``get_coco_stats`` result).
    """
    step_losses = [step_output["loss"] for step_output in outputs]
    validation_loss_mean = torch.stack(step_losses).mean()

    aggregated = self.aggregate_prediction_outputs(outputs)
    (
        predicted_class_labels,
        image_ids,
        predicted_bboxes,
        predicted_class_confidences,
        targets,
    ) = aggregated

    # Swap each coordinate pair so the boxes are in xyxy order for evaluation
    # (presumably the targets store them as yxyx -- confirm against the dataset).
    xyxy_column_order = [1, 0, 3, 2]

    truth_image_ids = [t["image_id"].detach().item() for t in targets]
    truth_boxes = [
        t["bboxes"].detach()[:, xyxy_column_order].tolist() for t in targets
    ]
    truth_labels = [t["labels"].detach().tolist() for t in targets]

    coco_inputs = {
        "prediction_image_ids": image_ids,
        "predicted_class_confidences": predicted_class_confidences,
        "predicted_bboxes": predicted_bboxes,
        "predicted_class_labels": predicted_class_labels,
        "target_image_ids": truth_image_ids,
        "target_bboxes": truth_boxes,
        "target_class_labels": truth_labels,
    }
    # Only the aggregate 'All' entry of the returned stats is reported.
    stats = get_coco_stats(**coco_inputs)["All"]

    return {"val_loss": validation_loss_mean, "metrics": stats}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment