Created
December 11, 2017 11:34
-
-
Save dfee/a2f5cf0d3e017a94b79a81aefbbb332a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from inspect import iscoroutinefunction | |
import pytest | |
import rx | |
# pylint: disable=W0621, redefined-outer-name | |
def debug(value):
    """Debugging helper: break into an interactive debugger, then print *value*.

    Uses ``ipdb`` when it is installed and falls back to the standard-library
    ``pdb`` otherwise, so calling the helper never fails with ImportError.
    """
    try:
        import ipdb as debugger
    except ImportError:
        # ipdb is an optional third-party package; pdb is always available.
        import pdb as debugger
    debugger.set_trace()
    print(value)
### Start AsyncObservable | |
class AsyncObservableMeta(type):
    """Metaclass that lets ``as_async_observable`` act as a decorator.

    Calling a class whose own ``__dict__`` sets ``__abstract__`` with a
    function returns a new subclass named ``<function name><class name>``
    whose ``handler`` is that function (as a staticmethod).  Calling any other
    class built by this metaclass instantiates it normally and wraps the
    instance in ``rx.Observable.using``: the instance is the resource, and
    calling it yields a coroutine that is scheduled with
    ``asyncio.ensure_future`` and exposed via ``rx.Observable.from_future``.
    """

    def __call__(cls, *args, **kwargs):
        if not cls.__dict__.get('__abstract__'):
            resource = super().__call__(*args, **kwargs)

            def _to_observable(async_func_obj):
                # Run the resource's coroutine as a future and observe it.
                return rx.Observable.from_future(
                    asyncio.ensure_future(async_func_obj())
                )

            return rx.Observable.using(lambda: resource, _to_observable)

        handler_func = args[0]
        subclass_name = '{}{}'.format(handler_func.__name__, cls.__name__)
        namespace = {'handler': staticmethod(handler_func)}
        return type(subclass_name, (cls,), namespace)
class as_async_observable(metaclass=AsyncObservableMeta):
    """Adapt a coroutine function into an Observable factory.

    Used as ``as_async_observable(coro_fn)``: because this class is marked
    ``__abstract__``, ``AsyncObservableMeta`` returns a concrete subclass whose
    ``handler`` is ``coro_fn``.  Calling that subclass with a value (and an
    optional iteration index) returns an ``rx.Observable.using`` pipeline in
    which the instance is the resource and calling the instance supplies the
    coroutine to run.
    """

    # pylint: disable=C0103, invalid-name
    __abstract__ = True
    __slots__ = ('value', 'iteration')

    def __init__(self, value, iteration=None):
        # `iteration` is stored but never read in this class; presumably it
        # receives the element index when used as a flat_map selector - verify.
        self.value, self.iteration = value, iteration

    def __call__(self):
        """Call ``handler(self.value)`` and return its result (a coroutine)."""
        handler = self.handler
        return handler(self.value)

    def dispose(self):
        """Release hook for the ``using`` resource; there is nothing to free."""
        return None

    @staticmethod
    async def handler(value):
        """Placeholder; concrete subclasses replace it with the wrapped function."""
        raise NotImplementedError()
### End AsyncObservable | |
### Start user supplied | |
class BaseModel:
    """Minimal model identified by its primary key ``pk``.

    Two models compare equal only when they are of the exact same class and
    share the same ``pk``.  ``__hash__`` agrees with ``__eq__``, so instances
    can be stored in sets or used as dict keys.
    """

    def __init__(self, pk):
        self.pk = pk

    def __repr__(self):
        # Readable output in assertion failures, e.g. ``Author(pk=0)``.
        return '{}(pk={!r})'.format(type(self).__name__, self.pk)

    def __eq__(self, other):
        if not isinstance(other, BaseModel):
            # Let Python try the reflected comparison, then fall back to identity.
            return NotImplemented
        return type(self) is type(other) and self.pk == other.pk

    def __hash__(self):
        # Must be consistent with __eq__: equal objects share type and pk.
        return hash((type(self), self.pk))

    @classmethod
    def load_from_pubsub_payload(cls, payload):
        """Build an instance from ``payload['node']['pk']``."""
        return cls(payload['node']['pk'])

    @classmethod
    async def async_load_from_pubsub_payload(cls, payload):
        """Async variant of :meth:`load_from_pubsub_payload`.

        Stands in for an asynchronous database lookup; it currently just
        delegates to the synchronous loader.
        """
        return cls.load_from_pubsub_payload(payload)
class Author(BaseModel):
    """Model selected for pub/sub payloads whose ``'type'`` is ``'Author'``."""


class Book(BaseModel):
    """Model selected for pub/sub payloads whose ``'type'`` is ``'Book'``."""
### End user supplied | |
### Start my framework API | |
class PubSub:
    """Fan payloads from an asyncio queue out through an Rx subject.

    A background task (started by :meth:`start`) reads payloads from
    ``self.queue`` and pushes each one into ``self.subject``.
    :meth:`register` builds, and caches per ``(model_cls, resolver)``, an
    Observable of ``{'mutation': ..., 'node': ...}`` dicts for payloads of a
    given model type.
    """

    def __init__(self):
        self.queue = asyncio.Queue()
        self.subject = rx.subjects.Subject()
        # Cache of registered pipelines keyed by (model_cls, resolver).
        self.observables = {}
        self._receive_task = None

    async def _receive(self):
        """Forward every queued payload to the subject until cancelled."""
        while True:
            # simulate a listening aioredis channel, for instance
            received = await self.queue.get()
            self.subject.on_next(received)

    async def start(self):
        """Start the background receive task; a no-op if it is already running."""
        # typically would await on subscribing to I/O resource
        if self._receive_task is not None and not self._receive_task.done():
            return
        self._receive_task = asyncio.ensure_future(self._receive())

    async def stop(self):
        """Cancel the receive task, wait for it to finish, then dispose the subject.

        Safe to call even if :meth:`start` was never awaited.
        """
        task, self._receive_task = self._receive_task, None
        if task is not None:
            task.cancel()
            # asyncio.wait() does not re-raise the task's CancelledError, so a
            # cancellation of stop() itself is not accidentally swallowed.
            await asyncio.wait([task])
        self.subject.dispose()

    @staticmethod
    def result_mapper(observable, result, iteration=None):
        """Combine a source payload and its resolved node into a result dict.

        ``observable`` is the incoming payload (only its ``'mutation'`` key is
        read), ``result`` is the resolver's output; ``iteration`` is ignored.
        """
        # pylint: disable=W0613, unused-argument
        return {
            'mutation': observable['mutation'],
            'node': result,
        }

    def register(self, model_cls, resolver, scheduler=None):
        """Return the (cached) Observable of results for ``model_cls``.

        :param model_cls: payloads whose ``'type'`` equals
            ``model_cls.__name__`` are selected.
        :param resolver: plain callable or coroutine function applied to each
            selected payload to produce the ``'node'`` value.
        :param scheduler: optional Rx scheduler passed to ``observe_on``.
        :returns: the same Observable for repeated calls with the same
            ``(model_cls, resolver)``; ``scheduler`` only takes effect on the
            first registration of a key.
        """
        key = (model_cls, resolver)
        if key not in self.observables:
            source = (
                self.subject.observe_on(scheduler)
                if scheduler is not None
                else self.subject
            )
            type_name = model_cls.__name__
            selected = source.filter(
                lambda payload: payload['type'] == type_name)
            if iscoroutinefunction(resolver):
                # Each payload is turned into an Observable of the coroutine's
                # result; result_mapper combines payload and result.
                self.observables[key] = selected.flat_map(
                    as_async_observable(resolver), self.result_mapper)
            else:
                self.observables[key] = selected.map(
                    lambda payload: self.result_mapper(
                        payload, resolver(payload)))
        return self.observables[key]
### End my framework API | |
@pytest.fixture
async def pubsub():
    """Yield a started :class:`PubSub` and stop it once the test finishes."""
    # NOTE: pytest marks such as pytest.mark.asyncio have no effect on
    # fixtures; the pytest-asyncio plugin drives this async fixture itself
    # (strict mode may require @pytest_asyncio.fixture - verify against the
    # installed plugin version).
    instance = PubSub()
    await instance.start()
    yield instance
    # Code after the yield runs as teardown even when the test fails.
    await instance.stop()
@pytest.mark.asyncio
async def test_async_subscription(pubsub, event_loop):
    """Coroutine resolvers: each model only receives its own payloads, in order."""
    author_results = []
    book_results = []
    # observe_on this scheduler, bound to the test's event loop.
    scheduler = rx.concurrency.AsyncIOScheduler(loop=event_loop)
    author_sub = pubsub.register(
        Author,
        Author.async_load_from_pubsub_payload,
        scheduler,
    ).subscribe(author_results.append)
    book_sub = pubsub.register(
        Book,
        Book.async_load_from_pubsub_payload,
        scheduler,
    ).subscribe(book_results.append)
    try:
        for model_cls, mutation in (
                (Author, 'created'),
                (Author, 'updated'),
                (Book, 'created'),
        ):
            pubsub.queue.put_nowait({
                'type': model_cls.__name__,
                'node': {'pk': 0},
                'mutation': mutation,
            })
            # Yield to the event loop so the receive task and the async
            # resolver can process the payload.
            await asyncio.sleep(.1)

        assert author_results == [
            dict(mutation='created', node=Author(0)),
            dict(mutation='updated', node=Author(0)),
        ]
        assert book_results == [
            dict(mutation='created', node=Book(0)),
        ]
    finally:
        author_sub.dispose()
        book_sub.dispose()
@pytest.mark.asyncio
async def test_sync_subscription(pubsub, event_loop):
    """Plain resolvers: each model only receives its own payloads, in order."""
    # event_loop is requested but not used directly in this test body.
    author_results = []
    book_results = []
    author_sub = pubsub.register(
        Author,
        Author.load_from_pubsub_payload,
    ).subscribe(author_results.append)
    book_sub = pubsub.register(
        Book,
        Book.load_from_pubsub_payload,
    ).subscribe(book_results.append)
    try:
        for model_cls, mutation in (
                (Author, 'created'),
                (Author, 'updated'),
                (Book, 'created'),
        ):
            pubsub.queue.put_nowait({
                'type': model_cls.__name__,
                'node': {'pk': 0},
                'mutation': mutation,
            })
            # Yield to the event loop so the receive task can forward the payload.
            await asyncio.sleep(.1)

        assert author_results == [
            dict(mutation='created', node=Author(0)),
            dict(mutation='updated', node=Author(0)),
        ]
        assert book_results == [
            dict(mutation='created', node=Book(0)),
        ]
    finally:
        author_sub.dispose()
        book_sub.dispose()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment