Last active
December 12, 2017 10:14
-
-
Save bluetech/108bc32dfa97bae22118baed9809f26b to your computer and use it in GitHub Desktop.
Trio #369
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import logging | |
import trio | |
import attr | |
import protocol | |
# Module logger for the client; root logging is configured for verbose output.
log = logging.getLogger('client')
logging.basicConfig(level=logging.DEBUG)
@attr.s
class Future:
    """A one-shot container for the eventual outcome of an RPC call.

    A producer fills it exactly once via set_value() or set_exception().
    Consumers ``await get()`` to block until it is filled. They then receive
    the value, or have the stored exception raised.
    """

    # Wrapped outcome (e.g. trio.hazmat.Value / trio.hazmat.Error) once set.
    _result = attr.ib(default=None)
    # Event signalled once the outcome has been stored.
    _finished = attr.ib(default=attr.Factory(trio.Event))

    def done(self):
        """Return True once an outcome has been stored."""
        return self._finished.is_set()

    def set_result(self, result):
        """Store an already-wrapped outcome and wake every waiter.

        Storing a second outcome is a programming error and trips the assert.
        """
        assert not self.done()
        self._result = result
        self._finished.set()

    def set_value(self, value):
        """Complete the future successfully with *value*."""
        wrapped = trio.hazmat.Value(value)
        self.set_result(wrapped)

    def set_exception(self, exc):
        """Complete the future with *exc*, to be raised from get()."""
        wrapped = trio.hazmat.Error(exc)
        self.set_result(wrapped)

    async def get(self):
        """Wait for the outcome, then return its value or raise its error."""
        await self._finished.wait()
        outcome = self._result
        return outcome.unwrap()
class RpcResponseError(Exception):
    """Raised by RpcClient calls when the server's response carries an error.

    Attributes:
        message: the non-null ``error`` value taken from the response.
    """
    def __init__(self, message):
        # Only the attribute is stored here. Exception still records the
        # positional constructor argument in ``args``, so str(exc) shows it.
        self.message = message
class RpcConnectionError(Exception):
    """Raised when the stream to the RPC server breaks or ends before a reply."""
class RpcClient:
    """Multiplexes concurrent RPC calls over a single stream.

    Each call is tagged with a fresh integer id and parked in
    ``pending_requests`` until receiver() routes the matching response back.
    receiver() must run in a background task for the lifetime of the client.
    """

    def __init__(self, stream):
        self.stream = stream
        # Serializes writers so concurrent calls never interleave frames.
        self.send_lock = trio.StrictFIFOLock()
        # request id -> Future awaiting the response with that id.
        self.pending_requests = {}
        self.id_generator = itertools.count(1)
        # Set once receiver() has stopped; later calls could never be answered.
        self._closed = False

    async def _call(self, method, params):
        """Send one request and wait for its response.

        Returns:
            The response's ``result`` field.

        Raises:
            RpcResponseError: the server returned a non-null ``error``.
            RpcConnectionError: the stream broke, or the receiver stopped,
                before a response arrived.
        """
        # No await between this check and registering the future, so the
        # receiver cannot finish in between (trio tasks switch only at awaits).
        if self._closed:
            raise RpcConnectionError()
        request_id = next(self.id_generator)
        request = {
            'id': request_id,
            'method': method,
            'params': params,
        }
        future = Future()
        self.pending_requests[request_id] = future
        try:
            # Hold the lock only while writing. Waiting for the reply under
            # it would serialize every call and defeat the multiplexing.
            async with self.send_lock:
                try:
                    await protocol.send_message(self.stream, request)
                except trio.BrokenStreamError as e:
                    raise RpcConnectionError() from e
            response = await future.get()
        finally:
            del self.pending_requests[request_id]
        error, result = response['error'], response['result']
        if error is not None:
            raise RpcResponseError(error)
        return result

    async def receiver(self):
        """Route incoming responses to the calls waiting for them.

        Runs until the stream ends, breaks, or this task is cancelled. It then
        fails every call still waiting with RpcConnectionError.
        """
        try:
            async for response in protocol.receive_messages(self.stream):
                response_id = response['id']
                future = self.pending_requests.get(response_id)
                if future is None:
                    # The caller already gave up (e.g. timed out), or the id
                    # is unknown.
                    log.warning('%s DISCARDED', response_id)
                    continue
                if future.done():
                    log.warning('%s DUPLICATED', response_id)
                    continue
                future.set_value(response)
        except trio.BrokenStreamError:
            # A broken connection ends the session just like a clean EOF;
            # waiting callers are failed below.
            pass
        finally:
            self._closed = True
            for future in self.pending_requests.values():
                # A future may already hold a response whose caller has not
                # yet woken to collect it. Completing it twice would trip the
                # assertion in Future.set_result.
                if not future.done():
                    future.set_exception(RpcConnectionError())

    async def what_is_the_meaning_of_life(self):
        """Call the ``meaning-of-life`` method (no params)."""
        return await self._call('meaning-of-life', params=None)

    async def add(self, **params):
        """Call the ``add`` method, sending the keyword arguments as a dict."""
        return await self._call('add', params)
async def find_out_the_meaning_of_life():
    """Connect to 127.0.0.1:9999 and ask the meaning of life 20 times concurrently.

    Each question is logged as STARTED, then FINISHED, FAILED or ABANDONED.
    Once all 20 are settled, the nursery is cancelled, which also stops the
    response receiver.
    """

    async def ask_once(client, i):
        # A single query; every outcome is logged and nothing propagates.
        log.info(f'{i:02} STARTED')
        try:
            # We don't have 7.5 million years.
            with trio.fail_after(5):
                answer = await client.what_is_the_meaning_of_life()
        except trio.TooSlowError:
            log.warning(f'{i:02} ABANDONED')
            return
        except Exception as e:
            log.info(f'{i:02} FAILED: {type(e).__name__}: {e}')
            return
        log.info(f'{i:02} FINISHED: {answer}')

    async def ask_many(client, shutdown):
        # Ask several times to ensure the meaning stays the same.
        async with trio.open_nursery() as askers:
            for i in range(20):
                askers.start_soon(ask_once, client, i)
        # Every question has been settled; stop the rest of the session.
        shutdown()

    async with await trio.open_tcp_stream('127.0.0.1', 9999) as stream:
        client = RpcClient(stream)
        async with trio.open_nursery() as session:
            session.start_soon(client.receiver)
            session.start_soon(ask_many, client, session.cancel_scope.cancel)
if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse RpcClient) does
    # not open a connection. Ctrl-C exits quietly instead of printing a
    # traceback.
    try:
        trio.run(find_out_the_meaning_of_life)
    except KeyboardInterrupt:
        pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import logging | |
# Logger for framing-level diagnostics (e.g. truncated messages at EOF).
log = logging.getLogger(name='protocol')
async def receive_messages(stream, receive_size=16384):
    """Yield length-prefixed JSON messages read from *stream* until EOF.

    Each frame is a 4-byte big-endian length followed by that many bytes of
    UTF-8 JSON. Frames may be split across reads or packed several per read.
    At EOF, a trailing partial frame is logged at debug level and dropped.

    Args:
        stream: object with an awaitable ``receive_some(max_bytes)`` that
            returns ``b''`` at end of stream.
        receive_size: maximum number of bytes requested per read.

    Yields:
        Each decoded JSON value, in arrival order.
    """
    buf = bytearray()
    # Body length of the frame being assembled, once its header has been read.
    length = None
    while True:
        chunk = await stream.receive_some(receive_size)
        buf += chunk
        # Parse with a moving offset and compact once per read. Re-slicing the
        # whole buffer after every message is quadratic for many small frames.
        pos = 0
        while True:
            if length is None:
                if len(buf) - pos < 4:
                    break
                length = int.from_bytes(buf[pos:pos + 4], 'big')
                pos += 4
            if len(buf) - pos < length:
                break
            raw_message = buf[pos:pos + length]
            pos += length
            length = None
            yield json.loads(raw_message.decode())
        del buf[:pos]
        if not chunk:
            # A parsed header still waiting for its body is also incomplete,
            # even if no body bytes are buffered.
            if buf or length is not None:
                log.debug('discarding incomplete message')
            break
async def send_message(stream, message):
    """Write *message* to *stream* as one length-prefixed JSON frame.

    The frame is a 4-byte big-endian byte count followed by the encoded JSON
    text, which is the layout receive_messages parses. It is sent with a
    single send_all call.
    """
    body = json.dumps(message).encode()
    header = len(body).to_bytes(4, 'big')
    await stream.send_all(b''.join((header, body)))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import signal | |
import logging | |
import trio | |
import protocol | |
# Module logger for the server; root logging is configured for verbose output.
log = logging.getLogger('server')
logging.basicConfig(level=logging.DEBUG)
async def rpc_request_handler(request, stream, send_lock, finished):
    """Answer one RPC request and write the response frame back to *stream*.

    Supported methods:
        ``add``: sums ``params``, a list of numbers or a dict whose values
            are numbers.
        ``meaning-of-life``: sleeps randomly, then answers 42 or fails.

    Args:
        request: decoded request dict with ``id``, ``method`` and ``params``.
        stream: stream the response frame is written to.
        send_lock: lock serializing writes to *stream* across handlers.
        finished: callback invoked exactly once when this handler ends,
            whether or not it succeeded. In this file, receiver() passes its
            semaphore's ``release`` here.
    """
    try:
        request_id = request['id']
        method = request['method']
        params = request['params']
        log.info(f'{request_id:02} RECEIVED: {request}')
        if method == 'add':
            # The client's add(**params) sends keyword arguments as a dict,
            # so sum its values rather than its (string) keys.
            operands = params.values() if isinstance(params, dict) else params
            try:
                result, error = sum(operands), None
            except TypeError:
                # Report bad params to the caller instead of crashing the
                # whole connection's nursery.
                result, error = None, 'invalid params'
        elif method == 'meaning-of-life':
            thinking_time = random.random() * 10
            await trio.sleep(thinking_time)
            # Simulate some failures.
            if random.random() > 0.2:
                result, error = 42, None
            else:
                result, error = None, 'thinking overload'
        else:
            result, error = None, 'unknown method'
        response = {
            'id': request_id,
            'result': result,
            'error': error,
        }
        async with send_lock:
            await protocol.send_message(stream, response)
    finally:
        # Release the slot even if sending failed or we were cancelled.
        # Otherwise the receiver eventually blocks forever waiting for a slot.
        finished()
async def receiver(stream, nursery):
    """Read requests from *stream* and answer each one in its own task.

    At most 128 handlers are in flight at once. A semaphore slot is taken
    before each spawn, and its ``release`` is handed to the handler as the
    *finished* callback. All handlers share one lock for writing responses.
    """
    write_lock = trio.Lock()
    slots = trio.Semaphore(128)
    async for request in protocol.receive_messages(stream):
        await slots.acquire()
        nursery.start_soon(rpc_request_handler, request, stream, write_lock, slots.release)
async def client_handler(stream):
    """Serve one client connection and log how the session ended.

    The stream is closed when the session ends. A broken stream is logged as
    the client bailing early. Any other Exception or trio.MultiError is
    logged with its traceback and not re-raised.

    NOTE(review): catching trio.MultiError here may also swallow a MultiError
    made only of trio.Cancelled during server shutdown, logging it as a crash.
    Confirm against the trio version in use.
    """
    try:
        async with stream, trio.open_nursery() as nursery:
            nursery.start_soon(receiver, stream, nursery)
    except trio.BrokenStreamError:
        log.info('client bailed early')
        return
    except (Exception, trio.MultiError):
        log.exception('client handler crashed')
        return
    log.info('client finished')
async def signal_handler(terminate):
    """Wait for the first SIGINT or SIGTERM batch, then call *terminate* once."""
    watched = {signal.SIGINT, signal.SIGTERM}
    with trio.catch_signals(watched) as signal_batches:
        async for _ in signal_batches:
            terminate()
            return
async def serve():
    """Accept TCP clients on port 9999 until SIGINT or SIGTERM is received."""
    async with trio.open_nursery() as nursery:
        stop = nursery.cancel_scope.cancel
        nursery.start_soon(signal_handler, stop)
        # Number of concurrent connections is limited by RLIMIT_NOFILE,
        # so no need to limit ourselves.
        nursery.start_soon(trio.serve_tcp, client_handler, 9999)
if __name__ == '__main__':
    # Start the server only when run as a script, not when imported.
    # SIGINT/SIGTERM are handled inside serve() by signal_handler.
    trio.run(serve)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment