-
-
Save simon-mo/bbd3a5382615e11d595187a32b23e0ec to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| import types | |
| from scanner import _PyObjScanner | |
# Alias for the built-in coroutine object type. The identifier is spelled
# "corotinue" as-is and is not referenced again in this example script.
corotinue = types.CoroutineType
async def main():
    """Await every coroutine nested in a container and rebuild the container.

    The scanner collects the coroutine objects found inside ``obj``; each one
    is awaited, and the container is then reconstructed with the awaited
    results substituted for the coroutines.
    """

    async def f():
        pass

    scanner = _PyObjScanner()
    obj = {"result": f(), "b": [f(), f()]}

    coros = scanner.find_nodes(obj)
    print(coros)

    # Await the coroutines one at a time, keyed by the coroutine object itself
    # so the scanner can look up each replacement.
    resolved = {}
    for coro in coros:
        resolved[coro] = await coro

    final = scanner.replace_nodes(resolved)
    print(final)
# asyncio.run() creates a fresh event loop, runs main() to completion and
# closes the loop afterwards. Calling asyncio.get_event_loop() with no running
# loop is deprecated, so the loop is no longer fetched implicitly.
asyncio.run(main())
# output (the object addresses differ from run to run)
# [<coroutine object main.<locals>.f at 0x124b5a640>, <coroutine object main.<locals>.f at 0x124b5a6c0>, <coroutine object main.<locals>.f at 0x124b5a740>]
# {'result': None, 'b': [None, None]}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ray | |
| import asyncio | |
| import io | |
| import sys | |
| # For python < 3.8 we need to explicitly use pickle5 to support protocol 5 | |
| if sys.version_info < (3, 8): | |
| try: | |
| import pickle5 as pickle # noqa: F401 | |
| except ImportError: | |
| import pickle # noqa: F401 | |
| else: | |
| import pickle # noqa: F401 | |
| from typing import List, Dict, Any, TypeVar | |
# Generic type of the replacement values kept in a scanner's replacement table.
T = TypeVar("T")
# Used in deserialization hooks to reference scanner instances.
# Maps id(scanner) -> scanner; an entry is added when a scanner is constructed
# and is looked up by _get_node() while the pickled buffer is loaded.
_instances: Dict[int, "_PyObjScanner"] = {}
import types
# Alias for the built-in coroutine object type; this is the type the scanner
# matches against. The identifier is spelled "corotinue" as-is.
corotinue = types.CoroutineType
def _get_node(instance_id: int, node_index: int):
    """Return the replacement for the ``node_index``-th object a scanner found.

    The scanner is looked up in ``_instances`` by ``instance_id`` and asked for
    the replacement registered for that index.

    Note: this function should stay static and globally importable,
    otherwise the serialization overhead would be very significant.
    """
    scanner = _instances[instance_id]
    return scanner._replace_index(node_index)
class _PyObjScanner(ray.cloudpickle.CloudPickler):
    """Utility to find and replace coroutine objects inside Python objects.

    This uses pickle to walk the object graph. ``find_nodes()`` records every
    coroutine object reached during serialization (a found coroutine is
    replaced by a placeholder and is not traversed itself). The caller can
    then compute a replacement table and rebuild the object, with each
    coroutine swapped for its replacement, via ``replace_nodes()``.

    Every instance registers itself in the module-level ``_instances`` table,
    which holds a strong reference to it. Call ``clear()`` after the last
    ``replace_nodes()`` call so the scanner can be garbage collected.
    """

    def __init__(self):
        # Buffer to keep intermediate serialized state.
        self._buf = io.BytesIO()
        # Coroutines found during the serialization pass; stays None until
        # find_nodes() has been called.
        self._found = None
        # Replacement table to consult during deserialization; stays None
        # until replace_nodes() has been called.
        self._replace_table: Dict[Any, T] = None
        super().__init__(self._buf)
        # Register only after the base pickler initialised successfully, so a
        # failed construction does not leave a stale entry behind.
        _instances[id(self)] = self

    def reducer_override(self, obj):
        """Hook for reducing objects.

        A coroutine is recorded in ``self._found`` and pickled as a call to
        ``_get_node(id(self), index)``; every other object is delegated to the
        parent pickler.
        """
        if isinstance(obj, corotinue):
            index = len(self._found)
            self._found.append(obj)
            return _get_node, (id(self), index)

        return super().reducer_override(obj)

    def find_nodes(self, obj: Any) -> List[corotinue]:
        """Serialize ``obj`` and return the coroutines found along the way.

        Raises:
            AssertionError: if called more than once on the same instance.
        """
        assert (
            self._found is None
        ), "find_nodes cannot be called twice on the same PyObjScanner instance."
        self._found = []
        self.dump(obj)
        return self._found

    def replace_nodes(self, table: Dict[corotinue, T]) -> Any:
        """Rebuild the scanned object, replacing found coroutines per ``table``.

        Args:
            table: maps each coroutine returned by ``find_nodes()`` to the
                value that should take its place.

        Returns:
            The object deserialized from the internal buffer.

        Raises:
            AssertionError: if ``find_nodes()`` has not been called yet.
        """
        assert self._found is not None, "find_nodes must be called first"
        self._replace_table = table
        # Rewind so the buffer is read from the start of the pickled data.
        self._buf.seek(0)
        return pickle.load(self._buf)

    def _replace_index(self, i: int) -> T:
        """Return the replacement registered for the ``i``-th found coroutine."""
        return self._replace_table[self._found[i]]

    def clear(self) -> None:
        """Remove this scanner from ``_instances``.

        Safe to call more than once. After clearing, ``replace_nodes()`` can
        no longer resolve placeholders for this scanner.
        """
        _instances.pop(id(self), None)

    def __del__(self):
        # The entry may already be gone (explicit clear(), or __init__ failed
        # before registering), so removal must tolerate a missing key.
        self.clear()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment