Instantly share code, notes, and snippets.
Last active
March 28, 2022 13:57
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(0)
0
You must be signed in to fork a gist
-
Save mjm522/ce0aab9458c738b0ed0958f6e94abd5b to your computer and use it in GitHub Desktop.
The following class creates a NumPy array backed by Linux POSIX shared memory.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ctypes
import ctypes.util
import mmap
import os
import stat
import time
from functools import reduce
from operator import mul
from typing import Any

import numpy as np
# Load the POSIX real-time library that exports shm_open/shm_unlink.
# find_library resolves the versioned runtime soname (e.g. librt.so.1); the
# bare "librt.so" link name is only a fallback, since it usually exists only
# when development packages are installed.
# use_errno=True makes ctypes save the C errno after each foreign call, which
# is what ctypes.get_errno() reports; without it get_errno() is always 0.
rtld = ctypes.CDLL(ctypes.util.find_library("rt") or "librt.so", use_errno=True)
_shm_open = rtld.shm_open
_shm_unlink = rtld.shm_unlink
def _create_string_buffer(name):
    """Return a NUL-terminated ``char`` buffer holding ``name`` for C calls.

    ``shm_open``/``shm_unlink`` take a ``const char *``, so a ``str`` name is
    first encoded with the filesystem encoding (``os.fsencode``). A
    ``wchar_t`` buffer would be misread by the C side: on Linux the UTF-32
    code units contain NUL bytes that truncate the name.

    Args:
        name: ``bytes`` or ``str`` object name.

    Returns:
        A ``ctypes`` char array containing the encoded name.

    Raises:
        TypeError: if ``name`` is neither ``bytes`` nor ``str``.
    """
    if isinstance(name, str):
        name = os.fsencode(name)
    if not isinstance(name, bytes):
        raise TypeError("`name` must be `bytes` or `str`")
    return ctypes.create_string_buffer(name)
def shm_open(
    name,
    oflag: int = os.O_RDWR | os.O_CREAT,
    mode: int = stat.S_IRUSR | stat.S_IWUSR,
):
    """Open (by default create) a POSIX shared-memory object.

    Plain Python ints are passed to C as ``int``, which matches ``oflag`` and
    the 32-bit ``mode_t`` on Linux. Callers may still pass ctypes scalars
    such as ``ctypes.c_int``.

    Args:
        name: ``bytes`` or ``str`` object name.
        oflag: ``os.O_*`` open flags (default read/write, create if missing).
        mode: permission bits used when the object is created
            (default owner read/write).

    Returns:
        The file descriptor returned by ``shm_open``.

    Raises:
        RuntimeError: with the ``strerror`` text when ``shm_open`` returns -1.
    """
    c_name = _create_string_buffer(name)
    fd = _shm_open(c_name, oflag, mode)
    if fd == -1:
        raise RuntimeError(os.strerror(ctypes.get_errno()))
    return fd
def shm_unlink(name):
    """Remove the POSIX shared-memory object called ``name``.

    Args:
        name: ``bytes`` or ``str`` object name.

    Raises:
        RuntimeError: with the ``strerror`` text when ``shm_unlink`` fails.
    """
    if _shm_unlink(_create_string_buffer(name)) != -1:
        return
    raise RuntimeError(os.strerror(ctypes.get_errno()))
class NpShm:
    """NumPy array stored in a named POSIX shared-memory segment.

    Segment layout (``num_bytes`` long)::

        [ array data : num_data_bytes ][ UTF-8 name : len(name) ]

    The trailing name is a marker meaning "already initialised". If it is
    missing when the segment is opened, this instance zero-fills the data,
    writes the marker and records itself as the creator (``is_creator``).
    Only the creator unlinks the segment name when it is destroyed.

    Args:
        name: Shared-memory object name; its UTF-8 bytes are also the marker.
        shape: Array shape. A bare integer is treated as ``(shape,)``.
        dtype: Array dtype (default ``np.float64``).
    """

    def __init__(self, name: str, shape: tuple, dtype: Any = np.float64):
        self.name = f"{name}".encode("utf-8")
        # Normalise a bare int to a 1-tuple so reduce() below can iterate it.
        if isinstance(shape, (int, np.integer)):
            shape = (int(shape),)
        self.shape = shape
        self.dtype = dtype
        # Initial value 1 makes a 0-d shape ``()`` valid instead of a TypeError.
        self.num_data_bytes = reduce(mul, shape, 1) * np.dtype(self.dtype).itemsize
        self.num_bytes = self.num_data_bytes + len(self.name)
        # Set before acquiring resources so __del__ is safe if we fail part-way.
        self.is_creator = False
        self.fid = shm_open(
            self.name,
            os.O_RDWR | os.O_CREAT,
            stat.S_IRUSR | stat.S_IWUSR,
        )
        # Every opener sizes the segment; a no-op if it already has num_bytes.
        os.ftruncate(self.fid, self.num_bytes)
        self.memory_map = mmap.mmap(self.fid, self.num_bytes)
        self.memory_view = memoryview(self.memory_map)
        offset = self._find_instance(self.name)
        if offset < 0:
            print("Does not exist, creating it")
            self.is_creator = True
            self._create_instance(self.name)
            offset = self._find_instance(self.name)
        self.data = np.ndarray(
            shape=shape,
            dtype=self.dtype,
            buffer=self._get_mutable_buffer(offset, self.num_data_bytes),
        )

    def _create_instance(self, name: bytes) -> None:
        """Zero-fill the data region and write the ``name`` marker after it."""
        data_end = self.num_data_bytes
        # Slice assignment writes at fixed offsets, independent of the
        # mmap's file position.
        self.memory_map[:data_end] = np.zeros(self.shape, dtype=self.dtype).tobytes()
        self.memory_map[data_end:data_end + len(name)] = name

    def _find_instance(self, name: bytes) -> int:
        """Return the offset of the ``name`` marker, or -1 if it is absent.

        The marker is only checked at ``num_data_bytes`` (right after the
        array data), where ``_create_instance`` writes it. Array contents that
        happen to contain the same bytes therefore cannot be mistaken for it.
        """
        marker_start = self.num_data_bytes
        marker_end = marker_start + len(name)
        if self.memory_map[marker_start:marker_end] == name:
            return marker_start
        name_string = name.decode("utf-8")
        print(f'Instance "{name_string}" is not found')
        return -1

    def _get_mutable_buffer(self, idx, message_size):
        """Return a writable view of the ``message_size`` bytes ending at ``idx``."""
        return self.memory_view[idx - message_size : idx]

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, val):
        self.data[i] = val

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return f"{self.data}"

    def __bytes__(self):
        return self.data.tobytes()

    def __del__(self):
        """Release this mapping; the creator also unlinks the segment name.

        Attaching processes must not unlink, otherwise one of them exiting
        would remove a segment the creator and other processes still use.
        ``getattr`` guards keep this safe when ``__init__`` failed early.
        """
        try:
            if getattr(self, "is_creator", False):
                print(f"{self}: destructor")
                shm_unlink(self.name)
        finally:
            memory_view = getattr(self, "memory_view", None)
            if memory_view is not None:
                memory_view.release()
            fid = getattr(self, "fid", None)
            if fid is not None:
                # mmap keeps its own duplicate of the descriptor.
                os.close(fid)
            # The mmap itself is not closed here: self.data still exports
            # its buffer, and mmap.close() would refuse while exports exist.

    def read(self) -> np.ndarray:
        """Return the shared array itself (not a copy)."""
        return self.data

    def write(self, data: np.ndarray) -> None:
        """Copy ``data`` into the shared array.

        Raises:
            ValueError: if ``data`` does not have the shared array's shape.
        """
        if np.shape(data) != self.data.shape:
            raise ValueError(
                f"expected data of shape {self.data.shape}, got {np.shape(data)}"
            )
        np.copyto(self.data, data)
# ---- Demo helpers and smoke checks ----
def get_numpy_shm(name, size):
    """Create, or attach to, the shared float64 array ``name`` of shape ``size``."""
    shared = NpShm(name=name, shape=size)
    return shared
def add_value_at_index(np_shm, idx, val):
    """Store ``val`` at index/key ``idx`` of the item-assignable ``np_shm``."""
    target = np_shm
    target[idx] = val
def main():
    """Demo: create/attach to the "bleeeh" array, mutate it, then keep it alive.

    After the edits the process idles until interrupted (Ctrl-C), so other
    processes can attach to the same segment in the meantime.
    """
    shm = get_numpy_shm("bleeeh", (10,))
    print("before")
    print(shm.read())
    shm.write(np.ones(10) * 10)
    print("read")
    print(shm.read())
    shm[5] = 100
    print("read after changing element at index 5")
    print(shm.read())
    shm[5:7] = np.array([200, 250])
    print("read after changing two elements at index 5, 6")
    print(shm.read())
    # Keep the mapping alive without busy-spinning a CPU core.
    try:
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        # Ctrl-C is the intended way to stop the demo; returning normally
        # lets `shm` be released while module globals are still intact.
        pass


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment