Interactive plot for Segment Anything Model example
%matplotlib widget  # Requires ipympl: https://github.com/matplotlib/ipympl
import matplotlib.pyplot as plt
import numpy as np

# Previous steps followed from the example notebook at
# https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb

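# A minimal sketch of the setup those previous steps are assumed to produce
# (defining image, predictor, image_embedding, and ort_session). File names such as
# "truck.jpg", the checkpoint, and "sam_onnx_example.onnx" are placeholders; use the
# values from the notebook linked above.
import cv2
import onnxruntime
from segment_anything import sam_model_registry, SamPredictor

image = cv2.cvtColor(cv2.imread("truck.jpg"), cv2.COLOR_BGR2RGB)
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()
ort_session = onnxruntime.InferenceSession("sam_onnx_example.onnx")
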
def show_mask(mask, ax):
    # From SAM notebook
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    # From SAM notebook
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)

def get_masks(input_point, input_label, image_embedding):
    # From Segment Anything repo, compressed into a single function
    # Image embedding computed previously
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }

    masks, _, low_res_logits = ort_session.run(None, ort_inputs)
    return masks > predictor.model.mask_threshold

# Plot the initial image and suppress extra figure elements
fig, ax = plt.subplots()
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
ax.imshow(image)
ax.axis('off')

def onclick(event):
    # Update the mask on each click event, using the click as a positive point prompt
    ax.cla()  # Clear previous clicks
    ax.imshow(image)
    ax.axis('off')

    input_point = np.array([[event.xdata, event.ydata]])
    input_label = np.array([1])

    masks = get_masks(input_point, input_label, image_embedding)
    show_mask(masks, ax)
    show_points(input_point, input_label, ax)

cid = fig.canvas.mpl_connect('button_press_event', onclick)
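# Optional teardown (not part of the original gist): stop responding to clicks by
# disconnecting the handler before closing the figure.
# fig.canvas.mpl_disconnect(cid)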