Created July 16, 2021 09:52
-
-
Save Chris-hughes10/8a8748457a71bdbd6e1bb9d1f1d3c493 to your computer and use it in GitHub Desktop.
Effdet_blog_aggregate_outputs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastcore.basics import patch | |
@patch
def aggregate_prediction_outputs(self: EfficientDetModel, outputs):
    """Merge per-batch prediction outputs and post-process them in one call.

    Attached to ``EfficientDetModel`` as a method via fastcore's ``@patch``.

    Args:
        outputs: Iterable of per-batch results. Each element must be a mapping
            with a ``"batch_predictions"`` entry, itself a mapping with:
            ``"predictions"`` (a tensor; the per-batch tensors are joined with
            ``torch.cat`` along dim 0), ``"image_ids"`` and ``"targets"``
            (iterables whose items are flattened into single lists).
            NOTE(review): presumably the list of step outputs collected by the
            training loop — confirm against the caller.

    Returns:
        Tuple ``(predicted_class_labels, image_ids, predicted_bboxes,
        predicted_class_confidences, targets)``. The three ``predicted_*``
        values are whatever ``self.post_process_detections`` returns for the
        concatenated detections. ``image_ids`` and ``targets`` are flat lists
        in the same batch order as the concatenated predictions.

    Raises:
        KeyError: if an element of ``outputs`` lacks one of the keys above.
        RuntimeError: from ``torch.cat`` when ``outputs`` is empty, because
            there are no tensors to concatenate.
    """
    batch_detections = []
    image_ids = []
    targets = []
    # Walk ``outputs`` exactly once. Iterating it a second time (as a
    # comprehension followed by a loop would) yields nothing for a one-shot
    # iterable, which would leave image_ids/targets silently empty.
    for output in outputs:
        batch_predictions = output["batch_predictions"]
        batch_detections.append(batch_predictions["predictions"])
        image_ids.extend(batch_predictions["image_ids"])
        targets.extend(batch_predictions["targets"])

    # Join all batches so post-processing runs once over every prediction.
    detections = torch.cat(batch_detections)

    (
        predicted_bboxes,
        predicted_class_confidences,
        predicted_class_labels,
    ) = self.post_process_detections(detections)

    # The return order (labels first) is kept as-is because callers unpack
    # it positionally.
    return (
        predicted_class_labels,
        image_ids,
        predicted_bboxes,
        predicted_class_confidences,
        targets,
    )
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.