Created
March 6, 2018 11:01
-
-
Save litnimax/f19d6b017be98dd2f7a6e03f99b620f2 to your computer and use it in GitHub Desktop.
MQTT JSON RPC Client-Server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from asyncio.locks import Event | |
| from asyncio import Queue | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import uuid | |
| from hbmqtt.client import MQTTClient, ClientException | |
| from hbmqtt.mqtt.constants import QOS_1, QOS_2 | |
| from tinyrpc.server import RPCServer | |
| from tinyrpc.protocols.jsonrpc import JSONRPCProtocol | |
| from tinyrpc.dispatch import RPCDispatcher | |
| from tinyrpc.exc import RPCError | |
| from tinyrpc.client import RPCProxy | |
# Module-wide logger, and the tinyrpc dispatcher that MQTTRPCServer routes
# incoming requests through.
logger = logging.getLogger('mqtt_rpc')
dispatcher = RPCDispatcher()

# Seconds an outgoing call waits for its reply; the REPLY_TIMEOUT environment
# variable overrides the default of 5.
REPLY_TIMEOUT = float(os.getenv('REPLY_TIMEOUT', 5))

# This node's identity on the MQTT bus; the CLIENT_UID environment variable
# overrides the default, which is the hardware address from uuid.getnode().
CLIENT_UID = os.getenv('CLIENT_UID', str(uuid.getnode()))
class RPCProxy(object):
    """Attribute-style proxy for remote procedures on a single destination.

    ``proxy.some_method(*args, **kwargs)`` returns whatever
    ``client.call(destination, 'some_method', args, kwargs, one_way=one_way)``
    returns. With ``MQTTRPCServer`` as the client, ``call`` is a coroutine
    function, so the result must be awaited.

    Note: this class shadows the ``RPCProxy`` imported from
    ``tinyrpc.client`` at the top of the file.
    """

    def __init__(self, client, destination, one_way=False):
        """
        :param client: object exposing ``call(destination, method, args,
            kwargs, one_way=...)``.
        :param destination: id of the remote node the calls are sent to.
        :param one_way: if true, calls are sent as notifications (no reply).
        """
        self.client = client
        self.destination = destination
        self.one_way = one_way

    def __getattr__(self, name):
        """Return a function that invokes remote procedure ``name``.

        :raises AttributeError: for dunder names. Protocol probes such as
            ``copy.deepcopy``'s ``getattr(x, '__deepcopy__', None)`` or
            pickling hooks must not receive a function that fires a remote
            call.
        """
        # __getattr__ only runs for attributes not found the normal way.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        def proxy_func(*args, **kwargs):
            # self.one_way is read at call time, as before.
            return self.client.call(
                self.destination,
                name,
                args,
                kwargs,
                one_way=self.one_way
            )

        proxy_func.__name__ = name
        return proxy_func
class MQTTRPCServer(object):
    """JSON-RPC server and client multiplexed over one MQTT connection.

    Topic layout used by this class:

    * ``rpc/<destination>/<sender>`` -- a request from ``<sender>`` to
      ``<destination>``.
    * ``rpc/<destination>/<sender>/reply`` -- the reply ``<destination>``
      publishes for a request that came from ``<sender>``.

    Requests addressed to this node (``rpc/CLIENT_UID/+``) are dispatched
    through the module-level tinyrpc ``dispatcher``. Outgoing calls are made
    with :meth:`call` or through a proxy from :meth:`get_proxy_for`.
    """

    # Raw strings so ``\w`` is a regex class, not an invalid string escape.
    _REQUEST_TOPIC_RE = re.compile(r'^rpc/(\w+)/(\w+)$')
    _REPLY_TOPIC_RE = re.compile(r'^rpc/(\w+)/(\w+)/reply$')

    # Class-level default; ``self.request_count += 1`` creates the
    # per-instance counter on first use.
    request_count = 0

    def __init__(self, mqtt_url, loop=None):
        """
        :param mqtt_url: broker URL passed to ``MQTTClient.connect``.
        :param loop: event loop used to schedule background tasks; defaults
            to ``asyncio.get_event_loop()``.
        """
        if not loop:
            loop = asyncio.get_event_loop()
        self.loop = loop
        self.protocol = JSONRPCProtocol()
        self.dispatcher = dispatcher
        self.mqtt_url = mqtt_url
        self.client = MQTTClient(client_id=CLIENT_UID)
        # Per-instance state. These used to be class attributes, so every
        # instance shared one pending-reply table and subscription list.
        # reply topic -> {request id -> Queue receiving the raw reply str}
        self.rpc_replies = {}
        # Reply topics we are subscribed to, so that concurrent calls to the
        # same destination reuse one subscription.
        self.subscriptions = []
        # Not set or awaited anywhere in this class; kept for compatibility.
        self.replied = Event()
        # Strong references to fire-and-forget tasks (see _spawn).
        self._background_tasks = set()
        logger.info('Client {} initialized'.format(CLIENT_UID))

    async def connect(self):
        """Connect to the broker and subscribe to requests addressed to us."""
        await self.client.connect(self.mqtt_url)
        await self.client.subscribe([
            ('rpc/{}/+'.format(CLIENT_UID), QOS_1),
        ])

    async def serve_forever(self):
        """Connect, then process incoming messages until an error is raised.

        The connection is awaited rather than fired as a background future.
        A failed connect therefore raises here, instead of being lost in an
        unawaited task while process_messages waits forever for a connection
        that never comes.
        """
        await self.connect()
        # process_messages loops forever and only exits by raising.
        await self.process_messages()

    async def process_messages(self):
        """Receive MQTT messages forever and route RPC requests and replies.

        Requests are handed to background tasks, so this loop goes straight
        back to ``deliver_message``. Replies are passed to the call waiting
        for them. Undecodable or malformed payloads are logged and skipped;
        they no longer raise out of, and end, this loop.
        """
        await self.client._connected_state.wait()
        while True:
            message = await self.client.deliver_message()
            topic = message.topic
            logger.debug('Message at topic {}'.format(topic))
            request_match = self._REQUEST_TOPIC_RE.match(topic)
            if request_match:
                # rpc/<us>/<requester>: the requester is the reply context.
                context = request_match.group(2)
                logger.debug('RPC request from {}'.format(context))
                data_str = self._decode_payload(message)
                if data_str is not None:
                    self._spawn(self.receive_rpc_request(context, data_str))
                continue
            reply_match = self._REPLY_TOPIC_RE.match(topic)
            if reply_match:
                # rpc/<replier>/<us>/reply: group 1 is who answered.
                logger.debug('RPC reply from {}'.format(reply_match.group(1)))
                data_str = self._decode_payload(message)
                if data_str is not None:
                    self._deliver_reply(topic, data_str)

    @staticmethod
    def _decode_payload(message):
        """Return the payload as UTF-8 text, or None (logged) if undecodable."""
        try:
            return message.data.decode()
        except UnicodeDecodeError:
            logger.error(
                'Undecodable payload at topic {}'.format(message.topic))
            return None

    def _deliver_reply(self, topic, data_str):
        """Hand a reply payload to the call waiting on it, or log why not."""
        waiting_replies = self.rpc_replies.get(topic)
        if not waiting_replies:
            logger.warning(
                'Got an unexpected RPC reply from {}: {}'.format(
                    topic, data_str))
            return
        try:
            data_js = json.loads(data_str)
        except json.JSONDecodeError:
            logger.error('RPC reply bad json data: {}'.format(data_str))
            return
        # A non-object reply (a list, a bare number) has no 'id' to match.
        request_id = data_js.get('id') if isinstance(data_js, dict) else None
        try:
            reply_queue = waiting_replies[request_id]
        except (KeyError, TypeError):  # unknown or unhashable id
            logger.warning(
                'Got a reply from {} to bad request id: {}'.format(
                    topic, data_str))
            return
        logger.debug('Waiting reply found for request {}'.format(request_id))
        # The queue is unbounded, so this never raises QueueFull.
        reply_queue.put_nowait(data_str)

    async def receive_rpc_request(self, context, data):
        """Dispatch one incoming request and publish its reply, if any.

        :param context: id of the requester; the reply is published on
            ``rpc/CLIENT_UID/<context>/reply``.
        :param data: request payload; non-``str`` values are first turned
            into a string with ``json.dumps``.
        """
        logger.debug('Request: {}'.format(data))
        self.request_count += 1
        if not isinstance(data, str):
            # Turn non-string to string or die trying.
            data = json.dumps(data)
        try:
            request = self.protocol.parse_request(data)
        except RPCError as e:
            response = e.error_respond()
        else:
            response = self.dispatcher.dispatch(
                request,
                getattr(self.protocol, '_caller', None)
            )
        # dispatch may return None (presumably for one-way requests); then
        # there is nothing to send.
        if response is not None:
            self.send_rpc_reply(context, response.serialize())

    def send_rpc_reply(self, context, reply):
        """Publish ``reply`` on ``rpc/CLIENT_UID/<context>/reply``.

        The publish runs as a background task and is not awaited.
        """
        logger.debug('RPC reply to {}: {}'.format(context, reply))
        self._spawn(self.client.publish(
            'rpc/{}/{}/reply'.format(CLIENT_UID, context),
            self._to_bytes(reply)))

    def _spawn(self, coro):
        """Schedule ``coro`` on ``self.loop`` without awaiting it.

        The event loop keeps only weak references to tasks. A strong
        reference is held here until the task is done, so it cannot be
        garbage-collected mid-execution.
        """
        task = self.loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    @staticmethod
    def _to_bytes(data):
        """Encode ``data`` for publishing.

        tinyrpc's ``serialize()`` may return ``str`` or ``bytes`` depending
        on its version, so both are accepted.
        """
        return data if isinstance(data, bytes) else data.encode()

    # MQTT RPC Client methods

    def get_proxy_for(self, destination, one_way=False):
        """Return an RPCProxy that calls procedures on ``destination``."""
        return RPCProxy(self, destination, one_way)

    async def _send_and_handle_reply(self, destination, req, one_way,
                                     no_exception=False):
        """Publish ``req`` to ``destination`` and, unless one-way, await it.

        :returns: None for one-way calls, otherwise the parsed tinyrpc
            response.
        :raises RPCError: on reply timeout, or when the reply carries an
            error and ``no_exception`` is false.
        """
        request_topic = 'rpc/{}/{}'.format(destination, CLIENT_UID)
        payload = self._to_bytes(req.serialize())
        if one_way:
            # A notification: nobody answers, so just publish and return.
            await self.client.publish(request_topic, payload)
            return None

        reply_topic = 'rpc/{}/{}/reply'.format(destination, CLIENT_UID)
        reply_queue = Queue()
        self.rpc_replies.setdefault(reply_topic, {})[req.unique_id] = reply_queue
        try:
            if reply_topic not in self.subscriptions:
                await self.client.subscribe([(reply_topic, QOS_1)])
                # Record the subscription; cleanup relies on it (it was never
                # recorded before, so cleanup's list.remove() raised
                # ValueError). Re-check, because a concurrent call to the
                # same destination may have recorded it while we awaited.
                if reply_topic not in self.subscriptions:
                    self.subscriptions.append(reply_topic)
            # Fire the publish without awaiting it; the reply queue already
            # exists, so an early reply cannot be missed.
            self._spawn(self.client.publish(request_topic, payload))
            try:
                reply_data = await asyncio.wait_for(
                    reply_queue.get(), REPLY_TIMEOUT)
            except asyncio.TimeoutError:
                raise RPCError('Reply Timeout')
        finally:
            # Always drop the pending entry, including when subscribe fails
            # or the caller is cancelled.
            await self._release_reply_slot(reply_topic, req.unique_id)

        logger.debug('Got a reply for request id: {}'.format(req.unique_id))
        rpc_response = self.protocol.parse_reply(reply_data)
        # Check that there are no RPC errors.
        if not no_exception and hasattr(rpc_response, 'error'):
            raise RPCError('Error calling remote procedure: %s' %
                           rpc_response.error)
        return rpc_response

    async def _release_reply_slot(self, reply_topic, request_id):
        """Forget a pending request.

        Unsubscribes from ``reply_topic`` once no other call is waiting on
        it.
        """
        pending = self.rpc_replies.get(reply_topic)
        if pending is not None:
            pending.pop(request_id, None)
            if pending:
                return  # other calls still await replies on this topic
            del self.rpc_replies[reply_topic]
        if reply_topic not in self.subscriptions:
            return
        # Remove before awaiting. A call that starts during the unsubscribe
        # round-trip then resubscribes, instead of trusting a subscription
        # that is going away.
        self.subscriptions.remove(reply_topic)
        logger.debug('Unsubscribe from reply {}'.format(reply_topic))
        await self.client.unsubscribe([reply_topic])

    async def call(self, destination, method, args, kwargs, one_way=False):
        """Invoke ``method`` on ``destination``.

        :returns: the call's result, or None for one-way calls.
        :raises RPCError: on timeout or a remote error.
        """
        req = self.protocol.create_request(method, args, kwargs, one_way)
        rep = await self._send_and_handle_reply(destination, req, one_way)
        if one_way:
            return None
        return rep.result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment