@sofroniewn
Created September 17, 2019 23:37
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dask with PyTorch for large scale image analysis\n",
"By Nicholas Sofroniew, Matthew Rocklin"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Executive Summary\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This post explores applying a pre-trained [PyTorch](https://pytorch.org/) model in parallel with Dask Array.\n",
"\n",
"We cover a simple example applying a pre-trained UNet to a stack of images to generate features for every pixel."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A Worked Example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let’s start with an example applying a pre-trained [UNet](https://arxiv.org/abs/1505.04597) to a stack of light sheet microscopy data. This particular UNet takes in a 2D image and returns a height x width x 16 array, where each pixel is now associated with a feature vector of length 16. Thanks to Mars Huang for training this particular UNet on a corpus of biological images to produce biologically relevant feature vectors during his work on [interactive bio-image segmentation](https://github.com/transformify-plugins/segmentify). These features can then be used for downstream image processing tasks such as image segmentation. We will use the same data that we analysed in our last [blog post on Dask and ITK](https://blog.dask.org/2019/08/09/image-itk), and you should note the similarities to that workflow even though we are now using new libraries and performing different analyses."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/nicholassofroniew/Github/image-demos/data/LLSM\n"
]
}
],
"source": [
"cd '/Users/nicholassofroniew/Github/image-demos/data/LLSM'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dask.array<from-zarr, shape=(20, 199, 768, 1024), dtype=float32, chunksize=(1, 1, 768, 1024)>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load our data from last time\n",
"import dask.array as da\n",
"imgs = da.from_zarr(\"AOLLSM_m4_560nm.zarr\")\n",
"imgs"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load our pretrained UNet\n",
"import torch\n",
"from segmentify.model import UNet, layers\n",
"\n",
"def load_unet(path):\n",
"    \"\"\"Load a pretrained UNet model.\n",
"    \"\"\"\n",
"\n",
"    # load in saved model\n",
"    pth = torch.load(path)\n",
"    model_args = pth['model_args']\n",
"    model_state = pth['model_state']\n",
"    model = UNet(**model_args)\n",
"    model.load_state_dict(model_state)\n",
"\n",
"    # remove last layer and activation\n",
"    model.segment = layers.Identity()\n",
"    model.activate = layers.Identity()\n",
"    model.eval()\n",
"\n",
"    return model\n",
"\n",
"model = load_unet(\"HPA_3.pth\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dask.array<unet_featurize, shape=(20, 199, 768, 1024, 16), dtype=float32, chunksize=(1, 1, 768, 1024, 16)>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Apply UNet featurization\n",
"import numpy as np\n",
"\n",
"def unet_featurize(image, model):\n",
"    \"\"\"Featurize pixels in an image using pretrained UNet model.\n",
"    \"\"\"\n",
"    import numpy as np\n",
"    import torch\n",
"\n",
"    # remove leading two length-one dimensions\n",
"    img = image[0, 0, ...]\n",
"\n",
"    # make sure image has four dimensions (b,c,w,h)\n",
"    img = np.expand_dims(np.expand_dims(img, 0), 0)\n",
"    img = np.transpose(img, (1,0,2,3))\n",
"\n",
"    # convert image to torch Tensor\n",
"    img = torch.Tensor(img).float()\n",
"\n",
"    # pass image through model\n",
"    with torch.no_grad():\n",
"        features = model(img).numpy()\n",
"\n",
"    # generate feature vectors (w,h,f)\n",
"    features = np.transpose(features, (0,2,3,1))[0]\n",
"\n",
"    # Add back the leading length-one dimensions\n",
"    result = features[None, None, ...]\n",
"\n",
"    return result\n",
"\n",
"out = da.map_blocks(unet_featurize, imgs, model, dtype=np.float32, chunks=(1, 1, imgs.shape[2], imgs.shape[3], 16), new_axis=-1)\n",
"out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Trigger computation and store\n",
"out.to_zarr(\"AOLLSM_featurized.zarr\", overwrite=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So in the example above we …\n",
"\n",
"1. Load data from Zarr into a multi-chunked Dask array\n",
"2. Load a pre-trained PyTorch model that featurizes images\n",
"3. Construct a function to apply the model onto each chunk\n",
"4. Apply that function across the Dask array with the dask.array.map_blocks function\n",
"5. Store the result back into Zarr format\n",
"\n",
"This workflow is very similar to our earlier example, which used the dask.array.map_blocks function with ITK to perform image deconvolution. Because Dask arrays are made of NumPy arrays, which are easily converted to PyTorch tensors, we can now apply machine learning models at scale in the same way."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
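
The shape bookkeeping inside `unet_featurize` is the part most likely to trip people up, so here is a minimal NumPy-only sketch of it. It assumes a hypothetical `fake_model` in place of the real UNet and a tiny random stack in place of the Zarr data, and it uses a plain Python loop to imitate what `map_blocks` does when it hands each `(1, 1, h, w)` chunk to the function. None of these stand-ins come from the notebook. They only illustrate how a chunk goes in with four dimensions and comes out with five.

```python
import numpy as np

N_FEATURES = 16  # feature-vector length, matching the output shape in the notebook


def fake_model(batch):
    """Hypothetical stand-in for the UNet: (b, 1, h, w) -> (b, 16, h, w)."""
    scales = np.arange(1, N_FEATURES + 1, dtype=np.float32).reshape(1, -1, 1, 1)
    return batch * scales  # broadcasting creates the 16 feature channels


def featurize_chunk(chunk, model):
    """Mirror the shape handling of unet_featurize for one (1, 1, h, w) chunk."""
    img = chunk[0, 0, ...]                              # (h, w)
    batch = img[None, None, ...]                        # (1, 1, h, w): batch, channel
    features = model(batch)                             # (1, 16, h, w)
    features = np.transpose(features, (0, 2, 3, 1))[0]  # (h, w, 16)
    return features[None, None, ...]                    # (1, 1, h, w, 16)


# Small made-up stack: two leading axes of length 2 and 3, then 8 x 10 images
stack = np.random.default_rng(0).random((2, 3, 8, 10), dtype=np.float32)

# Emulate the per-chunk application by visiting every (1, 1, h, w) block in turn
out = np.empty(stack.shape + (N_FEATURES,), dtype=np.float32)
for i in range(stack.shape[0]):
    for j in range(stack.shape[1]):
        out[i:i + 1, j:j + 1] = featurize_chunk(stack[i:i + 1, j:j + 1], fake_model)

print(out.shape)
```

Each block keeps its two leading length-one axes, so the results slot straight back into an output array with one extra trailing feature axis. That is the same contract the notebook declares to Dask through the `chunks` and `new_axis` arguments.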
@sofroniewn
Author

@mrocklin here's a rough draft of a potential blog post. It only does image "featurization" and doesn't go all the way to "segmentation", but it does use a pre-trained UNet from PyTorch. Let me know how you want to progress / clean up etc.

@sofroniewn
Author

oh also my chunks are slightly different from your chunks in the previous blog posts ...

@mrocklin

Woo! I'm going to copy a link to this from a public issue so that hopefully we can get a few more eyes on it.
