Skip to content

Instantly share code, notes, and snippets.

@s-m-e
Last active July 28, 2021 05:07
Show Gist options
  • Save s-m-e/bf01cba201d27c793873b819264d3a4b to your computer and use it in GitHub Desktop.
Save s-m-e/bf01cba201d27c793873b819264d3a4b to your computer and use it in GitHub Desktop.
Pure Python & Numpy openMP-style for-loop based on POSIX shared memory and forks for Unix-like systems
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
SHARED MEMORY TESTS
Pure Python & NumPy OpenMP-style for-loop based on System V (XSI) shared memory
and forks for Unix-like systems (tested on Linux)
Copyright (C) 2021 Sebastian M. Ernst <[email protected]>
Inspired by https://github.com/albertz/playground/blob/master/shared_mem.py
<LICENSE_BLOCK>
The contents of this file are subject to the GNU Lesser General Public License
Version 2.1 ("LGPL" or "License"). You may not use this file except in
compliance with the License. You may obtain a copy of the License at
https://www.gnu.org/licenses/old-licenses/lgpl-2.1.txt
Software distributed under the License is distributed on an "AS IS" basis,
WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for the
specific language governing rights and limitations under the License.
</LICENSE_BLOCK>
"""
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# IMPORT
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from abc import ABC
import ctypes
from functools import partial, wraps
from math import ceil
import os
import sys
from types import FunctionType
from typing import Any, Callable, Generator, Dict, Tuple
import numpy as np
try:
from typeguard import typechecked
except ModuleNotFoundError:
typechecked = lambda x: x
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# LIB-C API
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Handle to the symbols already loaded into the running process; the empty
# library name only resolves this way on Linux.
libc = ctypes.CDLL("", use_errno = True, use_last_error = True)

# C `key_t` is bound as a plain int here.
shm_key_t = ctypes.c_int

# Numeric values of the C macros used below (both are 0 in this file).
IPC_PRIVATE = 0
IPC_RMID = 0

# int shmget(key_t key, size_t size, int shmflg);
shmget = libc.shmget
shmget.argtypes = (shm_key_t, ctypes.c_size_t, ctypes.c_int)
shmget.restype = ctypes.c_int

# void* shmat(int shmid, const void *shmaddr, int shmflg);
shmat = libc.shmat
shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int)
shmat.restype = ctypes.c_void_p

# int shmdt(const void *shmaddr);
shmdt = libc.shmdt
shmdt.argtypes = (ctypes.c_void_p,)
shmdt.restype = ctypes.c_int

# int shmctl(int shmid, int cmd, struct shmid_ds *buf);
shmctl = libc.shmctl
shmctl.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_void_p)
shmctl.restype = ctypes.c_int
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# SHARED MEMORY CLASSES
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
class SharedMemoryABC(ABC):
    """
    Empty marker base for the shared memory classes.

    It declares no members; it only provides a name that is already defined
    when the concrete classes annotate their return types.
    """
@typechecked
class SharedMemory(SharedMemoryABC):
    """
    Common state of an attached shared memory segment.

    Holds the attach address (`ptr`), the segment id (`shmid`) and the size in
    bytes. Subclasses implement `close`; a closed object is recognised by its
    address having been reset to `None`.
    """

    def __init__(self, ptr: int, shmid: int, size: int):
        self._size = size
        self._shmid = shmid
        self._ptr = ptr

    @property
    def closed(self) -> bool:
        """True once the attach address has been cleared."""
        return self._ptr is None

    @property
    def size(self) -> int:
        """Size in bytes as passed to the constructor (readable after close)."""
        return self._size

    @property
    def ptr(self) -> int:
        """Attach address; raises ValueError when closed."""
        if not self.closed:
            return self._ptr
        raise ValueError('closed')

    @property
    def shmid(self) -> int:
        """Segment id; raises ValueError when closed."""
        if not self.closed:
            return self._shmid
        raise ValueError('closed')

    def reattach(self) -> SharedMemoryABC:
        """Return a new client object attached to the same segment id."""
        return SharedMemoryClient(shmid = self._shmid, size = self._size)

    def close(self):
        """Must be provided by subclasses."""
        raise NotImplementedError()
@typechecked
class SharedMemoryServer(SharedMemory):
    """
    Owning side of a shared memory segment.

    Creates a new private segment of `size` bytes via `shmget`, attaches it
    via `shmat`, and on `close` detaches it and asks for its removal
    (`shmctl` with `IPC_RMID`).

    Raises:
        ValueError: if `size` is not positive.
        OSError: if `shmget` or `shmat` report failure (carries `errno`).
    """

    def __init__(self, size: int):
        if size <= 0:
            raise ValueError('size must be positive')
        # shmget returns -1 on failure; every non-negative id (0 included) is valid.
        shmid = shmget(IPC_PRIVATE, size, 0o600)
        if shmid < 0:
            err = ctypes.get_errno()
            raise OSError(err, f'shmget failed: {os.strerror(err)}')
        # shmat signals failure with (void*)-1, which ctypes hands back as the
        # largest pointer-sized integer rather than a falsy value.
        ptr = shmat(shmid, 0, 0)
        if ptr is None or ptr == ctypes.c_void_p(-1).value:
            err = ctypes.get_errno()
            shmctl(shmid, IPC_RMID, 0)  # do not leave the fresh segment behind
            raise OSError(err, f'shmat failed: {os.strerror(err)}')
        super().__init__(ptr = ptr, shmid = shmid, size = size)

    def close(self):
        """Detach and remove the segment; calling it again is a no-op."""
        if self._ptr is None:
            return
        shmdt(self._ptr)
        self._ptr = None
        shmctl(self._shmid, IPC_RMID, 0)
        self._shmid = None
@typechecked
class SharedMemoryClient(SharedMemory):
    """
    Non-owning attachment to an existing shared memory segment.

    Attaches the segment identified by `shmid` via `shmat`; `close` only
    detaches and never removes the segment.

    Raises:
        ValueError: if `size` is not positive or `shmid` is negative.
        OSError: if `shmat` reports failure (carries `errno`).
    """

    def __init__(self, size: int, shmid: int):
        if size <= 0:
            raise ValueError('size must be positive')
        # Segment ids are non-negative integers; 0 is a legitimate id.
        if shmid < 0:
            raise ValueError('shmid must not be negative')
        # shmat signals failure with (void*)-1, which ctypes hands back as the
        # largest pointer-sized integer rather than a falsy value.
        ptr = shmat(shmid, 0, 0)
        if ptr is None or ptr == ctypes.c_void_p(-1).value:
            err = ctypes.get_errno()
            raise OSError(err, f'shmat failed: {os.strerror(err)}')
        super().__init__(ptr = ptr, shmid = shmid, size = size)

    def close(self):
        """Detach from the segment; calling it again is a no-op."""
        if self._ptr is None:
            return
        shmdt(self._ptr)
        self._ptr = None
        self._shmid = None
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# SHARED NUMPY NDARRAY TYPE
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
class sndarrayABC(ABC):
    """
    Empty marker base for the shared array type.

    It declares no members; it only provides a name that is already defined
    when the array class annotates its own constructors.
    """
@typechecked
class sndarray(sndarrayABC, np.ndarray):
    """
    shared nd array

    A numpy array whose data lives in a shared memory segment. The segment
    object is kept on the array as `shm`. Use `from_array` to copy an ordinary
    array into a fresh segment and `reattach` to map the same segment again
    (e.g. inside a forked worker).
    """

    def __new__(cls, *args, shm: SharedMemory, **kwargs) -> sndarrayABC:
        # Build a plain ndarray from the given arguments, then view it as this
        # class so that the segment object can be stored on the instance.
        obj = np.ndarray.__new__(np.ndarray, *args, **kwargs).view(cls)
        obj._shm = shm
        return obj

    @property
    def shm(self) -> SharedMemory:
        """Shared memory object backing this array."""
        return self._shm

    @shm.setter
    def shm(self, value: SharedMemory):
        self._shm = value

    @staticmethod
    def _make_buffer(array: np.ndarray, shm: SharedMemory) -> np.ndarray:
        """
        Return a writable array located at `shm.ptr` that has the shape,
        strides and dtype string of `array`.
        """
        class Buffer:
            __array_interface__ = {
                "data": (shm.ptr, False),  # (address, read-only flag)
                "shape": array.shape,
                "strides": array.__array_interface__['strides'],
                "typestr": array.dtype.str,
                "version": 3,
            }
        return np.array(Buffer, copy = False)

    def close(self):
        """Close the backing shared memory object."""
        self.shm.close()

    def info(self) -> str:
        """Return a short description: object id, data address and segment id."""
        return f'<sndarray id={id(self):x} data={self.__array_interface__["data"][0]:x} shmid={self.shm.shmid:d}>'

    def reattach(self) -> sndarrayABC:
        """Attach to the same segment again and return a new array on top of it."""
        shm = self.shm.reattach()
        buffer = self._make_buffer(self, shm)
        return type(self)(
            shape = self.shape,
            dtype = self.dtype,
            order = 'F' if not self.flags.c_contiguous and self.flags.f_contiguous else None,
            strides = self.__array_interface__['strides'],
            buffer = buffer,
            shm = shm,
        )

    @classmethod
    def from_array(cls, array: np.ndarray) -> sndarrayABC:
        """
        Copy `array` into a newly created shared memory segment.

        The segment holds exactly the array's elements. Arrays that are neither
        C- nor F-contiguous are compacted first: their strides address a span
        larger than `nbytes`, so reusing them on the segment would reach past
        its end. The segment is released again if construction fails.
        """
        if not (array.flags.c_contiguous or array.flags.f_contiguous):
            array = np.ascontiguousarray(array)
        # A zero-byte segment cannot be created, so empty arrays get one byte.
        shm = SharedMemoryServer(size = max(int(array.nbytes), 1))
        try:
            buffer = cls._make_buffer(array, shm)
            buffer[...] = array
            return cls(
                shape = array.shape,
                dtype = array.dtype,
                order = 'F' if not array.flags.c_contiguous and array.flags.f_contiguous else None,
                strides = array.__array_interface__['strides'],
                buffer = buffer,
                shm = shm,
            )
        except Exception:
            shm.close()
            raise
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# PRANGE (similar to Cython's and numba's prange)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@typechecked
def prange(*args: int, **kwargs: Any) -> Generator[int, None, None]:
    """
    Module-level stand-in for the parallel range.

    Calling it always raises `SystemError`; functions run through `parallel`
    get a working replacement placed into their globals under the same name.
    """
    raise SystemError('prange called outside of parallel function')
@typechecked
def _check_prange_nesting(func: Callable) -> Callable:
    """
    Wrap generator function `func` so that it cannot be iterated re-entrantly.

    The wrapper takes an extra keyword-only `ctx` dict (not forwarded to
    `func`). While the wrapped generator is being iterated, `ctx['in_prange']`
    is True; starting a second iteration with the same `ctx` in that state
    raises `SystemError`. The flag is reset when the generator finishes or is
    torn down.
    """
    flag = 'in_prange'

    @wraps(func)
    def guarded(
        *args,
        ctx: Dict[str, Any],
        **kwargs,
    ) -> Generator[int, None, None]:
        already_active = ctx[flag]
        if already_active:
            raise SystemError('nested prange not permitted')
        ctx[flag] = True
        try:
            yield from func(*args, **kwargs)
        finally:
            ctx[flag] = False

    return guarded
@_check_prange_nesting
@typechecked
def _prange(
    *args,
    worker_id: int,
    processes: int,
    **kwargs,
) -> Generator[int, None, None]:
    """
    Yield the part of `range(*args, **kwargs)` that belongs to one worker.

    The range is cut into `processes` consecutive chunks of
    `ceil(len(range) / processes)` values each; worker `worker_id` receives
    chunk number `worker_id` (trailing chunks may be shorter or empty).

    Raises (on first iteration, since this is a generator):
        ValueError: if `processes` is below 1, or `worker_id` is negative or
            not smaller than `processes`.
    """
    # The process count is validated first: checked last, it could never be
    # reported, because any non-negative worker id is >= a count below 1.
    if processes < 1:
        raise ValueError('at least one process required')
    if worker_id < 0:
        raise ValueError('negative worker id not permitted')
    if worker_id >= processes:
        raise ValueError('worker id greater or equal to number of processes not permitted')
    values = range(*args, **kwargs)
    if not values:
        return
    width = ceil(len(values) / processes)
    yield from values[worker_id * width : (worker_id + 1) * width]
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# DECORATOR FOR PARALLEL FUNCTIONS
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@typechecked
def parallel(
    processes: int = 1,
    shared: Tuple[str, ...] = tuple(),
):
    """
    Decorator factory: run the decorated function in `processes` forked workers.

    The decorated function must be called with keyword arguments only and must
    return None. In every worker its globals additionally contain `prange`
    (the per-worker range), `worker_id` and `processes`.

    Arguments named in `shared` must be plain numpy arrays. They are copied
    into shared memory before forking, and copied back into the caller's
    arrays once all workers have been waited for. Arguments that already are
    `sndarray` objects are re-attached and closed inside each worker.

    Raises (in the calling process):
        SystemError: on positional arguments, or if any worker ended with a
            non-zero exit status.
        TypeError: if a shared argument is not a plain numpy array.
        KeyError: if a name in `shared` was not passed.
    """
    assert processes > 0

    def flush_stdio():
        # Best effort only: a stream may be missing or already broken, and that
        # must not change the outcome of the parallel call.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.flush()
            except Exception:
                pass

    @typechecked
    def wrapper(func: Callable) -> Callable:

        def run_worker(worker_id: int, kwargs: Dict[str, Any]):
            """Run `func` in a forked child; this function never returns."""
            exit_code = 1
            try:
                # attach to shared memory in worker
                kwargs = {
                    name: value.reattach() if isinstance(value, sndarray) else value
                    for name, value in kwargs.items()
                }
                # This module's globals stay available as before; the decorated
                # function's own module globals are layered on top so functions
                # defined in other modules keep seeing their imports/helpers.
                new_globals = globals().copy()
                new_globals.update(getattr(func, '__globals__', {}))
                new_globals.update(dict(
                    prange = partial(
                        _prange,
                        worker_id = worker_id,
                        processes = processes,
                        ctx = dict(in_prange = False),
                    ), # custom prange
                    worker_id = worker_id,
                    processes = processes,
                )) # extra parameters for parallel function
                # Rebuild the function on the new globals, keeping its default
                # values and closure cells (a bare code+globals rebuild drops them).
                new_func = FunctionType(
                    code = getattr(func, '__code__'),
                    globals = new_globals,
                    name = getattr(func, '__name__', None),
                    argdefs = getattr(func, '__defaults__', None),
                    closure = getattr(func, '__closure__', None),
                )
                new_func.__kwdefaults__ = getattr(func, '__kwdefaults__', None)
                ret = new_func(**kwargs) # call parallel function
                for value in kwargs.values():
                    if isinstance(value, sndarray):
                        value.close() # close shared memory in worker
                if ret is not None:
                    raise SystemError('return values are not supported')
                exit_code = 0
            except SystemExit as exc:
                # The parallel function asked to exit: honour its status.
                if exc.code is None:
                    exit_code = 0
                elif isinstance(exc.code, int):
                    exit_code = exc.code
                else:
                    print(exc.code, file = sys.stderr)
                    exit_code = 1
            except BaseException:
                import traceback
                traceback.print_exc()
            finally:
                # Leave the child right here. Raising (SystemExit included)
                # would unwind the caller's stack inside the child and run the
                # caller's except/finally blocks and exit handlers a second time.
                flush_stdio()
                os._exit(exit_code)

        @wraps(func)
        def wrapped(*args: Any, **kwargs: Any):
            if len(args) > 0:
                raise SystemError('positional arguments not permitted')
            # Validate everything before allocating, so a rejected argument
            # cannot leave an already created segment behind.
            buffer = {}
            for name in shared:
                var = kwargs[name]
                if not isinstance(var, np.ndarray):
                    raise TypeError('can only share numpy arrays')
                if isinstance(var, sndarray):
                    raise TypeError('array already shared')
                buffer[name] = var # buffering unshared memory
            worker_pids = []
            clean_exit = True
            try:
                for name, var in buffer.items():
                    kwargs[name] = sndarray.from_array(var) # sharing memory
                # Pending buffered output would be duplicated into every child.
                flush_stdio()
                for idx in range(processes):
                    pid = os.fork()
                    if pid == 0: # in fork
                        run_worker(idx, kwargs) # never returns
                    worker_pids.append(pid) # in parent
            finally:
                # Parent only (children leave inside run_worker). Runs even when
                # sharing or forking failed part-way, so started workers are
                # reaped and created segments are released.
                for pid in worker_pids:
                    _, status = os.waitpid(pid, 0) # wait for forks
                    if status != 0:
                        clean_exit = False
                for name, var in buffer.items():
                    shared_var = kwargs[name]
                    if isinstance(shared_var, sndarray):
                        var[...] = shared_var # sync shared with buffered unshared memory
                        shared_var.close() # close (buffered) shared memory in parent
            if not clean_exit:
                raise SystemError('worker exited with non-zero exit status')

        return wrapped

    return wrapper
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# DEMO 1
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@typechecked
@parallel(processes = 2)
def task1(a: np.ndarray, b: int):
    """
    Demo worker body: add `b` to each element of this worker's share of `a`.

    `worker_id` and `prange` are not defined in this file's scope for this
    function; they are supplied through the globals set up by `parallel`.
    `a` must offer `info()`, i.e. be a shared array.
    """
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] Hello!')
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] {a.info():s}')
    length = a.shape[0]
    for idx in prange(length):
        print(f'[worker_id={worker_id:d} pid={os.getpid():d}] idx=={idx:d}')
        a[idx] += b
def demo1():
    """
    Demo with an explicitly shared array: create it, print it, let `task1`
    modify it in the workers, print it again and close it.
    """
    def report(array):
        print(array.info())
        print(array)

    print('demo1')
    source = np.arange(1, 11, dtype = 'f8')
    a = sndarray.from_array(source)
    report(a)
    task1(a = a, b = 7)
    report(a)
    a.close()
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# DEMO 2
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@typechecked
@parallel(processes = 2, shared = ('a',))
def task2(a: np.ndarray, b: int):
    """
    Demo worker body: add `b` to each element of this worker's share of `a`.

    `a` is listed in `shared`, so the decorator passes a shared copy of the
    caller's array. `worker_id` and `prange` are supplied through the globals
    set up by `parallel`.
    """
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] Hello!')
    print(f'[worker_id={worker_id:d} pid={os.getpid():d}] {a.info():s}')
    length = a.shape[0]
    for idx in prange(length):
        print(f'[worker_id={worker_id:d} pid={os.getpid():d}] idx=={idx:d}')
        a[idx] += b
def demo2():
    """
    Demo with an ordinary array: print its id and contents, run `task2`
    (which shares it for the duration of the call), then print both again.
    """
    def report(array):
        print(id(array))
        print(array)

    print('demo2')
    a = np.arange(1, 11, dtype = 'f8')
    report(a)
    task2(a = a, b = 7)
    report(a)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# ENTRY POINT
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Run both demos, in order, when executed as a script.
if __name__ == "__main__":
    for demo in (demo1, demo2):
        demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment