Created
October 9, 2020 10:59
-
-
Save m-albert/a2aa957bf0e96b665d8d53c723506e41 to your computer and use it in GitHub Desktop.
Affine transformation for dask arrays: Wrapper around `ndimage.affine_transform`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
"""
Chunk-wise affine transformation of dask arrays.

`affine_transform_dask` wraps `ndimage.affine_transform` so that every
output chunk is resampled from only the part of the input it maps into.
"""
__author__ = "Marvin Albert"
__email__ = "[email protected]"
import numpy as np | |
import dask.array as da | |
from scipy import ndimage | |
def affine_transform_dask(
        input,
        matrix,
        offset=0.0,
        output_shape=None,
        output_chunks=None,
        **kwargs
):
    """
    Wraps `ndimage.affine_transform` for dask arrays.

    For every output chunk, only the slice of the input that the chunk maps
    into (enlarged by the interpolation footprint) is passed on to
    `ndimage.affine_transform`.

    To do:
     - optionally use cupyx.scipy.ndimage.affine_transform

    The API mirrors `ndimage.affine_transform`, except for `output_chunks`.

    NOTE: because the input is cropped per chunk, results can differ slightly
    from a single `ndimage.affine_transform` call when
    (a) ``order > 1`` and ``prefilter=True`` (the default), since the spline
    prefilter then runs on each crop instead of the whole input, or
    (b) ``mode`` fills out-of-bounds values from pixels far from the boundary
    (e.g. 'wrap', 'reflect'), since only the crop is visible.

    :param input: N-D numpy or dask array. (The name shadows the builtin but
        is kept for API compatibility.)
    :param matrix: inverse coordinate transformation mapping output to input
        coordinates, in any form accepted by `ndimage.affine_transform`:
        (N, N), diagonal (N,), or homogeneous (N + 1, N + 1) / (N, N + 1).
        For homogeneous matrices `offset` is ignored.
    :param offset: scalar or length-N offset added to the transformed
        coordinates.
    :param output_shape: shape of the output; defaults to ``input.shape``.
    :param output_chunks: chunks of the output dask array; defaults to the
        chunk size of `input` if it is a dask array, otherwise 'auto'.
    :param kwargs: forwarded to `ndimage.affine_transform`
        (e.g. ``order``, ``mode``, ``cval``, ``prefilter``).
    :return: dask array of shape `output_shape` with the dtype of `input`
    """
    ndim = input.ndim

    # Normalize matrix/offset to an (N, N) matrix plus a length-N offset so
    # the per-chunk bookkeeping below only has to handle one form.
    matrix = np.asarray(matrix, dtype=np.float64)
    if matrix.ndim == 1:
        # Diagonal matrix given by its diagonal.
        matrix = np.diag(matrix)
    elif matrix.shape in ((ndim + 1, ndim + 1), (ndim, ndim + 1)):
        # Homogeneous coordinates: the translation is the last column and
        # `offset` is ignored, as in ndimage.affine_transform.
        offset = matrix[:ndim, ndim]
        matrix = matrix[:ndim, :ndim]
    offset = np.zeros(ndim) + np.asarray(offset, dtype=np.float64)

    # Spline order used by ndimage.affine_transform (its default is 3); it
    # determines how many input pixels each output pixel depends on.
    order = kwargs.get('order', 3)
    input_shape = np.array(input.shape)

    def resample_chunk(chunk, matrix, offset, kwargs, block_info=None):
        """Resample one output chunk from the relevant crop of `input`."""
        chunk_shape = np.array(chunk.shape)
        chunk_offset = np.array(
            [loc[0] for loc in block_info[0]['array-location']])

        # Corners of the output chunk. The upper corners lie one pixel past
        # the chunk's last pixel, which errs on the side of a larger crop.
        unit_corners = np.array(list(np.ndindex((2,) * ndim)))
        corners = unit_corners * chunk_shape + chunk_offset
        input_corners = np.dot(corners, matrix.T) + offset

        # A spline of order k evaluated at coordinate x reads input indices
        # floor(x) - k//2 ... floor(x) + k//2 + 1; add one guard pixel on
        # each side (upper bound is exclusive, hence + 3).
        lower = np.floor(input_corners.min(axis=0)) - order // 2 - 1
        upper = np.floor(input_corners.max(axis=0)) + order // 2 + 3

        # Clip to the input extent but keep at least one pixel per dimension,
        # so chunks mapping entirely outside the input still get a valid,
        # non-empty input on which `mode`/`cval` can be evaluated.
        lower = np.clip(lower, 0, input_shape - 1).astype(np.int64)
        upper = np.clip(upper, lower + 1, input_shape).astype(np.int64)

        input_slice = tuple(slice(int(lo), int(hi))
                            for lo, hi in zip(lower, upper))
        # For a dask `input`, np.asarray computes just this slice.
        rel_input = np.asarray(input[input_slice])

        # Express the transform in cropped coordinates:
        #   y = M x + o, with x = x' + x0 (chunk offset), y = y' + y0 (crop)
        #   => y' = M x' + (o + M x0 - y0)
        offset_prime = offset + np.dot(matrix, chunk_offset) - lower

        return ndimage.affine_transform(rel_input,
                                        matrix,
                                        offset_prime,
                                        output_shape=chunk.shape,
                                        **kwargs)

    if output_shape is None:
        output_shape = input.shape
    if output_chunks is None:
        # da.zeros rejects chunks=None, so choose a usable default.
        output_chunks = (input.chunksize if isinstance(input, da.Array)
                         else 'auto')

    # The zeros array only provides the output chunk layout; every block is
    # replaced by resample_chunk.
    transformed = da.zeros(output_shape,
                           dtype=input.dtype,
                           chunks=output_chunks)
    transformed = transformed.map_blocks(resample_chunk,
                                         dtype=input.dtype,
                                         matrix=matrix,
                                         offset=offset,
                                         kwargs=kwargs)
    return transformed
if __name__ == "__main__":
    from timeit import default_timer as timer

    from matplotlib import pyplot
    import tifffile

    # Create a smooth random test image: upsample coarse noise 20x and
    # rescale it to integer values in [0, 1000].
    N = 3
    a = 100
    np.random.seed(0)
    im = np.random.random([int(a / 20)] * N)
    im = ndimage.zoom(im, 20, order=1)
    im = im / im.max()
    im *= 1000
    im = im.astype(np.uint16)

    # Wrap the image into a chunked dask array.
    chunksize = [32] * N
    im_dask = da.from_array(im, chunks=chunksize)

    # Define a (random) transformation close to the identity.
    matrix = np.eye(N) + (np.random.random((N, N)) - 0.5) / 5.
    offset = (np.random.random(N) - 0.5) / 5. * np.array(im.shape)
    print('matrix\n', matrix)
    print('offset\n', offset)

    # Resampling options.
    # output_shape = im.shape
    output_shape = [int(a / 4)] * N
    # Chunks smaller than output_shape, so that several output chunks (and
    # their borders) are actually exercised.
    output_chunks = [16] * N
    interp_order = 3

    # Reference: transform the whole image at once without dask.
    ti = timer()
    im_t_nodask = ndimage.affine_transform(im, matrix, offset,
                                           output_shape=output_shape,
                                           order=interp_order)
    tf = timer()
    print('Timing without dask: %s seconds' % (tf - ti))

    # Chunk-wise transform with the function above, using dask.
    ti = timer()
    scheduler = 'single-threaded'
    # scheduler = 'threads'
    im_t_dask = affine_transform_dask(im_dask, matrix, offset,
                                      output_shape=output_shape,
                                      output_chunks=output_chunks,
                                      order=interp_order)
    im_t_dask_computed = im_t_dask.compute(scheduler=scheduler)
    tf = timer()
    print('Timing with dask: %s seconds' % (tf - ti))

    # Quantitative comparison; small deviations near chunk borders are
    # expected for order > 1 because the spline prefilter runs per crop.
    max_abs_diff = np.max(np.abs(im_t_nodask.astype(np.float64)
                                 - im_t_dask_computed.astype(np.float64)))
    print('Max. absolute difference (without vs. with dask): %s'
          % max_abs_diff)

    # Write out the dask graph to visualize the chunk flow.
    # python-graphviz needs to be installed, see:
    # https://docs.dask.org/en/latest/graphviz.html
    # im_t_dask.visualize(filename='affine_transformation_dask.png')

    # Show and visually compare both transformation results.
    tifffile.imshow(np.array([im_t_nodask, im_t_dask_computed]),
                    vmin=0, vmax=1000)
    pyplot.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment