Skip to content

Instantly share code, notes, and snippets.

@timhughes
Last active March 15, 2021 09:39
Show Gist options
  • Save timhughes/3e85a3a0e75858670c2dd3c21c790d3f to your computer and use it in GitHub Desktop.
Save timhughes/3e85a3a0e75858670c2dd3c21c790d3f to your computer and use it in GitHub Desktop.
Asyncio WebSocket server that exits cleanly
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
import asyncio
import datetime
import json
import logging
from typing import Any, AsyncIterator, Dict, Optional

import websockets
logging.basicConfig(level=logging.INFO)

# Interface and port the WebSocket server binds to (passed to websockets.serve
# in the __main__ section).
HOST = "localhost"
PORT = 6789

# NOTE(review): not read or written anywhere else in this file as shown;
# presumably a leftover for tracking open connections -- confirm before
# relying on it or removing it.
connections = set()
def encode_msg(msg: Dict) -> str:
    """Serialize *msg* to a JSON string for sending over the socket.

    Non-ASCII characters are written as-is rather than as ``\\uXXXX`` escapes.
    """
    encoder = json.JSONEncoder(ensure_ascii=False)
    return encoder.encode(msg)
def decode_msg(text: str) -> Dict:
    """Parse one incoming WebSocket frame as JSON.

    ``json.loads`` accepts ``str`` as well as ``bytes`` in UTF-8/16/32, so
    binary frames decode too. Despite the ``Dict`` annotation, any JSON value
    (list, number, ...) is returned as parsed. Malformed input raises
    ``ValueError`` (``json.JSONDecodeError``).
    """
    parsed = json.loads(text)
    return parsed
async def producer(interval: float = 5) -> AsyncIterator[Dict[str, Any]]:
    """Yield a heartbeat message every *interval* seconds, forever.

    This is an async generator, so it is consumed with ``async for``. It never
    finishes on its own and stops only when the consumer stops iterating or is
    cancelled.

    Args:
        interval: Seconds to sleep before each heartbeat. Defaults to 5, the
            period this server has always used.

    Yields:
        ``{"type": "heartbeat", "message": {"datetime": <str(now())>}}``.
        The timestamp comes from ``datetime.datetime.now()``, a naive
        local-time value.
    """
    while True:
        await asyncio.sleep(interval)
        yield {
            "type": "heartbeat",
            "message": {"datetime": str(datetime.datetime.now())},
        }
async def consumer_handler(websocket) -> None:
    """Read JSON messages from the client and log them until the connection ends.

    Each incoming frame is decoded with ``decode_msg`` and logged at WARNING
    level. A frame that cannot be decoded is logged and skipped rather than
    ending the loop. An uncaught decode error would finish this task with an
    exception, which makes ``handler`` cancel the heartbeat producer and drop
    the client over one bad message.

    A disconnect surfacing as ``websockets.ConnectionClosed`` is logged at
    INFO level and ends the coroutine normally.

    Args:
        websocket: The client connection; iterating it yields incoming messages.
    """
    try:
        async for message_raw in websocket:
            try:
                message = decode_msg(message_raw)
            except ValueError as exc:
                # json.JSONDecodeError and UnicodeDecodeError both subclass
                # ValueError: skip the bad frame, keep the connection alive.
                logging.warning(
                    "Ignoring undecodable message from %s:%s, %s",
                    websocket.remote_address[0],
                    websocket.remote_address[1],
                    str(exc),
                )
                continue
            logging.warning(message)
    except websockets.ConnectionClosed as exc:
        logging.info(
            "Lost connection from %s:%s, %s",
            websocket.remote_address[0],
            websocket.remote_address[1],
            str(exc),
        )
async def producer_handler(websocket) -> None:
    """Send every heartbeat from ``producer()`` to the client via ``encode_msg``.

    Runs until cancelled (``handler`` cancels it once the consumer side
    finishes) or until a send fails because the connection has closed.

    Args:
        websocket: The client connection the heartbeats are sent on.
    """
    try:
        async for message in producer():
            await websocket.send(encode_msg(message))
    except websockets.ConnectionClosed as exc:
        # Mirrors consumer_handler: a closed connection is an expected way for
        # this loop to end, so log it instead of letting it escape as an
        # unhandled task exception.
        logging.info(
            "Stopped sending heartbeats to %s:%s, %s",
            websocket.remote_address[0],
            websocket.remote_address[1],
            str(exc),
        )
async def handler(
    websocket: websockets.WebSocketServerProtocol, path: Optional[str] = None
) -> None:
    """Serve one client connection until either direction finishes.

    Runs ``consumer_handler`` (reads client messages) and ``producer_handler``
    (pushes heartbeats) concurrently. As soon as one of them finishes, the
    other is cancelled. Both are then awaited, so no task outlives the
    connection. This also holds when ``handler`` itself is cancelled, for
    example during server shutdown.

    Args:
        websocket: The connection object supplied by ``websockets.serve``.
        path: Request path. Unused. It is optional so the handler works both
            with websockets releases that pass ``(websocket, path)`` and with
            releases that call ``handler(websocket)``. NOTE(review): the exact
            calling convention depends on the installed websockets version.
    """
    remote_ip = websocket.remote_address
    logging.info(
        "Accepted connection from %s:%s",
        remote_ip[0],
        remote_ip[1],
    )
    consumer_task = asyncio.ensure_future(consumer_handler(websocket))
    producer_task = asyncio.ensure_future(producer_handler(websocket))
    tasks = [consumer_task, producer_task]
    try:
        # Return as soon as either side is done: the client went away
        # (consumer) or sending stopped (producer).
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Reached on normal exit AND when handler is cancelled; asyncio.wait
        # does not cancel the tasks it waits on, so do it here. cancel() is a
        # no-op for a task that has already finished.
        for task in tasks:
            task.cancel()
        # Await the children so cancellation actually completes and every
        # task's exception is retrieved (no "Task exception was never
        # retrieved" warnings at garbage collection).
        results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, (asyncio.CancelledError, websockets.ConnectionClosed)):
            continue  # expected ways for a per-connection task to end
        if isinstance(result, BaseException):
            logging.error("Connection task failed", exc_info=result)
if __name__ == "__main__":
    # Create and install the loop explicitly. Relying on
    # asyncio.get_event_loop() to create one implicitly is deprecated in
    # recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    ws_server = None
    try:
        # Awaiting the serve() object starts listening and returns the server
        # instance that is closed on shutdown.
        ws_server = loop.run_until_complete(websockets.serve(handler, HOST, PORT))
        logging.info("Listening on ws://%s:%s", HOST, PORT)
        logging.info("Event loop running forever, press Ctrl+C to interrupt.")
        loop.run_forever()
    except KeyboardInterrupt:
        # Cleanup: stop the server (if it got started) and wait until it has
        # finished closing.
        if ws_server is not None:
            ws_server.close()
            loop.run_until_complete(ws_server.wait_closed())
    finally:
        # Finalize async generators (such as producer()) that are still
        # suspended, then release the loop even if startup failed.
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()
    # Only reached after a Ctrl+C shutdown; any other exception propagates
    # past this line.
    logging.info("Successfully shutdown the service.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment