Skip to content

Instantly share code, notes, and snippets.

@mjm522
Last active March 28, 2022 13:57
Show Gist options
  • Save mjm522/ce0aab9458c738b0ed0958f6e94abd5b to your computer and use it in GitHub Desktop.
Save mjm522/ce0aab9458c738b0ed0958f6e94abd5b to your computer and use it in GitHub Desktop.
The following class creates a NumPy array backed by Linux POSIX shared memory.
import ctypes
import ctypes.util
import mmap
import os
import stat
from functools import reduce
from operator import mul
from typing import Any

import numpy as np
rtld = ctypes.cdll.LoadLibrary("librt.so")
_shm_open = rtld.shm_open
_shm_unlink = rtld.shm_unlink
def _create_string_buffer(name):
if isinstance(name, bytes):
name = ctypes.create_string_buffer(name)
elif isinstance(name, str):
name = ctypes.create_unicode_buffer(name)
else:
raise TypeError("`name` must be `bytes` or `str`")
return name
def shm_open(
    name,
    oflag: Any = os.O_RDWR | os.O_CREAT,
    mode: Any = stat.S_IRUSR | stat.S_IWUSR,
):
    """Open the POSIX shared-memory object *name* (by default creating it).

    Args:
        name: Object name as ``bytes`` or ``str``.
        oflag: ``open(2)``-style flags, as a plain ``int`` or a ``ctypes``
            integer such as ``ctypes.c_int``. Defaults to ``O_RDWR | O_CREAT``.
        mode: Permission bits applied if the object is created, as a plain
            ``int`` or a ``ctypes`` integer. Defaults to owner read/write.

    Returns:
        The file descriptor returned by the C ``shm_open``.

    Raises:
        RuntimeError: If the C ``shm_open`` returns -1; the message is
            ``os.strerror`` of ``ctypes.get_errno()``.
    """
    # Callers may pass ctypes integers (``.value`` unwraps them) or plain ints.
    # Plain-int defaults avoid sharing one mutable ctypes object across calls.
    flags = int(getattr(oflag, "value", oflag))
    perms = int(getattr(mode, "value", mode))
    # mode_t is an unsigned int on Linux/glibc, so pass it at that width.
    result = _shm_open(
        _create_string_buffer(name), ctypes.c_int(flags), ctypes.c_uint(perms)
    )
    if result == -1:
        raise RuntimeError(os.strerror(ctypes.get_errno()))
    return result
def shm_unlink(name):
    """Remove the POSIX shared-memory object *name*.

    Args:
        name: Object name as ``bytes`` or ``str``.

    Raises:
        RuntimeError: If the C ``shm_unlink`` returns -1; the message is
            ``os.strerror`` of ``ctypes.get_errno()``.
    """
    status = _shm_unlink(_create_string_buffer(name))
    if status != -1:
        return
    raise RuntimeError(os.strerror(ctypes.get_errno()))
class NpShm:
    """NumPy array stored in a named POSIX shared-memory segment.

    Segment layout (``num_bytes`` long)::

        [ array data: num_data_bytes ][ name marker: len(name) bytes ]

    The UTF-8 encoded ``name`` written right after the array data marks the
    segment as initialised. An instance that opens a segment without the
    marker zero-fills the data, writes the marker and becomes the segment's
    *creator*; only the creator unlinks the segment in ``__del__``.

    Args:
        name: Shared-memory object name, passed to ``shm_open``.
        shape: Array shape; a bare integer is treated as a 1-D length.
        dtype: Array element type (default ``np.float64``).

    Raises:
        RuntimeError: If ``shm_open`` fails.
    """

    def __init__(self, name: str, shape: tuple, dtype: Any = np.float64):
        self.name = f"{name}".encode("utf-8")
        # Normalise so an int shape works and ``reduce`` always gets a sequence.
        self.shape = tuple(shape) if np.iterable(shape) else (int(shape),)
        self.dtype = dtype
        # Initial value 1: a 0-d shape ``()`` holds one element instead of
        # making ``reduce`` raise on an empty sequence.
        num_elements = reduce(mul, self.shape, 1)
        self.num_data_bytes = num_elements * np.dtype(self.dtype).itemsize
        self.num_bytes = self.num_data_bytes + len(self.name)
        self.fid = shm_open(
            self.name,
            ctypes.c_int(os.O_RDWR | os.O_CREAT),
            ctypes.c_ushort(stat.S_IRUSR | stat.S_IWUSR),
        )
        # Only ever grow the segment: shrinking one that another process has
        # mapped at a larger size would make its accesses past the new end fault.
        if os.fstat(self.fid).st_size < self.num_bytes:
            os.ftruncate(self.fid, self.num_bytes)
        self.memory_map = mmap.mmap(self.fid, self.num_bytes)
        self.memory_view = memoryview(self.memory_map)
        self.is_creator = False
        offset = self._find_instance(self.name)
        if offset < 0:
            print("Does not exist, creating it")
            self.is_creator = True
            self._create_instance(self.name)
            offset = self._find_instance(self.name)
        # The array aliases the mapped memory directly, so writes are shared.
        self.data = np.ndarray(
            shape=self.shape,
            dtype=self.dtype,
            buffer=self._get_mutable_buffer(offset, self.num_data_bytes),
        )

    def _create_instance(self, name: bytes) -> None:
        """Zero the array region and write the *name* marker right after it."""
        end = self.num_data_bytes
        self.memory_map[:end] = bytes(end)
        self.memory_map[end : end + len(name)] = name

    def _find_instance(self, name: bytes) -> int:
        """Return the offset of the *name* marker, or -1 if it is absent.

        The marker is only checked at its fixed location right after the array
        data. Searching the whole map (``mmap.find``) would also start at the
        current file position and could match array bytes that happen to equal
        the name, yielding a bogus offset.
        """
        offset = self.num_data_bytes
        if self.memory_map[offset : offset + len(name)] == name:
            return offset
        name_string = name.decode("utf-8")
        print(f'Instance "{name_string}" is not found')
        return -1

    def _get_mutable_buffer(self, idx, message_size):
        """Return a writable view of the *message_size* bytes that end at *idx*."""
        return self.memory_view[idx - message_size : idx]

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, val):
        self.data[i] = val

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return f"{self.data}"

    def __bytes__(self):
        return self.data.tobytes()

    def __del__(self):
        """Detach from the segment; the creator also unlinks its name.

        Attributes are read with ``getattr`` because ``__init__`` may have
        raised before setting all of them, and ``__del__`` still runs then.
        """
        memory_view = getattr(self, "memory_view", None)
        if memory_view is not None:
            try:
                memory_view.release()
            except BufferError:
                pass  # still exported elsewhere; freed together with its last user
        fid = getattr(self, "fid", None)
        if fid is not None:
            # Closing the descriptor does not unmap the region (POSIX mmap).
            try:
                os.close(fid)
            except OSError:
                pass  # already closed
        # Ownership: only the instance that initialised the segment removes
        # its name; other instances merely detach.
        if getattr(self, "is_creator", False):
            print(f"{self.name.decode('utf-8')}: destructor")
            shm_unlink(self.name)

    def read(self) -> np.ndarray:
        """Return the shared array itself (a live view, not a copy)."""
        return self.data

    def write(self, data: np.ndarray) -> None:
        """Copy *data* element-wise into the shared array.

        Raises:
            ValueError: If *data* does not have the shared array's shape.
        """
        data = np.asarray(data)
        if data.shape != self.data.shape:
            raise ValueError(
                f"expected data of shape {self.data.shape}, got {data.shape}"
            )
        np.copyto(self.data, data)
#### checks
def get_numpy_shm(name, size):
    """Return an ``NpShm`` called *name* whose shape is *size*.

    The element type is left at ``NpShm``'s default.
    """
    return NpShm(name=name, shape=size)
def add_value_at_index(np_shm, idx, val):
    """Assign *val* to ``np_shm[idx]`` in place.

    Despite the name, this overwrites the element rather than adding to it.
    Any object supporting item assignment (e.g. an ``NpShm``) is accepted.
    """
    np_shm[idx] = val
    return None
def main():
    """Demo: open the shared array "bleeeh", mutate it, then stay alive.

    The process then idles (presumably so other processes can attach to the
    same segment) until interrupted with Ctrl-C, after which it returns.
    """
    import time  # local import: only this demo needs it

    shm = get_numpy_shm("bleeeh", (10,))
    print("before")
    print(shm.read())
    shm.write(np.ones(10) * 10)
    print("read")
    print(shm.read())
    shm[5] = 100
    print("read after changing element at index 5")
    print(shm.read())
    shm[5:7] = np.array([200, 250])
    print("read after changing two elements at index 5, 6")
    print(shm.read())
    # Sleep rather than spin: ``while True: pass`` pins a CPU core at 100%.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass  # Ctrl-C ends the demo normally instead of with a traceback
# Run the demo only when executed as a script, not when imported as a module
# (otherwise an import would never return).
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment