Created
March 26, 2025 10:05
-
-
Save shahpnmlab/bfed09e668c7a0225c5ba44cb38277f2 to your computer and use it in GitHub Desktop.
napari mesh to mask
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path

import mrcfile
import napari
import numpy as np
import trimesh
from magicgui import magicgui
from magicgui.widgets import Container, FileEdit, FloatSlider, PushButton, Slider
from napari.qt.threading import thread_worker
from qtpy.QtWidgets import QFileDialog, QVBoxLayout
from scipy import ndimage as ndi
from scipy.interpolate import splev, splprep
from scipy.spatial import Delaunay
from skimage import draw, filters, measure
from skimage.morphology import binary_closing, binary_dilation, binary_erosion
# Shared module state
point_sets = dict()  # clicked points grouped per slice, keyed by "z_<index>"
current_tomogram = None  # file stem of the most recently opened tomogram
RESULTS_DIRECTORY = './results'

# The napari viewer that every widget and callback in this script uses
viewer = napari.Viewer()

# File picker used to choose the tomogram to open
file_browser = FileEdit(label='Open Tomogram', mode='r')
def open_file(file_path):
    """Load an MRC tomogram into the viewer and prepare it for point picking.

    The volume replaces the data of the existing ``tomogram`` image layer, or
    creates that layer if it is missing. An empty ``object_points`` layer is
    added when absent, and ``on_click`` is registered as a mouse-drag callback
    on the tomogram layer.

    Args:
        file_path: Path of the MRC file to open. A falsy value is ignored.

    Errors raised while loading are caught and printed, not propagated.
    """
    global current_tomogram
    if not file_path:
        return
    try:
        # Load the MRC file
        with mrcfile.open(file_path) as mrc:
            volume_data = mrc.data
        # Get the filename without extension
        current_tomogram = Path(file_path).stem
        # Add volume to viewer
        if 'tomogram' in viewer.layers:
            viewer.layers['tomogram'].data = volume_data
        else:
            viewer.add_image(volume_data, name='tomogram')
        # Add points layer if it doesn't exist
        if 'object_points' not in viewer.layers:
            viewer.add_points(
                np.empty((0, 3)),
                name='object_points',
                size=5,
                face_color='red',
            )
        # Connect the click handler only once: the layer is reused when a
        # second file is opened, and a repeated append would make every
        # click record the same point several times.
        callbacks = viewer.layers['tomogram'].mouse_drag_callbacks
        if on_click not in callbacks:
            callbacks.append(on_click)
    except Exception as e:
        print(f"Error loading file: {e}")


file_browser.changed.connect(open_file)
# Handle click events to collect points
def on_click(layer, event):
    """Record a left click on the tomogram as a [z, y, x] point.

    The z value is the viewer's current step along the first axis; y and x
    come from the event position converted to the tomogram's data
    coordinates. The point is stored in ``point_sets`` under ``z_<index>``
    and added to the ``object_points`` layer when that layer exists.
    """
    if 'tomogram' not in viewer.layers or event.button != 1:
        return

    z_index = viewer.dims.current_step[0]
    data_pos = viewer.layers['tomogram'].world_to_data(event.position)
    clicked = np.array([z_index, data_pos[1], data_pos[2]])

    # Keep a per-slice record of everything clicked so far
    point_sets.setdefault(f"z_{z_index}", []).append(clicked)

    # Mirror the point into the points layer so it is visible
    if 'object_points' in viewer.layers:
        viewer.layers['object_points'].add(clicked[np.newaxis, :])
# Fit spline to points in a slice
def fit_spline(points, smoothing=0.1, num_points=100):
    """Fit a closed, periodic spline through the [z, y, x] points of one slice.

    Args:
        points: Sequence of points indexed as ``[z, y, x]``; only y and x are
            used (z is constant within a slice).
        smoothing: Smoothing condition passed to ``splprep`` as ``s``.
        num_points: Number of samples generated along the fitted curve.

    Returns:
        An array of shape ``(num_points, 2)`` holding ``(y, x)`` samples, or
        ``None`` when fewer than four distinct points are available.

    Raises:
        Whatever ``splprep`` / ``splev`` raise for input they cannot fit.
    """
    if len(points) < 4:
        return None
    # Extract y,x coordinates (we keep z constant for each slice)
    coords = np.asarray([(p[1], p[2]) for p in points], dtype=float)
    # splprep rejects repeated consecutive points ("Invalid inputs"), which
    # happens easily when the same spot is clicked twice, so drop them.
    distinct = np.ones(len(coords), dtype=bool)
    distinct[1:] = np.any(np.diff(coords, axis=0) != 0, axis=1)
    coords = coords[distinct]
    # A contour that was already closed by hand would otherwise end with a
    # duplicated point once the closing point is appended below.
    if len(coords) > 1 and np.array_equal(coords[0], coords[-1]):
        coords = coords[:-1]
    if len(coords) < 4:
        return None
    # Close the curve
    closed = np.vstack([coords, coords[:1]])
    # Fit spline
    tck, _ = splprep([closed[:, 0], closed[:, 1]], s=smoothing, per=True)
    # Generate points along spline
    u_new = np.linspace(0, 1, num_points)
    y_new, x_new = splev(u_new, tck)
    return np.column_stack((y_new, x_new))
# Fit spline to points in a slice - version for direct point input
def fit_spline_from_points(points, smoothing=0.1, num_points=100):
    """Fit a closed, periodic spline through ``(y, x)`` points.

    Args:
        points: Sequence of points indexed as ``[y, x]``.
        smoothing: Smoothing condition passed to ``splprep`` as ``s``.
        num_points: Number of samples generated along the fitted curve.

    Returns:
        An array of shape ``(num_points, 2)`` holding ``(y, x)`` samples, or
        ``None`` when fewer than four distinct points are available or the
        fit fails (the error and traceback are printed in that case).
    """
    if len(points) < 4:
        return None
    try:
        # Extract y,x coordinates
        coords = np.asarray([(p[0], p[1]) for p in points], dtype=float)
        # splprep rejects repeated consecutive points ("Invalid inputs"),
        # which happens easily when the same spot is clicked twice. Dropping
        # them keeps the slice usable instead of silently failing it.
        distinct = np.ones(len(coords), dtype=bool)
        distinct[1:] = np.any(np.diff(coords, axis=0) != 0, axis=1)
        coords = coords[distinct]
        # A contour that was already closed by hand would otherwise end with
        # a duplicated point once the closing point is appended below.
        if len(coords) > 1 and np.array_equal(coords[0], coords[-1]):
            coords = coords[:-1]
        if len(coords) < 4:
            return None
        # Close the curve
        closed = np.vstack([coords, coords[:1]])
        # Fit spline
        tck, _ = splprep([closed[:, 0], closed[:, 1]], s=smoothing, per=True)
        # Generate points along spline
        u_new = np.linspace(0, 1, num_points)
        y_new, x_new = splev(u_new, tck)
        return np.column_stack((y_new, x_new))
    except Exception as e:
        # Best effort: a failed slice is reported and skipped by the caller
        print(f"Error in spline fitting: {e}")
        import traceback
        traceback.print_exc()
        return None
# Create 3D mask from contours
def _group_points_by_slice(points):
    """Group [z, y, x] points by z rounded to the nearest integer slice."""
    by_slice = {}
    for point in points:
        by_slice.setdefault(int(round(point[0])), []).append(point)
    return by_slice


def _interpolate_between_slices(mask, slices):
    """Fill, in place, the slices lying between consecutive filled slices.

    Every slice in ``slices`` (sorted, each with a non-empty 2D mask) is
    turned into a signed distance map: positive outside the object, negative
    inside. Slices in between are obtained by blending the two neighbouring
    maps with a smoothstep weight and keeping the pixels whose blended
    distance is <= 0.
    """
    signed_dt = {}
    for z in slices:
        dt_outside = ndi.distance_transform_edt(mask[z] == 0)
        dt_inside = ndi.distance_transform_edt(mask[z] == 1)
        signed_dt[z] = dt_outside - dt_inside

    for z1, z2 in zip(slices, slices[1:]):
        if z2 - z1 <= 1:
            continue  # adjacent slices leave nothing to fill
        print(f"Creating smooth interpolation between z1={z1} and z2={z2}")
        dt1, dt2 = signed_dt[z1], signed_dt[z2]
        for z in range(z1 + 1, z2):
            t = (z - z1) / (z2 - z1)
            # Smoothstep weight for more natural transitions
            alpha = t * t * (3 - 2 * t)
            interp_dt = (1 - alpha) * dt1 + alpha * dt2
            mask[z] = (interp_dt <= 0).astype(np.int8)
            print(f"Interpolated slice z={z} with smoothstep alpha={alpha:.3f}")


def _clean_up_mask(mask):
    """Fill holes, smooth, keep the largest component and close the mask.

    Returns the cleaned mask as an ``int8`` array.
    """
    print("Performing 3D hole filling")
    mask = ndi.binary_fill_holes(mask).astype(np.int8)

    print("Applying 3D smoothing filter")
    # More smoothing in z-direction (axis=0) than in the xy plane
    smoothed = filters.gaussian(mask.astype(float), sigma=(2.0, 0.7, 0.7)) > 0.5
    if smoothed.any():
        mask = smoothed.astype(np.int8)
    else:
        # A mask only one or two slices thick never reaches 0.5 after the
        # z-smoothing; keep it rather than returning an empty mask.
        print("Smoothing would erase the mask; keeping the unsmoothed mask")

    print("Cleaning up with component analysis")
    labeled, num = ndi.label(mask)
    if num > 1:
        sizes = np.bincount(labeled.ravel())
        sizes[0] = 0  # Ignore background
        mask = (labeled == np.argmax(sizes)).astype(np.int8)

    print("Finalizing mask")
    # binary_closing returns booleans; cast back so the dtype matches every
    # other stage of the pipeline.
    mask = binary_closing(mask).astype(np.int8)
    print("Applied binary closing")
    return mask


def create_3d_mesh():
    """Build a 3D mask and surface mesh from the points in ``object_points``.

    Points are grouped per z slice. Every slice inside the volume with at
    least four points gets a closed spline contour, which is rasterised as a
    filled polygon. Gaps between filled slices are interpolated, and the
    result is hole-filled, smoothed, reduced to its largest connected
    component and closed. A mesh is then extracted with marching cubes.

    Returns:
        ``(mesh, mask)`` on success, ``(None, mask)`` when only the mesh
        extraction failed, and ``(None, None)`` when no mask could be built.
        ``mask`` is an ``int8`` array with the tomogram's shape.
    """
    import traceback

    print("Starting mesh creation function")
    # Get points from the object_points layer instead of using point_sets
    if 'object_points' not in viewer.layers:
        print("No object_points layer found")
        return None, None
    if 'tomogram' not in viewer.layers:
        print("No tomogram layer found")
        return None, None
    # Get all points from the points layer
    all_points = viewer.layers['object_points'].data
    print(f"Found {len(all_points)} points in object_points layer")
    if len(all_points) < 4:
        print("Need at least 4 points to create a mask")
        return None, None
    # Get volume dimensions
    volume_shape = viewer.layers['tomogram'].data.shape
    print(f"Volume shape: {volume_shape}")
    smoothing = smoothing_slider.value
    print(f"Using smoothing value: {smoothing}")
    try:
        # Create a mask volume
        mask = np.zeros(volume_shape, dtype=np.int8)
        # Organize points by z-slice
        points_by_slice = _group_points_by_slice(all_points)
        z_slices = sorted(points_by_slice)
        print(f"Found points on z slices: {z_slices}")

        # Fit contours to each slice with enough points
        contours = {}
        for z in z_slices:
            points = points_by_slice[z]
            print(f"Z-slice {z} has {len(points)} points")
            if not 0 <= z < volume_shape[0]:
                # A negative index would silently wrap around to the other
                # end of the volume, and a too-large one would raise.
                print(f"Skipping z={z}: outside the volume")
                continue
            if len(points) < 4:
                print(f"Not enough points on z={z} for spline fitting")
                continue
            print(f"Fitting spline to points in z={z}")
            # Extract y,x coordinates for spline fitting
            spline_points = [[p[1], p[2]] for p in points]
            spline = fit_spline_from_points(spline_points, smoothing=smoothing)
            if spline is None:
                print(f"Spline fitting failed for z={z}")
                continue
            contours[z] = spline
            print(f"Successfully created contour for z={z}")

        valid_slices = sorted(contours)
        print(f"Valid slices with contours: {valid_slices}")
        if not valid_slices:
            print("Not enough valid slices with points")
            return None, None

        # Fill 2D masks for valid slices
        filled_slices = []
        for z in valid_slices:
            print(f"Processing slice z={z}")
            contour = contours[z]
            print(f"Contour has {len(contour)} points")
            try:
                rr, cc = draw.polygon(
                    contour[:, 0],
                    contour[:, 1],
                    shape=volume_shape[1:3],
                )
                print(f"Polygon drawn with {len(rr)} points")
                if len(rr) > 0:
                    mask[z, rr, cc] = 1
                    filled_slices.append(z)
                    print(f"Filled mask for z={z}")
                else:
                    print(f"Warning: Empty polygon for z={z}")
            except Exception as e:
                print(f"Error drawing polygon for z={z}: {e}")
                traceback.print_exc()

        # Check if any masks were created
        mask_sum = np.sum(mask)
        print(f"Total filled voxels in mask: {mask_sum}")
        if mask_sum == 0:
            print("No voxels were filled in the mask")
            return None, None

        # Interpolate only between slices that actually contain a filled
        # polygon; an empty slice has no meaningful distance map.
        if len(filled_slices) > 1:
            print("Interpolating between slices with improved smoothing")
            _interpolate_between_slices(mask, filled_slices)

        mask = _clean_up_mask(mask)

        try:
            # Create a simple mesh from the mask for visualization
            print("Creating 3D mesh with marching cubes")
            verts, faces, _, _ = measure.marching_cubes(mask)
            print(f"Mesh created with {len(verts)} vertices and {len(faces)} faces")
            mesh = trimesh.Trimesh(vertices=verts, faces=faces)
            print("Trimesh object created")
            return mesh, mask
        except Exception as e:
            print(f"Error in marching cubes: {e}")
            traceback.print_exc()
            # Even if mesh creation fails, return the mask
            return None, mask
    except Exception as e:
        print(f"Error creating mask: {e}")
        traceback.print_exc()
        return None, None
# Button to create mesh and mask
create_mask_button = PushButton(label='Create 3D Mesh & Mask')


def on_create_mesh():
    """Run ``create_3d_mesh`` and show its mask and mesh in the viewer.

    The mask goes into the ``object_mask`` labels layer, which is created on
    first use and updated afterwards. The mesh, when available, is shown as
    the ``object_mesh`` surface layer. Display errors are printed, not raised.
    """
    print("Starting create mesh process")
    mesh, mask = create_3d_mesh()
    try:
        if mask is None:
            print("Failed to create mask and mesh")
            return

        # Display the mask
        print("Displaying mask in viewer")
        if 'object_mask' in viewer.layers:
            viewer.layers['object_mask'].data = mask
        else:
            viewer.add_labels(mask, name='object_mask', opacity=0.5)

        # Drop the previous mesh in every case, so a mesh from an earlier
        # run is never left on screen next to a newer mask.
        if 'object_mesh' in viewer.layers:
            viewer.layers.remove('object_mesh')

        # Try to visualize the mesh if available
        if mesh is not None:
            print("Displaying mesh in viewer")
            try:
                viewer.add_surface((mesh.vertices, mesh.faces), name='object_mesh', colormap='red')
                print("Mesh visualization successful")
            except Exception as e:
                print(f"Error displaying mesh: {e}")
        else:
            print("No mesh generated, but mask is available")
        print("Mask created successfully")
    except Exception as e:
        print(f"Error in on_create_mesh: {e}")
        import traceback
        traceback.print_exc()


create_mask_button.changed.connect(on_create_mesh)
# Clear points button
clear_points_button = PushButton(label='Clear Points')


def clear_points():
    """Forget every recorded point and empty the ``object_points`` layer."""
    global point_sets
    point_sets = dict()
    if 'object_points' not in viewer.layers:
        return
    viewer.layers['object_points'].data = np.empty((0, 3))


clear_points_button.changed.connect(clear_points)
# Save mask
def save_mask(*args):
    """Write the ``object_mask`` layer to an MRC file in ``RESULTS_DIRECTORY``.

    The file is named after the current tomogram (``<stem>_mask.mrc``), or
    ``unnamed_mask.mrc`` when no tomogram name is known. An existing file is
    overwritten. Errors are printed, not raised.
    """
    if 'object_mask' not in viewer.layers:
        print('No mask layer available')
        return
    try:
        file_name = f'{current_tomogram}_mask.mrc' if current_tomogram else 'unnamed_mask.mrc'
        output_file = Path(RESULTS_DIRECTORY) / file_name
        # Save the mask as an MRC file
        mask_data = viewer.layers['object_mask'].data.astype(np.int8)
        mrcfile.write(output_file, mask_data, overwrite=True)
        print(f"Mask saved to {output_file}")
    except Exception as e:
        print(f"Error saving mask: {e}")


save_button = PushButton(label='Save Mask')
save_button.changed.connect(save_mask)
# Apply lowpass filter
lowpass_button = PushButton(label='Apply Lowpass Filter')


def apply_lowpass(*args):
    """Replace the tomogram layer's data with a Gaussian-blurred copy (sigma=1).

    Does nothing when no ``tomogram`` layer exists. Each call filters the
    data currently in the layer, so repeated clicks accumulate blur.
    """
    if 'tomogram' not in viewer.layers:
        return
    # preserve_range keeps the original intensity scale. Without it,
    # integer-typed volumes are rescaled to a unit range before filtering,
    # which no longer matches the layer's contrast limits.
    filtered_data = filters.gaussian(
        viewer.layers['tomogram'].data,
        sigma=1,
        preserve_range=True,
    )
    viewer.layers['tomogram'].data = filtered_data


lowpass_button.changed.connect(apply_lowpass)
# Smoothing factor for splines. The range and step are fractional, so this
# must be a FloatSlider; magicgui's plain Slider is the integer slider.
smoothing_slider = FloatSlider(value=0.1, min=0.01, max=1.0, step=0.01, label='Spline Smoothing')

# Container for all widgets, in the order they appear in the dock
container_widget = Container()
container_widget.append(file_browser)
container_widget.append(smoothing_slider)
container_widget.append(create_mask_button)
container_widget.append(clear_points_button)
container_widget.append(save_button)
container_widget.append(lowpass_button)

# Setup output directory
Path(RESULTS_DIRECTORY).mkdir(exist_ok=True, parents=True)

# Run napari
viewer.window.add_dock_widget(container_widget)
napari.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment