Last active: August 22, 2021 07:43
-
-
Save sjlongland/27d3c10a30b34a98c39c9d0545bc83ee to your computer and use it in GitHub Desktop.
Proxying MQTT-over-WebSocket connections from aiohttp to a back-end MQTT server (amqtt in this case)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
# NOTE(review): `asyncio.coroutine` was removed in Python 3.11 and is not used
# in this snippet; this import fails on 3.11+ and should be dropped once
# nothing else in the file needs it.
from asyncio import (
    FIRST_COMPLETED,
    coroutine,
    ensure_future,
    gather,
    get_event_loop,
    wait,
)

from aiohttp import ClientSession, WSMsgType
from aiohttp.web import Application, View, WebSocketResponse, json_response
from yaml import safe_load
# Module-level logger. `log` was previously used below without ever being
# defined, which raised NameError at import time.
log = logging.getLogger(__name__)

# Read the configuration file. `safe_load` (rather than `yaml.load`) keeps the
# YAML from constructing arbitrary Python objects; the `with` block closes the
# file handle once it has been parsed (it was previously leaked).
with open('config.yml', 'r', encoding='utf-8') as config_file:
    config = safe_load(config_file)

# Set up the core application. `max_payload_size` from the optional `http`
# section becomes aiohttp's `client_max_size`; it defaults to 10485760
# (10 * 1024 * 1024).
http_config = config.pop('http', {})
app = Application(
    logger=log.getChild('http'),
    client_max_size=http_config.pop('max_payload_size', 10485760)
)

# Application set-up: we wrap amqtt's `Broker` class in a "manager" that
# abstracts handling of the configuration and calls the broker's `start`
# method. Here we instantiate it with the `mqtt` configuration section and a
# child logger.
#
# Crucially, the manager exposes the port of the broker's MQTT-over-WebSocket
# listener as `mqttmgr._ws_port`; the proxy view connects to that port.
mqttmgr = MQTTManager(config.pop('mqtt'), log.getChild('mqtt'))
app['mqttmgr'] = mqttmgr

# Start the broker up.
# NOTE(review): no event loop is running yet at this point, so `ensure_future`
# binds the task to the default loop. Confirm the app is later served on that
# same loop; some `aiohttp.web.run_app` versions create a fresh loop, in which
# case this task would never run.
ensure_future(mqttmgr.start())
# This is how we'd proxy that to the outside world.
class MQTTProxyView(View):
    """
    Bridge an incoming MQTT-over-WebSocket client to the back-end broker's
    WebSocket listener, relaying frames in both directions.
    """

    @classmethod
    def attach(cls, aiohttp_application, base_uri):
        """
        Connect this view to an aiohttp application instance.

        The view is registered at ``base_uri + 'mqtt'``. No separator is
        inserted, so ``base_uri`` should normally end with ``/``.
        """
        aiohttp_application.router.add_view('%smqtt' % base_uri, cls)

    async def get(self):
        """
        Proxy through Websocket requests.

        Accepts the client's WebSocket (``mqtt`` sub-protocol) and opens a
        second WebSocket to the back-end broker on
        ``localhost:<app['mqttmgr']._ws_port>``. Frames are relayed between
        the two until either side stops. Returns the prepared
        ``WebSocketResponse``.

        Failures are logged. If the client handshake had already completed,
        the client socket is closed and still returned. Earlier failures (for
        example, a request that is not a WebSocket upgrade) are re-raised so
        aiohttp can answer with an error response instead of receiving
        ``None``.
        """
        log = logging.getLogger(__name__)
        ws_server = None
        try:
            manager = self.request.app['mqttmgr']

            # Establish websocket connection with the client
            log.debug('Negotiating incoming WS connection')
            ws_server = WebSocketResponse(protocols=('mqtt',))
            await ws_server.prepare(self.request)

            # Establish websocket connection with MQTT backend. The session
            # is used as a context manager so it is closed when the link ends
            # (previously it was never closed).
            log.debug('Connecting to back-end MQTT/WS server')
            async with ClientSession() as client:
                async with client.ws_connect(
                        'http://localhost:%d/' % manager._ws_port,
                        protocols=('mqtt',)
                ) as ws_client:
                    await self._pump(log, ws_server, ws_client)
        except Exception:
            log.exception('Failed to handle MQTT websocket link')
            if ws_server is None or not ws_server.prepared:
                # Handshake never completed: let aiohttp build an error
                # response rather than being handed ``None``.
                raise
            if not ws_server.closed:
                await ws_server.close()
        return ws_server

    @classmethod
    async def _pump(cls, log, ws_server, ws_client):
        """
        Relay frames in both directions until one direction finishes, then
        close both sockets.

        An exception raised by the relay that finished first is re-raised.
        """
        # Credit: @JustinTArthur for the initial hint regarding Tasks.
        # https://gitter.im/aio-libs/Lobby
        tasks = [
            ensure_future(
                cls._relay(log, ws_server, ws_client, 'MQTT Client')),
            ensure_future(
                cls._relay(log, ws_client, ws_server, 'MQTT Server')),
        ]
        try:
            # Stop as soon as *either* direction ends. Waiting for both (as a
            # plain `gather` does) hangs when one peer disconnects, because
            # the other relay keeps reading a socket nobody ever closes.
            done, _pending = await wait(tasks, return_when=FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            # Let the cancelled relay unwind. return_exceptions=True also
            # retrieves any error, so asyncio does not warn about it.
            await gather(*tasks, return_exceptions=True)
            # Clean up
            if not ws_client.closed:
                await ws_client.close()
            if not ws_server.closed:
                await ws_server.close()
        for task in done:
            # Surface a failure from the relay that ended the link.
            task.result()

    @staticmethod
    async def _relay(log, source, sink, label):
        """
        Copy TEXT and BINARY frames from ``source`` to ``sink`` until
        ``source`` stops yielding messages. ERROR frames are logged, with
        ``label`` identifying which side reported them; other frame types
        are ignored.
        """
        async for msg in source:
            if msg.type == WSMsgType.TEXT:
                await sink.send_str(msg.data)
            elif msg.type == WSMsgType.BINARY:
                await sink.send_bytes(msg.data)
            elif msg.type == WSMsgType.ERROR:
                log.error('%s reports error %s', label, source.exception())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.