Skip to content

Instantly share code, notes, and snippets.

@shahpnmlab
Created March 26, 2025 10:05
Show Gist options
  • Save shahpnmlab/bfed09e668c7a0225c5ba44cb38277f2 to your computer and use it in GitHub Desktop.
Save shahpnmlab/bfed09e668c7a0225c5ba44cb38277f2 to your computer and use it in GitHub Desktop.
napari mesh to mask
from pathlib import Path

import mrcfile
import napari
import numpy as np
import trimesh
from magicgui import magicgui
from magicgui.widgets import Container, FileEdit, FloatSlider, PushButton, Slider
from napari.qt.threading import thread_worker
from qtpy.QtWidgets import QVBoxLayout, QFileDialog
from scipy import ndimage as ndi
from scipy.interpolate import splprep, splev
from scipy.spatial import Delaunay
from skimage import draw, filters, measure
from skimage.morphology import binary_closing, binary_dilation, binary_erosion
# Global variables
# Clicked points grouped per slice: keys are "z_<index>" strings, values are
# lists of [z, y, x] arrays (filled by on_click, reset by clear_points).
point_sets = {}
# File name (without extension) of the tomogram that was opened last; stays
# None until open_file succeeds. Used to name the saved mask.
current_tomogram = None
# Directory that save_mask writes the mask files into.
RESULTS_DIRECTORY = './results'
# Initialize viewer
viewer = napari.Viewer()
# File browser widget
file_browser = FileEdit(label='Open Tomogram', mode='r')
def open_file(file_path):
    """Load an MRC tomogram into the viewer and prepare the annotation layers.

    The volume is shown in the 'tomogram' image layer (replacing its data when
    the layer already exists), an empty 'object_points' layer is created on
    first use, and ``on_click`` is registered as a mouse-drag callback on the
    tomogram layer. The file stem is remembered in ``current_tomogram``.

    Args:
        file_path: Path of the MRC file to open; a falsy value is ignored.

    Returns:
        None. Errors while loading are printed rather than raised.
    """
    global current_tomogram
    if not file_path:
        return
    try:
        # Load the MRC file
        with mrcfile.open(file_path) as mrc:
            volume_data = mrc.data
        # Get the filename without extension
        current_tomogram = Path(file_path).stem
        # Add volume to viewer
        if 'tomogram' in viewer.layers:
            viewer.layers['tomogram'].data = volume_data
        else:
            viewer.add_image(volume_data, name='tomogram')
        # Add points layer if it doesn't exist
        if 'object_points' not in viewer.layers:
            viewer.add_points(
                np.empty((0, 3)),
                name='object_points',
                size=5,
                face_color='red',
            )
        # Connect click event. The tomogram layer is reused across file loads,
        # so register the callback only once; otherwise every additional file
        # opened would make a single click add the same point several times.
        callbacks = viewer.layers['tomogram'].mouse_drag_callbacks
        if on_click not in callbacks:
            callbacks.append(on_click)
    except Exception as e:
        print(f"Error loading file: {e}")


file_browser.changed.connect(open_file)
# Handle click events to collect points
def on_click(layer, event):
    """Record a left-click on the tomogram as a [z, y, x] annotation point.

    The z value is the viewer's current step along the first axis; y and x come
    from the click position converted to the tomogram layer's data
    coordinates. The point is stored in ``point_sets`` under "z_<index>" and,
    when the 'object_points' layer exists, appended to it.

    Args:
        layer: Layer that received the event (not used).
        event: Mouse event providing ``button`` and ``position``.
    """
    if 'tomogram' not in viewer.layers or event.button != 1:
        return
    # Slice currently shown in the viewer
    slice_index = viewer.dims.current_step[0]
    # Click position translated from world to data coordinates
    data_pos = viewer.layers['tomogram'].world_to_data(event.position)
    # Keep the point as [z, y, x]
    clicked = np.array([slice_index, data_pos[1], data_pos[2]])
    # Remember the point in the per-slice bookkeeping
    point_sets.setdefault(f"z_{slice_index}", []).append(clicked)
    # Show the point in the points layer
    if 'object_points' in viewer.layers:
        viewer.layers['object_points'].add(clicked[np.newaxis, :])
# Fit spline to points in a slice
def fit_spline(points, smoothing=0.1, num_points=100):
    """Fit a closed, periodic spline through the (y, x) part of [z, y, x] points.

    Args:
        points: Sequence of [z, y, x] points in drawing order; z is ignored.
        smoothing: Smoothing condition handed to ``splprep`` as ``s``.
        num_points: Number of samples to generate along the fitted curve.

    Returns:
        Array of shape (num_points, 2) holding (y, x) samples, or None when
        fewer than four points are given, fewer than four distinct consecutive
        vertices remain, or the spline fit fails.
    """
    if len(points) < 4:
        return None
    # Extract y,x coordinates (we keep z constant for each slice)
    vertices = np.asarray([(p[1], p[2]) for p in points], dtype=float)
    # Close the curve
    closed = np.vstack([vertices, vertices[:1]])
    # splprep cannot parametrize zero-length segments, so drop every vertex
    # that repeats its predecessor (double clicks, or a start point that was
    # clicked again to close the outline by hand).
    keep = np.ones(len(closed), dtype=bool)
    keep[1:] = np.any(np.diff(closed, axis=0) != 0, axis=1)
    closed = closed[keep]
    # Four distinct vertices plus the closing vertex are the minimum kept here.
    if len(closed) < 5:
        return None
    # Fit spline; report failures the same way fit_spline_from_points does
    try:
        tck, _ = splprep([closed[:, 0], closed[:, 1]], s=smoothing, per=True)
    except (TypeError, ValueError) as e:
        print(f"Error in spline fitting: {e}")
        return None
    # Generate points along spline
    u_new = np.linspace(0, 1, num_points)
    y_new, x_new = splev(u_new, tck)
    return np.column_stack((y_new, x_new))
# Fit spline to points in a slice - version for direct point input
def fit_spline_from_points(points, smoothing=0.1, num_points=100):
    """Fit a closed, periodic spline through (y, x) points.

    Args:
        points: Sequence of (y, x) pairs in drawing order.
        smoothing: Smoothing condition handed to ``splprep`` as ``s``.
        num_points: Number of samples to generate along the fitted curve.

    Returns:
        Array of shape (num_points, 2) holding (y, x) samples, or None when
        fewer than four points are given, fewer than four distinct consecutive
        vertices remain, or the fit fails (the error and traceback are
        printed).
    """
    if len(points) < 4:
        return None
    try:
        # Extract y,x coordinates
        vertices = np.asarray([(p[0], p[1]) for p in points], dtype=float)
        # Close the curve
        closed = np.vstack([vertices, vertices[:1]])
        # splprep cannot parametrize zero-length segments, so drop every
        # vertex that repeats its predecessor (double clicks, or a start point
        # that was clicked again to close the outline by hand).
        keep = np.ones(len(closed), dtype=bool)
        keep[1:] = np.any(np.diff(closed, axis=0) != 0, axis=1)
        closed = closed[keep]
        # Four distinct vertices plus the closing vertex are the minimum.
        if len(closed) < 5:
            return None
        # Fit spline
        tck, _ = splprep([closed[:, 0], closed[:, 1]], s=smoothing, per=True)
        # Generate points along spline
        u_new = np.linspace(0, 1, num_points)
        y_new, x_new = splev(u_new, tck)
        return np.column_stack((y_new, x_new))
    except Exception as e:
        print(f"Error in spline fitting: {e}")
        import traceback
        traceback.print_exc()
        return None
def _interpolate_between_slices(mask, slices):
    """Fill the gaps between annotated slices of ``mask`` in place.

    ``slices`` must be sorted z indices whose mask slice is already filled.
    """
    print("Interpolating between slices with improved smoothing")
    # Signed distance map per annotated slice: negative inside the contour,
    # positive outside.
    dt_maps = {}
    for z in slices:
        dt_pos = ndi.distance_transform_edt(mask[z] == 0)
        dt_neg = ndi.distance_transform_edt(mask[z] == 1)
        dt_maps[z] = dt_pos - dt_neg
    # Interpolate between all neighbouring annotated slices
    for z1, z2 in zip(slices, slices[1:]):
        if z2 - z1 <= 1:
            # Adjacent slices leave no gap to fill.
            continue
        print(f"Creating smooth interpolation between z1={z1} and z2={z2}")
        dt1, dt2 = dt_maps[z1], dt_maps[z2]
        for z in range(z1 + 1, z2):
            t = (z - z1) / (z2 - z1)
            # Smoothstep weight eases the transition near both annotated slices.
            alpha = t * t * (3 - 2 * t)
            # Blend the distance fields and convert back to a binary slice
            interp_dt = (1 - alpha) * dt1 + alpha * dt2
            mask[z] = (interp_dt <= 0).astype(np.int8)
            print(f"Interpolated slice z={z} with smoothstep alpha={alpha:.3f}")


def _smooth_and_clean(mask):
    """Return ``mask`` with cavities filled, surface smoothed and only the largest component kept."""
    print("Performing 3D hole filling")
    mask = ndi.binary_fill_holes(mask).astype(np.int8)
    print("Applying 3D smoothing filter")
    # More smoothing in z-direction (axis=0) than in the xy plane
    mask_smooth = filters.gaussian(mask.astype(float), sigma=(2.0, 0.7, 0.7))
    smoothed = (mask_smooth > 0.5).astype(np.int8)
    if smoothed.any():
        mask = smoothed
    else:
        # A mask that is only one or two slices thick never reaches 0.5 after
        # blurring with sigma 2 along z; keep it instead of returning nothing.
        print("Smoothing would remove the whole mask; keeping the unsmoothed mask")
    print("Cleaning up with component analysis")
    labeled, num = ndi.label(mask)
    if num > 1:
        sizes = np.bincount(labeled.ravel())
        sizes[0] = 0  # Ignore background
        mask = (labeled == np.argmax(sizes)).astype(np.int8)
    return mask


# Create 3D mask from contours
def create_3d_mesh():
    """Build a binary 3D mask and a surface mesh from the annotated points.

    Points are read from the 'object_points' layer as (z, y, x) and grouped by
    rounded z. Every slice inside the volume with at least four points gets a
    closed spline contour (smoothing taken from ``smoothing_slider``), which is
    rasterized into the mask. When more than one slice was filled, the gaps
    between them are interpolated through signed distance maps and the volume
    is hole-filled, smoothed and reduced to its largest connected component.
    A binary closing is applied last and the mask is meshed with marching
    cubes.

    Returns:
        Tuple ``(mesh, mask)``: a ``trimesh.Trimesh`` and an int8 array shaped
        like the tomogram. ``(None, mask)`` when only the meshing step fails,
        ``(None, None)`` when layers are missing, there are too few points, no
        voxel could be filled, or an unexpected error occurs (errors are
        printed, not raised).
    """
    import traceback

    print("Starting mesh creation function")
    # Get points from the object_points layer instead of using point_sets
    if 'object_points' not in viewer.layers:
        print("No object_points layer found")
        return None, None
    if 'tomogram' not in viewer.layers:
        print("No tomogram layer found")
        return None, None
    # Get all points from the points layer
    all_points = viewer.layers['object_points'].data
    print(f"Found {len(all_points)} points in object_points layer")
    if len(all_points) < 4:
        print("Need at least 4 points to create a mask")
        return None, None
    # Get volume dimensions
    volume_shape = viewer.layers['tomogram'].data.shape
    print(f"Volume shape: {volume_shape}")
    smoothing = smoothing_slider.value
    print(f"Using smoothing value: {smoothing}")
    try:
        # Create a mask volume
        mask = np.zeros(volume_shape, dtype=np.int8)
        # Organize points by z-slice
        points_by_slice = {}
        for point in all_points:
            z = int(round(point[0]))  # Points are in (z,y,x) format
            points_by_slice.setdefault(z, []).append(point)
        z_slices = sorted(points_by_slice)
        print(f"Found points on z slices: {z_slices}")
        valid_slices = []
        contours = {}
        # Fit contours to each slice with enough points
        for z in z_slices:
            points = points_by_slice[z]
            print(f"Z-slice {z} has {len(points)} points")
            if not 0 <= z < volume_shape[0]:
                # A negative index would silently wrap around to another
                # slice and a too-large one would abort the whole run.
                print(f"Skipping z={z}: outside the volume")
                continue
            if len(points) < 4:
                print(f"Not enough points on z={z} for spline fitting")
                continue
            print(f"Fitting spline to points in z={z}")
            # Extract y,x coordinates for spline fitting
            spline_points = [[p[1], p[2]] for p in points]  # Get y,x from [z,y,x]
            spline = fit_spline_from_points(spline_points, smoothing=smoothing)
            if spline is None:
                print(f"Spline fitting failed for z={z}")
                continue
            valid_slices.append(z)
            contours[z] = spline
            print(f"Successfully created contour for z={z}")
        print(f"Valid slices with contours: {valid_slices}")
        if not valid_slices:
            print("Not enough valid slices with points")
            return None, None
        # Fill 2D masks for valid slices; remember which ones really got
        # voxels so that empty slices do not take part in the interpolation.
        filled_slices = []
        for z in valid_slices:
            print(f"Processing slice z={z}")
            contour = contours[z]
            print(f"Contour has {len(contour)} points")
            # Draw filled polygon from the (y, x) contour samples
            try:
                rr, cc = draw.polygon(
                    contour[:, 0],
                    contour[:, 1],
                    shape=volume_shape[1:3],
                )
                print(f"Polygon drawn with {len(rr)} points")
                if len(rr) > 0 and len(cc) > 0:
                    mask[z, rr, cc] = 1
                    filled_slices.append(z)
                    print(f"Filled mask for z={z}")
                else:
                    print(f"Warning: Empty polygon for z={z}")
            except Exception as e:
                print(f"Error drawing polygon for z={z}: {e}")
                traceback.print_exc()
        # Check if any masks were created
        mask_sum = np.sum(mask)
        print(f"Total filled voxels in mask: {mask_sum}")
        if mask_sum == 0:
            print("No voxels were filled in the mask")
            return None, None
        # Interpolate and post-process only when several slices were filled;
        # a single annotated slice is kept exactly as drawn.
        if len(filled_slices) > 1:
            _interpolate_between_slices(mask, filled_slices)
            mask = _smooth_and_clean(mask)
        print("Finalizing mask")
        # Clean up the mask with morphological operations; binary_closing
        # returns booleans, so convert back to the int8 used everywhere else.
        mask = binary_closing(mask).astype(np.int8)
        print("Applied binary closing")
        try:
            # Create a simple mesh from the mask for visualization
            print("Creating 3D mesh with marching cubes")
            verts, faces, _, _ = measure.marching_cubes(mask)
            print(f"Mesh created with {len(verts)} vertices and {len(faces)} faces")
            mesh = trimesh.Trimesh(vertices=verts, faces=faces)
            print("Trimesh object created")
            return mesh, mask
        except Exception as e:
            print(f"Error in marching cubes: {e}")
            traceback.print_exc()
            # Even if mesh creation fails, return the mask
            return None, mask
    except Exception as e:
        print(f"Error creating mask: {e}")
        traceback.print_exc()
        return None, None
# Button to create mesh and mask
create_mask_button = PushButton(label='Create 3D Mesh & Mask')


def on_create_mesh():
    """Run ``create_3d_mesh`` and show its results in the viewer.

    The mask goes into the 'object_mask' labels layer (created on first use,
    updated afterwards). When a mesh is available, any existing 'object_mesh'
    layer is removed and a new surface layer is added. Errors are printed.
    """
    print("Starting create mesh process")
    mesh, mask = create_3d_mesh()
    try:
        if mask is None:
            print("Failed to create mask and mesh")
            return
        # Show the mask, reusing the layer when it is already there
        print("Displaying mask in viewer")
        if 'object_mask' not in viewer.layers:
            viewer.add_labels(mask, name='object_mask', opacity=0.5)
        else:
            viewer.layers['object_mask'].data = mask
        # Show the mesh when one could be generated
        if mesh is None:
            print("No mesh generated, but mask is available")
        else:
            print("Displaying mesh in viewer")
            if 'object_mesh' in viewer.layers:
                viewer.layers.remove('object_mesh')
            try:
                surface = (mesh.vertices, mesh.faces)
                viewer.add_surface(surface, name='object_mesh', colormap='red')
                print("Mesh visualization successful")
            except Exception as e:
                print(f"Error displaying mesh: {e}")
        print("Mask created successfully")
    except Exception as e:
        print(f"Error in on_create_mesh: {e}")
        import traceback
        traceback.print_exc()


create_mask_button.changed.connect(on_create_mesh)
# Clear points button
clear_points_button = PushButton(label='Clear Points')


def clear_points():
    """Forget every collected point and empty the 'object_points' layer if it exists."""
    global point_sets
    point_sets = {}
    if 'object_points' not in viewer.layers:
        return
    # Replace the layer data with an empty (0, 3) coordinate array
    viewer.layers['object_points'].data = np.empty(shape=(0, 3))


clear_points_button.changed.connect(clear_points)
# Save mask
def save_mask(*args):
    """Write the 'object_mask' layer to an MRC file in ``RESULTS_DIRECTORY``.

    The file is named '<tomogram stem>_mask.mrc', or 'unnamed_mask.mrc' when no
    tomogram name is known, and an existing file is overwritten. Problems are
    printed instead of raised.

    Args:
        *args: Ignored; accepts whatever the button signal passes along.
    """
    if 'object_mask' not in viewer.layers:
        print('No mask layer available')
        return
    try:
        file_name = f'{current_tomogram}_mask.mrc' if current_tomogram else 'unnamed_mask.mrc'
        output_file = Path(RESULTS_DIRECTORY) / file_name
        # Save the mask as an MRC file
        mask_data = viewer.layers['object_mask'].data.astype(np.int8)
        mrcfile.write(output_file, mask_data, overwrite=True)
        print(f"Mask saved to {output_file}")
    except Exception as e:
        print(f"Error saving mask: {e}")


save_button = PushButton(label='Save Mask')
save_button.changed.connect(save_mask)
# Apply lowpass filter
lowpass_button = PushButton(label='Apply Lowpass Filter')


def apply_lowpass(*args):
    """Replace the tomogram layer data with a Gaussian-filtered copy (sigma 1).

    Does nothing when no 'tomogram' layer exists. Every call filters the data
    currently in the layer, so repeated clicks accumulate.

    Args:
        *args: Ignored; accepts whatever the button signal passes along.
    """
    if 'tomogram' not in viewer.layers:
        return
    tomogram_layer = viewer.layers['tomogram']
    # preserve_range keeps the original intensity scale; without it
    # integer-typed tomograms are rescaled to a unit float range, which no
    # longer matches the layer's contrast limits.
    tomogram_layer.data = filters.gaussian(
        tomogram_layer.data, sigma=1, preserve_range=True
    )


lowpass_button.changed.connect(apply_lowpass)
# Smoothing factor for splines. The range 0.01-1.0 with step 0.01 is fractional,
# so the float slider is required; magicgui's plain Slider adjusts integers.
smoothing_slider = FloatSlider(value=0.1, min=0.01, max=1.0, step=0.01, label='Spline Smoothing')
# Container for all widgets, in the order they appear in the dock
container_widget = Container()
container_widget.append(file_browser)
container_widget.append(smoothing_slider)
container_widget.append(create_mask_button)
container_widget.append(clear_points_button)
container_widget.append(save_button)
container_widget.append(lowpass_button)
# Setup output directory (save_mask writes into it)
Path(RESULTS_DIRECTORY).mkdir(exist_ok=True, parents=True)
# Run napari
viewer.window.add_dock_widget(container_widget)
napari.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment