Last active
October 7, 2021 18:28
-
-
Save cwood1967/aa74d475c4f52108e7f6bcd30a878478 to your computer and use it in GitHub Desktop.
Training with mask_rcnn using cellfinder
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob
import os
import warnings

import numpy as np
import tifffile
import torch

from cellfinder import train, infer
# Ignore every warning (no category filter is given) for the rest of the run.
# NOTE(review): this also hides deprecation warnings; consider narrowing the
# filter once the noisy warnings are identified.
warnings.filterwarnings('ignore')

# ---------------------------------------------------------------------------
# Training -- see https://github.com/cwood1967/xlearn
#
# Arguments to train.main:
#   root : str
#       The path to the data directory. Default is 'Data'.
#   image_dir : str
#       Directory inside root where the training images are located.
#       Default is 'Images'.
#   mask_dir : str
#       Directory for the training masks. Default is 'Masks'.
#   epochs : int
#       Number of epochs to train. Default is 50.
#   cropsize : tuple (int, int)
#       Size of the patches used in training. Default is (400, 400).
#   batch_size : int
#       Size of the image batches. Default is 8.
#
# train.main saves a snapshot after any epoch that results in a lower loss.
# The directory where the snapshots are saved is created during training.
# ---------------------------------------------------------------------------
tm = train.main(
    '.',        # root
    'Images',   # image_dir
    'Masks',    # mask_dir
    epochs=100,
    batch_size=8,
    cropsize=(400, 400),
)
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

# Directory holding the snapshots written by a train.main run.
# NOTE(review): this timestamped name comes from one specific run; change it
# to the directory created by your own training run.
model_dir = 'model_2021-10-06-16-29-00'

# Image to run inference on.
# NOTE(review): placeholder -- the original script used `filename` without
# ever defining it; set this to a real image file before running.
filename = 'image.tif'

# Load the saved model files. After sorting, the last one is used.
# NOTE(review): this assumes lexicographic order of the snapshot names matches
# the order they were saved in -- confirm against train.main's naming scheme.
model_files = sorted(
    glob.glob(os.path.join(model_dir, 'trained_min_mask_model_*.pt'))
)
if not model_files:
    # Fail with a clear message instead of an IndexError on model_files[-1].
    raise FileNotFoundError(
        f'no trained_min_mask_model_*.pt snapshots found in {model_dir!r}'
    )

# Build the network. The int argument is the number of classes
# (including background).
tmx = train.get_model(2)
# Load the saved weights into the network. map_location='cpu' allows a
# checkpoint saved from a GPU to be loaded on a machine without CUDA;
# load_state_dict copies the values into the model's existing parameters.
tmx.load_state_dict(torch.load(model_files[-1], map_location='cpu'))

# Create the inference object; it is called on the image data below.
cnn = infer.predict(tmx, size=(400, 400), max_project=True, probability=.75)

data = tifffile.imread(filename)
pmap, bs = cnn(data)
# cnn output:
#   pmap : array
#       Probability map.
#   bs : tuple
#       bs[0] -- list of arrays of the bounding-box boundaries.
#       bs[1] -- score (probability) of each object being the predicted class.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment