Created
February 8, 2020 01:09
-
-
Save qxcv/d0f3d4a4324a768941ce52c06bbb571e to your computer and use it in GitHub Desktop.
Testing multiprocessing-transparent numpy array
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import ctypes\n", | |
"import multiprocessing as mp\n", | |
"import pickle\n", | |
"mp.set_start_method('spawn')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# original construction function from rlpyt\n", | |
"def np_mp_array(shape, dtype):\n", | |
" size = int(np.prod(shape))\n", | |
" nbytes = size * np.dtype(dtype).itemsize\n", | |
" mp_array = mp.RawArray(ctypes.c_char, nbytes)\n", | |
" return np.frombuffer(mp_array, dtype=dtype, count=size).reshape(shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"target_source = \"\"\"\n", | |
"import numpy as np\n", | |
"\n", | |
"def np_proc_target(array):\n", | |
" print(\"At start of np_proc_target, array is\", array)\n", | |
" array[:] = 42\n", | |
" print(\"At end of np_proc_target, array is\", array)\n", | |
"\n", | |
"def arr_proc_target(ctypes_array):\n", | |
" ctypes_view = np.frombuffer(ctypes_array, dtype=np.float32, count=3).reshape((3,))\n", | |
" print(\"At start of np_proc_target, local view is\", ctypes_view)\n", | |
" ctypes_view[:] = 42\n", | |
" print(\"At end of np_proc_target, local view is\", ctypes_view)\n", | |
"\"\"\"\n", | |
"with open(\"_testing_mp_bug.py\", \"w\") as fp:\n", | |
" fp.write(target_source)\n", | |
"import _testing_mp_bug" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def do_np_mp_test():\n", | |
" shared_array = np_mp_array((3,), np.float32)\n", | |
" shared_array[:] = 666\n", | |
" print(\"In do_np_mp_test, array is initially\", shared_array)\n", | |
" p = mp.Process(target=_testing_mp_bug.np_proc_target, args=(shared_array, ))\n", | |
" p.start()\n", | |
" p.join()\n", | |
" print(\"Subprocess exited with code\", p.exitcode)\n", | |
" print(\"At end of do_np_mp_test, array is finally\", shared_array)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"do_np_mp_test()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def do_array_test():\n", | |
" ctypes_array = mp.RawArray(ctypes.c_char, 3 * np.dtype(np.float32).itemsize)\n", | |
" ctypes_view = np.frombuffer(ctypes_array, dtype=np.float32, count=3).reshape((3,))\n", | |
" ctypes_view[:] = 666\n", | |
" print(\"In do_array_test, local view is initially\", ctypes_view)\n", | |
" p = mp.Process(target=_testing_mp_bug.arr_proc_target, args=(ctypes_array, ))\n", | |
" p.start()\n", | |
" p.join()\n", | |
" print(\"Subprocess exited with code\", p.exitcode)\n", | |
" print(\"At end of do_array_test, local view is finally\", ctypes_view)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"do_array_test()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"jupyter": { | |
"source_hidden": true | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from rlpyt.utils import shmemarray\n", | |
"import ctypes\n", | |
"\n", | |
"class jankarray(np.ndarray):\n", | |
" # \"inspired\" by https://github.com/numpy/numpy/pull/7533/files\n", | |
" # (which was rejected from numpy, probably for good reason).\n", | |
" # This version only supports Linux (probably, maybe other unix works too).\n", | |
" def __new__(cls, shape, dtype=None, buffer=None, offset=None, strides=None,\n", | |
" order=None):\n", | |
" if buffer is None:\n", | |
" assert offset is None\n", | |
" assert strides is None\n", | |
" size = int(np.prod(shape))\n", | |
" nbytes = size * np.dtype(dtype).itemsize\n", | |
" # creates a new buffer with random tag (we'll record it in __reduce__)\n", | |
" buffer = shmemarray.ShmemRawArray(ctypes.c_char, nbytes, tag=None, create=True)\n", | |
" offset = 0\n", | |
" elif isinstance(buffer, tuple):\n", | |
" # restoring from a pickle\n", | |
" buf_nbytes, buf_tag = buffer\n", | |
" assert isinstance(buf_tag, str)\n", | |
" assert isinstance(buf_nbytes, int)\n", | |
" buffer = shmemarray.ShmemRawArray(ctypes.c_char, buf_nbytes, tag=buf_tag, create=False)\n", | |
" else:\n", | |
" raise ValueError(f\"jankarray does not support specifying custom buffers, but was given {buffer!r}\")\n", | |
"\n", | |
" obj = np.ndarray.__new__(cls, shape, dtype=dtype, buffer=buffer,\n", | |
" offset=offset, strides=strides, order=order)\n", | |
"\n", | |
" return obj\n", | |
" \n", | |
" def __reduce__(self):\n", | |
" # find the \"real\" base created by ShmemRawArray by unwinding nesting\n", | |
" truebase = self.base\n", | |
" while not hasattr(truebase, '_buffer'):\n", | |
" assert truebase is not None, \\\n", | |
" \"need meaningful self.base(.base(.base(.…))) to keep track of input buffer\"\n", | |
" truebase = truebase.base\n", | |
"\n", | |
" # credit to https://stackoverflow.com/a/53534485 for awful/wonderful __array_interface__ hack\n", | |
" absolute_offset = self.__array_interface__['data'][0]\n", | |
" base_address = ctypes.addressof(truebase)\n", | |
" offset = absolute_offset - base_address\n", | |
" buf_info = (truebase._buffer.size, truebase._buffer._mem.name)\n", | |
" \n", | |
" order = 'FC'[self.flags['C_CONTIGUOUS']]\n", | |
" \n", | |
" return (jankarray, (self.shape, self.dtype, buf_info, offset, self.strides, order))\n", | |
" \n", | |
"def make_jankarray(shape, dtype=None):\n", | |
" return jankarray(shape, dtype=dtype, buffer=None)\n", | |
"\n", | |
"janky = make_jankarray((5, 3))\n", | |
"print('initially janky is\\n', janky)\n", | |
"\n", | |
"import pickle\n", | |
"new_jank = pickle.loads(pickle.dumps(janky))\n", | |
"\n", | |
"print(\"\\ntesting that changes in janky are reflected in new_jank\")\n", | |
"janky[0, :] = 42\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"\n", | |
"print(\"\\nreverse: testing that changes in new_jank are reflected in janky\")\n", | |
"new_jank[0, :] = 7\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"\n", | |
"print('\\ntesting mutation of views')\n", | |
"new_jank_view = new_jank[1, 1:]\n", | |
"new_jank_view[:] = 99\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"print('new_jank_view\\n', new_jank_view)\n", | |
"\n", | |
"print('\\ntesting mutation of unpickled views')\n", | |
"up_new_jank_view = pickle.loads(pickle.dumps(new_jank_view))\n", | |
"up_new_jank_view[:] = 43\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"print('new_jank_view\\n', new_jank_view)\n", | |
"print('up_new_jank_view\\n', up_new_jank_view)\n", | |
"\n", | |
"print('\\ntesting mutation of reshaped/transposed views')\n", | |
"reshape_jank = pickle.loads(pickle.dumps(new_jank.reshape((1, 15)).T)).T.reshape((5, 3))\n", | |
"reshape_jank[1, :] = -12\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"print('reshape_jank\\n', reshape_jank)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"jupyter": { | |
"source_hidden": true | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from multiprocessing import shared_memory\n", | |
"import ctypes\n", | |
"\n", | |
"class jankarray(np.ndarray):\n", | |
" # \"inspired\" by https://github.com/numpy/numpy/pull/7533/files\n", | |
" # (which was rejected from numpy, probably for good reason).\n", | |
" # This version only supports Linux (probably, maybe other unix works too).\n", | |
" _shmem = None\n", | |
" def __new__(cls, shape, dtype=None, buffer=None, offset=None, strides=None,\n", | |
" order=None):\n", | |
" if buffer is None:\n", | |
" assert offset is None\n", | |
" assert strides is None\n", | |
" size = int(np.prod(shape))\n", | |
" nbytes = size * np.dtype(dtype).itemsize\n", | |
" # creates a new buffer with random tag (we'll record it in __reduce__)\n", | |
" shmem = shared_memory.SharedMemory(name=None, create=True, size=nbytes)\n", | |
" offset = 0\n", | |
" elif isinstance(buffer, shared_memory.SharedMemory):\n", | |
" # restoring from a pickle\n", | |
" shmem = buffer\n", | |
" else:\n", | |
" raise ValueError(f\"jankarray does not support specifying custom buffers, but was given {buffer!r}\")\n", | |
"\n", | |
" obj = np.ndarray.__new__(cls, shape, dtype=dtype, buffer=shmem.buf,\n", | |
" offset=offset, strides=strides, order=order)\n", | |
" obj._shmem = shmem\n", | |
"\n", | |
" return obj\n", | |
" \n", | |
" def __array_finalize__(self, obj):\n", | |
" if obj is not None:\n", | |
" self._shmem = obj._shmem\n", | |
" \n", | |
" def __reduce__(self):\n", | |
" # credit to https://stackoverflow.com/a/53534485 for awful/wonderful __array_interface__ hack\n", | |
" absolute_offset = self.__array_interface__['data'][0]\n", | |
" base_address = ctypes.addressof(ctypes.c_char.from_buffer(self._shmem.buf))\n", | |
" offset = absolute_offset - base_address\n", | |
" assert offset <= self._shmem.size, (offset, self._shmem.size)\n", | |
" order = 'FC'[self.flags['C_CONTIGUOUS']]\n", | |
" newargs = (self.shape, self.dtype, self._shmem, offset, self.strides, order)\n", | |
" return (jankarray, newargs)\n", | |
" \n", | |
"def make_jankarray(shape, dtype=None):\n", | |
" return jankarray(shape, dtype=dtype, buffer=None)\n", | |
"\n", | |
"janky = make_jankarray((5, 3))\n", | |
"print('initially janky is\\n', janky)\n", | |
"\n", | |
"import pickle\n", | |
"new_jank = pickle.loads(pickle.dumps(janky))\n", | |
"\n", | |
"print(\"\\ntesting that changes in janky are reflected in new_jank\")\n", | |
"janky[0, :] = 42\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"\n", | |
"print(\"\\nreverse: testing that changes in new_jank are reflected in janky\")\n", | |
"new_jank[0, :] = 7\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"\n", | |
"print('\\ntesting mutation of views')\n", | |
"new_jank_view = new_jank[1, 1:]\n", | |
"new_jank_view[:] = 99\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"print('new_jank_view\\n', new_jank_view)\n", | |
"\n", | |
"print('\\ntesting mutation of unpickled views')\n", | |
"up_new_jank_view = pickle.loads(pickle.dumps(new_jank_view))\n", | |
"up_new_jank_view[:] = 43\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"print('new_jank_view\\n', new_jank_view)\n", | |
"print('up_new_jank_view\\n', up_new_jank_view)\n", | |
"\n", | |
"print('\\ntesting mutation of reshaped/transposed views')\n", | |
"reshape_jank = pickle.loads(pickle.dumps(new_jank.reshape((1, 15)).T)).T.reshape((5, 3))\n", | |
"reshape_jank[1, :] = -12\n", | |
"print('janky\\n', janky)\n", | |
"print('new_jank\\n', new_jank)\n", | |
"print('reshape_jank\\n', reshape_jank)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Things to test:\n", | |
"# (1) Simple changes in the child should be reflected in the parent.\n", | |
"# (2) Simple changes in the parent should be reflected in the child.\n", | |
"# (3) The result of several specific operations should obey (1) and (2):\n", | |
"# (i) Reversing the array with strided slices.\n", | |
"# (ii) Taking a sub-array somewhere (no striding/reversal).\n", | |
"# (iii) Tranposes\n", | |
"# If all that works then the thing is probably working fine.\n", | |
"\n", | |
"target_source = \"\"\"\n", | |
"import numpy as np\n", | |
"import multiprocessing as mp\n", | |
"import ctypes\n", | |
"\n", | |
"def do_crazy_manipulations(some_array):\n", | |
" print('(in do_crazy_manipulations) initial array:')\n", | |
" print(some_array)\n", | |
" assert some_array.shape == (5, 3)\n", | |
" crazy_slice = some_array[::-2, 1:]\n", | |
" crazy_slice[:] = 42\n", | |
" trans_array = some_array.T[:2, :1]\n", | |
" trans_array[:] = 7\n", | |
" plain_sub = some_array[3:5, 2:]\n", | |
" plain_sub[:] = 13\n", | |
" print('(in do_crazy_manipulations) final array:')\n", | |
" print(some_array)\n", | |
"\n", | |
"def do_wait_test(arr, parent_should_go, child_should_go):\n", | |
" print(\"[child] in do_wait_test, before waiting on event:\")\n", | |
" print(arr)\n", | |
" parent_should_go.set()\n", | |
" child_should_go.wait(30)\n", | |
" print(\"[child] in do_wait_test, after waiting on event:\")\n", | |
" print(arr)\n", | |
" print(\"[child] do_wait_test finished\")\n", | |
"\n", | |
"def do_manip_test(arr):\n", | |
" print(\"[child] now do_manip_test is doing its thing\")\n", | |
" do_crazy_manipulations(arr)\n", | |
" print(\"[child] now do_manip_test finished\")\n", | |
" \n", | |
"class mpjankarray(np.ndarray):\n", | |
" '''ndarray which can be shared between `multiprocessing` processes by\n", | |
" passing it to a `Process` init function (or similar). Note that this\n", | |
" can only be shared _on process startup_ otherwise, it will not work.\n", | |
" Also it cannot/should not be pickled (outside of multiprocessing's\n", | |
" internals).'''\n", | |
" _shmem = None\n", | |
" def __new__(cls, shape, dtype=None, buffer=None, offset=None, strides=None,\n", | |
" order=None):\n", | |
" if buffer is None:\n", | |
" assert offset is None\n", | |
" assert strides is None\n", | |
" size = int(np.prod(shape))\n", | |
" nbytes = size * np.dtype(dtype).itemsize\n", | |
" # this is the part that can be passed between processes\n", | |
" shmem = mp.RawArray(ctypes.c_char, nbytes)\n", | |
" offset = 0\n", | |
" elif isinstance(buffer, ctypes.Array):\n", | |
" # restoring from a pickle\n", | |
" shmem = buffer\n", | |
" else:\n", | |
" raise ValueError(\n", | |
" f\"jankarray does not support specifying custom buffers, but \"\n", | |
" f\"was given {buffer!r}\")\n", | |
"\n", | |
" obj = np.ndarray.__new__(cls, shape, dtype=dtype, buffer=shmem,\n", | |
" offset=offset, strides=strides, order=order)\n", | |
" obj._shmem = shmem\n", | |
"\n", | |
" return obj\n", | |
" \n", | |
" def __array_finalize__(self, obj):\n", | |
" if obj is not None:\n", | |
" self._shmem = obj._shmem\n", | |
" \n", | |
" def __reduce__(self):\n", | |
" # credit to https://stackoverflow.com/a/53534485 for awful/wonderful\n", | |
" # __array_interface__ hack\n", | |
" absolute_offset = self.__array_interface__['data'][0]\n", | |
" # base_address = ctypes.addressof(ctypes.c_char.from_buffer(self._shmem.buf))\n", | |
" base_address = ctypes.addressof(self._shmem)\n", | |
" offset = absolute_offset - base_address\n", | |
" # assert offset <= self._shmem.size, (offset, self._shmem.size)\n", | |
" assert offset <= len(self._shmem), (offset, len(self._shmem))\n", | |
" order = 'FC'[self.flags['C_CONTIGUOUS']]\n", | |
" newargs = (self.shape, self.dtype, self._shmem, offset, self.strides, order)\n", | |
" return (type(self), newargs)\n", | |
" \n", | |
"def make_mpjankarray(shape, dtype=None):\n", | |
" return mpjankarray(shape, dtype=dtype, buffer=None)\n", | |
"\"\"\"\n", | |
"with open(\"_testing_mp_fix.py\", \"w\") as fp:\n", | |
" fp.write(target_source)\n", | |
"import _testing_mp_fix\n", | |
"\n", | |
"# TODO: split this all out into client/server code; going to be a pain to\n", | |
"# shove into files\n", | |
"print(\"Baseline result for ordinary ndarray:\")\n", | |
"ref_array = np.zeros((5, 3))\n", | |
"_testing_mp_fix.do_crazy_manipulations(ref_array)\n", | |
"print(\"^^^ Baseline result done. Refer to that for ground truth ^^^\\n\\n\\n\")\n", | |
"\n", | |
"print(\"Doing test: parent sends to child -> child manipulates -> parent reads back\")\n", | |
"janky = _testing_mp_fix.make_mpjankarray((5, 3))\n", | |
"print(\"[parent] In parent, array looks like this:\")\n", | |
"print(janky)\n", | |
"p = mp.Process(target=_testing_mp_fix.do_manip_test, args=(janky, ))\n", | |
"p.start()\n", | |
"p.join()\n", | |
"print(\"[parent] Subprocess exited with code\", p.exitcode)\n", | |
"print(\"[parent] Now the array looks like this:\")\n", | |
"print(janky)\n", | |
"print(\"\\n\\n\\n\")\n", | |
"\n", | |
"print(\"Doing test: parent sends to child -> child prints array, parent waits -> parent manipulates, child waits -> child prints\")\n", | |
"janky = _testing_mp_fix.make_mpjankarray((5, 3))\n", | |
"parent_should_go_ev, child_should_go_ev = mp.Event(), mp.Event()\n", | |
"print(\"[parent] In parent, array looks like this:\")\n", | |
"print(janky)\n", | |
"p = mp.Process(target=_testing_mp_fix.do_wait_test, args=(janky, parent_should_go_ev, child_should_go_ev))\n", | |
"p.start()\n", | |
"parent_should_go_ev.wait(5)\n", | |
"print(\"[parent] now parent is manipulating\")\n", | |
"_testing_mp_fix.do_crazy_manipulations(janky)\n", | |
"print(\"[parent] parent done manipulations, going back to child\")\n", | |
"child_should_go_ev.set()\n", | |
"p.join()\n", | |
"print(\"[parent] Subprocess exited with code\", p.exitcode)\n", | |
"print(\"[parent] Now the array looks like this:\")\n", | |
"print(janky)\n", | |
"print(\"\\n\\n\\n\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment