Skip to content

Instantly share code, notes, and snippets.

@Msameim181
Last active December 30, 2023 16:44
Show Gist options
  • Save Msameim181/a060be20caca2e458c791d4e074f2659 to your computer and use it in GitHub Desktop.
Save Msameim181/a060be20caca2e458c791d4e074f2659 to your computer and use it in GitHub Desktop.
Socket.IO with FastAPI, run with `uvicorn`. The server handles HTTP API routes and Socket.IO connections simultaneously.
import socketio
import asyncio
from abc import ABC, abstractmethod
from typing import Dict
import logging
import threading
import time
import sys
# Module-level logger for the client script; no handlers are attached here.
logger = logging.getLogger(name=__name__)
class IBaseSocketClient(ABC):
    """Abstract interface that every Socket.IO client in this module implements."""

    @abstractmethod
    def call_backs(self):
        """Register the event handlers the socket communicates through.

        Handlers can be attached with the ``self.sio.event`` decorator or with
        ``self.sio.on(<event name>)``.
        """

    @abstractmethod
    async def connect_to_server(self):
        """Open the Socket.IO connection to the server."""

    @abstractmethod
    async def run(self):
        """Run the socket client."""

    @abstractmethod
    def start_background_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Start the background event loop."""
class SocketClientConfig:
    """Connection settings for a Socket.IO client.

    ``logger`` and ``engineio_logger`` are boolean switches (``BaseSocketClient``
    forwards them to ``socketio.AsyncClient``), not logger objects.
    """

    # Class-level defaults mirror the __init__ defaults below.
    server_url: str
    headers: Dict
    socketio_path: str = "/socket.io"
    auth: Dict = None
    logger: bool = True
    engineio_logger: bool = True
    reconnection: bool = True
    reconnection_delay: int = 3
    reconnection_attempts: int = 10

    def __init__(
        self,
        server_url: str,
        headers: Dict,
        socketio_path: str = "/socket.io",
        auth: Dict = None,
        logger: bool = True,
        engineio_logger: bool = True,
        reconnection: bool = True,
        reconnection_delay: int = 3,
        reconnection_attempts: int = 10,
    ):
        # Where and how to connect.
        self.server_url, self.headers = server_url, headers
        self.socketio_path, self.auth = socketio_path, auth
        # Library logging switches.
        self.logger, self.engineio_logger = logger, engineio_logger
        # Automatic reconnection policy.
        self.reconnection = reconnection
        self.reconnection_delay = reconnection_delay
        self.reconnection_attempts = reconnection_attempts
class BaseSocketClient(IBaseSocketClient):
    """Socket.IO client whose coroutines run on a dedicated asyncio loop.

    The loop is created in ``__init__`` but not started; ``start_background_loop``
    runs it (blocking), so it is meant to be called from a separate thread.
    """

    def __init__(
        self,
        config: SocketClientConfig,
        logger: logging.Logger = None,
    ):
        """Build the ``socketio.AsyncClient`` and a fresh event loop.

        Args:
            config: connection, logging and reconnection settings.
            logger: logger for this client's messages; when omitted, this
                module's logger is used so that ``self.logger`` is never None.
        """
        self.config = config
        # The handlers and connect_to_server all call self.logger.*, so never leave it None.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.sio = socketio.AsyncClient(
            handle_sigint=True,
            logger=self.config.logger,
            engineio_logger=self.config.engineio_logger,
            reconnection=self.config.reconnection,
            reconnection_delay=self.config.reconnection_delay,
            reconnection_attempts=self.config.reconnection_attempts,
        )
        # Created here, only run by start_background_loop.
        self._client_loop = asyncio.new_event_loop()

    def start_background_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Make ``loop`` the current thread's event loop and run it until stopped.

        Blocks the calling thread.
        """
        asyncio.set_event_loop(loop)
        loop.run_forever()

    @property
    def client_loop(self) -> asyncio.AbstractEventLoop:
        """The event loop created for this client in ``__init__``."""
        return self._client_loop

    async def connect_to_server(self):
        """Connect to the configured server, then wait on the connection.

        A failed connection attempt is logged and the coroutine returns
        without waiting.
        """
        try:
            await self.sio.connect(
                self.config.server_url,
                headers=self.config.headers,
                socketio_path=self.config.socketio_path,
                auth=self.config.auth,
            )
        except (socketio.exceptions.ConnectionError, ConnectionError) as exc:
            # python-socketio raises its own ConnectionError, which is NOT a
            # subclass of the builtin ConnectionError; catch both.
            self.logger.error("Connection failed: %s", exc)
            return
        self.logger.info("Connected to the server")
        await self.sio.wait()

    def call_backs(self):
        """Register the default connect / disconnect / message handlers on ``self.sio``."""

        @self.sio.on("connect")
        async def connect():
            self.logger.info("Socket connected to server")

        @self.sio.on("disconnect")
        def disconnect():
            self.logger.info("Socket disconnected from server")
            # Stopping the loop makes start_background_loop return.
            # NOTE(review): this also prevents the automatic reconnection
            # enabled by config.reconnection from ever running — confirm intended.
            self.client_loop.stop()

        @self.sio.on("message")
        async def message(data):
            self.logger.info(f"Message from server: {data}")

    async def run(self):
        """Register the event handlers, then connect and wait on the connection."""
        self.call_backs()
        await self.connect_to_server()
if __name__ == "__main__":
    # Interactive demo: connect, then send each line typed on stdin as a "message" event.
    # Without a handler, the module logger's INFO records would never be shown.
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())

    token = "test"
    config = SocketClientConfig(
        server_url="http://0.0.0.0:8009",
        headers={"Authorization": f"Bearer {token}"},
        socketio_path="/socket.io",
        auth={"token": token},
        engineio_logger=True,
    )
    base_client = BaseSocketClient(config, logger=logger)

    # The client's loop runs on a daemon thread; coroutines are handed to it
    # from this (main) thread with run_coroutine_threadsafe.
    th = threading.Thread(
        target=base_client.start_background_loop,
        args=(base_client.client_loop,),
        daemon=True,
    )
    th.start()
    run_future = asyncio.run_coroutine_threadsafe(base_client.run(), base_client.client_loop)

    # Bounded wait for the handshake instead of a fixed 0.5 s sleep, which could
    # report a slow-but-successful connection as a failure. run() finishing
    # early means the connection attempt did not succeed.
    deadline = time.monotonic() + 5
    while (
        not base_client.sio.connected
        and not run_future.done()
        and time.monotonic() < deadline
    ):
        time.sleep(0.05)

    while True:
        if not base_client.sio.connected:
            logger.info("Client is not connected to the server")
            sys.exit(1)
        try:
            user_input = input('Enter a message (or "exit" to quit): ')
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C, or stdin was closed.
            break
        if user_input.lower() == "exit":
            break
        asyncio.run_coroutine_threadsafe(
            base_client.sio.emit("message", user_input), base_client.client_loop
        )

    # AsyncClient.disconnect() is a coroutine: calling it directly from this
    # thread only creates an un-awaited coroutine and never disconnects. Run it
    # on the client's loop; the default disconnect handler then stops that
    # loop, which lets the background thread finish.
    if base_client.sio.connected:
        asyncio.run_coroutine_threadsafe(base_client.sio.disconnect(), base_client.client_loop)
        th.join(timeout=5)
import socketio
import asyncio
import uvicorn
from fastapi import FastAPI
from abc import ABC, abstractmethod
from typing import Dict
import logging
# Module-level logger for the server script; no handlers are attached here.
logger = logging.getLogger(name=__name__)
class IBaseSocketServer(ABC):
    """Abstract interface that every Socket.IO server in this module implements."""

    @abstractmethod
    def call_backs(self):
        """Register the event handlers the socket communicates through.

        Handlers can be attached with the ``self.sio.event`` decorator or with
        ``self.sio.on(<event name>)``.
        """

    @abstractmethod
    def run_server(self):
        """Run the socket server.

        Expected steps:
          - create the app
          - attach the app to the socket
          - register the callback (event) functions
          - run the application on the configured host and port
        """
class SocketServerConfig:
    """Settings for a Socket.IO server: bind address, uvicorn options, and
    Socket.IO server options.

    ``logger`` and ``engineio_logger`` are boolean switches, not logger objects.
    """

    host: str
    port: int
    log_level: str = "info"
    cors_allowed_origins: str = "*"
    socketio_path: str = "/socket.io"
    logger: bool = True
    always_connect: bool = True
    engineio_logger: bool = True
    server_workers: int = None
    # NOTE(review): ``reload`` exists only as this class attribute; __init__
    # has no parameter for it and BaseSocketServer.run_server never reads it.
    reload: bool = False

    def __init__(
        self,
        host: str,
        port: int = 8000,
        log_level: str = "info",
        cors_allowed_origins: str = "*",
        socketio_path: str = "/socket.io",
        logger: bool = True,
        always_connect: bool = True,
        engineio_logger: bool = True,
        server_workers: int = None,
    ):
        # Server process settings.
        self.host, self.port, self.log_level = host, port, log_level
        self.server_workers = server_workers
        # Socket.IO server settings.
        self.cors_allowed_origins = cors_allowed_origins
        self.socketio_path = socketio_path
        self.always_connect = always_connect
        self.logger, self.engineio_logger = logger, engineio_logger
class BaseSocketServer(IBaseSocketServer):
    """FastAPI app and Socket.IO server exposed together as one ASGI app.

    ``self.app`` is ``socketio.ASGIApp`` wrapping ``self.sio`` around
    ``self.fastapi_app`` with ``config.socketio_path`` as the Socket.IO path.
    """

    def __init__(
        self,
        config: SocketServerConfig,
        async_mode: str = "asgi",
        cors_allowed_origins: str = "*",
        logger: logging.Logger = None,
    ):
        """Create the FastAPI app and its routes, the Socket.IO server and the ASGI wrapper.

        Args:
            config: server settings.
            async_mode: accepted for compatibility but not used; the server is
                always created with ``async_mode="asgi"``.
            cors_allowed_origins: accepted for compatibility but not used;
                ``config.cors_allowed_origins`` is what the server receives.
            logger: logger for server messages; defaults to this module's
                logger so that ``self.logger`` is never None.
        """
        self.config = config
        # Every route and event handler calls self.logger.*; never leave it None.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.fastapi_app = FastAPI()
        self.api_route()
        self.sio = socketio.AsyncServer(
            async_mode="asgi",
            cors_allowed_origins=self.config.cors_allowed_origins,
            always_connect=self.config.always_connect,
            logger=self.config.logger,
            engineio_logger=self.config.engineio_logger,
        )
        self.app = socketio.ASGIApp(
            self.sio, self.fastapi_app, socketio_path=self.config.socketio_path
        )

    def run_server(self):
        """Register the Socket.IO event handlers, then serve ``self.app`` with uvicorn.

        Blocks until uvicorn exits.

        NOTE(review): ``self.app`` is passed as an object rather than an import
        string; uvicorn documents that multiple workers need an import string,
        so ``config.server_workers`` > 1 likely fails. Confirm before relying on it.
        """
        self.call_backs()
        uvicorn.run(
            self.app,
            host=self.config.host,
            port=self.config.port,
            log_level=self.config.log_level,
            workers=self.config.server_workers,
        )

    def api_route(self):
        """Register the demo HTTP routes on ``self.fastapi_app``."""

        @self.fastapi_app.get("/")
        def home():
            return {"message": "Hello World"}

        @self.fastapi_app.get("/send/{room_id}/{message}")
        async def send_message(room_id: str, message: str):
            # Lets a plain HTTP client push a Socket.IO "message" event to a room.
            self.logger.info(f"[API] Send message to {room_id}: {message}")
            await self.send_message(room_id, message)
            return {"message": "OK"}

    def call_backs(self):
        """Register the Socket.IO event handlers on ``self.sio``."""

        @self.sio.on("connect")
        async def connect(sid, environ, auth):
            # Log only environ keys and whether auth was sent: the values carry
            # the client's Authorization header and token, which must not leak
            # into logs.
            self.logger.debug(f"Connect environ keys for {sid}: {sorted(environ)}")
            self.logger.info(f"Socket connected with ID: {sid}")
            if auth:
                self.logger.info("Auth payload received")
            else:
                self.logger.info("No Auth")
            # TODO: validate auth and reject the connection when it is invalid.
            await self.sio.emit("connected", room=sid)

        @self.sio.on("disconnect")
        def disconnect(sid):
            self.logger.info(f"Socket disconnected with ID: {sid}")

        @self.sio.on("join")
        async def join(sid, room):
            await self.sio.enter_room(sid, room)
            self.logger.info(f"Client {sid} joined room {room}, {self.sio.rooms(sid)}")

        @self.sio.on("leave")
        async def leave(sid, room):
            await self.sio.leave_room(sid, room)
            self.logger.info(f"Client {sid} left room {room}")

        @self.sio.on("message")
        async def message(sid, data):
            self.logger.info(f"Message from {sid}: {data}")
            self.logger.info("Sending message to somewhere")

    async def send_message(self, room_id: str, message: str):
        """Emit ``message`` as a ``"message"`` event to every client in ``room_id``.

        Used to push messages originating outside Socket.IO (e.g. the HTTP route).
        """
        self.logger.info(f"Send message to {room_id}: {message}")
        await self.sio.emit("message", message, room=room_id)
if __name__ == "__main__":
    # Send this module's log records to stderr at INFO level.
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())

    # Listen on all interfaces, port 8009, with the Socket.IO / Engine.IO
    # library loggers switched off.
    server_config = SocketServerConfig(
        host="0.0.0.0",
        port=8009,
        log_level="info",
        logger=False,
        engineio_logger=False,
    )
    BaseSocketServer(config=server_config, logger=logger).run_server()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment