Skip to content

Instantly share code, notes, and snippets.

@qxcv
Created February 8, 2020 01:09
Show Gist options
  • Save qxcv/d0f3d4a4324a768941ce52c06bbb571e to your computer and use it in GitHub Desktop.
Save qxcv/d0f3d4a4324a768941ce52c06bbb571e to your computer and use it in GitHub Desktop.
Testing multiprocessing-transparent numpy array
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import ctypes\n",
"import multiprocessing as mp\n",
"import pickle\n",
"mp.set_start_method('spawn')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# original construction function from rlpyt\n",
"def np_mp_array(shape, dtype):\n",
" size = int(np.prod(shape))\n",
" nbytes = size * np.dtype(dtype).itemsize\n",
" mp_array = mp.RawArray(ctypes.c_char, nbytes)\n",
" return np.frombuffer(mp_array, dtype=dtype, count=size).reshape(shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target_source = \"\"\"\n",
"import numpy as np\n",
"\n",
"def np_proc_target(array):\n",
" print(\"At start of np_proc_target, array is\", array)\n",
" array[:] = 42\n",
" print(\"At end of np_proc_target, array is\", array)\n",
"\n",
"def arr_proc_target(ctypes_array):\n",
" ctypes_view = np.frombuffer(ctypes_array, dtype=np.float32, count=3).reshape((3,))\n",
" print(\"At start of np_proc_target, local view is\", ctypes_view)\n",
" ctypes_view[:] = 42\n",
" print(\"At end of np_proc_target, local view is\", ctypes_view)\n",
"\"\"\"\n",
"with open(\"_testing_mp_bug.py\", \"w\") as fp:\n",
" fp.write(target_source)\n",
"import _testing_mp_bug"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def do_np_mp_test():\n",
" shared_array = np_mp_array((3,), np.float32)\n",
" shared_array[:] = 666\n",
" print(\"In do_np_mp_test, array is initially\", shared_array)\n",
" p = mp.Process(target=_testing_mp_bug.np_proc_target, args=(shared_array, ))\n",
" p.start()\n",
" p.join()\n",
" print(\"Subprocess exited with code\", p.exitcode)\n",
" print(\"At end of do_np_mp_test, array is finally\", shared_array)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"do_np_mp_test()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def do_array_test():\n",
" ctypes_array = mp.RawArray(ctypes.c_char, 3 * np.dtype(np.float32).itemsize)\n",
" ctypes_view = np.frombuffer(ctypes_array, dtype=np.float32, count=3).reshape((3,))\n",
" ctypes_view[:] = 666\n",
" print(\"In do_array_test, local view is initially\", ctypes_view)\n",
" p = mp.Process(target=_testing_mp_bug.arr_proc_target, args=(ctypes_array, ))\n",
" p.start()\n",
" p.join()\n",
" print(\"Subprocess exited with code\", p.exitcode)\n",
" print(\"At end of do_array_test, local view is finally\", ctypes_view)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"do_array_test()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from rlpyt.utils import shmemarray\n",
"import ctypes\n",
"\n",
"class jankarray(np.ndarray):\n",
" # \"inspired\" by https://github.com/numpy/numpy/pull/7533/files\n",
" # (which was rejected from numpy, probably for good reason).\n",
" # This version only supports Linux (probably, maybe other unix works too).\n",
" def __new__(cls, shape, dtype=None, buffer=None, offset=None, strides=None,\n",
" order=None):\n",
" if buffer is None:\n",
" assert offset is None\n",
" assert strides is None\n",
" size = int(np.prod(shape))\n",
" nbytes = size * np.dtype(dtype).itemsize\n",
" # creates a new buffer with random tag (we'll record it in __reduce__)\n",
" buffer = shmemarray.ShmemRawArray(ctypes.c_char, nbytes, tag=None, create=True)\n",
" offset = 0\n",
" elif isinstance(buffer, tuple):\n",
" # restoring from a pickle\n",
" buf_nbytes, buf_tag = buffer\n",
" assert isinstance(buf_tag, str)\n",
" assert isinstance(buf_nbytes, int)\n",
" buffer = shmemarray.ShmemRawArray(ctypes.c_char, buf_nbytes, tag=buf_tag, create=False)\n",
" else:\n",
" raise ValueError(f\"jankarray does not support specifying custom buffers, but was given {buffer!r}\")\n",
"\n",
" obj = np.ndarray.__new__(cls, shape, dtype=dtype, buffer=buffer,\n",
" offset=offset, strides=strides, order=order)\n",
"\n",
" return obj\n",
" \n",
" def __reduce__(self):\n",
" # find the \"real\" base created by ShmemRawArray by unwinding nesting\n",
" truebase = self.base\n",
" while not hasattr(truebase, '_buffer'):\n",
" assert truebase is not None, \\\n",
" \"need meaningful self.base(.base(.base(.…))) to keep track of input buffer\"\n",
" truebase = truebase.base\n",
"\n",
" # credit to https://stackoverflow.com/a/53534485 for awful/wonderful __array_interface__ hack\n",
" absolute_offset = self.__array_interface__['data'][0]\n",
" base_address = ctypes.addressof(truebase)\n",
" offset = absolute_offset - base_address\n",
" buf_info = (truebase._buffer.size, truebase._buffer._mem.name)\n",
" \n",
" order = 'FC'[self.flags['C_CONTIGUOUS']]\n",
" \n",
" return (jankarray, (self.shape, self.dtype, buf_info, offset, self.strides, order))\n",
" \n",
"def make_jankarray(shape, dtype=None):\n",
" return jankarray(shape, dtype=dtype, buffer=None)\n",
"\n",
"janky = make_jankarray((5, 3))\n",
"print('initially janky is\\n', janky)\n",
"\n",
"import pickle\n",
"new_jank = pickle.loads(pickle.dumps(janky))\n",
"\n",
"print(\"\\ntesting that changes in janky are reflected in new_jank\")\n",
"janky[0, :] = 42\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"\n",
"print(\"\\nreverse: testing that changes in new_jank are reflected in janky\")\n",
"new_jank[0, :] = 7\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"\n",
"print('\\ntesting mutation of views')\n",
"new_jank_view = new_jank[1, 1:]\n",
"new_jank_view[:] = 99\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"print('new_jank_view\\n', new_jank_view)\n",
"\n",
"print('\\ntesting mutation of unpickled views')\n",
"up_new_jank_view = pickle.loads(pickle.dumps(new_jank_view))\n",
"up_new_jank_view[:] = 43\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"print('new_jank_view\\n', new_jank_view)\n",
"print('up_new_jank_view\\n', up_new_jank_view)\n",
"\n",
"print('\\ntesting mutation of reshaped/transposed views')\n",
"reshape_jank = pickle.loads(pickle.dumps(new_jank.reshape((1, 15)).T)).T.reshape((5, 3))\n",
"reshape_jank[1, :] = -12\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"print('reshape_jank\\n', reshape_jank)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from multiprocessing import shared_memory\n",
"import ctypes\n",
"\n",
"class jankarray(np.ndarray):\n",
" # \"inspired\" by https://github.com/numpy/numpy/pull/7533/files\n",
" # (which was rejected from numpy, probably for good reason).\n",
" # This version only supports Linux (probably, maybe other unix works too).\n",
" _shmem = None\n",
" def __new__(cls, shape, dtype=None, buffer=None, offset=None, strides=None,\n",
" order=None):\n",
" if buffer is None:\n",
" assert offset is None\n",
" assert strides is None\n",
" size = int(np.prod(shape))\n",
" nbytes = size * np.dtype(dtype).itemsize\n",
" # creates a new buffer with random tag (we'll record it in __reduce__)\n",
" shmem = shared_memory.SharedMemory(name=None, create=True, size=nbytes)\n",
" offset = 0\n",
" elif isinstance(buffer, shared_memory.SharedMemory):\n",
" # restoring from a pickle\n",
" shmem = buffer\n",
" else:\n",
" raise ValueError(f\"jankarray does not support specifying custom buffers, but was given {buffer!r}\")\n",
"\n",
" obj = np.ndarray.__new__(cls, shape, dtype=dtype, buffer=shmem.buf,\n",
" offset=offset, strides=strides, order=order)\n",
" obj._shmem = shmem\n",
"\n",
" return obj\n",
" \n",
" def __array_finalize__(self, obj):\n",
" if obj is not None:\n",
" self._shmem = obj._shmem\n",
" \n",
" def __reduce__(self):\n",
" # credit to https://stackoverflow.com/a/53534485 for awful/wonderful __array_interface__ hack\n",
" absolute_offset = self.__array_interface__['data'][0]\n",
" base_address = ctypes.addressof(ctypes.c_char.from_buffer(self._shmem.buf))\n",
" offset = absolute_offset - base_address\n",
" assert offset <= self._shmem.size, (offset, self._shmem.size)\n",
" order = 'FC'[self.flags['C_CONTIGUOUS']]\n",
" newargs = (self.shape, self.dtype, self._shmem, offset, self.strides, order)\n",
" return (jankarray, newargs)\n",
" \n",
"def make_jankarray(shape, dtype=None):\n",
" return jankarray(shape, dtype=dtype, buffer=None)\n",
"\n",
"janky = make_jankarray((5, 3))\n",
"print('initially janky is\\n', janky)\n",
"\n",
"import pickle\n",
"new_jank = pickle.loads(pickle.dumps(janky))\n",
"\n",
"print(\"\\ntesting that changes in janky are reflected in new_jank\")\n",
"janky[0, :] = 42\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"\n",
"print(\"\\nreverse: testing that changes in new_jank are reflected in janky\")\n",
"new_jank[0, :] = 7\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"\n",
"print('\\ntesting mutation of views')\n",
"new_jank_view = new_jank[1, 1:]\n",
"new_jank_view[:] = 99\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"print('new_jank_view\\n', new_jank_view)\n",
"\n",
"print('\\ntesting mutation of unpickled views')\n",
"up_new_jank_view = pickle.loads(pickle.dumps(new_jank_view))\n",
"up_new_jank_view[:] = 43\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"print('new_jank_view\\n', new_jank_view)\n",
"print('up_new_jank_view\\n', up_new_jank_view)\n",
"\n",
"print('\\ntesting mutation of reshaped/transposed views')\n",
"reshape_jank = pickle.loads(pickle.dumps(new_jank.reshape((1, 15)).T)).T.reshape((5, 3))\n",
"reshape_jank[1, :] = -12\n",
"print('janky\\n', janky)\n",
"print('new_jank\\n', new_jank)\n",
"print('reshape_jank\\n', reshape_jank)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Things to test:\n",
"# (1) Simple changes in the child should be reflected in the parent.\n",
"# (2) Simple changes in the parent should be reflected in the child.\n",
"# (3) The result of several specific operations should obey (1) and (2):\n",
"# (i) Reversing the array with strided slices.\n",
"# (ii) Taking a sub-array somewhere (no striding/reversal).\n",
"# (iii) Tranposes\n",
"# If all that works then the thing is probably working fine.\n",
"\n",
"target_source = \"\"\"\n",
"import numpy as np\n",
"import multiprocessing as mp\n",
"import ctypes\n",
"\n",
"def do_crazy_manipulations(some_array):\n",
" print('(in do_crazy_manipulations) initial array:')\n",
" print(some_array)\n",
" assert some_array.shape == (5, 3)\n",
" crazy_slice = some_array[::-2, 1:]\n",
" crazy_slice[:] = 42\n",
" trans_array = some_array.T[:2, :1]\n",
" trans_array[:] = 7\n",
" plain_sub = some_array[3:5, 2:]\n",
" plain_sub[:] = 13\n",
" print('(in do_crazy_manipulations) final array:')\n",
" print(some_array)\n",
"\n",
"def do_wait_test(arr, parent_should_go, child_should_go):\n",
" print(\"[child] in do_wait_test, before waiting on event:\")\n",
" print(arr)\n",
" parent_should_go.set()\n",
" child_should_go.wait(30)\n",
" print(\"[child] in do_wait_test, after waiting on event:\")\n",
" print(arr)\n",
" print(\"[child] do_wait_test finished\")\n",
"\n",
"def do_manip_test(arr):\n",
" print(\"[child] now do_manip_test is doing its thing\")\n",
" do_crazy_manipulations(arr)\n",
" print(\"[child] now do_manip_test finished\")\n",
" \n",
"class mpjankarray(np.ndarray):\n",
" '''ndarray which can be shared between `multiprocessing` processes by\n",
" passing it to a `Process` init function (or similar). Note that this\n",
" can only be shared _on process startup_ otherwise, it will not work.\n",
" Also it cannot/should not be pickled (outside of multiprocessing's\n",
" internals).'''\n",
" _shmem = None\n",
" def __new__(cls, shape, dtype=None, buffer=None, offset=None, strides=None,\n",
" order=None):\n",
" if buffer is None:\n",
" assert offset is None\n",
" assert strides is None\n",
" size = int(np.prod(shape))\n",
" nbytes = size * np.dtype(dtype).itemsize\n",
" # this is the part that can be passed between processes\n",
" shmem = mp.RawArray(ctypes.c_char, nbytes)\n",
" offset = 0\n",
" elif isinstance(buffer, ctypes.Array):\n",
" # restoring from a pickle\n",
" shmem = buffer\n",
" else:\n",
" raise ValueError(\n",
" f\"jankarray does not support specifying custom buffers, but \"\n",
" f\"was given {buffer!r}\")\n",
"\n",
" obj = np.ndarray.__new__(cls, shape, dtype=dtype, buffer=shmem,\n",
" offset=offset, strides=strides, order=order)\n",
" obj._shmem = shmem\n",
"\n",
" return obj\n",
" \n",
" def __array_finalize__(self, obj):\n",
" if obj is not None:\n",
" self._shmem = obj._shmem\n",
" \n",
" def __reduce__(self):\n",
" # credit to https://stackoverflow.com/a/53534485 for awful/wonderful\n",
" # __array_interface__ hack\n",
" absolute_offset = self.__array_interface__['data'][0]\n",
" # base_address = ctypes.addressof(ctypes.c_char.from_buffer(self._shmem.buf))\n",
" base_address = ctypes.addressof(self._shmem)\n",
" offset = absolute_offset - base_address\n",
" # assert offset <= self._shmem.size, (offset, self._shmem.size)\n",
" assert offset <= len(self._shmem), (offset, len(self._shmem))\n",
" order = 'FC'[self.flags['C_CONTIGUOUS']]\n",
" newargs = (self.shape, self.dtype, self._shmem, offset, self.strides, order)\n",
" return (type(self), newargs)\n",
" \n",
"def make_mpjankarray(shape, dtype=None):\n",
" return mpjankarray(shape, dtype=dtype, buffer=None)\n",
"\"\"\"\n",
"with open(\"_testing_mp_fix.py\", \"w\") as fp:\n",
" fp.write(target_source)\n",
"import _testing_mp_fix\n",
"\n",
"# TODO: split this all out into client/server code; going to be a pain to\n",
"# shove into files\n",
"print(\"Baseline result for ordinary ndarray:\")\n",
"ref_array = np.zeros((5, 3))\n",
"_testing_mp_fix.do_crazy_manipulations(ref_array)\n",
"print(\"^^^ Baseline result done. Refer to that for ground truth ^^^\\n\\n\\n\")\n",
"\n",
"print(\"Doing test: parent sends to child -> child manipulates -> parent reads back\")\n",
"janky = _testing_mp_fix.make_mpjankarray((5, 3))\n",
"print(\"[parent] In parent, array looks like this:\")\n",
"print(janky)\n",
"p = mp.Process(target=_testing_mp_fix.do_manip_test, args=(janky, ))\n",
"p.start()\n",
"p.join()\n",
"print(\"[parent] Subprocess exited with code\", p.exitcode)\n",
"print(\"[parent] Now the array looks like this:\")\n",
"print(janky)\n",
"print(\"\\n\\n\\n\")\n",
"\n",
"print(\"Doing test: parent sends to child -> child prints array, parent waits -> parent manipulates, child waits -> child prints\")\n",
"janky = _testing_mp_fix.make_mpjankarray((5, 3))\n",
"parent_should_go_ev, child_should_go_ev = mp.Event(), mp.Event()\n",
"print(\"[parent] In parent, array looks like this:\")\n",
"print(janky)\n",
"p = mp.Process(target=_testing_mp_fix.do_wait_test, args=(janky, parent_should_go_ev, child_should_go_ev))\n",
"p.start()\n",
"parent_should_go_ev.wait(5)\n",
"print(\"[parent] now parent is manipulating\")\n",
"_testing_mp_fix.do_crazy_manipulations(janky)\n",
"print(\"[parent] parent done manipulations, going back to child\")\n",
"child_should_go_ev.set()\n",
"p.join()\n",
"print(\"[parent] Subprocess exited with code\", p.exitcode)\n",
"print(\"[parent] Now the array looks like this:\")\n",
"print(janky)\n",
"print(\"\\n\\n\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment