Skip to content

Instantly share code, notes, and snippets.

@oeway
Last active August 14, 2019 20:45
Show Gist options
  • Save oeway/5edb106eb9360405412bba8ebd2dbeb5 to your computer and use it in GitHub Desktop.
Save oeway/5edb106eb9360405412bba8ebd2dbeb5 to your computer and use it in GitHub Desktop.
<docs lang="markdown">
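## Unet-Segmentation

This plugin trains a U-Net on a folder of images and uses the trained model to segment the images in a test folder.

* **train** – enter the number of epochs, then select a training folder (the script passes `image` and `label` to the data generator, presumably as sub-folder names). The best model (lowest training loss) is saved as `unet_membrane.hdf5`.
* **predict** – select a test folder; the saved model is loaded and the results are written back to that folder. Run **train** at least once first.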
</docs>
<config lang="json">
{
"name": "Unet-Segmentation",
"type": "native-python",
"version": "0.1.0",
"description": "This plugin uses U-net to segment images.",
"tags": [],
"ui": "",
"cover": "",
"inputs": null,
"outputs": null,
"flags": [],
"icon": "extension",
"api_version": "0.1.6",
"env": "conda create -n unet-segmentation-cpu python=3.6.8",
"permissions": [],
"requirements": ["repo:https://github.com/imjoy-team/unet", "pip:keras==2.1.4 tensorflow==1.8.0"],
"dependencies": []
}
</config>
<script lang="python">
import os
# Switch into the 'unet' directory before importing from it. Presumably this is
# the checkout created by the "repo:" entry in <config> -- TODO confirm.
# Relative paths used later (e.g. the saved model file) resolve inside it.
os.chdir('unet')
# Star imports: helpers used below without an explicit import (unet,
# trainGenerator, testGenerator, saveResult, ModelCheckpoint) presumably come
# from these two modules -- verify against the repository.
from model import *
from data import *
from imjoy import api
from keras.callbacks import LambdaCallback
from keras.models import load_model
class ImJoyPlugin():
    """ImJoy plugin that trains a U-Net and uses it to segment images.

    Two operations are registered in ``setup``: ``train`` and ``predict``.
    ``run`` dispatches to the matching coroutine based on ``ctx._op``.
    """

    # Private tunables shared by train() and predict(). Keeping the model
    # filename in one place guarantees predict() loads what train() saved.
    _MODEL_PATH = 'unet_membrane.hdf5'
    _BATCH_SIZE = 2
    _STEPS_PER_EPOCH = 300
    # Number of generator batches consumed during prediction; presumably this
    # matches the default image count of testGenerator -- TODO confirm.
    _PREDICT_STEPS = 30

    def setup(self):
        """Register the ``train`` and ``predict`` operations with ImJoy."""
        api.register(name="train", ui="epochs: {id: 'epochs', type: 'number', min: 1, max: 100000, placeholder: 10}")
        api.register(name="predict", ui="perform unet prediction")

    async def predict(self):
        """Segment the images of a user-selected folder with the saved model.

        Raises:
            FileNotFoundError: if no trained model file exists yet. Checked
                before the folder dialog opens so the user is not asked for
                input that cannot be used.
        """
        if not os.path.isfile(self._MODEL_PATH):
            raise FileNotFoundError(
                "Trained model '{}' not found in '{}'; run the 'train' "
                "operation first.".format(self._MODEL_PATH, os.getcwd())
            )
        ret = await api.showFileDialog(type="directory", title="please select a test input folder", engine=api.ENGINE_URL)
        test_folder = ret.path  # e.g. "data/membrane/test"
        model = load_model(self._MODEL_PATH)
        test_gen = testGenerator(test_folder)
        api.showStatus('Start prediction...')
        results = model.predict_generator(test_gen, self._PREDICT_STEPS, verbose=1)
        # Results are written into the same folder the inputs were read from.
        saveResult(test_folder, results)

    async def train(self, epochs):
        """Train a new U-Net on a user-selected folder.

        Args:
            epochs: number of training epochs (an int; ``run`` converts it).

        The checkpoint callback monitors the training loss and keeps only the
        best model, saved under ``_MODEL_PATH``.
        """
        # Augmentation settings forwarded to trainGenerator.
        data_gen_args = dict(
            rotation_range=0.2,
            width_shift_range=0.05,
            height_shift_range=0.05,
            shear_range=0.05,
            zoom_range=0.05,
            horizontal_flip=True,
            fill_mode='nearest',
        )
        ret = await api.showFileDialog(type="directory", title="please select a training input folder", engine=api.ENGINE_URL)
        train_folder = ret.path
        api.showStatus('Preparing data...')
        train_gen = trainGenerator(self._BATCH_SIZE, train_folder, 'image', 'label', data_gen_args, save_to_dir=None)
        model = unet()
        # NOTE(review): ModelCheckpoint is not imported explicitly in this
        # file; presumably it arrives through a star import -- verify.
        model_checkpoint = ModelCheckpoint(self._MODEL_PATH, monitor='loss', verbose=1, save_best_only=True)

        def on_epoch_end(epoch, logs):
            # `epoch` is zero-based, hence the +1 for display and progress.
            api.showStatus('Training epoch: {}/{} {}'.format(epoch + 1, epochs, logs))
            api.showProgress((epoch + 1) / epochs * 100)

        progress_callback = LambdaCallback(on_epoch_end=on_epoch_end)
        api.showStatus('Start training...')
        # NOTE(review): fit_generator is a synchronous call inside a
        # coroutine; confirm this is acceptable for the plugin engine.
        model.fit_generator(
            train_gen,
            steps_per_epoch=self._STEPS_PER_EPOCH,
            epochs=epochs,
            callbacks=[progress_callback, model_checkpoint],
        )

    async def run(self, ctx):
        """Dispatch on ``ctx._op``; any other operation is a no-op.

        For ``train`` the epoch count is read from ``ctx.config.epochs``.
        """
        if ctx._op == 'train':
            epochs = ctx.config.epochs
            await self.train(int(epochs))
            api.alert('training done!')
        elif ctx._op == 'predict':
            await self.predict()
            api.alert('prediction done!')
# Instantiate the plugin and hand it to ImJoy via api.export.
api.export(ImJoyPlugin())
</script>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment