Created
August 29, 2018 16:16
-
-
Save akhilman/6f4aa516a7317f36ebac427a9d392865 to your computer and use it in GitHub Desktop.
Asynchronous client for aiozmq RPC that works without an event loop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Based on synchronous implementation of the aiozmq.rpc.RPCClient | |
https://gist.github.com/derfenix/f18e4a8f0ee9bad738c2b22106a3ad4d | |
""" | |
import functools | |
import logging | |
import os | |
import random | |
import struct | |
import sys | |
import time | |
from collections import ChainMap | |
from concurrent.futures import Future | |
from functools import partial | |
import zmq | |
from aiozmq.rpc.base import GenericError | |
from aiozmq.rpc.rpc import _default_error_table | |
from slivoglot.core.packer import PicklePacker | |
# Public API of this module: ``from ... import *`` exports only the client
# class.
__all__ = [
    "RPCClient",
]
@functools.lru_cache()
def log():
    """Return this module's logger; the lookup is cached after the first call."""
    logger_name = __name__
    return logging.getLogger(logger_name)
class RPCFuture(Future):
    """Future for a single RPC request, resolved by polling its client.

    Replies are only processed when the owning client is polled, so
    ``result()`` and ``exception()`` drive ``client.poll_until_done`` until
    this future is done (the client may raise if *timeout* expires first).
    """

    def __init__(self, client, req_id):
        super().__init__()
        self.client = client  # client that sent the request and owns the socket
        self.req_id = req_id  # request id used to match the incoming reply

    def _wait(self, timeout):
        """Poll the owning client until this future is done, if it is not yet."""
        if self.done():
            return
        self.client.poll_until_done(self, timeout)
        assert self.done()

    def result(self, timeout=None):
        """Return the remote call's result, polling the client first if needed."""
        self._wait(timeout)
        return super().result()

    def exception(self, timeout=None):
        """Return the remote call's exception (or None), polling first if needed."""
        self._wait(timeout)
        return super().exception()
class RPCClient(object):
    """aiozmq-compatible RPC client that needs no asyncio event loop.

    Requests are sent over a zmq DEALER socket as a multipart message
    ``[header, method_name, packed_args, packed_kwargs]``.  Each call returns
    an :class:`RPCFuture`.  Replies are read from the socket only when the
    client is polled, either explicitly through :meth:`poll` or implicitly
    through ``RPCFuture.result()`` / ``RPCFuture.exception()``.

    Accessing any public name that is not a real attribute produces a remote
    call, so ``client.add(1, 2)`` is equivalent to ``client.call('add', 1, 2)``.
    """

    # Request header = REQ_PREFIX (pid, random) + REQ_SUFFIX (counter, time).
    REQ_PREFIX = struct.Struct('=HH')
    REQ_SUFFIX = struct.Struct('=Ld')
    # Reply header: pid, random, request id, timestamp, is_error flag.
    RESP = struct.Struct('=HHLd?')

    def __init__(self, *, connect, timeout=None, error_table=None):
        """Create a DEALER socket and connect it to *connect*.

        :param connect: zmq endpoint URI to connect to.
        :param timeout: stored as ``self.timeout``; not read anywhere in this
            class.
        :param error_table: optional mapping from the exception type sent by
            the server to a local exception class; it is layered over
            aiozmq's default error table.
        """
        self.timeout = timeout
        self.uri = connect
        # Identifies this client instance in every request header.
        self.prefix = self.REQ_PREFIX.pack(os.getpid() % 0x10000,
                                           random.randrange(0x10000))
        self.packer = PicklePacker()
        context = zmq.Context()
        self.socket = context.socket(zmq.DEALER)
        self.socket.connect(connect)
        self.calls = {}  # req_id -> pending RPCFuture
        self._counter = 0
        if error_table is None:
            self.error_table = _default_error_table
        else:
            self.error_table = ChainMap(error_table, _default_error_table)

    def __del__(self):
        # Read the instance dict directly.  If __init__ failed before the
        # socket was created, ``self.socket`` would fall through to
        # __getattr__ and return an RPC partial instead of a socket.
        socket = self.__dict__.get('socket')
        if socket is not None:
            socket.close()

    def _new_id(self):
        """Return ``(header_bytes, req_id)`` for the next request.

        The request counter is packed as a 4-byte unsigned value and wraps
        around to 0 after 0xffffffff.
        """
        self._counter += 1
        if self._counter > 0xffffffff:
            self._counter = 0
        header = self.prefix + self.REQ_SUFFIX.pack(self._counter, time.time())
        return header, self._counter

    def __getattr__(self, item):
        """Map an unknown public attribute to a remote call of that name.

        Python calls this only after normal attribute lookup has failed.
        Names starting with an underscore raise AttributeError, so protocol
        probes made through ``hasattr``/``getattr`` (``__wrapped__``,
        ``_repr_html_``, ...) are not mistaken for RPC methods.  A remote
        method whose name starts with an underscore can still be invoked
        explicitly with ``call(name, ...)``.
        """
        if item.startswith('_'):
            raise AttributeError('{!r} object has no attribute {!r}'.format(
                type(self).__name__, item))
        return partial(self.call, item)

    def __call__(self, name, *args, **kwargs):
        """Shorthand for :meth:`call`."""
        return self.call(name, *args, **kwargs)

    def call(self, name: str, *args, **kwargs):
        """Send a request for remote method *name*; return its RPCFuture.

        The request is handed to the socket right away.  The reply is only
        processed when this client is polled.
        """
        binary_name = name.encode('utf-8')
        binary_args = self.packer.packb(args)
        binary_kwargs = self.packer.packb(kwargs)
        header, req_id = self._new_id()
        self.socket.send_multipart(
            [header, binary_name, binary_args, binary_kwargs])
        fut = RPCFuture(self, req_id)
        # Mark the future RUNNING immediately.  A running Future cannot be
        # cancelled.
        fut.set_running_or_notify_cancel()
        self.calls[req_id] = fut
        return fut

    def poll(self, timeout=None):
        """Wait up to *timeout* seconds for replies and dispatch them all.

        :param timeout: seconds to wait; ``None`` waits until a reply arrives
            and ``0`` only checks what is already queued.  Negative values
            are treated as ``0``.
        :returns: the number of reply messages dispatched (0 when nothing
            arrived or no calls are pending).
        """
        if not self.calls:
            return 0
        if timeout is not None:
            # zmq expects milliseconds, and a negative value would make it
            # wait forever, so clamp at 0.
            timeout = max(0, int(timeout * 1000))
        # socket.poll() returns an event mask, not a message count.
        if not self.socket.poll(timeout=timeout):
            return 0
        # Drain every reply that is ready now, without blocking again.
        count = 0
        while True:
            try:
                data = self.socket.recv_multipart(zmq.NOBLOCK)
            except zmq.Again:
                break
            self.msg_received(data)
            count += 1
        return count

    def poll_until_done(self, fut, timeout=None):
        """Poll the socket until *fut* is done.

        :param fut: a future created by :meth:`call`.
        :param timeout: maximum number of seconds to wait; ``None`` waits
            indefinitely.  ``0`` checks the already-queued replies once.
        :raises TimeoutError: if *fut* is still pending when the timeout
            expires.
        """
        if timeout is None:
            while not fut.done():
                self.poll(None)
            return
        # Use a monotonic clock so wall-clock adjustments cannot stretch or
        # shrink the deadline.
        deadline = time.monotonic() + timeout
        while not fut.done():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # Do one last non-blocking sweep so a reply that is already
                # queued is not reported as a timeout.  This also makes
                # timeout=0 mean "check once" rather than "block forever".
                self.poll(0)
                if not fut.done():
                    raise TimeoutError('Timeout')
                return
            self.poll(remaining)

    def _translate_error(self, exc_type, exc_args, exc_repr):
        """Build a local exception for an error reply.

        Uses the class registered for *exc_type* in ``self.error_table``, or
        falls back to ``GenericError`` when the type is not registered.
        """
        found = self.error_table.get(exc_type)
        if found is None:
            return GenericError(exc_type, exc_args, exc_repr)
        return found(*exc_args)

    def msg_received(self, data):
        """Resolve the pending future that matches one reply message.

        Malformed replies and replies with an unknown request id are logged
        and dropped.
        """
        try:
            header, banswer = data
            pid, rnd, req_id, timestamp, is_error = self.RESP.unpack(header)
            # NOTE(review): replies are decoded with PicklePacker.  If that is
            # pickle-based (the name suggests so; confirm), a malicious
            # server can run arbitrary code here.  Only connect to trusted
            # endpoints.
            answer = self.packer.unpackb(banswer)
        except Exception:
            log().critical("Cannot unpack %r", data, exc_info=sys.exc_info())
            return
        call = self.calls.pop(req_id, None)
        if call is None:
            log().critical("Unknown answer id: %d (%d %d %f %d) -> %s",
                           req_id, pid, rnd, timestamp, is_error, answer)
        elif call.cancelled():
            log().debug("The future for request #%08x has been cancelled, "
                        "skip the received result.", req_id)
        elif is_error:
            call.set_exception(self._translate_error(*answer))
        else:
            call.set_result(answer)
def connect_rpc(**kwargs):
    """Create an :class:`RPCClient`, forwarding all keyword arguments.

    ``RPCClient`` accepts ``connect`` (required), ``timeout`` and
    ``error_table`` as keyword-only arguments.
    """
    client = RPCClient(**kwargs)
    return client
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment