@arunm8489
Created June 5, 2020 03:02
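
```python
import torch

# Assumed to come from the enclosing function (not shown in this gist):
#   prediction:  tensor of shape (batch_size, num_boxes, 5 + num_classes)
#   num_classes: number of object classes
batch_size = prediction.size(0)
write = False  # presumably a flag used later to tell whether the output tensor exists yet

# Non-max suppression can only be done on individual images, so loop over the batch
for ind in range(batch_size):
    image_pred = prediction[ind]  # (num_boxes, 5 + num_classes)
    # For each box, keep only the maximum class probability
    # and the index of that class
    max_conf, max_conf_score = torch.max(image_pred[:, 5:5 + num_classes], 1)
    # max_conf holds the max probability; max_conf_score holds the class index
    max_conf = max_conf.float().unsqueeze(1)
    max_conf_score = max_conf_score.float().unsqueeze(1)
    combined = (image_pred[:, :5], max_conf, max_conf_score)
    # Concatenate the max probability and its class index as extra columns
    # after the box coordinates and objectness score -> (num_boxes, 7)
    image_pred = torch.cat(combined, 1)
```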
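To make concrete what each row of `image_pred` holds after the `torch.max` / `torch.cat` step, here is a minimal pure-Python sketch of the same per-box computation on plain lists (no PyTorch). It assumes each row is laid out as four box coordinates, an objectness score, then `num_classes` class probabilities; the function name `append_best_class` and the sample numbers are hypothetical.

```python
from typing import List


def append_best_class(rows: List[List[float]], num_classes: int) -> List[List[float]]:
    """Mirror of the torch.max + torch.cat step on plain lists.

    For each box row, keep the first 5 columns (box coordinates + objectness)
    and append the highest class probability and that class's index,
    converted to float as the .float() call does in the gist.
    """
    out = []
    for row in rows:
        class_probs = row[5:5 + num_classes]
        # Index of the largest probability (Python's max picks the first on ties)
        best_idx = max(range(num_classes), key=lambda i: class_probs[i])
        out.append(row[:5] + [class_probs[best_idx], float(best_idx)])
    return out


# Two hypothetical boxes, three classes each
rows = [
    [10.0, 20.0, 30.0, 40.0, 0.9, 0.1, 0.7, 0.2],
    [15.0, 25.0, 35.0, 45.0, 0.6, 0.5, 0.3, 0.2],
]
result = append_best_class(rows, num_classes=3)
print(result)
# -> [[10.0, 20.0, 30.0, 40.0, 0.9, 0.7, 1.0], [15.0, 25.0, 35.0, 45.0, 0.6, 0.5, 0.0]]
```

Each row shrinks from `5 + num_classes` columns to 7, which is what lets later code sort and suppress boxes per class without carrying the full probability vector around.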
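The gist stops after building the 7-column rows; the suppression step itself is not shown. The following is a minimal pure-Python sketch of one common variant, greedy per-class IoU-based non-max suppression on a single image's rows. It assumes the first four columns are corner coordinates (x1, y1, x2, y2), that boxes are ranked by objectness (column 4), and that a box is dropped when its IoU with a higher-ranked box of the same class exceeds a threshold. It is not the author's code; `iou`, `nms_per_image`, and the 0.4 threshold are hypothetical choices for illustration.

```python
from typing import List

Row = List[float]  # [x1, y1, x2, y2, objectness, class_prob, class_idx]


def iou(a: Row, b: Row) -> float:
    """Intersection over union of two corner-format boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms_per_image(rows: List[Row], iou_threshold: float = 0.4) -> List[Row]:
    """Greedy NMS run separately for each detected class in one image."""
    kept: List[Row] = []
    for cls in sorted({row[6] for row in rows}):
        # Highest objectness first within this class
        candidates = sorted((r for r in rows if r[6] == cls),
                            key=lambda r: r[4], reverse=True)
        while candidates:
            best = candidates.pop(0)
            kept.append(best)
            # Discard remaining boxes that overlap the kept box too much
            candidates = [r for r in candidates if iou(best, r) <= iou_threshold]
    return kept


detections = [
    [0.0, 0.0, 10.0, 10.0, 0.9, 0.8, 0.0],
    [1.0, 1.0, 11.0, 11.0, 0.7, 0.6, 0.0],    # overlaps the first box, same class
    [1.0, 1.0, 11.0, 11.0, 0.8, 0.9, 1.0],    # same place, different class
    [50.0, 50.0, 60.0, 60.0, 0.6, 0.7, 0.0],  # far away, same class
]
kept = nms_per_image(detections)
```

Here the second box is suppressed (IoU with the first is 81/119, about 0.68), while the third survives because suppression only compares boxes of the same class. This per-image, per-class structure is why the gist loops over the batch instead of working on the whole `prediction` tensor at once.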