@jbencina
Created April 6, 2023 02:55
Interactive plot for Segment Anything Model example
%matplotlib widget
# Requires ipympl for interactive figures in Jupyter: https://github.com/matplotlib/ipympl
import matplotlib.pyplot as plt
import numpy as np

# Previous steps followed from the example notebook at
# https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb
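
# Setup sketch: the interactive code below assumes four objects created earlier in the
# linked notebook: `predictor`, `image`, `image_embedding`, and `ort_session`. The lines
# that follow are a minimal sketch of that setup, assuming a ViT-H checkpoint and an
# already-exported ONNX mask decoder; the file paths are placeholders, not values
# taken from this gist.
import cv2
import onnxruntime
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # placeholder checkpoint path
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("images/truck.jpg"), cv2.COLOR_BGR2RGB)  # placeholder image path
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()

ort_session = onnxruntime.InferenceSession("sam_onnx_example.onnx")  # placeholder exported decoder path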

def show_mask(mask, ax):
    # From SAM notebook
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    # From SAM notebook
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def get_masks(input_point, input_label, image_embedding):
    # From the Segment Anything repo, compressed into a single function
    # (the image embedding was computed previously)
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }

    masks, _, low_res_logits = ort_session.run(None, ort_inputs)
    return masks > predictor.model.mask_threshold

# Plot the initial image and suppress extra figure elements
fig, ax = plt.subplots()
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
ax.imshow(image)
ax.axis('off')

def onclick(event):
    # Update the mask on each click event
    if event.xdata is None or event.ydata is None:
        return  # Ignore clicks outside the image axes

    ax.cla()  # Clear previous clicks
    ax.imshow(image)
    ax.axis('off')

    input_point = np.array([[event.xdata, event.ydata]])
    input_label = np.array([1])

    masks = get_masks(input_point, input_label, image_embedding)
    show_mask(masks, ax)
    show_points(input_point, input_label, ax)

cid = fig.canvas.mpl_connect('button_press_event', onclick)
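
# The click handler stays registered for the life of the figure; it can be removed later
# with fig.canvas.mpl_disconnect(cid), the standard Matplotlib counterpart to mpl_connect.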