Created
December 9, 2024 12:36
-
-
Save sockheadrps/042323036edb3e88f156f7dde7822570 to your computer and use it in GitHub Desktop.
WebSocket server and client using Python, FastAPI, Pydantic, asyncio, and websockets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# py_websockets/client/main.py | |
import asyncio | |
from typing import Any, Dict | |
import websockets | |
from pydantic import BaseModel | |
from py_websockets.models.models import WebsocketConnect | |
class WebsocketEvent(BaseModel):
    """Envelope for every message the client sends or receives.

    ``data`` is typed as a union of the server's ``WebsocketConnect`` grant
    and a free-form dict; callers should check which one they got.
    """

    event: str
    data: WebsocketConnect | Dict[str, Any]
    # Sender's server-assigned id. Optional because it is only known once the
    # connect handshake has completed. Declaring it here means a ``client_id=``
    # keyword is serialised instead of being silently ignored by pydantic.
    client_id: str | None = None
class Client:
    """Minimal asyncio websocket client speaking the ``WebsocketEvent`` protocol.

    Call ``connect()`` once to open the socket and obtain a server-assigned
    ``client_id``. Then exchange events with ``send()`` / ``recv()``, and call
    ``close()`` when done.
    """

    def __init__(self, websocket_url: str):
        self.websocket_url = websocket_url
        # Opened by connect(); None before that and after close().
        self.websocket = None
        # Filled in from the server's handshake reply.
        self.client_id = None

    def _require_websocket(self):
        """Return the open websocket, or raise if the client is not connected."""
        if self.websocket is None:
            raise RuntimeError("Client is not connected; call connect() first")
        return self.websocket

    async def connect(self):
        """Open the websocket and perform the connect handshake.

        Sends a ``connect`` event, waits for the server's reply and stores the
        ``client_id`` it grants.

        Raises:
            ConnectionError: if the reply's ``data`` is not a ``WebsocketConnect``.
            Any error from the socket or from validating the reply is re-raised.
            The socket is closed before any exception propagates.
        """
        self.websocket = await websockets.connect(self.websocket_url)
        try:
            connect_event = WebsocketEvent(
                event="connect",
                data={"message": "connection request"},
            )
            payload = connect_event.model_dump_json()
            print(payload)
            await self.websocket.send(payload)

            response = WebsocketEvent.model_validate_json(await self.websocket.recv())
            print(response)
            # ``data`` is a union; only a WebsocketConnect carries ``client_id``.
            # Anything else would otherwise surface as an opaque AttributeError.
            if not isinstance(response.data, WebsocketConnect):
                raise ConnectionError(f"Unexpected handshake response: {response!r}")
            self.client_id = response.data.client_id
        except Exception:
            # Don't leave a half-open socket behind when the handshake fails.
            await self.close()
            raise

    async def send(self, event: str, data: Dict[str, Any] | None = None):
        """Send ``event`` with ``data`` (``{"message": "default"}`` when omitted).

        An explicitly passed empty dict is sent as-is rather than being
        replaced by the default.

        Raises:
            RuntimeError: if called before ``connect()`` or after ``close()``.
        """
        websocket = self._require_websocket()
        message = WebsocketEvent(
            event=event,
            client_id=self.client_id,
            data=data if data is not None else {"message": "default"},
        )
        payload = message.model_dump_json()
        print(payload)
        await websocket.send(payload)

    async def recv(self) -> WebsocketEvent:
        """Wait for the next message and parse it as a ``WebsocketEvent``.

        Raises:
            RuntimeError: if called before ``connect()`` or after ``close()``.
        """
        websocket = self._require_websocket()
        return WebsocketEvent.model_validate_json(await websocket.recv())

    async def close(self):
        """Close the websocket if it is open; calling this again is a no-op."""
        if self.websocket is not None:
            websocket, self.websocket = self.websocket, None
            await websocket.close()
async def handle_client(client: Client):
    """Connect ``client`` and print every event it receives.

    Returns normally when the server closes the connection. The socket is
    closed on every exit path.
    """
    await client.connect()
    try:
        while True:
            event = await client.recv()
            print(f"Client {client.client_id} received:", event.model_dump())
    except websockets.ConnectionClosed as exc:
        # The server went away: end the session instead of crashing with a traceback.
        print(f"Client {client.client_id} disconnected: {exc}")
    finally:
        if client.websocket is not None:
            await client.websocket.close()
async def main():
    """Create one client for the local development server and hand it to handle_client."""
    server_url = "ws://localhost:8000/ws/py_client"
    py_client = Client(server_url)
    await handle_client(py_client)


if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    entry_coro = main()
    asyncio.run(entry_coro)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# py_websockets/models/models.py | |
from pydantic import BaseModel, Field | |
from typing import Literal, Dict, Any | |
import uuid | |
from fastapi import WebSocket | |
class WebsocketConnect(BaseModel):
    """Payload the server sends to grant a client's connect request."""

    # Only the literal string "granted" validates.
    status: Literal["granted"]
    # Identifier the server generated for this connection.
    client_id: str
class ClientWebsocketConnectEvent(BaseModel):
    """Handshake message a client must send first after the socket is accepted."""

    # Must be exactly "connect".
    event: Literal["connect"]
    # Free-form payload, e.g. {"message": "connection request"}.
    data: Dict[str, Any]
class WebsocketData(BaseModel):
    """Wrapper that holds one connect-related payload under a ``data`` key."""

    # Either a server grant or a client connect event.
    data: WebsocketConnect | ClientWebsocketConnectEvent
class ClientWebsocketEvent(BaseModel):
    """Event a client sends after the handshake; validated in the server's receive loop."""

    event: str
    # Expected to be the id granted during the handshake (the visible server
    # code does not check it).
    client_id: str
    # NOTE(review): because this is a WebsocketData, the wire format is nested
    # twice: {"data": {"data": {...}}}. A flat payload such as
    # {"data": {"message": "default"}} (what the visible client's send() builds)
    # fails validation. Confirm the intended format.
    data: WebsocketData
class ServerWebsocketEvent(BaseModel):
    """Event the server sends to a client, e.g. the "connect" grant."""

    event: str
    # Either a WebsocketConnect grant or a ClientWebsocketConnectEvent payload.
    data: WebsocketConnect | ClientWebsocketConnectEvent
class Connection(BaseModel):
    """Server-side record of one accepted websocket connection."""

    # Pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    # Arbitrary types must be allowed because WebSocket is not a pydantic type.
    model_config = {"arbitrary_types_allowed": True}

    # Random UUID4 string, generated fresh for each instance.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    websocket: WebSocket
    # Presumably the model used to validate this connection's handshake
    # message; not referenced by the visible server code.
    event_validator: type[ClientWebsocketConnectEvent] = ClientWebsocketConnectEvent
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# py_websockets/server/main.py | |
from pathlib import Path
from typing import Dict

import uvicorn
from fastapi import FastAPI, WebSocket
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.websockets import WebSocketDisconnect

from py_websockets.models.models import (
    ClientWebsocketConnectEvent,
    ClientWebsocketEvent,
    Connection,
    ServerWebsocketEvent,
    WebsocketConnect,
    WebsocketData,
)
# Resolve asset directories relative to this module rather than the current
# working directory, so the server starts no matter where it is launched from.
_SERVER_DIR = Path(__file__).resolve().parent

app = FastAPI()
app.mount(
    "/static", StaticFiles(directory=_SERVER_DIR / "static"), name="static"
)
templates = Jinja2Templates(directory=_SERVER_DIR / "templates")
class BaseConnectionManager:
    """Track accepted websocket connections, keyed by their server-assigned id."""

    def __init__(self):
        self.active_connections: Dict[str, Connection] = {}

    async def connect(self, websocket: WebSocket) -> str:
        """Accept ``websocket``, run the connect handshake and register it.

        The client's first message must validate as a
        ``ClientWebsocketConnectEvent``. The server replies with a
        ``ServerWebsocketEvent`` granting a freshly generated client id.

        Returns:
            The id under which the connection is registered.

        Raises:
            Whatever receiving, validating or sending raises (for example
            ``WebSocketDisconnect`` or a pydantic ``ValidationError``). In that
            case the connection is not left in ``active_connections``.
        """
        await websocket.accept()
        connection = Connection(websocket=websocket)
        print(f"Client {connection.id} connected")

        message = await websocket.receive_json()
        # model_validate reports non-object JSON (lists, strings, ...) as a
        # ValidationError instead of a TypeError from ``**message``.
        event = ClientWebsocketConnectEvent.model_validate(message)
        print(f"Received event: {event}")

        connect_event = ServerWebsocketEvent(
            event="connect",
            data=WebsocketConnect(status="granted", client_id=connection.id),
        )
        payload = connect_event.model_dump(mode="json")
        print(payload)

        # Register only after the handshake has validated. If this method raises,
        # the caller never receives the id and so could never remove the entry.
        self.active_connections[connection.id] = connection
        try:
            await websocket.send_json(payload)
        except Exception:
            self.active_connections.pop(connection.id, None)
            raise
        return connection.id

    async def disconnect(self, connection_id: str) -> None:
        """Forget ``connection_id``; unknown ids are ignored."""
        self.active_connections.pop(connection_id, None)

    async def send_to_client(self, client_id: str, message: dict) -> None:
        """Send ``message`` as JSON to ``client_id``; does nothing for unknown ids."""
        connection = self.active_connections.get(client_id)
        if connection is not None:
            await connection.websocket.send_json(message)
py_connection_manager = BaseConnectionManager()


@app.websocket("/ws/py_client")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one Python client: run the handshake, then log each event it sends.

    The session ends when the client disconnects or any error occurs
    (including a message that fails validation). On every exit path the
    connection is removed from ``py_connection_manager``.
    """
    connection_id = None
    try:
        connection_id = await py_connection_manager.connect(websocket)
        while True:
            message = await websocket.receive_json()
            event = ClientWebsocketEvent.model_validate(message)
            print(f"Received event: {event.event}")
    except WebSocketDisconnect:
        print(f"Client {connection_id} disconnected")
    except Exception as e:
        print(f"Error: {e}")
        # Close explicitly with 1011 (internal error) so the peer gets a close
        # frame instead of relying on the server to tear the socket down.
        try:
            await websocket.close(code=1011)
        except (RuntimeError, WebSocketDisconnect):
            # The socket is already closed or gone; nothing left to do.
            pass
    finally:
        if connection_id is not None:
            await py_connection_manager.disconnect(connection_id)
if __name__ == "__main__":
    # Pass the app as an import string; uvicorn imports the module itself.
    uvicorn.run(
        app="py_websockets.server.main:app",
        host="localhost",
        port=8000,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment