Last active
December 30, 2023 16:44
-
-
Save Msameim181/a060be20caca2e458c791d4e074f2659 to your computer and use it in GitHub Desktop.
Socket.IO with FastAPI, running under `uvicorn`. A single server process handles both the HTTP API routes and the Socket.IO connections.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import socketio | |
import asyncio | |
from abc import ABC, abstractmethod | |
from typing import Dict | |
import logging | |
import threading | |
import time | |
import sys | |
logger = logging.getLogger(__name__) | |
class IBaseSocketClient(ABC):
    """Contract that a Socket.IO client in this module has to fulfil."""

    @abstractmethod
    def call_backs(self):
        """Register the event handlers the socket communicates through.

        A handler can be attached with either of:
            self.sio.event
            self.sio.on(<event name>)
        """

    @abstractmethod
    async def connect_to_server(self):
        """Open the Socket.IO connection to the server."""

    @abstractmethod
    async def run(self):
        """Entry point coroutine that drives the client."""

    @abstractmethod
    def start_background_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Run ``loop`` as the background event loop."""
class SocketClientConfig:
    """Settings bundle consumed by the socket client.

    ``server_url``, ``headers``, ``socketio_path`` and ``auth`` describe the
    connection; the remaining fields are forwarded to the Socket.IO client
    constructor (logging switches and reconnection policy).
    """

    # Connection target.
    server_url: str
    headers: Dict
    socketio_path: str = "/socket.io"
    auth: Dict = None
    # Logging switches.
    logger: bool = True
    engineio_logger: bool = True
    # Reconnection policy.
    reconnection: bool = True
    reconnection_delay: int = 3
    reconnection_attempts: int = 10

    def __init__(
        self,
        server_url: str,
        headers: Dict,
        socketio_path: str = "/socket.io",
        auth: Dict = None,
        logger: bool = True,
        engineio_logger: bool = True,
        reconnection: bool = True,
        reconnection_delay: int = 3,
        reconnection_attempts: int = 10,
    ):
        """Store every argument on the instance under the same name."""
        # Connection target.
        self.server_url, self.headers = server_url, headers
        self.socketio_path, self.auth = socketio_path, auth
        # Logging switches.
        self.logger, self.engineio_logger = logger, engineio_logger
        # Reconnection policy.
        self.reconnection = reconnection
        self.reconnection_delay, self.reconnection_attempts = (
            reconnection_delay,
            reconnection_attempts,
        )
class BaseSocketClient(IBaseSocketClient):
    """Async Socket.IO client that owns a dedicated asyncio event loop.

    The loop is created in ``__init__`` and exposed through ``client_loop``.
    It is meant to be driven by ``start_background_loop`` (typically from a
    separate thread), with ``run()`` scheduled onto it.
    """

    def __init__(
        self,
        config: SocketClientConfig,
        logger: logging.Logger = None,
    ):
        """Build the underlying ``socketio.AsyncClient`` from ``config``.

        Args:
            config: Connection and client settings.
            logger: Logger used for this client's own messages. When omitted,
                the module-level logger for ``__name__`` is used.
        """
        self.config = config
        self.sio = socketio.AsyncClient(
            handle_sigint=True,
            logger=self.config.logger,
            engineio_logger=self.config.engineio_logger,
            reconnection=self.config.reconnection,
            reconnection_delay=self.config.reconnection_delay,
            reconnection_attempts=self.config.reconnection_attempts,
        )
        # The parameter defaults to None, but every method below calls
        # self.logger.*; fall back to a real logger so the default is usable.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self._client_loop = asyncio.new_event_loop()

    def start_background_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Make ``loop`` current for the calling thread and run it until stopped."""
        asyncio.set_event_loop(loop)
        loop.run_forever()

    @property
    def client_loop(self):
        """The event loop created for this client in ``__init__``."""
        return self._client_loop

    async def connect_to_server(self):
        """Connect using the configured URL/headers/path/auth, then wait.

        On a connection failure the error is logged and the coroutine returns
        without waiting. Otherwise it blocks on ``self.sio.wait()``.
        """
        try:
            await self.sio.connect(
                self.config.server_url,
                headers=self.config.headers,
                socketio_path=self.config.socketio_path,
                auth=self.config.auth,
            )
        except (socketio.exceptions.ConnectionError, ConnectionError):
            # python-socketio signals a failed connect with its own
            # ConnectionError class, which does not derive from the builtin
            # one; the builtin is still caught for transport-level errors.
            self.logger.error("Connection failed.")
            return
        self.logger.info("Connected to the server")
        await self.sio.wait()

    def call_backs(self):
        """Register the ``connect``, ``disconnect`` and ``message`` handlers."""

        @self.sio.on("connect")
        async def connect():
            self.logger.info("Socket connected to server")

        @self.sio.on("disconnect")
        def disconnect():
            self.logger.info("Socket disconnected from server")
            # Stopping the loop makes run_forever() in start_background_loop
            # return. NOTE(review): this presumably also prevents automatic
            # reconnection from running on this loop - confirm that is intended.
            self.client_loop.stop()

        @self.sio.on("message")
        async def message(data):
            self.logger.info(f"Message from server: {data}")

    async def run(self):
        """Register the event handlers, then connect to the server."""
        self.call_backs()
        await self.connect_to_server()
if __name__ == "__main__":
    # Interactive demo: connect in a background thread, then forward every
    # line typed on stdin to the server as a "message" event.

    # Without a handler and level, the INFO diagnostics below are dropped.
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())

    token = "test"
    config = SocketClientConfig(
        server_url="http://0.0.0.0:8009",
        headers={"Authorization": f"Bearer {token}"},
        socketio_path="/socket.io",
        auth={"token": token},
        engineio_logger=True,
    )
    base_client = BaseSocketClient(config, logger=logger)

    # The client's event loop runs in a daemon thread so the main thread
    # stays free for blocking input().
    th = threading.Thread(
        target=base_client.start_background_loop,
        args=(base_client.client_loop,),
        daemon=True,
    )
    th.start()
    asyncio.run_coroutine_threadsafe(base_client.run(), base_client.client_loop)
    # Give the connection attempt a moment before the first connected check.
    time.sleep(0.5)

    while True:
        if not base_client.sio.connected:
            logger.info("Client is not connected to the server")
            sys.exit(1)
        try:
            user_input = input('Enter a message (or "exit" to quit): ')
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C or a closed stdin both end the session cleanly.
            break
        if user_input.lower() == "exit":
            break
        asyncio.run_coroutine_threadsafe(
            base_client.sio.emit("message", user_input), base_client.client_loop
        )

    # AsyncClient.disconnect() is a coroutine: it has to be scheduled on the
    # client's loop, a bare call would only create an un-awaited coroutine.
    if base_client.sio.connected:
        asyncio.run_coroutine_threadsafe(
            base_client.sio.disconnect(), base_client.client_loop
        )
        # The disconnect handler stops the loop, which ends the thread; the
        # timeout keeps shutdown bounded if that never happens.
        th.join(timeout=5)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import socketio | |
import asyncio | |
import uvicorn | |
from fastapi import FastAPI | |
from abc import ABC, abstractmethod | |
from typing import Dict | |
import logging | |
logger = logging.getLogger(__name__) | |
class IBaseSocketServer(ABC):
    """Contract that a Socket.IO server in this module has to fulfil."""

    @abstractmethod
    def call_backs(self):
        """Register the event handlers the socket communicates through.

        A handler can be attached with either of:
            self.sio.event
            self.sio.on(<event name>)
        """

    @abstractmethod
    def run_server(self):
        """Start the socket server.

        Expected steps:
            - create the app
            - attach the app to the socket
            - register the callback (event) functions
            - run the application on the configured host and port
        """
class SocketServerConfig:
    """Settings bundle consumed by the socket server.

    Every constructor argument is stored on the instance under the same name.
    """

    host: str
    port: int
    log_level: str = "info"
    cors_allowed_origins: str = "*"
    socketio_path: str = "/socket.io"
    logger: bool = True
    always_connect: bool = True
    engineio_logger: bool = True
    server_workers: int = None
    reload: bool = False

    def __init__(
        self,
        host: str,
        port: int = 8000,
        log_level: str = "info",
        cors_allowed_origins: str = "*",
        socketio_path: str = "/socket.io",
        logger: bool = True,
        always_connect: bool = True,
        engineio_logger: bool = True,
        server_workers: int = None,
        reload: bool = False,
    ):
        """Store the server settings.

        Args:
            host: Interface to bind to.
            port: TCP port to listen on.
            log_level: Log level name handed to the ASGI server.
            cors_allowed_origins: Origins allowed to connect.
            socketio_path: URL path the Socket.IO endpoint is mounted on.
            logger: Enables the Socket.IO logger.
            always_connect: Forwarded to the Socket.IO server constructor.
            engineio_logger: Enables the Engine.IO logger.
            server_workers: Worker count handed to the ASGI server, or None.
            reload: Auto-reload flag. It was declared on the class but could
                not be set through the constructor; it is accepted last, with
                the class default, so existing positional calls keep working.
        """
        self.host = host
        self.port = port
        self.log_level = log_level
        self.cors_allowed_origins = cors_allowed_origins
        self.socketio_path = socketio_path
        self.logger = logger
        self.always_connect = always_connect
        self.engineio_logger = engineio_logger
        self.server_workers = server_workers
        self.reload = reload
class BaseSocketServer(IBaseSocketServer):
    """Socket.IO server combined with a FastAPI app behind one ASGI entry point.

    The FastAPI app is wrapped in ``socketio.ASGIApp`` together with the
    ``socketio.AsyncServer``; the resulting ``self.app`` is what uvicorn runs.
    """

    def __init__(
        self,
        config: SocketServerConfig,
        async_mode: str = "asgi",
        cors_allowed_origins: str = "*",
        logger: logging.Logger = None,
    ):
        """Create the FastAPI app, the Socket.IO server and the combined ASGI app.

        Args:
            config: Server settings; CORS origins, logging switches and the
                Socket.IO path are all read from here.
            async_mode: Currently ignored - the server is always created with
                ``async_mode="asgi"`` because it is wrapped in ``ASGIApp``.
                Kept for signature compatibility.
            cors_allowed_origins: Currently ignored in favour of
                ``config.cors_allowed_origins``. Kept for signature
                compatibility.
            logger: Logger used for this server's own messages. When omitted,
                the module-level logger for ``__name__`` is used.
        """
        self.config = config
        # The parameter defaults to None, but every route and handler calls
        # self.logger.*; fall back to a real logger so the default is usable.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.fastapi_app = FastAPI()
        self.api_route()
        self.sio = socketio.AsyncServer(
            async_mode="asgi",
            cors_allowed_origins=self.config.cors_allowed_origins,
            always_connect=self.config.always_connect,
            logger=self.config.logger,
            engineio_logger=self.config.engineio_logger,
        )
        self.app = socketio.ASGIApp(
            self.sio, self.fastapi_app, socketio_path=self.config.socketio_path
        )

    def run_server(self):
        """Register the Socket.IO handlers and serve ``self.app`` with uvicorn."""
        self.call_backs()
        # NOTE(review): uvicorn documents that multiple workers need the app
        # passed as an import string - confirm server_workers > 1 works when
        # an app object is passed as done here.
        uvicorn.run(
            self.app,
            host=self.config.host,
            port=self.config.port,
            log_level=self.config.log_level,
            workers=self.config.server_workers,
        )

    def api_route(self):
        """Register the HTTP routes on the FastAPI app."""

        @self.fastapi_app.get("/")
        def home():
            return {"message": "Hello World"}

        @self.fastapi_app.get("/send/{room_id}/{message}")
        async def send_message(room_id: str, message: str):
            # HTTP bridge: pushes a message into a Socket.IO room.
            self.logger.info(f"[API] Send message to {room_id}: {message}")
            await self.send_message(room_id, message)
            return {"message": "OK"}

    def call_backs(self):
        """Register the Socket.IO event handlers on ``self.sio``."""

        @self.sio.on("connect")
        async def connect(sid, environ, auth):
            # NOTE(review): environ and auth may carry credentials (the demo
            # client sends a bearer token in both) - consider not logging them.
            self.logger.info(environ)
            self.logger.info(f"Socket connected with ID: {sid}")
            if auth:
                self.logger.info(f"Auth: {auth}")
            else:
                self.logger.info("No Auth")
            # @TODO: add auth to socket
            await self.sio.emit("connected", room=sid)

        @self.sio.on("disconnect")
        def disconnect(sid):
            self.logger.info(f"Socket disconnected with ID: {sid}")

        @self.sio.on("join")
        async def join(sid, room):
            await self.sio.enter_room(sid, room)
            self.logger.info(f"Client {sid} joined room {room}, {self.sio.rooms(sid)}")

        @self.sio.on("leave")
        async def leave(sid, room):
            await self.sio.leave_room(sid, room)
            self.logger.info(f"Client {sid} left room {room}")

        @self.sio.on("message")
        async def message(sid, data):
            self.logger.info(f"Message from {sid}: {data}")
            self.logger.info("Sending message to somewhere")

    # Send a message originating outside the socket to the clients in a room.
    async def send_message(self, room_id: str, message: str):
        """Emit ``message`` as a ``"message"`` event to room ``room_id``."""
        self.logger.info(f"Send message to {room_id}: {message}")
        await self.sio.emit("message", message, room=room_id)
if __name__ == "__main__":
    # Demo entry point: serve on all interfaces, port 8009, with the
    # Socket.IO / Engine.IO library loggers switched off.
    config = SocketServerConfig(
        host="0.0.0.0",
        port=8009,
        log_level="info",
        logger=False,
        engineio_logger=False,
    )

    # The server's own messages go to the console at INFO level.
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    console_handler = logging.StreamHandler()
    logger.addHandler(console_handler)

    server = BaseSocketServer(config=config, logger=logger)
    server.run_server()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment