Instantly share code, notes, and snippets.
Last active
July 28, 2021 05:07
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(1)
1
You must be signed in to fork a gist
-
Save s-m-e/bf01cba201d27c793873b819264d3a4b to your computer and use it in GitHub Desktop.
Pure Python & Numpy openMP-style for-loop based on POSIX shared memory and forks for Unix-like systems
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""

SHARED MEMORY TESTS

Pure Python & NumPy OpenMP-style for-loop based on POSIX shared memory
and forks for Unix-like systems (tested on Linux)

Copyright (C) 2021 Sebastian M. Ernst <[email protected]>
Inspired by https://github.com/albertz/playground/blob/master/shared_mem.py

<LICENSE_BLOCK>
The contents of this file are subject to the GNU Lesser General Public License
Version 2.1 ("LGPL" or "License"). You may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt
Software distributed under the License is distributed on an "AS IS" basis,
WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for the
specific language governing rights and limitations under the License.
</LICENSE_BLOCK>

"""
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# IMPORT | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
from abc import ABC | |
import ctypes | |
from functools import partial, wraps | |
from math import ceil | |
import os | |
import sys | |
from types import FunctionType | |
from typing import Any, Callable, Generator, Dict, Tuple | |
import numpy as np | |
try: | |
from typeguard import typechecked | |
except ModuleNotFoundError: | |
typechecked = lambda x: x | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# LIB-C API | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# Handle used to look up the shared memory functions below. The empty
# library name is what the original author marked as "linux only".
libc = ctypes.CDLL("", use_errno = True, use_last_error = True)

# Key type and command/key constants passed to the shm* calls.
shm_key_t = ctypes.c_int
IPC_PRIVATE = 0
IPC_RMID = 0

shmget = libc.shmget
shmat = libc.shmat
shmdt = libc.shmdt
shmctl = libc.shmctl

# Declare the C prototypes so ctypes converts arguments and results correctly.
for _function, _restype, _argtypes in (
    # int shmget(key_t key, size_t size, int shmflg);
    (shmget, ctypes.c_int, (shm_key_t, ctypes.c_size_t, ctypes.c_int)),
    # void* shmat(int shmid, const void *shmaddr, int shmflg);
    (shmat, ctypes.c_void_p, (ctypes.c_int, ctypes.c_void_p, ctypes.c_int)),
    # int shmdt(const void *shmaddr);
    (shmdt, ctypes.c_int, (ctypes.c_void_p,)),
    # int shmctl(int shmid, int cmd, struct shmid_ds *buf);
    (shmctl, ctypes.c_int, (ctypes.c_int, ctypes.c_int, ctypes.c_void_p)),
):
    _function.restype = _restype
    _function.argtypes = _argtypes
del _function, _restype, _argtypes  # keep the module namespace unchanged
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# SHARED MEMORY CLASSES | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
class SharedMemoryABC(ABC):
    """Empty base type of the shared memory classes.

    It declares no members; it only provides a name that can be used in
    annotations before the concrete classes are defined.
    """
@typechecked
class SharedMemory(SharedMemoryABC):
    """Common state of an attached shared memory segment.

    Stores the attach address (``ptr``), the segment identifier (``shmid``)
    and the segment size in bytes (``size``). Subclasses implement ``close``
    and mark the object as closed by setting the stored address to ``None``.
    """

    def __init__(self, ptr: int, shmid: int, size: int):
        self._ptr, self._shmid, self._size = ptr, shmid, size

    def close(self):
        """Release the segment; must be provided by subclasses."""
        raise NotImplementedError()

    def reattach(self) -> SharedMemoryABC:
        """Return a new client object attached to the same segment."""
        return SharedMemoryClient(size = self._size, shmid = self._shmid)

    @property
    def closed(self) -> bool:
        """``True`` once the stored address has been cleared."""
        return self._ptr is None

    @property
    def ptr(self) -> int:
        """Attach address; raises ``ValueError`` when closed."""
        if not self.closed:
            return self._ptr
        raise ValueError('closed')

    @property
    def shmid(self) -> int:
        """Segment identifier; raises ``ValueError`` when closed."""
        if not self.closed:
            return self._shmid
        raise ValueError('closed')

    @property
    def size(self) -> int:
        """Segment size in bytes (readable even when closed)."""
        return self._size
@typechecked
class SharedMemoryServer(SharedMemory):
    """Creates a new private shared memory segment and attaches to it.

    ``close`` detaches from the segment and marks it for removal.
    """

    def __init__(self, size: int):
        """Create and attach a segment of ``size`` bytes.

        Raises:
            ValueError: if ``size`` is not positive.
            OSError: if ``shmget`` or ``shmat`` fails (errno is preserved).
        """
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if size <= 0:
            raise ValueError('size must be positive')
        shmid = shmget(IPC_PRIVATE, size, 0o600)
        if shmid < 0:  # failure is reported as -1; 0 is a valid identifier
            err = ctypes.get_errno()
            raise OSError(err, f'shmget failed: {os.strerror(err):s}')
        ptr = shmat(shmid, 0, 0)
        # shmat reports failure as (void *) -1, which ctypes returns as the
        # largest unsigned pointer value rather than as None or 0.
        if ptr is None or ptr == ctypes.c_void_p(-1).value:
            err = ctypes.get_errno()
            shmctl(shmid, IPC_RMID, 0)  # do not leak the segment just created
            raise OSError(err, f'shmat failed: {os.strerror(err):s}')
        super().__init__(ptr = ptr, shmid = shmid, size = size)

    def close(self):
        """Detach and mark the segment for removal; repeated calls are no-ops."""
        if self.closed:
            return
        shmdt(self._ptr)
        self._ptr = None
        shmctl(self._shmid, IPC_RMID, 0)
        self._shmid = None
@typechecked
class SharedMemoryClient(SharedMemory):
    """Attaches to an already existing shared memory segment.

    ``close`` only detaches; it does not remove the segment.
    """

    def __init__(self, size: int, shmid: int):
        """Attach to segment ``shmid`` whose size is ``size`` bytes.

        Raises:
            ValueError: if ``size`` is not positive or ``shmid`` is negative.
            OSError: if ``shmat`` fails (errno is preserved).
        """
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if size <= 0:
            raise ValueError('size must be positive')
        if shmid < 0:  # 0 is a valid identifier, only negatives are invalid
            raise ValueError('shmid must not be negative')
        ptr = shmat(shmid, 0, 0)
        # shmat reports failure as (void *) -1, which ctypes returns as the
        # largest unsigned pointer value rather than as None or 0.
        if ptr is None or ptr == ctypes.c_void_p(-1).value:
            err = ctypes.get_errno()
            raise OSError(err, f'shmat failed: {os.strerror(err):s}')
        super().__init__(ptr = ptr, shmid = shmid, size = size)

    def close(self):
        """Detach from the segment; repeated calls are no-ops."""
        if self.closed:
            return
        shmdt(self._ptr)
        self._ptr = None
        self._shmid = None
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# SHARED NUMPY NDARRAY TYPE | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
class sndarrayABC(ABC):
    """Empty base type of the shared array class.

    It declares no members; it only provides a name that can be used in
    annotations before the concrete class is defined.
    """
@typechecked
class sndarray(sndarrayABC, np.ndarray):
    """Shared nd array: a numpy array whose data lives in a shared memory segment.

    The owning ``SharedMemory`` object is kept in ``shm``; ``close`` releases
    it. Use ``from_array`` to create an instance from an ordinary array and
    ``reattach`` to obtain a second view on the same segment.
    """

    def __new__(cls, *args, shm: SharedMemory, **kwargs) -> sndarrayABC:
        """Build the array from ``np.ndarray`` constructor arguments and store ``shm``."""
        obj = np.ndarray.__new__(np.ndarray, *args, **kwargs).view(cls)
        obj._shm = shm
        return obj

    @property
    def shm(self) -> SharedMemory:
        """The shared memory object backing this array."""
        return self._shm

    @shm.setter
    def shm(self, value: SharedMemory):
        self._shm = value

    @staticmethod
    def _make_buffer(array: np.ndarray, shm: SharedMemory) -> np.ndarray:
        """Expose the memory at ``shm.ptr`` as a plain, writable ndarray.

        Shape, strides and dtype are taken from ``array``; the data itself is
        not copied. The caller must make sure this layout fits into the segment.
        """
        class Buffer:
            __array_interface__ = {
                "data": (shm.ptr, False),  # (address, read-only flag)
                "shape": array.shape,
                "strides": array.__array_interface__['strides'],
                "typestr": array.dtype.str,
                "version": 3,
            }
        return np.array(Buffer, copy = False)

    def close(self):
        """Close the backing shared memory object."""
        self.shm.close()

    def info(self) -> str:
        """Return a short description: object id, data address and segment id."""
        return f'<sndarray id={id(self):x} data={self.__array_interface__["data"][0]:x} shmid={self.shm.shmid:d}>'

    def reattach(self) -> sndarrayABC:
        """Attach to the segment again and return a new array with the same layout."""
        shm = self.shm.reattach()
        buffer = self._make_buffer(self, shm)
        return type(self)(
            shape = self.shape,
            dtype = self.dtype,
            order = 'F' if not self.flags.c_contiguous and self.flags.f_contiguous else None,
            strides = self.__array_interface__['strides'],
            buffer = buffer,
            shm = shm,
        )

    @classmethod
    def from_array(cls, array: np.ndarray) -> sndarrayABC:
        """Copy ``array`` into a newly created shared memory segment.

        C- and Fortran-contiguous inputs keep their layout. Any other input is
        first made C-contiguous: its strides would otherwise address memory
        outside of a segment that only holds ``array.nbytes`` bytes.
        """
        if not (array.flags.c_contiguous or array.flags.f_contiguous):
            array = np.ascontiguousarray(array)
        # A segment must have a positive size, so empty arrays get one byte.
        shm = SharedMemoryServer(size = max(int(array.nbytes), 1))
        try:
            buffer = cls._make_buffer(array, shm)
            buffer[...] = array
            return cls(
                shape = array.shape,
                dtype = array.dtype,
                order = 'F' if not array.flags.c_contiguous and array.flags.f_contiguous else None,
                strides = array.__array_interface__['strides'],
                buffer = buffer,
                shm = shm,
            )
        except BaseException:
            shm.close()  # do not leak the segment when the copy fails
            raise
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# PRANGE (similar to Cython's and numba's prange) | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
@typechecked
def prange(*args: int, **kwargs: Any) -> Generator[int, None, None]:
    """Placeholder for the parallel range.

    This module-level version always raises ``SystemError``; it exists so the
    name resolves when used outside of a parallel function.
    """
    raise SystemError('prange called outside of parallel function')
@typechecked
def _check_prange_nesting(func: Callable) -> Callable:
    """Decorate a generator function so that it cannot be nested.

    The returned generator takes an extra keyword ``ctx`` (a dict with the key
    ``'in_prange'``). While it is being iterated the flag is ``True``; starting
    a second iteration with the same ``ctx`` during that time raises
    ``SystemError``. The flag is reset when the iteration ends.
    """

    @wraps(func)
    def guarded(
        *args,
        ctx: Dict[str, Any],
        **kwargs,
    ) -> Generator[int, None, None]:
        already_active = ctx['in_prange']
        if already_active:
            raise SystemError('nested prange not permitted')
        ctx['in_prange'] = True
        try:
            inner = func(*args, **kwargs)
            yield from inner
        finally:
            ctx['in_prange'] = False  # also reset when the loop body raises

    return guarded
@_check_prange_nesting
@typechecked
def _prange(
    *args,
    worker_id: int,
    processes: int,
    **kwargs,
) -> Generator[int, None, None]:
    """Yield the part of ``range(*args)`` that belongs to one worker.

    The range is cut into ``processes`` consecutive chunks of
    ``ceil(len / processes)`` values; worker ``worker_id`` receives chunk
    number ``worker_id``. The last chunks may be shorter or empty.

    Raises:
        ValueError: if ``processes`` is smaller than one or ``worker_id`` is
            not in ``0 .. processes - 1``.
    """
    # Checked first: with the worker id checks in front this was unreachable.
    if processes < 1:
        raise ValueError('at least one process required')
    if worker_id < 0:
        raise ValueError('negative worker id not permitted')
    if worker_id >= processes:
        raise ValueError('worker id greater or equal to number of processes not permitted')
    values = range(*args, **kwargs)
    if not values:
        return
    # Integer ceiling division; avoids float rounding for very long ranges.
    width = (len(values) + processes - 1) // processes
    yield from values[worker_id * width : (worker_id + 1) * width]
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# DECORATOR FOR PARALLEL FUNCTIONS | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
def _run_worker(func: Callable, kwargs: Dict[str, Any], worker_id: int, processes: int):
    """Run ``func`` inside a forked worker and terminate the process.

    This function never returns: the process always ends through
    ``os._exit`` so that no exception, ``finally`` block or exit handler of
    the parent's call stack is executed a second time in the child.
    """
    exit_code = 1
    try:
        # attach to shared memory in worker
        attached = {
            name: value.reattach()
            for name, value in kwargs.items()
            if isinstance(value, sndarray)
        }
        kwargs.update(attached)
        # Names of this module stay visible; the function's own module wins.
        new_globals = globals().copy()
        new_globals.update(func.__globals__)
        new_globals.update(dict(
            prange = partial(
                _prange,
                worker_id = worker_id,
                processes = processes,
                ctx = dict(in_prange = False),
            ),  # custom prange
            worker_id = worker_id,
            processes = processes,
        ))  # extra parameters for parallel function
        new_func = FunctionType(
            func.__code__,
            new_globals,
            func.__name__,
            func.__defaults__,  # keep default values of the original function
            func.__closure__,  # keep closure cells of the original function
        )  # injecting parameters into parallel function
        new_func.__kwdefaults__ = func.__kwdefaults__
        ret = new_func(**kwargs)  # call parallel function
        for value in attached.values():
            value.close()  # close shared memory in worker
        if ret is not None:
            raise SystemError('return values are not supported')
        exit_code = 0
    except SystemExit as exc:  # honour an explicit sys.exit() in the task
        code = exc.code
        if code is None:
            exit_code = 0
        elif isinstance(code, int):
            exit_code = code
        else:
            print(code, file = sys.stderr)
            exit_code = 1
    except BaseException:
        sys.excepthook(*sys.exc_info())  # report, then exit non-zero
    finally:
        try:
            for stream in (sys.stdout, sys.stderr):
                try:
                    stream.flush()  # os._exit does not flush buffers
                except Exception:
                    pass  # best effort: nothing left to report to
        finally:
            os._exit(exit_code)  # terminate fork


@typechecked
def parallel(
    processes: int = 1,
    shared: Tuple[str, ...] = tuple(),
):
    """Decorator factory running a function in ``processes`` forked workers.

    The decorated function must be called with keyword arguments only and
    must not return a value. In every worker the names ``prange``,
    ``worker_id`` and ``processes`` are injected as globals of the function.

    Arguments named in ``shared`` must be plain numpy arrays. They are copied
    into shared memory before forking and copied back into the original
    arrays after all workers have finished. Arguments that already are
    ``sndarray`` objects are re-attached in every worker.

    Raises:
        ValueError: if ``processes`` is smaller than one.
        SystemError: (on call) for positional arguments or if a worker exits
            with a non-zero status.
        TypeError: (on call) if a shared argument is not a plain numpy array.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if processes < 1:
        raise ValueError('at least one process required')

    @typechecked
    def wrapper(func: Callable) -> Callable:

        @wraps(func)
        def wrapped(*args: Any, **kwargs: Any):
            if len(args) > 0:
                raise SystemError('positional arguments not permitted')

            # Validate everything first so a bad argument cannot leave
            # already created segments behind.
            for name in shared:
                var = kwargs[name]
                if not isinstance(var, np.ndarray):
                    raise TypeError('can only share numpy arrays')
                if isinstance(var, sndarray):
                    raise TypeError('array already shared')

            buffer = {}
            worker_pids = []
            clean_exit = True
            try:
                for name in shared:
                    var = kwargs[name]
                    shared_var = sndarray.from_array(var)  # sharing memory
                    buffer[name] = var  # buffering unshared memory
                    kwargs[name] = shared_var

                # Flush now, otherwise every child inherits the pending
                # output and writes it again.
                for stream in (sys.stdout, sys.stderr):
                    if stream is not None:
                        stream.flush()

                for idx in range(processes):
                    pid = os.fork()
                    if pid == 0:  # in fork
                        _run_worker(func, kwargs, idx, processes)  # never returns
                    worker_pids.append(pid)  # in parent
            finally:
                # Only reached in the parent; also runs when a fork failed so
                # that started workers are reaped and segments are removed.
                for pid in worker_pids:
                    _, status = os.waitpid(pid, 0)  # wait for forks
                    if status != 0:
                        clean_exit = False
                for name, var in buffer.items():
                    shared_var = kwargs[name]
                    try:
                        var[...] = shared_var  # sync shared with buffered unshared memory
                    finally:
                        shared_var.close()  # close (buffered) shared memory in parent

            if not clean_exit:
                raise SystemError('worker exited with non-zero exit status')

        return wrapped

    return wrapper
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# DEMO 1 | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
@typechecked
@parallel(processes = 2)
def task1(a: np.ndarray, b: int):
    """Demo task: add ``b`` to every element along the first axis of ``a``.

    ``worker_id`` and ``prange`` are not defined in this module's scope for
    the task; they are injected by the ``parallel`` decorator in each worker.
    """
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] Hello!')
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] {a.info():s}')
    length = a.shape[0]
    for idx in prange(length):
        print(f'[worker_id={worker_id:d} pid={os.getpid():d}] idx=={idx:d}')
        a[idx] = a[idx] + b
def demo1():
    """Demo 1: the caller creates the shared array itself and passes it to ``task1``."""
    print('demo1')
    a = sndarray.from_array(np.arange(1, 11, dtype = 'f8'))
    try:
        print(a.info())
        print(a)
        task1(a = a, b = 7)
        print(a.info())
        print(a)
    finally:
        a.close()  # release the segment even if the task failed
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# DEMO 2 | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
@typechecked
@parallel(processes = 2, shared = ('a',))
def task2(a: np.ndarray, b: int):
    """Demo task: add ``b`` to every element along the first axis of ``a``.

    ``a`` is listed in ``shared``, so the decorator moves it into shared
    memory for the duration of the call. ``worker_id`` and ``prange`` are
    injected by the ``parallel`` decorator in each worker.
    """
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] Hello!')
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] {a.info():s}')
    length = a.shape[0]
    for idx in prange(length):
        print(f'[worker_id={worker_id:d} pid={os.getpid():d}] idx=={idx:d}')
        a[idx] = a[idx] + b
def demo2():
    """Demo 2: pass an ordinary array and let ``task2`` share it implicitly.

    The object id is printed before and after the call to show that the
    caller keeps working with the same array object.
    """
    def show(array):
        print(id(array))
        print(array)

    print('demo2')
    values = np.arange(1, 11, dtype = 'f8')
    show(values)
    task2(a = values, b = 7)
    show(values)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
# ENTRY POINT | |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
if __name__ == "__main__":
    # Run the demos in order when executed as a script.
    for demo in (demo1, demo2):
        demo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment