Created March 1, 2021 18:00
-
-
Save carsen-stringer/05aeb1b4e49b3ea12bc2148d7129dfd8 to your computer and use it in GitHub Desktop.
test plugin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" first try at cellpose plugin for napari""" | |
from enum import Enum | |
import sys, pathlib | |
import numpy as np | |
import napari | |
from napari import Viewer, gui_qt | |
from napari.layers import Image | |
from napari.qt.threading import thread_worker | |
from napari_plugin_engine import napari_hook_implementation | |
from magicgui import magicgui, magic_factory | |
from cellpose.models import Cellpose | |
from cellpose.utils import masks_to_outlines | |
class Model(Enum):
    """Cellpose model types selectable in the widget.

    Each member's value is the string handed to ``Cellpose(model_type=...)``.
    """

    cyto = "cyto"
    nuclei = "nuclei"
class Channel1(Enum):
    """Main channel to segment.

    The integer value is used as the first entry of the cellpose
    ``channels`` list (0 = gray, 1 = red, 2 = green, 3 = blue).
    """

    # Tuple-unpacking assigns 0, 1, 2, 3 in declaration order.
    gray, red, green, blue = range(4)
class Channel2(Enum):
    """Optional nuclear channel.

    The integer value is used as the second entry of the cellpose
    ``channels`` list (0 = none, 1 = red, 2 = green, 3 = blue).
    """

    # Tuple-unpacking assigns 0, 1, 2, 3 in declaration order.
    none, red, green, blue = range(4)
def create_image(layerA, layerB, operation):
    """Combine the data of two layers and wrap the result in a new ``Image``.

    ``operation.value`` must be a callable taking two arrays (for example,
    an Enum member whose value is a binary function); it is applied to
    ``layerA.data`` and ``layerB.data``.

    NOTE(review): this helper is not referenced anywhere in this file.
    """
    combine = operation.value
    combined = combine(layerA.data, layerB.data)
    return Image(combined)
@napari_hook_implementation
def napari_experimental_provide_dock_widget():
    """napari plugin hook: provide the Cellpose dock widget.

    Returns the ``magic_factory``-decorated ``start_cellpose`` factory;
    napari calls it to build the widget.
    """
    widget_factory = start_cellpose
    return widget_factory
def reset_view(viewer: 'napari.Viewer', layer: 'napari.layers.Layer'):
    """Center the 2D camera on ``layer`` and zoom so it fills the canvas.

    Does nothing unless the viewer is displaying in 2D.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer whose camera is adjusted. Reads ``viewer.dims.ndisplay``,
        ``viewer.dims.displayed`` and the private ``viewer._canvas_size``;
        writes ``viewer.camera.center`` and ``viewer.camera.zoom``.
    layer : napari.layers.Layer
        Layer to focus on; its world-space extent over the displayed
        dimensions defines the view.
    """
    if viewer.dims.ndisplay != 2:
        return
    # Row 0 is the lower corner and row 1 the upper corner; keep only the
    # currently displayed axes.
    extent = np.asarray(layer.extent.world)[:, list(viewer.dims.displayed)]
    size = extent[1] - extent[0]
    viewer.camera.center = extent[0] + size / 2
    largest = np.max(size)
    # A degenerate (zero-size) extent would otherwise produce an infinite
    # zoom from a division by zero; in that case only re-center.
    if largest > 0:
        viewer.camera.zoom = np.min(viewer._canvas_size) / largest
def run_cellpose(img, model_type, channels, diameter, gpu=True, **eval_kwargs):
    """Segment ``img`` with a pretrained Cellpose model.

    Runs synchronously: a new ``Cellpose`` model is constructed and
    evaluated on every call, and the call blocks until it finishes.

    Parameters
    ----------
    img : array-like
        Image data passed directly to ``Cellpose.eval``.
    model_type : str
        Cellpose model name (this file offers ``'cyto'`` and ``'nuclei'``).
    channels : list of int
        ``[main channel, optional nuclear channel]``, passed to
        ``Cellpose.eval`` as ``channels``.
    diameter : float
        Passed to ``Cellpose.eval`` as ``diameter``.
    gpu : bool, optional
        Passed to the ``Cellpose`` constructor. Defaults to ``True``, the
        value that was previously hard-coded.
    **eval_kwargs
        Extra keyword arguments forwarded unchanged to ``Cellpose.eval``
        (e.g. ``cellprob_threshold``). Which names are accepted depends on
        the installed cellpose version.

    Returns
    -------
    tuple
        ``(masks, flows, diams, outlines)``. ``masks``, ``flows`` and
        ``diams`` are as returned by ``Cellpose.eval`` (its ``styles``
        output is discarded). ``outlines`` is
        ``masks_to_outlines(masks) * masks`` -- presumably the outline
        pixels carrying their mask's label; confirm against cellpose.utils.
    """
    model = Cellpose(model_type=model_type, gpu=gpu)
    masks, flows, _styles, diams = model.eval(img,
                                              channels=channels,
                                              diameter=diameter,
                                              **eval_kwargs)
    outlines = masks_to_outlines(masks) * masks
    return masks, flows, diams, outlines
@magic_factory(
    call_button='Run',
    layout='vertical',
    model_match_threshold={"widget_type": "FloatSlider", "max": 6},
    viewer={'visible': False, 'label': ' '},
)
def start_cellpose(
    image_layer: Image,
    viewer: napari.viewer.Viewer,
    model_type=Model.cyto,
    net_average=False,
    resample_dynamics=False,
    main_channel=Channel1.gray,
    optional_nuclear_channel=Channel2.none,
    diameter=30.0,
    cellprob_threshold=0.0,
    model_match_threshold=6,
    clear_previous_segmentations=True,
):
    """Segment ``image_layer`` with Cellpose and add the results to ``viewer``.

    Decorated with ``magic_factory``, so the module-level ``start_cellpose``
    is a widget factory. Calling it builds a dock widget whose "Run" button
    invokes this function.

    Each run adds four layers named after the input layer:
    ``<name>_flows`` and ``<name>_cellprob`` (image layers) and
    ``<name>_masks`` and ``<name>_outlines`` (labels layers). Repeated runs
    get a ``_2``, ``_3``, ... suffix. The layers of every run are recorded
    in ``viewer.cellpose_layers``. Earlier runs are either removed
    (``clear_previous_segmentations``) or hidden. The parameters of the
    latest run are stored in ``viewer.cellpose_settings``.

    Raises
    ------
    ValueError
        If no image layer is selected.
    """
    # Fail before touching existing layers when there is nothing to segment.
    if image_layer is None:
        raise ValueError('No image layer selected; open an image before running Cellpose.')

    if not hasattr(viewer, 'cellpose_layers'):
        viewer.cellpose_layers = []
    if clear_previous_segmentations:
        for seg in viewer.cellpose_layers:
            for layer in seg:
                # The user may already have deleted some of these layers by
                # hand; removing a layer that is no longer present would raise.
                if layer in viewer.layers:
                    viewer.layers.remove(layer)
        viewer.cellpose_layers = []

    # Focus the camera on the layer being segmented.
    reset_view(viewer, image_layer)

    # NOTE(review): net_average, resample_dynamics, cellprob_threshold and
    # model_match_threshold are exposed in the widget but are NOT forwarded
    # to cellpose. Only the model type, channels and diameter affect the
    # segmentation.
    masks, flows, diams, outlines = run_cellpose(
        img=image_layer.data,
        model_type=model_type.value,
        channels=[max(0, main_channel.value),
                  max(0, optional_nuclear_channel.value)],
        diameter=diameter,
    )

    # Suffix repeated runs (_2, _3, ...) so their layer names can be told apart.
    n_previous = len(viewer.cellpose_layers)
    iseg = '_' + str(n_previous + 1) if n_previous > 0 else ''

    # Two image layers (flows[0] and flows[2]) and two labels layers
    # (masks and outlines) per run.
    flow_layer = viewer.add_image(flows[0], name=image_layer.name + '_flows' + iseg)
    cellprob_layer = viewer.add_image(flows[2], name=image_layer.name + '_cellprob' + iseg)
    mask_layer = viewer.add_labels(masks, name=image_layer.name + '_masks' + iseg)
    outline_layer = viewer.add_labels(outlines, name=image_layer.name + '_outlines' + iseg)
    viewer.cellpose_layers.append([flow_layer, cellprob_layer, mask_layer, outline_layer])
    viewer.cellpose_settings = [model_type.value,
                                main_channel.value,
                                optional_nuclear_channel.value,
                                diameter,
                                model_match_threshold,
                                cellprob_threshold]

    viewer.layers.unselect_all()
    # Hide every earlier segmentation so only the newest one is shown.
    for seg in viewer.cellpose_layers[:-1]:
        for layer in seg:
            layer.visible = False
    image_layer.visible = True
    mask_layer.visible = True
    outline_layer.visible = True
    flow_layer.visible = False
    cellprob_layer.visible = False
def main():
    """Launch napari with the Cellpose widget docked on the right.

    Any command-line arguments are opened in the viewer as separate layers
    (``stack=False``) before the widget is added.
    """
    # Read argv before the viewer (and its Qt application) is created.
    paths = sys.argv[1:]
    viewer = napari.Viewer()
    if paths:
        viewer.open(paths, stack=False)
    cellpose_widget = start_cellpose()
    viewer.window.add_dock_widget(cellpose_widget, area='right')
    napari.run()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.