Skip to content

Instantly share code, notes, and snippets.

@simon-mo
Created August 3, 2022 01:58
Show Gist options

  • Save simon-mo/bbd3a5382615e11d595187a32b23e0ec to your computer and use it in GitHub Desktop.
import asyncio
import types

from scanner import _PyObjScanner

# Concrete type of the object produced by calling an `async def` function.
# NOTE: the name is a typo of "coroutine"; kept for consistency with the
# scanner module, which uses the same name. Unused in this demo script.
corotinue = types.CoroutineType


async def main():
    """Demo: find coroutines nested inside a PyObj, await them, splice results back.

    Walks ``obj`` with the scanner to collect every coroutine object, awaits
    each one, then asks the scanner to rebuild ``obj`` with each coroutine
    replaced by its awaited result.
    """
    scanner = _PyObjScanner()

    async def f():
        pass

    # Coroutines nested at different depths inside a plain Python object.
    obj = {"result": f(), "b": [f(), f()]}
    coros = scanner.find_nodes(obj)
    print(coros)
    # Replacement table: each coroutine maps to its awaited value (None here).
    resolved = {c: (await c) for c in coros}
    final = scanner.replace_nodes(resolved)
    print(final)


# asyncio.run() replaces the deprecated get_event_loop().run_until_complete()
# pattern, which warns on 3.10+ and fails when no event loop is current.
asyncio.run(main())

# output
# [<coroutine object main.<locals>.f at 0x124b5a640>, <coroutine object main.<locals>.f at 0x124b5a6c0>, <coroutine object main.<locals>.f at 0x124b5a740>]
# {'result': None, 'b': [None, None]}
import ray
import asyncio
import io
import sys

# For python < 3.8 we need to explicitly use pickle5 to support protocol 5
if sys.version_info < (3, 8):
    try:
        import pickle5 as pickle  # noqa: F401
    except ImportError:
        # pickle5 backport not installed; fall back to the stdlib pickle
        # (protocol 5 then unavailable on < 3.8).
        import pickle  # noqa: F401
else:
    import pickle  # noqa: F401

from typing import List, Dict, Any, TypeVar

# Generic placeholder for the replacement values passed to replace_nodes().
T = TypeVar("T")

# Used in deserialization hooks to reference scanner instances.
# Keyed by id(scanner); entries are registered in _PyObjScanner.__init__
# and removed in _PyObjScanner.__del__.
_instances: Dict[int, "_PyObjScanner"] = {}

import types

# Concrete type of objects produced by calling an `async def` function.
# NOTE: the name is a typo of "coroutine"; kept because the rest of the
# module references it by this spelling.
corotinue = types.CoroutineType
def _get_node(instance_id: int, node_index: int):
    """Resolve the replacement value recorded at ``node_index`` by the
    scanner registered under ``instance_id``.

    Note: This function should be static and globally importable,
    otherwise the serialization overhead would be very significant.
    """
    scanner = _instances[instance_id]
    return scanner._replace_index(node_index)
class _PyObjScanner(ray.cloudpickle.CloudPickler):
    """Utility to find and replace coroutine objects in Python objects.

    This uses pickle to walk the PyObj graph and find first-level coroutine
    instances on ``find_nodes()``. The caller can then compute a replacement
    table and then replace the nodes via ``replace_nodes()``.
    """

    def __init__(self):
        # Buffer to keep intermediate serialized state.
        self._buf = io.BytesIO()
        # List of top-level coroutines found during the serialization pass;
        # None until find_nodes() has been called.
        self._found: List[corotinue] = None
        # Replacement table to consult during deserialization.
        self._replace_table: Dict[corotinue, T] = None
        # Register globally so the importable _get_node() hook can find us
        # during unpickling.
        _instances[id(self)] = self
        super().__init__(self._buf)

    def reducer_override(self, obj):
        """Hook for reducing objects.

        Coroutines are serialized as a (_get_node, (scanner_id, index))
        call so that unpickling substitutes the replacement value instead.
        """
        if isinstance(obj, corotinue):
            index = len(self._found)
            self._found.append(obj)
            return _get_node, (id(self), index)
        return super().reducer_override(obj)

    def find_nodes(self, obj: Any) -> List[corotinue]:
        """Find top-level coroutines, serializing ``obj`` into the buffer."""
        assert (
            self._found is None
        ), "find_nodes cannot be called twice on the same PyObjScanner instance."
        self._found = []
        self.dump(obj)
        return self._found

    def replace_nodes(self, table: Dict[corotinue, T]) -> Any:
        """Replace previously found coroutines per the given table."""
        assert self._found is not None, "find_nodes must be called first"
        self._replace_table = table
        self._buf.seek(0)
        return pickle.load(self._buf)

    def _replace_index(self, i: int) -> corotinue:
        # Called back by _get_node() while pickle.load runs in replace_nodes.
        return self._replace_table[self._found[i]]

    def __del__(self):
        # pop() instead of `del`: __del__ may run during interpreter shutdown
        # (or after the entry is gone) and must never raise.
        _instances.pop(id(self), None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment