
@cwood1967
Last active October 7, 2021 18:28
Training Mask R-CNN with cellfinder
import glob
import os
import warnings
import numpy as np
import torch
import tifffile
from cellfinder import train, infer
warnings.filterwarnings('ignore')
'''
https://github.com/cwood1967/xlearn
Arguments to train.main:
root : str
    Path to the data directory. Default is 'Data'
image_dir : str
    Directory inside root where the training images are located. Default is 'Images'
mask_dir : str
    Directory for the training masks. Default is 'Masks'
epochs : int
    Number of epochs to train. Default is 50
cropsize : tuple (int, int)
    Size of the patches used in training. Default is (400, 400)
batch_size : int
    Number of images per batch. Default is 8

train.main saves a snapshot after any epoch that results in a lower loss. The directory
where the snapshots are saved is created during training.
'''
tm = train.main('.', 'Images', 'Masks', cropsize=(400, 400), batch_size=8, epochs=100)
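The snapshot rule in the docstring above can be sketched in a few lines of plain Python. This is an illustration under assumptions, not cellfinder's code: it takes "a lower loss" to mean lower than the best loss seen so far (the `trained_min_mask_model_*.pt` file names hint at a running minimum), and the helper `snapshot_epochs` is hypothetical.

```python
from typing import List, Sequence


def snapshot_epochs(losses: Sequence[float]) -> List[int]:
    """Return the 0-based epochs after which a snapshot would be saved.

    Hypothetical helper, not part of cellfinder. Assumes a snapshot is
    written whenever an epoch's loss is strictly below the best so far.
    """
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(losses):
        if loss < best:  # strictly lower than every earlier epoch
            best = loss
            # in a real training loop this is where something like
            # torch.save(model.state_dict(), path) would run
            saved.append(epoch)
    return saved


if __name__ == "__main__":
    # epochs 0, 1 and 3 set a new minimum; ties (epoch 4) do not trigger a save
    print(snapshot_epochs([1.0, 0.8, 0.9, 0.7, 0.7, 0.75]))
```

Under this reading, the newest snapshot on disk is always the lowest-loss model reached so far.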
'''
Inference
'''
# collect the snapshot files saved by train.main (replace the directory name with the one from your run)
model_files = sorted(glob.glob('model_2021-10-06-16-29-00/trained_min_mask_model_*.pt'))
# create the network. The int argument is the number of classes, including background
tmx = train.get_model(2)
# load the weights from the last file in sorted order into the network
tmx.load_state_dict(torch.load(model_files[-1]))
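`sorted()` compares file names as strings. If the snapshot names end in an unpadded number (an assumption, since the naming scheme is not documented here), `..._10.pt` sorts before `..._9.pt` and `model_files[-1]` is not the highest-numbered file. A minimal natural-sort key, using only the standard library; `natural_key` is a hypothetical helper:

```python
import re
from typing import List, Union


def natural_key(path: str) -> List[Union[str, int]]:
    """Split a path into text and integer chunks so numbers compare numerically.

    re.split with a capturing group puts text at even indices and the
    captured digit runs at odd indices, so each position always holds the
    same type and the lists compare cleanly.
    """
    parts = re.split(r"(\d+)", path)
    return [int(p) if i % 2 else p for i, p in enumerate(parts)]


if __name__ == "__main__":
    files = ["trained_min_mask_model_9.pt", "trained_min_mask_model_10.pt"]
    print(sorted(files)[-1])                   # plain string sort picks ..._9.pt
    print(sorted(files, key=natural_key)[-1])  # natural sort picks ..._10.pt
```

If the names do carry such a number, passing `key=natural_key` to `sorted` when building `model_files` makes `model_files[-1]` the highest-numbered snapshot.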
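# create the inference object
cnn = infer.predict(tmx, size=(400, 400), max_project=True, probability=.75)
# read the image to segment ('image.tif' is a hypothetical placeholder path)
filename = 'image.tif'
data = tifffile.imread(filename)
pmap, bs = cnn(data)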
'''
cnn output
pmap : array
    probability map
bs : tuple
    bs[0] : list of arrays with the bounding-box boundaries
    bs[1] : score (probability) of each object belonging to the predicted class
'''
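Assuming `bs[0]` and `bs[1]` line up index by index (one score per box), the detections can be paired, filtered with a stricter cutoff, and ranked by score. This is a minimal NumPy sketch: the helper `top_detections`, the `min_score` value, and the example arrays are hypothetical. It makes no assumption about the coordinate order inside each box, and how the `probability=.75` argument to `infer.predict` already filters results is not documented here.

```python
from typing import List, Sequence, Tuple

import numpy as np


def top_detections(
    boxes: Sequence[np.ndarray],
    scores: Sequence[float],
    min_score: float = 0.9,
) -> List[Tuple[np.ndarray, float]]:
    """Keep (box, score) pairs with score >= min_score, highest score first.

    Hypothetical helper: assumes boxes[i] and scores[i] describe the same object.
    """
    if len(boxes) != len(scores):
        raise ValueError("boxes and scores must have the same length")
    kept = [(np.asarray(b), float(s)) for b, s in zip(boxes, scores) if s >= min_score]
    kept.sort(key=lambda pair: pair[1], reverse=True)  # best-scoring object first
    return kept


if __name__ == "__main__":
    # made-up example values, standing in for bs[0] and bs[1]
    boxes = [np.array([10, 10, 50, 50]), np.array([60, 5, 90, 40]), np.array([0, 0, 8, 8])]
    scores = [0.95, 0.80, 0.99]
    for box, score in top_detections(boxes, scores, min_score=0.9):
        print(score, box.tolist())
```

With the real output this would be called as `top_detections(bs[0], bs[1])`.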