Created
November 13, 2020 16:36
-
-
Save will-moore/819ade3c4e46864d9405555a1bf4933c to your computer and use it in GitHub Desktop.
Performance test for concatenating tiles into a multi-dimensional pyramid with dask.array.stack vs. map_blocks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Display a 5D dask multiscale pyramid | |
""" | |
from dask import array as da | |
from dask import delayed | |
import datetime | |
import numpy as np | |
import napari | |
from math import ceil | |
# Selects the strategy benchmarked by stitch_planes / stack_planes:
# True -> build arrays with da.map_blocks; False -> da.concatenate / da.stack.
USE_MAP_BLOCKS = True
def stitch_planes(planes, axis=0):
    """Join dask arrays end-to-end along an existing axis.

    With ``USE_MAP_BLOCKS`` false this is ``da.concatenate``. Otherwise the
    result is built with ``da.map_blocks`` so that output block ``i`` along
    ``axis`` is filled by ``planes[i]`` (see ``read_data``).

    Parameters
    ----------
    planes : sequence of dask arrays
        Arrays to join. They must have the same number of dimensions and
        the same size on every axis except ``axis``. Their sizes along
        ``axis`` may differ, e.g. a narrower edge tile.
    axis : int
        Axis to join along; negative values count from the end.

    Returns
    -------
    dask.array.Array
        Array whose size along ``axis`` is the sum of the planes' sizes.

    Raises
    ------
    ValueError
        If ``planes`` is empty, ``axis`` is out of range, or the planes
        disagree on any dimension other than ``axis``.
    """
    if not USE_MAP_BLOCKS:
        return da.concatenate(planes, axis=axis)
    if not planes:
        raise ValueError("Need at least one plane to stitch")
    first = planes[0]
    ndim = first.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for {ndim}-dimensional planes"
        )
    axis %= ndim
    for plane in planes[1:]:
        same_other_dims = plane.ndim == ndim and all(
            plane.shape[d] == first.shape[d] for d in range(ndim) if d != axis
        )
        if not same_other_dims:
            raise ValueError(
                f"cannot stitch plane of shape {plane.shape} with "
                f"{first.shape} along axis {axis}"
            )
    # One chunk per plane along `axis`, sized from that plane's real extent.
    # The previous version assumed every plane matched planes[0], which
    # over-declared the shape whenever the last tile/row was shorter
    # (e.g. 5000 px split into 256 px tiles was declared as 5120 px).
    # Every other axis is a single chunk spanning the whole plane.
    chunks = tuple(
        tuple(plane.shape[axis] for plane in planes)
        if d == axis
        else (first.shape[d],)
        for d in range(ndim)
    )
    # NOTE(review): `planes` (a list of dask arrays) is passed as a keyword
    # argument, so dask likely makes the whole list a dependency of every
    # output block. That may explain why this path benchmarks slower than
    # da.concatenate. Confirm by inspecting the task graph.
    return da.map_blocks(
        read_data,
        chunks=chunks,
        planes=planes,
        axis=axis,
        arrayfunc=np.asanyarray,
        # meta overrides any `dtype` argument; give it the output's ndim.
        meta=np.empty((0,) * ndim, dtype=first.dtype),
    )
def read_data(planes, axis, block_info=None, **kwargs):
    """Pick the input plane backing the output block being built.

    This is the ``da.map_blocks`` callback used by ``stitch_planes``: the
    output block at position ``n`` along ``axis`` is ``planes[n]``.

    Parameters
    ----------
    planes : sequence
        The ``planes`` list forwarded through ``map_blocks``, indexed by
        block position along ``axis``.
    axis : int
        Axis whose block position selects the plane.
    block_info : dict
        Supplied by dask. ``block_info[None]['chunk-location']`` holds the
        output block's position on every axis.
    **kwargs
        Any other keyword arguments from ``map_blocks``; ignored.
    """
    output_info = block_info[None]
    block_position = output_info['chunk-location']
    return planes[block_position[axis]]
def stack_planes(planes):
    """Stack same-shaped planes along a new leading axis.

    With ``USE_MAP_BLOCKS`` true, the result is a ``da.map_blocks`` array
    chunked one plane per block on axis 0. Each block is produced by
    ``read_data_with_extra_dimension``. Otherwise this is ``da.stack``.
    The shape and dtype are taken from ``planes[0]``.
    """
    if USE_MAP_BLOCKS:
        first = planes[0]
        stacked_shape = (len(planes),) + first.shape
        block_chunks = da.core.normalize_chunks((1,) + first.shape, stacked_shape)
        return da.map_blocks(
            read_data_with_extra_dimension,
            chunks=block_chunks,
            planes=planes,
            arrayfunc=np.asanyarray,
            meta=np.asanyarray([]).astype(first.dtype),  # meta overrides `dtype`
        )
    return da.stack(planes)
def read_data_with_extra_dimension(planes, block_info=None, **kwargs):
    """Return one plane, given a new leading axis of length 1, for this block.

    This is the ``da.map_blocks`` callback used by ``stack_planes``, whose
    output has one plane per block along axis 0. The start of this block's
    axis-0 extent (``block_info[None]['array-location'][0]``) is therefore
    the index of the plane to return. Extra keyword arguments are ignored.
    """
    start, _stop = block_info[None]['array-location'][0]
    selected = planes[start]
    return np.expand_dims(selected, axis=0)
def get_tile(tile_name):
    """Generate a synthetic int16 tile from its comma-separated description.

    ``tile_name`` is ``"level,t,c,z,y,x,w,h"`` (all integers); the result
    has shape ``(h, w)``. The y/x offsets are parsed but not used.

    Pixel values:
    - odd ``c``: column index + 2*t + 2*z
    - even ``c``: (row index + (level % 2) * column index) // 2
    """
    print('get_tile level, t, c, z, y, x, w, h', tile_name)
    fields = tile_name.split(",")
    level, t, c, z, y, x, w, h = map(int, fields)

    def pixel_value(row, col):
        # np.fromfunction passes whole index grids: row = axis 0, col = axis 1.
        if c % 2 != 1:
            return (row + ((level % 2) * col)) // 2
        return col + (2 * t) + (2 * z)

    return np.fromfunction(pixel_value, (h, w), dtype=np.int16)
# dask.delayed wrapper around get_tile: calling lazy_reader(tile_name) returns a
# lazy task instead of building the tile immediately. get_lazy_plane wraps each
# such task with da.from_delayed.
lazy_reader = delayed(get_tile)
def get_lazy_plane(level, t, c, z, plane_y, plane_x, tile_shape):
    """Assemble one lazy 2D plane of shape (plane_y, plane_x) from tiles.

    ``tile_shape`` is unpacked as ``(tile_w, tile_h)``. The plane is cut
    into a grid of tiles; the last row/column is clipped to the plane edge.
    Each tile is a ``da.from_delayed`` wrapper around ``lazy_reader``. Tiles
    are stitched into rows along axis 1, then rows along axis 0, with
    ``stitch_planes``.
    """
    print('get_lazy_plane: level, t, c, z, plane_y, plane_x', level, t, c, z, plane_y, plane_x)
    tile_w, tile_h = tile_shape
    n_rows = ceil(plane_y / tile_h)
    n_cols = ceil(plane_x / tile_w)
    print('rows', n_rows, 'cols', n_cols)

    def lazy_tile(row_index, col_index):
        # Top-left corner of this tile and its size, clipped at the plane edge.
        x0 = col_index * tile_w
        y0 = row_index * tile_h
        width = min(tile_w, plane_x - x0)
        height = min(tile_h, plane_y - y0)
        name = ",".join(str(v) for v in (level, t, c, z, y0, x0, width, height))
        return da.from_delayed(lazy_reader(name), shape=(height, width), dtype=np.int16)

    row_arrays = []
    for row_index in range(n_rows):
        row_array = stitch_planes(
            [lazy_tile(row_index, col_index) for col_index in range(n_cols)],
            axis=1,
        )
        print('lazy_row.shape', row_array.shape)
        row_arrays.append(row_array)
    return stitch_planes(row_arrays, axis=0)
def get_pyramid_lazy(shape, tile_shape, levels):
    """Build a multiscale list of lazy 5D dask arrays from synthetic tiles.

    Every plane comes from ``get_lazy_plane``. Planes are stacked with
    ``stack_planes`` over z, then c, then t, giving arrays ordered
    ``(t, c, z, y, x)``. Level 0 is full resolution; each later level
    halves y and x (integer division).

    Parameters
    ----------
    shape : tuple of int
        ``(size_t, size_c, size_z, size_y, size_x)`` of level 0.
    tile_shape : tuple of int
        Tile size, forwarded unchanged to ``get_lazy_plane``.
    levels : int
        Number of resolution levels.

    Returns
    -------
    list of dask.array.Array
        One array per level, largest first.
    """
    size_t, size_c, size_z, size_y, size_x = shape

    def stack_level(level, plane_y, plane_x):
        # Innermost z, then c, then t -> (t, c, z, y, x) axis order.
        per_t = []
        for t in range(size_t):
            per_c = [
                stack_planes([
                    get_lazy_plane(level, t, c, z, plane_y, plane_x, tile_shape)
                    for z in range(size_z)
                ])
                for c in range(size_c)
            ]
            per_t.append(stack_planes(per_c))
        return stack_planes(per_t)

    pyramid = []
    plane_y, plane_x = size_y, size_x
    for level in range(levels):
        print('level', level)
        pyramid.append(stack_level(level, plane_y, plane_x))
        plane_y //= 2
        plane_x //= 2

    print('pyramid...')
    for level_array in pyramid:
        print(level_array.shape)
    return pyramid
# Benchmark parameters: 5D image shape (t, c, z, y, x), tile size, and the
# number of pyramid levels (each level halves y and x).
shape = (10, 2, 5, 3000, 5000)
tile_shape = (256, 256)
levels = 4

# Set to True to also time computing every level (smallest level first).
TIME_COMPUTE = False

start = datetime.datetime.now()
pyramid = get_pyramid_lazy(shape, tile_shape, levels)
lazy_timer = (datetime.datetime.now() - start).total_seconds()
print('lazy pyramid timer', lazy_timer)

if TIME_COMPUTE:
    times = []
    for level in range(levels, 0, -1):
        start = datetime.datetime.now()
        pyramid[level - 1].compute()
        timer = (datetime.datetime.now() - start).total_seconds()
        times.append(timer)
        print(f'Level {level - 1} compute timer', timer)
    print('shape', shape, 'tile_shape', tile_shape)
    print('lazy_pyramid creation', lazy_timer)
    print('compute times', times)

# Example output
# shape (10, 2, 5, 3000, 5000) tile_shape (256, 256)
# With USE_MAP_BLOCKS = False
# lazy_pyramid creation 8.401882
# compute times [0.403972, 1.156574, 5.493262, 26.310839]
# With USE_MAP_BLOCKS = True (expected to be faster, but measured slower)
# lazy_pyramid creation 14.68406
# compute times [1.341739, 2.146021, 8.515899, 36.635117]
# NOTE(review): these numbers were recorded before stitch_planes used
# per-plane chunk sizes; re-measure.

# napari.gui_qt() is deprecated and has been removed from napari. Create the
# viewer explicitly and start the Qt event loop with napari.run().
viewer = napari.Viewer()
viewer.add_image(pyramid, channel_axis=1, multiscale=True)
napari.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment