Created July 26, 2017 07:56
-
-
Save bmerry/a254a69597dce9b6957fe4470c1bbf84 to your computer and use it in GitHub Desktop.
Benchmarks of different implementations of dask.array.where
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
No broadcasting or scalars
where_orig : 0.000482s ± 0.000005 (construct), 0.165047s ± 0.001631s (compute)
where_where: 0.000273s ± 0.000002 (construct), 0.070438s ± 0.000439s (compute)
where_new : 0.000289s ± 0.000004 (construct), 0.069269s ± 0.000712s (compute)
Scalar condition
where_orig : 0.000013s ± 0.000000 (construct), 0.024283s ± 0.000316s (compute)
where_where: 0.000012s ± 0.000000 (construct), 0.024131s ± 0.000329s (compute)
where_new : 0.000017s ± 0.000000 (construct), 0.023551s ± 0.000195s (compute)
Broadcasting
where_orig : 0.000810s ± 0.000008 (construct), 0.168275s ± 0.001371s (compute)
where_where: 0.000593s ± 0.000003 (construct), 0.061382s ± 0.000654s (compute)
where_new : 0.000392s ± 0.000002 (construct), 0.058013s ± 0.000645s (compute)
Broadcasting, many chunks
where_orig : 0.022362s ± 0.000423 (construct), 1.387578s ± 0.004326s (compute)
where_where: 0.022167s ± 0.000406 (construct), 1.315636s ± 0.004534s (compute)
where_new : 0.016981s ± 0.000361 (construct), 0.875134s ± 0.002577s (compute)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
from __future__ import print_function, division

import time
import timeit

import dask.array as da
import numpy as np
import scipy.stats
# Original version function from master (0ef9424)
def where_orig(condition, x=None, y=None):
    """Pick elements from ``x`` where ``condition`` holds, else from ``y``.

    Baseline implementation: both operands are converted to dask arrays,
    broadcast to their common shape and cast to their promoted dtype before
    any selection happens. A scalar ``condition`` returns one of the two
    prepared operands; an array ``condition`` is cast to bool and combined
    with ``da.choose``.

    Raises:
        TypeError: if ``x`` or ``y`` is omitted.
    """
    if x is None or y is None:
        raise TypeError(da.core.where_error_message)

    x_arr, y_arr = da.asarray(x), da.asarray(y)
    common_shape = da.core.broadcast_shapes(x_arr.shape, y_arr.shape)
    common_dtype = np.promote_types(x_arr.dtype, y_arr.dtype)
    x_arr = da.broadcast_to(x_arr, common_shape).astype(common_dtype)
    y_arr = da.broadcast_to(y_arr, common_shape).astype(common_dtype)

    if not np.isscalar(condition):
        mask = da.asarray(condition).astype('bool')
        # choose() indexes the list by the mask value: False -> y, True -> x.
        return da.choose(mask, [y_arr, x_arr])
    return x_arr if condition else y_arr
# Variant of where_orig that uses np.where instead of np.choose underneath
def where_where(condition, x=None, y=None):
    """Pick elements from ``x`` where ``condition`` holds, else from ``y``.

    Identical preparation to ``where_orig``: both operands are converted to
    dask arrays, broadcast to their common shape and cast to their promoted
    dtype. The selection step differs: an array ``condition`` is combined
    with ``np.where`` via ``da.core.elemwise`` rather than ``da.choose``.

    Raises:
        TypeError: if ``x`` or ``y`` is omitted.
    """
    if x is None or y is None:
        raise TypeError(da.core.where_error_message)

    operands = [da.asarray(x), da.asarray(y)]
    shape = da.core.broadcast_shapes(*(op.shape for op in operands))
    dtype = np.promote_types(*(op.dtype for op in operands))
    x_ready, y_ready = [
        da.broadcast_to(op, shape).astype(dtype) for op in operands
    ]

    if np.isscalar(condition):
        return x_ready if condition else y_ready
    # condition is passed through unconverted; np.where evaluates its truthiness.
    return da.core.elemwise(np.where, condition, x_ready, y_ready)
# New version from my branch (20d7682)
def result_type(*args):
    """Return ``np.result_type`` for a mix of scalars and array-likes.

    Arguments that ``da.core.is_scalar_for_elemwise`` classifies as scalars
    are forwarded as values; every other argument contributes only its
    ``.dtype``. Forwarding scalars as values presumably lets NumPy apply its
    scalar-aware promotion rules (confirm against the NumPy version in use).
    """
    promotion_inputs = []
    for arg in args:
        if da.core.is_scalar_for_elemwise(arg):
            promotion_inputs.append(arg)
        else:
            promotion_inputs.append(arg.dtype)
    return np.result_type(*promotion_inputs)
def where_new(condition, x=None, y=None):
    """Pick elements from ``x`` where ``condition`` holds, else from ``y``.

    Proposed implementation. A scalar ``condition`` selects one operand,
    which is then broadcast to the combined shape of ``x`` and ``y`` and
    cast to their result dtype. An array ``condition`` is handed directly
    to ``np.where`` via ``da.core.elemwise``, without pre-broadcasting or
    pre-casting the operands.

    Raises:
        TypeError: if ``x`` or ``y`` is omitted.
    """
    if x is None or y is None:
        # Qualified through da.core like the other variants: no bare
        # ``where_error_message`` exists in this script, so the unqualified
        # name raised NameError instead of the intended TypeError.
        raise TypeError(da.core.where_error_message)

    if np.isscalar(condition):
        # Computed on the raw inputs (before da.asarray) so result_type can
        # forward scalar operands by value rather than as array dtypes.
        dtype = result_type(x, y)
        x = da.asarray(x)
        y = da.asarray(y)
        shape = da.core.broadcast_shapes(x.shape, y.shape)
        out = x if condition else y
        return da.broadcast_to(out, shape).astype(dtype)

    return da.core.elemwise(np.where, condition, x, y)
def time_function(func, passes):
    """Time repeated calls to ``func``.

    ``func`` is called once untimed as a warm-up, then ``passes`` more
    times, each call timed individually with ``timeit.default_timer``.

    Parameters
    ----------
    func : callable
        Zero-argument callable to time; its return value is discarded.
    passes : int
        Number of timed calls.

    Returns
    -------
    mean : float
        Mean elapsed time of one call, in seconds.
    sem : float
        Standard error of that mean (``scipy.stats.sem``, ddof=1); NaN when
        ``passes`` is 1.
    """
    func()  # Warm-up call, excluded from the statistics.
    # timeit.default_timer is the stdlib's preferred benchmark clock
    # (perf_counter on Python 3); time.time() is a wall clock that can jump
    # and has coarse resolution on some platforms.
    clock = timeit.default_timer
    times = []
    for _ in range(passes):
        start = clock()
        func()
        times.append(clock() - start)
    return np.mean(times), scipy.stats.sem(times)
def timer(name, func, passes=20, construct_passes=None):
    """Benchmark construction and computation of one implementation.

    ``func`` is timed as the construction step. The object it returns is
    then timed via its ``compute`` method. One result line is printed:
    mean time ± standard error of the mean for each phase, in seconds.

    Parameters
    ----------
    name : str
        Label printed at the start of the result line.
    func : callable
        Zero-argument callable returning an object with a ``compute`` method.
    passes : int, optional
        Number of timed ``compute`` calls.
    construct_passes : int, optional
        Number of timed construction calls. Defaults to ``10 * passes``
        (the original fixed ratio), presumably because construction is much
        cheaper than computation.
    """
    if construct_passes is None:
        construct_passes = passes * 10
    construct_mean, construct_sem = time_function(func, construct_passes)
    array = func()
    compute_mean, compute_sem = time_function(array.compute, passes)
    # Both uncertainties are SEMs in seconds, so both carry the 's' unit.
    print('{}: {:.6f}s ± {:.6f}s (construct), {:.6f}s ± {:.6f}s (compute)'.format(
        name, construct_mean, construct_sem, compute_mean, compute_sem))
def time_all(*args, **kwargs):
    """Benchmark every ``where`` variant on the same arguments.

    Each implementation is called as ``impl(*args, **kwargs)`` inside
    ``timer``, in the order where_orig, where_where, where_new.
    """
    variants = (
        ('where_orig ', where_orig),
        ('where_where', where_where),
        ('where_new ', where_new),
    )
    for label, impl in variants:
        # Bind impl as a default so each lambda keeps its own implementation.
        timer(label, lambda impl=impl: impl(*args, **kwargs))
# Benchmark scenarios: each prints a heading, then one result line per variant.

# Condition and both operands share the same length and chunking.
print('No broadcasting or scalars')
a = da.zeros(10 ** 7, chunks=10 ** 6)
b = da.ones(10 ** 7, chunks=10 ** 6)
c = da.random.randint(0, 2, 10 ** 7, chunks=10 ** 6)
time_all(c, a, b)

# A Python bool as the condition exercises each variant's scalar path.
print('Scalar condition')
time_all(True, a, b)

# y has shape (1,) and must be broadcast against the length-10**7 arrays.
print('Broadcasting')
b2 = da.ones(1, chunks=1)
time_all(c, a, b2)

# x is split into 1000 chunks of 10**4, unlike c's chunks of 10**6.
print('Broadcasting, many chunks')
a2 = da.zeros(10 ** 7, chunks=10 ** 4)
time_all(c, a2, b2)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.