Created February 13, 2025 16:38
-
-
Save mathematicalmichael/126d9955dc1a797321f481784c65d93b to your computer and use it in GitHub Desktop.
Segment Auditing Streamlit App
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Launch the segment-audit Streamlit app inside a throwaway (isolated) uv
# environment with pinned dependency versions.

# Packages made available to the isolated environment.
deps=(
  --with streamlit==1.42.0
  --with opencv-python-headless==4.8.1.78
  --with matplotlib==3.10.0
)

# Streamlit server/browser options, passed before the target script.
streamlit_opts=(
  --server.headless true
  --browser.gatherUsageStats false
  --browser.serverAddress 0.0.0.0
  --server.port 8502
)

uv run --isolated "${deps[@]}" streamlit run "${streamlit_opts[@]}" segment_audit_st.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import cv2 | |
import matplotlib.patches as patches | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import streamlit as st | |
def load_image(image_file):
    """Decode an uploaded image file into an RGB numpy array.

    Args:
        image_file: Readable binary file-like object (e.g. a Streamlit
            ``UploadedFile``) holding an encoded image.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB, or
        ``None`` if the data is empty or cannot be decoded (an error is
        shown in the app in that case).
    """
    file_bytes = np.frombuffer(image_file.read(), dtype=np.uint8)
    # cv2.imdecode asserts on an empty buffer, so treat an empty upload as
    # undecodable instead of letting that assertion escape.
    # IMREAD_COLOR requests a 3-channel BGR result.
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
    if image is None:
        # imdecode reports failure by returning None rather than raising;
        # passing None on to cvtColor would crash with an opaque cv2.error.
        st.error("Error loading image file: the data could not be decoded as an image.")
        return None
    # OpenCV decodes to BGR; matplotlib and st.image expect RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def _find_contour_format_problem(data):
    """Return a description of the first structural problem in *data*, or None."""
    if not isinstance(data, list):
        return f"expected a list of contours, got {type(data).__name__}."
    if not data:
        return "the file contains no contours."
    for number, contour in enumerate(data, start=1):
        if not isinstance(contour, list) or not contour:
            return f"contour {number} must be a non-empty list of points."
        for point in contour:
            is_valid_point = (
                isinstance(point, list)
                and len(point) >= 2
                and all(
                    isinstance(coord, (int, float)) and not isinstance(coord, bool)
                    for coord in point[:2]
                )
            )
            if not is_valid_point:
                return (
                    f"contour {number} contains an invalid point {point!r}; "
                    "expected [x, y]."
                )
    return None


def load_contours(json_file):
    """Load and validate contour data from an uploaded JSON file.

    Expected format: a non-empty JSON list of contours, where each contour is
    a non-empty list of points and each point starts with numeric ``[x, y]``
    coordinates (extra trailing values per point are tolerated and ignored
    by the rest of the app).

    Args:
        json_file: Readable file-like object (text or binary) containing JSON.

    Returns:
        The parsed list of contours, or ``None`` if the content is not valid
        JSON or does not have the expected structure (an error is shown in
        the app in that case).
    """
    try:
        data = json.load(json_file)
    except ValueError as e:  # covers json.JSONDecodeError and UnicodeDecodeError
        st.error(f"Error loading JSON file: {e}")
        return None
    # Validate up front so malformed files fail here with a clear message
    # instead of crashing later inside the drawing / export code.
    problem = _find_contour_format_problem(data)
    if problem is not None:
        st.error(f"Error loading JSON file: {problem}")
        return None
    return data
def calculate_bounding_box(segmentation, scale_factor=1.0, min_dim=5):
    """Compute the axis-aligned bounding box of a segmentation contour.

    Each point's x/y is multiplied by ``scale_factor`` before the box is
    computed. If the box is narrower or shorter than ``min_dim`` pixels, that
    dimension is grown to ``min_dim``, centred on the original extent, so
    tiny segments stay visible.

    Args:
        segmentation: Sequence of points; only the first two coordinates
            (x, y) of each point are used.
        scale_factor: Multiplier applied to every coordinate.
        min_dim: Minimum width and height of the returned box (default 5).

    Returns:
        ``[x_min, y_min, w, h]`` in scaled coordinates.

    Raises:
        ValueError: If ``segmentation`` contains no points.
    """
    if len(segmentation) == 0:
        raise ValueError("segmentation must contain at least one point")

    # Scale in float64, then store as float32 (the precision OpenCV point
    # arrays use), so the results match the previous hull-based version.
    scaled_points = np.array(
        [[point[0] * scale_factor, point[1] * scale_factor] for point in segmentation]
    ).astype(np.float32)

    # The extreme x/y values of a point set are always attained at vertices of
    # its convex hull, so min/max over all points yields the same box as
    # min/max over the hull, without computing the hull.
    x_min, y_min = scaled_points.min(axis=0)
    x_max, y_max = scaled_points.max(axis=0)

    w = x_max - x_min
    h = y_max - y_min

    # Grow undersized dimensions symmetrically around the original extent.
    if w < min_dim:
        x_min -= (min_dim - w) / 2
        w = min_dim
    if h < min_dim:
        y_min -= (min_dim - h) / 2
        h = min_dim

    return [x_min, y_min, w, h]
def apply_segmentation_contours(image, contours, original_width, proxy_width):
    """Draw a numbered green bounding box for every contour on a copy of *image*.

    Args:
        image: Full-resolution image array, unpacked as (height, width, channels).
        contours: Sequence of contours in proxy-image coordinates.
        original_width: Not used; the scale is taken from ``image``'s own
            width. Kept for interface compatibility.
        proxy_width: Width of the image the contours were generated on.

    Returns:
        A new annotated image; ``image`` itself is not modified.
    """
    _, full_width, _ = image.shape
    scale = full_width / proxy_width
    canvas = image.copy()

    for number, segmentation in enumerate(contours, start=1):
        left, top, box_w, box_h = map(
            int, calculate_bounding_box(segmentation, scale)
        )
        cv2.rectangle(
            canvas, (left, top), (left + box_w, top + box_h), (0, 255, 0), 2
        )
        # The 1-based label sits just inside the box's top-left corner.
        label_origin = (left + 5, top + 20)
        cv2.putText(
            canvas,
            str(number),
            label_origin,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )

    return canvas
def visualize_segment(image, contour, original_width, proxy_width, segment_index):
    """Render a close-up figure of a single segment.

    The image is cropped to the segment's bounding box plus a 10% margin on
    each side (clamped to the image bounds). The segment's closed outline is
    drawn in red and its bounding box in green on top of the crop.

    Args:
        image: Full-resolution image array, unpacked as (height, width, channels).
        contour: One contour as a sequence of points in proxy-image coordinates.
        original_width: Not used; the scale is taken from ``image``'s own
            width. Kept for interface compatibility.
        proxy_width: Width of the image the contour was generated on.
        segment_index: Label shown in the figure title.

    Returns:
        The Matplotlib ``Figure``. It is created through ``pyplot``, so the
        caller should ``plt.close`` it once it has been displayed.
    """
    height, width, _ = image.shape
    scale_factor = width / proxy_width

    # Bounding box in full-resolution pixels, truncated to ints for slicing.
    x, y, w, h = [int(coord) for coord in calculate_bounding_box(contour, scale_factor)]

    # 10% margin around the box, clamped to the image.
    margin_x = int(w * 0.1)
    margin_y = int(h * 0.1)
    crop_x1 = max(0, x - margin_x)
    crop_y1 = max(0, y - margin_y)
    crop_x2 = min(width, x + w + margin_x)
    crop_y2 = min(height, y + h + margin_y)
    cropped_segment = image[crop_y1:crop_y2, crop_x1:crop_x2]

    # Scale the outline to full resolution, then shift it into crop coordinates.
    polygon = np.array(
        [
            [point[0] * scale_factor - crop_x1, point[1] * scale_factor - crop_y1]
            for point in contour
        ]
    )
    # Repeat the first vertex so the closing edge is drawn as well
    # (the same convention plot_segments_mpl uses).
    polygon = np.vstack([polygon, polygon[:1]])

    fig, ax = plt.subplots(1)
    ax.imshow(cropped_segment)
    ax.set_title(f"Segment {segment_index}")
    ax.axis("off")

    ax.plot(polygon[:, 0], polygon[:, 1], color="r", linewidth=1.5)

    # Bounding box, expressed relative to the crop origin.
    ax.add_patch(
        patches.Rectangle(
            (x - crop_x1, y - crop_y1),
            w,
            h,
            linewidth=1,
            edgecolor="g",
            facecolor="none",
        )
    )

    # Lay out this specific figure rather than whatever pyplot considers current.
    fig.tight_layout()
    return fig
def plot_segments_mpl(image, masks, original_width, proxy_width):
    """Overlay every contour on *image* as a coloured outline with its number.

    Args:
        image: Image array to draw on (shown with ``imshow``).
        masks: Sequence of contours, each a sequence of points in
            proxy-image coordinates.
        original_width: Width used to scale contours up from the proxy size.
        proxy_width: Width of the image the contours were generated on.

    Returns:
        A Matplotlib ``Figure`` created through ``pyplot``.
    """
    factor = original_width / proxy_width
    # One evenly spaced colour per contour from the rainbow colormap.
    palette = plt.cm.rainbow(np.linspace(0, 1, len(masks)))

    fig, ax = plt.subplots(figsize=(16, 12))
    ax.imshow(image)

    for number, (mask, colour) in enumerate(zip(masks, palette), start=1):
        outline = np.array([[p[0] * factor, p[1] * factor] for p in mask])
        # Append the first vertex so the outline is drawn closed.
        outline = np.vstack([outline, outline[0]])
        ax.plot(outline[:, 0], outline[:, 1], "-", color=colour, linewidth=1.5)

        # Label position: mean of the closed outline's vertices.
        label_x, label_y = outline.mean(axis=0)
        ax.text(
            label_x,
            label_y,
            str(number),
            color="white",
            fontsize=8,
            bbox={"facecolor": colour, "alpha": 0.7, "pad": 0.5},
            ha="center",
            va="center",
        )

    ax.set_title(f"Segmented Regions: {len(masks)} found")
    ax.axis("off")
    fig.tight_layout()
    return fig
def _render_figure(fig):
    """Display a Matplotlib figure in the app, then release it."""
    st.pyplot(fig)
    # Figures created through pyplot stay registered (and in memory) until
    # closed; Streamlit re-executes this script on every interaction, so
    # unclosed figures would pile up across reruns.
    plt.close(fig)


def _show_overviews(image, contours, proxy_width):
    """Show the whole image with numbered bounding boxes, then with contour outlines."""
    annotated_image = apply_segmentation_contours(
        image, contours, image.shape[1], proxy_width
    )
    st.image(
        annotated_image,
        caption="Annotated Image (Bounding Boxes)",
        use_container_width=True,
    )
    st.markdown("Annotated Image (Contours)")  # caption for the Matplotlib figure
    _render_figure(plot_segments_mpl(image, contours, image.shape[1], proxy_width))


def _audit_single_segment(image, contours, proxy_width):
    """Ask for a 1-based segment number and show a close-up of that segment."""
    segment_text = st.text_input("Enter segment number to audit", value="")
    if not segment_text:
        return
    # Only the user-input conversion is guarded, so errors raised while
    # drawing are not misreported as bad input.
    try:
        segment_number = int(segment_text)
    except ValueError:
        st.error("Please enter a valid number.")
        return
    if not 1 <= segment_number <= len(contours):
        st.error("Invalid segment number.")
        return
    _render_figure(
        visualize_segment(
            image,
            contours[segment_number - 1],
            image.shape[1],
            proxy_width,
            segment_number,
        )
    )


def _parse_exclude_indices(text):
    """Parse comma-separated 1-based segment numbers into 0-based indices.

    Blank items are skipped. Raises ValueError if any other item is not an
    integer.
    """
    return {
        int(item) - 1
        for item in (part.strip() for part in text.split(","))
        if item
    }


def _export_filtered_contours(image, contours, proxy_width):
    """Offer a download of the contours rescaled to full resolution, filtered by the user."""
    st.markdown("---")
    st.markdown("### Export Filtered Contours")
    cutoff_index = st.number_input(
        "Export contours up to index (exclusive)",
        min_value=1,
        max_value=len(contours),
        value=len(contours),
        help="Segments numbered 1 through this value are exported "
        "(i.e. the 0-based indices below it).",
    )
    exclude_indices_str = st.text_input(
        "Indices to exclude (comma-separated)", value="", help="e.g. 1,5,10"
    )
    try:
        exclude_indices = _parse_exclude_indices(exclude_indices_str)
    except ValueError:
        st.error(
            "Invalid exclude indices format. Please use comma-separated numbers."
        )
        # Don't offer a download that silently ignores the requested exclusions.
        return

    out_of_range = sorted(
        i + 1 for i in exclude_indices if not 0 <= i < len(contours)
    )
    if out_of_range:
        st.warning(
            f"Ignoring segment numbers outside 1-{len(contours)}: {out_of_range}"
        )

    # Rescale from proxy-image coordinates to the uploaded image's width.
    scale_factor = image.shape[1] / proxy_width
    scaled_contours = [
        [[point[0] * scale_factor, point[1] * scale_factor] for point in contour]
        for i, contour in enumerate(contours[:cutoff_index])
        if i not in exclude_indices
    ]
    st.download_button(
        label="Download Scaled Contours",
        data=json.dumps(scaled_contours),
        file_name="scaled_contours.json",
        mime="application/json",
    )


def main():
    """Run the Segment Audit Streamlit app.

    Flow:
    1. Upload an image and the contour JSON produced on a smaller "proxy"
       copy of it.
    2. View the contours overlaid on the image.
    3. Inspect any single segment up close.
    4. Download a rescaled, filtered copy of the contours.
    """
    st.title("Segment Audit App")

    image_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
    json_file = st.file_uploader("Upload Contours JSON", type=["json"])
    proxy_width = st.number_input(
        "Proxy Photo Width",
        min_value=1,
        value=640,
        help="Width of the image used to generate the contours",
    )

    if not (image_file and json_file):
        return

    image = load_image(image_file)
    contours = load_contours(json_file)
    if image is None or contours is None:
        return
    if not contours:
        # The export cutoff widget needs max_value >= min_value (1), and there
        # is nothing to audit anyway.
        st.warning("The contours file contains no segments.")
        return

    st.write(f"Total number of segments: {len(contours)}")
    _show_overviews(image, contours, proxy_width)
    _audit_single_segment(image, contours, proxy_width)
    _export_filtered_contours(image, contours, proxy_width)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.