Last active
August 14, 2019 20:45
-
-
Save oeway/5edb106eb9360405412bba8ebd2dbeb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<docs lang="markdown"> | |
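## Unet-Segmentation

Segment images with a U-Net model (Keras / TensorFlow).

* **train**: enter the number of epochs, then select a training folder (expected to contain `image` and `label` subfolders). The model with the lowest training loss is saved as `unet_membrane.hdf5`.
* **predict**: select a folder of test images. The saved `unet_membrane.hdf5` model is loaded, applied to the images, and the results are saved for the selected folder. Run **train** first so that the model file exists.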
</docs> | |
<config lang="json"> | |
{ | |
"name": "Unet-Segmentation", | |
"type": "native-python", | |
"version": "0.1.0", | |
"description": "This plugin uses U-net to segment images.", | |
"tags": [], | |
"ui": "", | |
"cover": "", | |
"inputs": null, | |
"outputs": null, | |
"flags": [], | |
"icon": "extension", | |
"api_version": "0.1.6", | |
"env": "conda create -n unet-segmentation-cpu python=3.6.8", | |
"permissions": [], | |
"requirements": ["repo:https://github.com/imjoy-team/unet", "pip:keras==2.1.4 tensorflow==1.8.0"],
"dependencies": [] | |
} | |
</config> | |
<script lang="python"> | |
import os | |
os.chdir('unet') | |
from model import * | |
from data import * | |
from imjoy import api | |
from keras.callbacks import LambdaCallback | |
from keras.models import load_model | |
class ImJoyPlugin():
    """ImJoy plugin exposing U-Net training and prediction as two ops.

    ``setup`` registers the ``train`` and ``predict`` ops and ``run``
    dispatches on ``ctx._op``. The model checkpoint is written to and read
    from ``MODEL_PATH``, resolved against the current working directory.
    """

    # Checkpoint file written by ``train`` and loaded by ``predict``.
    MODEL_PATH = 'unet_membrane.hdf5'
    # ``steps`` argument given to ``model.predict_generator``.
    PREDICT_STEPS = 30
    # First argument given to ``trainGenerator`` (presumably the batch size;
    # confirm against the ``data`` module).
    TRAIN_BATCH_SIZE = 2
    # ``steps_per_epoch`` argument given to ``model.fit_generator``.
    STEPS_PER_EPOCH = 300

    def setup(self):
        """Register the ``train`` and ``predict`` ops with ImJoy."""
        api.register(name="train", ui="epochs: {id: 'epochs', type: 'number', min: 1, max: 100000, placeholder: 10}")
        api.register(name="predict", ui="perform unet prediction")

    async def predict(self):
        """Run the saved model on a user-selected folder and save the results.

        Raises:
            FileNotFoundError: if ``MODEL_PATH`` does not exist. This is
                checked before the folder dialog is shown, so the user is not
                asked to pick a folder for a prediction that cannot run.
        """
        if not os.path.isfile(self.MODEL_PATH):
            raise FileNotFoundError(
                "Model file '{}' not found in '{}'; run the 'train' op first.".format(
                    self.MODEL_PATH, os.getcwd()
                )
            )

        ret = await api.showFileDialog(type="directory", title="please select a test input folder", engine=api.ENGINE_URL)
        test_folder = ret.path  # e.g. "data/membrane/test"

        model = load_model(self.MODEL_PATH)
        test_generator = testGenerator(test_folder)
        api.showStatus('Start prediction...')
        results = model.predict_generator(test_generator, self.PREDICT_STEPS, verbose=1)
        saveResult(test_folder, results)

    async def train(self, epochs):
        """Train a fresh U-Net on a user-selected folder.

        Args:
            epochs: number of training epochs; must be at least 1.

        Raises:
            ValueError: if ``epochs`` is smaller than 1. Zero epochs would
                write no checkpoint yet still be reported as a finished
                training run by ``run``.
        """
        if epochs < 1:
            raise ValueError('epochs must be at least 1, got {}'.format(epochs))

        # Augmentation settings handed to ``trainGenerator``.
        data_gen_args = dict(
            rotation_range=0.2,
            width_shift_range=0.05,
            height_shift_range=0.05,
            shear_range=0.05,
            zoom_range=0.05,
            horizontal_flip=True,
            fill_mode='nearest',
        )

        ret = await api.showFileDialog(type="directory", title="please select a training input folder", engine=api.ENGINE_URL)
        train_folder = ret.path

        api.showStatus('Preparing data...')
        train_generator = trainGenerator(
            self.TRAIN_BATCH_SIZE, train_folder, 'image', 'label', data_gen_args, save_to_dir=None
        )
        model = unet()
        # Only overwrite the checkpoint when the training loss improves.
        model_checkpoint = ModelCheckpoint(self.MODEL_PATH, monitor='loss', verbose=1, save_best_only=True)

        def on_epoch_end(epoch, logs):
            # Keras passes a zero-based epoch index; report it one-based.
            api.showStatus('Training epoch: {}/{} {}'.format(epoch + 1, epochs, logs))
            api.showProgress((epoch + 1) / epochs * 100)

        progress_callback = LambdaCallback(on_epoch_end=on_epoch_end)

        api.showStatus('Start training...')
        model.fit_generator(
            train_generator,
            steps_per_epoch=self.STEPS_PER_EPOCH,
            epochs=epochs,
            callbacks=[progress_callback, model_checkpoint],
        )

    async def run(self, ctx):
        """Dispatch on ``ctx._op`` to ``train`` or ``predict``.

        For ``train`` the epoch count is read from ``ctx.config.epochs`` and
        converted to ``int``. An alert is shown once the chosen op returns.
        Any other value of ``ctx._op`` is ignored.
        """
        if ctx._op == 'train':
            epochs = ctx.config.epochs
            await self.train(int(epochs))
            api.alert('training done!')
        elif ctx._op == 'predict':
            await self.predict()
            api.alert('prediction done!')
api.export(ImJoyPlugin()) | |
</script> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment