Skip to content

Instantly share code, notes, and snippets.

@mathematicalmichael
Last active February 12, 2025 03:53
Show Gist options
  • Save mathematicalmichael/63e04c727225303229ed57a543d966a3 to your computer and use it in GitHub Desktop.
Save mathematicalmichael/63e04c727225303229ed57a543d966a3 to your computer and use it in GitHub Desktop.
Streamlit App for Insect Segmentation with Segment Anything
#!/usr/bin/env python3
"""
Streamlit app for segmenting insects using SAM (Segment Anything Model).

Expects the SAM checkpoint ``sam_vit_h_4b8939.pth`` in the working directory.

Run with:

uv run \
--with streamlit \
--with segment_anything \
--with opencv-python-headless \
--with torch \
--with matplotlib \
streamlit run sam_segment_st.py
"""
import json
import cv2
import numpy as np
import streamlit as st
import torch
from matplotlib import pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
@st.cache_resource
def load_sam_model():
    """Build the SAM network and hand it back together with its device.

    The function is wrapped in ``st.cache_resource`` so the checkpoint is
    not reloaded on every call.

    Returns:
        A ``(sam, device)`` tuple: the model built from the ViT-H checkpoint
        and the name of the device it was moved to.
    """
    # CPU is hard-wired for now because of float64 issues on MPS.
    device = "cpu"
    st.info(f"Using device: {device}")

    # "vit_h" is the highest-quality variant; the checkpoint is a bare
    # filename, so it is looked up relative to the working directory.
    build_sam = sam_model_registry["vit_h"]
    sam = build_sam(checkpoint="sam_vit_h_4b8939.pth")
    sam.to(device=device)
    return sam, device
def process_image(image, mask_generator, min_area=0.0001, max_area=0.1):
    """Generate segments with SAM's automatic mask generator and filter by size.

    Args:
        image: RGB image array. ``uint8`` arrays are used as-is; any other
            dtype is multiplied by 255 and cast to ``uint8``. Floating-point
            images are clipped to ``[0, 1]`` before scaling so that values
            slightly outside that range saturate instead of wrapping around.
        mask_generator: Object whose ``generate(image)`` returns a list of
            mask dicts, each carrying an ``"area"`` entry in pixels.
        min_area: Minimum area as a fraction of the image area.
        max_area: Maximum area as a fraction of the image area.

    Returns:
        The mask dicts whose area lies within the (inclusive) bounds,
        sorted by area with the largest first.
    """
    # The generator is fed uint8 data; convert anything else up front.
    if image.dtype != np.uint8:
        if np.issubdtype(image.dtype, np.floating):
            # Without clipping, e.g. 1.2 * 255 = 306 overflows the uint8 cast.
            image = np.clip(image, 0.0, 1.0)
        image = (image * 255).astype(np.uint8)

    # Translate the fractional bounds into pixel counts for this image.
    image_area = image.shape[0] * image.shape[1]
    min_area_pixels = image_area * min_area
    max_area_pixels = image_area * max_area

    # No gradients are needed for mask generation.
    with torch.inference_mode():
        masks = mask_generator.generate(image)

    kept = [m for m in masks if min_area_pixels <= m["area"] <= max_area_pixels]
    # list.sort is stable, so equally sized masks keep the generator's order.
    kept.sort(key=lambda m: m["area"], reverse=True)
    return kept
def plot_results(image, masks):
    """Plot the original image next to the same image with masks overlaid.

    Args:
        image: Image array to display; its first two dimensions give the
            overlay size.
        masks: Sequence of mask dicts. Each must provide ``"segmentation"``
            (used as a boolean index into the image's pixel grid) and
            ``"bbox"`` as ``[x, y, w, h]``.

    Returns:
        The Matplotlib figure holding both panels.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

    # Left panel: the untouched input.
    ax1.imshow(image)
    ax1.set_title("Original")
    ax1.axis("off")

    # Right panel: the input with every mask tinted on top of it.
    ax2.imshow(image)

    # One RGBA layer that is fully transparent outside the masks. Drawing a
    # separate opaque RGB layer per mask with a global alpha would also blend
    # black over the background, darkening the picture with every mask.
    overlay = np.zeros((*image.shape[:2], 4), dtype=np.float32)
    for mask in masks:
        color = np.random.rand(3).astype(np.float32)  # Force float32
        segmentation = mask["segmentation"]
        overlay[segmentation, :3] = color
        overlay[segmentation, 3] = 0.35

        # Outline the mask's bounding box in the same color.
        x, y, w, h = mask["bbox"]
        rect = plt.Rectangle(
            (x, y),
            w,
            h,
            linewidth=1,
            edgecolor=color,
            facecolor="none",
        )
        ax2.add_patch(rect)

    if len(masks) > 0:
        ax2.imshow(overlay)

    ax2.set_title(f"Segmentation ({len(masks)} segments)")
    ax2.axis("off")
    plt.tight_layout()
    return fig
def main():
    """Run the Streamlit UI: upload an image, segment it with SAM, show results.

    The flow is: load the cached SAM model, let the user upload an image and
    tune the area / threshold sliders, run the automatic mask generator with
    those settings, plot the result, and offer the masks as a JSON download.
    The last segmentation is kept in ``st.session_state`` so that reruns that
    change neither the file nor the sliders (e.g. pressing the download
    button) do not repeat the SAM pass.
    """
    st.title("Insect Segmentation with SAM")

    # Only the model load is expected to fail on a missing checkpoint.
    try:
        sam, _device = load_sam_model()
    except FileNotFoundError:
        st.error(
            "\nPlease download the SAM checkpoint file:\n"
            "wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\n"
        )
        return

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the whole buffer regardless of the read position.
    raw_bytes = uploaded_file.getvalue()
    if not raw_bytes:
        st.error("The uploaded file is empty.")
        return
    file_bytes = np.frombuffer(raw_bytes, dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals an undecodable buffer by returning None.
        st.error("Could not decode the uploaded file as an image.")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Parameter controls
    col1, col2 = st.columns(2)
    with col1:
        min_area_pct = st.slider(
            "Minimum Area (%)",
            min_value=0.01,
            max_value=1.0,
            value=0.05,
            step=0.01,
            help="Minimum segment area as percentage of image area",
        )
        pred_iou_thresh = st.slider(
            "Prediction IoU Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.86,
            help="Higher values = more selective segmentation",
        )
    with col2:
        max_area_pct = st.slider(
            "Maximum Area (%)",
            min_value=1.0,
            max_value=20.0,
            value=5.0,
            step=0.1,
            help="Maximum segment area as percentage of image area",
        )
        stability_score_thresh = st.slider(
            "Stability Score Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.92,
            help="Higher values = more stable segments",
        )

    # Everything that influences the segmentation result goes into the key.
    state_key = "sam_segmentation_cache"
    cache_key = (
        uploaded_file.name,
        len(raw_bytes),
        hash(raw_bytes),
        min_area_pct,
        max_area_pct,
        pred_iou_thresh,
        stability_score_thresh,
    )
    cached = st.session_state.get(state_key)
    if cached is not None and cached[0] == cache_key:
        masks = cached[1]
    else:
        # Build the generator here so the threshold sliders take effect.
        mask_generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=32,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100,  # Minimum area in pixels
        )
        with st.spinner("Processing image with SAM..."):
            masks = process_image(
                image,
                mask_generator,
                min_area=min_area_pct / 100,
                max_area=max_area_pct / 100,
            )
        st.session_state[state_key] = (cache_key, masks)

    # Plot results, then release the figure so reruns do not accumulate them.
    fig = plot_results(image, masks)
    st.pyplot(fig)
    plt.close(fig)

    # Serialization is deferred until requested because the per-pixel
    # segmentation arrays are expanded into nested lists.
    if st.button("Download Masks as JSON"):
        masks_json = [
            {
                "segmentation": mask["segmentation"].tolist(),
                "area": float(mask["area"]),
                "bbox": [float(x) for x in mask["bbox"]],
            }
            for mask in masks
        ]
        st.download_button(
            "Download JSON",
            data=json.dumps(masks_json),
            file_name="masks.json",
            mime="application/json",
        )
# Entry point: build the UI only when this file is executed as a script
# (as `streamlit run` does), not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment