|
import mmap |
|
import tempfile |
|
import weakref |
|
import uuid |
|
|
|
import SharedArray as sa |
|
import numpy as np |
|
"""https://gitlab.com/tenzing/shared-array""" |
|
|
|
|
|
def from_shm(name):
    """Attach to an existing SharedArray segment by name.

    This is the reconstructor returned by ``SharedWrappedArray.__reduce__``,
    so an unpickling process maps the already-populated segment instead of
    receiving a pickled copy of the data.

    :param name: SharedArray handle the segment was created under.
    :returns: the result of ``sa.attach(name)``.
    """
    attached = sa.attach(name)
    return attached
|
|
|
|
|
def try_remove(u):
    """Best-effort deletion of the SharedArray segment named ``u``.

    A segment that no longer exists is not treated as an error, because it
    may already have been deleted elsewhere.
    """
    try:
        sa.delete(u)
    except FileNotFoundError:
        # Segment is already gone; nothing left to clean up.
        return None
    return None
|
|
|
|
|
class SharedWrappedArray:
    """Array stand-in that pickles as a handle to a SharedArray segment.

    On first pickle (or dask tokenization) ``data`` is copied once into a new
    SharedArray segment. From then on, pickling sends only the segment name,
    and the receiving side calls ``from_shm`` to attach to it. The segment is
    deleted when the local shared copy (``self.m``) is garbage-collected.

    NOTE(review): the segment lives only as long as this wrapper's shared
    copy, so keep the wrapper referenced until every receiver has attached.

    :param data: array to share. It must expose ``shape``, ``dtype`` and
        slice assignment, e.g. an ndarray.
    :param u: name of an existing SharedArray segment to attach to.
    :raises ValueError: if neither ``data`` nor a non-empty ``u`` is given.
    """

    # Array-protocol hooks forwarded to the wrapped array. Any other
    # attribute not found normally raises AttributeError.
    _FORWARDED = frozenset(
        ["__array__", "__array_function__", "__array_interface__", "__array_ufunc__"]
    )

    def __init__(self, data=None, u=None):
        # Use the same truthiness test as the attach below. This rejects an
        # empty handle ("") instead of producing a wrapper with nothing
        # behind it.
        if data is None and not u:
            raise ValueError("Must supply data or shm handle")
        self.data = data  # local array; may be None when built from a handle
        self.u = u  # segment name; filled in lazily by _to_shm
        self.m = sa.attach(u) if u else None  # shared copy, once it exists

    def __getattr__(self, item):
        # Only called for attributes not found by normal lookup. Forward the
        # array-protocol hooks to the local array when present, otherwise to
        # the shared copy.
        if item not in self._FORWARDED:
            raise AttributeError(item)
        source = self.data if self.data is not None else self.m
        return getattr(source, item)

    def __dask_tokenize__(self):
        # Side effect: forces creation of the shared segment. The token is
        # the segment's random name, so two wrappers of equal data get
        # different tokens.
        return f"shared_np_{self._to_shm()}"

    def _to_shm(self):
        """Ensure a shared copy of ``data`` exists and return its segment name."""
        if self.m is None:
            # Full 128-bit name makes collisions with existing segments
            # negligible.
            u = uuid.uuid4().hex
            m = sa.create(u, self.data.shape, self.data.dtype)
            # Register cleanup before copying, so a failed copy does not
            # leak the segment. Only ``u`` is captured, not ``self``.
            weakref.finalize(m, try_remove, u)
            # copy - shared version is read-write, but does not change original
            m[:] = self.data
            # Publish only once fully populated.
            self.m, self.u = m, u
        return self.u

    def __reduce__(self):
        # Pickle as "attach to segment <name>". Receivers get the result of
        # from_shm, not a SharedWrappedArray.
        return from_shm, (self._to_shm(),)
|
|
|
|
|
def remake(fn, dtype, shape):
    """Rebuild an array from a raw-bytes file by memory-mapping it.

    This is the reconstructor returned by ``SharedMMapArray.__reduce__``, so
    each unpickling process maps the same file rather than receiving a copy.
    The mapping uses ``mmap``'s default shared, read-write access. Writes
    through the returned array therefore land in the file and are visible to
    other processes mapping it.

    :param fn: path of a file holding the elements in C order, as written by
        ``ndarray.tofile``.
    :param dtype: element dtype (anything ``np.dtype`` accepts).
    :param shape: shape of the result (int or tuple).
    :returns: an ndarray backed by the mapping, or a new empty array when
        ``shape`` has no elements.
    :raises ValueError: if the file size does not match ``dtype``/``shape``.
    """
    # mmap refuses to map a zero-length file, which is exactly what tofile
    # writes for an array with no elements.
    if np.prod(shape) == 0:
        return np.empty(shape, dtype=dtype)
    with open(fn, "r+b") as f:
        # The mapping stays valid after the file object is closed.
        m = mmap.mmap(f.fileno(), 0)
    return np.frombuffer(m, dtype=dtype).reshape(shape)
|
|
|
|
|
class SharedMMapArray:
    """Array stand-in that pickles as a path to a memory-mappable file.

    On first pickle (or dask tokenization) ``data`` is written once to a new
    temporary file. Pickling then sends only the path, dtype and shape, and
    the receiving side rebuilds the array with ``remake`` (an mmap of that
    file). The file is removed when this wrapper is garbage-collected.

    NOTE(review): keep the wrapper referenced until every receiver has
    unpickled, otherwise the file may be gone before it is mapped.

    :param data: array to share. It must provide ``tofile``, ``dtype`` and
        ``shape``, e.g. an ndarray.
    :raises ValueError: if ``data`` is None.
    """

    def __init__(self, data=None):
        if data is None:
            raise ValueError("Must supply data")
        self.data = data
        self.fn = None  # path of the backing temp file, created lazily

    def __dask_tokenize__(self):
        # Side effect: writes the backing file. The token is the file's
        # unique path, so distinct instances never share a dask key.
        return f"shared_np_{self._to_shm()}"

    def _to_shm(self):
        """Ensure the backing file exists and return its path."""
        if self.fn is None:
            # mkstemp creates the file atomically (mktemp is race-prone and
            # deprecated). open() adopts the descriptor and closes it.
            fd, fn = tempfile.mkstemp()
            # Cleanup captures only the path, never ``self``.
            weakref.finalize(self, self._remove_file, fn)
            with open(fd, "wb") as f:
                self.data.tofile(f)  # raw bytes, always C order
            self.fn = fn
        return self.fn

    @staticmethod
    def _remove_file(fn):
        """Best-effort delete of a backing file that may already be gone."""
        import os

        try:
            os.remove(fn)
        except FileNotFoundError:
            pass

    def __reduce__(self):
        # Pickle as "mmap this file". Receivers get the ndarray built by
        # remake, not a SharedMMapArray.
        fn = self._to_shm()
        return remake, (fn, self.data.dtype, self.data.shape)
|
|
|
|
|
def simple_bench():
    """Time dask round-trips of a large array under each transport.

    Submits ``arg.sum()`` on a (1000, 100000) float64 array of ones. The
    array is passed as a ``SharedMMapArray`` (twice), a plain ndarray, a
    ``SharedWrappedArray`` (twice) and a scattered future. Each result is
    checked, and the elapsed wall-clock time per round-trip is printed. The
    first call with each shared wrapper includes its one-off copy into shared
    storage.
    """
    import time

    import dask.distributed

    x = np.ones((1000, 100000))
    # Compute the reference once, outside every timed region.
    expected = x.sum()
    x2 = SharedWrappedArray(x)
    x3 = SharedMMapArray(x)

    with dask.distributed.Client(n_workers=1) as client:

        def timed_sum(arg):
            """Run ``arg.sum()`` on the cluster, check it, return seconds taken."""
            start = time.perf_counter()
            # pure=False gives every submission a fresh key, so a repeat call
            # really re-runs the task instead of reusing a cached result.
            result = client.submit(lambda a: a.sum(), arg, pure=False).result()
            elapsed = time.perf_counter() - start
            assert result == expected
            return elapsed

        mmap_first = timed_sum(x3)
        mmap_second = timed_sum(x3)
        numpy_time = timed_sum(x)
        shared_first = timed_sum(x2)
        shared_second = timed_sum(x2)

        # The scatter cost is part of this transport's round-trip.
        start = time.perf_counter()
        future = client.scatter(x)
        scatter_time = (time.perf_counter() - start) + timed_sum(future)

    print("mmap 1:", mmap_first)
    print("mmap2:", mmap_second)
    print("numpy time:", numpy_time)
    print("shared time, first:", shared_first)
    print("shared time, second:", shared_second)
    print("scatter time:", scatter_time)