Tip: when populating a grid of Matplotlib axes with a number of datasets that is not known in advance, some axes may end up empty; clear them AFTER the main loop.
Code source: DETR_panoptic notebook on Colab
- In the enumerated loop below, we don't know in advance where `i` will end, because the length of `out["pred_masks"][keep]` is never stored in a variable.
- After the loop, we reuse the final value of the loop index `i` in a second loop to clear any empty axes.
    import math
    import matplotlib.pyplot as plt

    #[...]
    scores = out["pred_logits"].softmax(-1)[..., :-1].max(-1)[0]
    keep = scores > 0.85
    n_cols = 5
    n_rows = math.ceil(keep.sum().item() / n_cols)
    # squeeze=False keeps `axs` 2-D even when there is only one row of masks
    fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(10, 10), squeeze=False)
    for i, mask in enumerate(out["pred_masks"][keep]):
        ax = axs[i // n_cols, i % n_cols]
        ax.imshow(mask, cmap="cividis")
        ax.axis('off')
    # Clear the unused axes beyond the last axis that was populated with an image
    # (after the loop, `i` still holds the index of that last axis):
    for empty_ax in axs[i // n_cols, i % n_cols + 1:]:
        empty_ax.axis('off')
    fig.tight_layout()
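A variant worth considering, sketched here as an assumption rather than what the notebook does: instead of hiding the leftover axes with `axis('off')`, delete them with `Axes.remove()`. Counting how many masks were drawn replaces the need to read the loop index afterwards. `plot_and_prune` is a hypothetical helper name.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_and_prune(masks, n_rows, n_cols=5):
    """Hypothetical helper: fill the grid, then delete (not just hide) unused axes."""
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    n_used = 0
    # zip stops at the shorter of the two: the grid or the masks
    for ax, mask in zip(axs.flat, masks):
        ax.imshow(mask, cmap="cividis")
        ax.axis("off")
        n_used += 1
    # Remove every axis that did not receive a mask
    for ax in axs.flat[n_used:]:
        ax.remove()
    fig.tight_layout()
    return fig


rng = np.random.default_rng(0)
fig = plot_and_prune((rng.random((8, 8)) for _ in range(7)), n_rows=2)
print(len(fig.axes))  # 7
```

The trade-off: removed axes no longer appear in `fig.axes`, so later code that iterates over the figure's axes won't touch them. Because `zip` stops at the shorter input, any masks beyond the grid's capacity are silently dropped instead of raising an `IndexError`.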