Skip to content

Instantly share code, notes, and snippets.

@sockheadrps
Created December 9, 2024 12:36
Show Gist options
  • Save sockheadrps/042323036edb3e88f156f7dde7822570 to your computer and use it in GitHub Desktop.
Save sockheadrps/042323036edb3e88f156f7dde7822570 to your computer and use it in GitHub Desktop.
WebSocket server and client using Python, FastAPI, Pydantic, asyncio, and websockets
# py_websockets/client/main.py
import asyncio
from typing import Any, Dict
import websockets
from pydantic import BaseModel
from py_websockets.models.models import WebsocketConnect
class WebsocketEvent(BaseModel):
    """Envelope for every message the client sends or receives.

    ``event`` names the message type (e.g. ``"connect"``) and ``data`` carries
    the payload. A payload matching ``WebsocketConnect`` (the server's connect
    grant) is parsed into that model; any other JSON object stays a plain dict.
    """

    event: str
    # Sender id granted by the server during the connect handshake; None until
    # then. Without this field the ``client_id=`` passed by ``Client.send`` was
    # silently discarded (pydantic ignores unknown fields by default), while
    # the server's ``ClientWebsocketEvent`` requires it.
    client_id: str | None = None
    data: WebsocketConnect | Dict[str, Any]
class Client:
    """Websocket client that performs the server's connect handshake.

    Typical use: ``await client.connect()`` once, then ``send``/``recv``
    events, and ``await client.close()`` when finished.
    """

    def __init__(self, websocket_url: str):
        self.websocket_url = websocket_url
        # Open connection; stays None until connect() completes successfully.
        self.websocket = None
        # Id granted by the server in its connect reply; None until connected.
        self.client_id = None

    def _require_connection(self):
        """Return the open websocket, or raise if connect() has not succeeded."""
        if self.websocket is None:
            raise RuntimeError("Client is not connected; call connect() first")
        return self.websocket

    async def connect(self):
        """Open the websocket and perform the connect handshake.

        Sends a ``connect`` event, waits for the server's reply and stores the
        granted ``client_id``.

        Raises:
            ConnectionError: if the server's reply is not a connect grant.
        """
        websocket = await websockets.connect(self.websocket_url)
        try:
            connect_event = WebsocketEvent(
                event="connect",
                data={"message": "connection request"},
            )
            payload = connect_event.model_dump_json()
            print(payload)
            await websocket.send(payload)
            response = WebsocketEvent.model_validate_json(await websocket.recv())
            print(response)
            # Any reply that is not a grant would otherwise surface later as an
            # opaque AttributeError on ``response.data.client_id``.
            if not isinstance(response.data, WebsocketConnect):
                raise ConnectionError(
                    f"Server did not grant the connection: {response.data!r}"
                )
        except BaseException:
            # Don't leak the open socket when the handshake fails.
            await websocket.close()
            raise
        self.websocket = websocket
        self.client_id = response.data.client_id

    async def send(self, event: str, data: Dict[str, Any] | None = None):
        """Send an event tagged with this client's id.

        Args:
            event: Event name.
            data: Payload; ``{"message": "default"}`` is sent when omitted.
                An explicitly passed empty dict is sent as-is.

        Raises:
            RuntimeError: if called before a successful connect().
        """
        # NOTE(review): the server's ClientWebsocketEvent declares ``data`` as a
        # nested WebsocketData object, not a flat dict -- confirm the protocol.
        websocket = self._require_connection()
        message = WebsocketEvent(
            event=event,
            client_id=self.client_id,
            data={"message": "default"} if data is None else data,
        )
        payload = message.model_dump_json()
        print(payload)
        await websocket.send(payload)

    async def recv(self) -> WebsocketEvent:
        """Wait for the next message and parse it as a ``WebsocketEvent``.

        Raises:
            RuntimeError: if called before a successful connect().
        """
        raw = await self._require_connection().recv()
        return WebsocketEvent.model_validate_json(raw)

    async def close(self) -> None:
        """Close the websocket if it is open; safe to call more than once."""
        if self.websocket is not None:
            websocket, self.websocket = self.websocket, None
            await websocket.close()
async def handle_client(client: Client):
    """Connect *client* and print every event it receives.

    Runs until the connection ends. A normal close by the server ends the loop
    quietly; abnormal closes and other errors still propagate. The websocket
    is closed on the way out in every case, including cancellation.
    """
    # websockets is already a dependency of this module.
    from websockets.exceptions import ConnectionClosedOK

    try:
        await client.connect()
        while True:
            event = await client.recv()
            print(f"Client {client.client_id} received:", event.model_dump())
    except ConnectionClosedOK:
        print(f"Client {client.client_id} connection closed")
    finally:
        if client.websocket is not None:
            await client.websocket.close()
async def main(websocket_url: str = "ws://localhost:8000/ws/py_client"):
    """Run one client against *websocket_url* until its connection ends.

    Args:
        websocket_url: Server endpoint. Defaults to the ``/ws/py_client``
            route on localhost:8000.
    """
    client = Client(websocket_url)
    await handle_client(client)
if __name__ == "__main__":
    # Script entry point: run main() in a fresh asyncio event loop.
    asyncio.run(main())
# py_websockets/models/models.py
from pydantic import BaseModel, Field
from typing import Literal, Dict, Any
import uuid
from fastapi import WebSocket
class WebsocketConnect(BaseModel):
    """Payload of the server's reply granting a connect request."""

    # Only the literal "granted" validates; any other status is rejected.
    status: Literal["granted"]
    # Id the server assigned to this connection.
    client_id: str
class ClientWebsocketConnectEvent(BaseModel):
    """Connect request a client sends as its first message."""

    # Must be exactly "connect".
    event: Literal["connect"]
    # Free-form payload; contents are not validated beyond being an object.
    data: Dict[str, Any]
class WebsocketData(BaseModel):
    """Wrapper holding either a connect grant or a connect request under ``data``."""

    # NOTE(review): used as ClientWebsocketEvent.data, this adds an extra
    # ``{"data": ...}`` level to client events -- confirm this is the intended
    # wire format.
    data: WebsocketConnect | ClientWebsocketConnectEvent
class ClientWebsocketEvent(BaseModel):
    """Event a client sends after the connect handshake."""

    # Event name.
    event: str
    # Required: the id granted to the client during the handshake.
    client_id: str
    # NOTE(review): a nested WebsocketData object, not a flat dict; a payload
    # such as {"message": ...} fails validation here. Confirm intended shape.
    data: WebsocketData
class ServerWebsocketEvent(BaseModel):
    """Event sent from the server to a client."""

    # Event name, e.g. "connect" for the handshake reply.
    event: str
    # Payload: a connect grant or a connect-request-shaped object.
    data: WebsocketConnect | ClientWebsocketConnectEvent
class Connection(BaseModel):
    """Server-side record for one accepted websocket."""

    # pydantic v2 configuration, replacing the deprecated inner ``class Config``.
    # Arbitrary types are needed because WebSocket is not a pydantic model.
    # (ConfigDict is a TypedDict, so a plain dict literal is equivalent.)
    model_config = {"arbitrary_types_allowed": True}

    # Fresh random UUID4 string for every instance.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Underlying websocket for this connection.
    websocket: WebSocket
    # Model class intended for validating a client's connect request.
    event_validator: type[ClientWebsocketConnectEvent] = ClientWebsocketConnectEvent
# py_websockets/server/main.py
from fastapi import FastAPI, WebSocket
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from typing import Dict
import uvicorn
from starlette.websockets import WebSocketDisconnect
from py_websockets.models.models import WebsocketConnect, ClientWebsocketConnectEvent, WebsocketData, ClientWebsocketEvent, ServerWebsocketEvent, Connection
# Resolve asset directories relative to this module instead of the current
# working directory, so the server starts from any CWD. Assuming this module is
# py_websockets/server/main.py, these are the same directories the previous
# "py_websockets/server/..." paths pointed to when run from the repo root.
from pathlib import Path

_SERVER_DIR = Path(__file__).resolve().parent

app = FastAPI()
app.mount("/static", StaticFiles(directory=_SERVER_DIR / "static"), name="static")
# NOTE(review): no route in this module renders templates yet.
templates = Jinja2Templates(directory=_SERVER_DIR / "templates")
class BaseConnectionManager:
    """Tracks accepted websocket connections by their server-assigned id."""

    def __init__(self):
        # connection id -> Connection for each client that completed connect().
        self.active_connections: Dict[str, Connection] = {}

    async def connect(self, websocket: WebSocket) -> str:
        """Accept *websocket*, run the connect handshake and register it.

        The client's first message is validated with the connection's
        ``event_validator``. The server then replies with a ``connect`` event
        granting the new connection id.

        Returns:
            The id assigned to the new connection.

        Raises:
            Any error from receiving, validating or replying (for example
            WebSocketDisconnect or a pydantic ValidationError). The
            connection is unregistered before the error propagates.
        """
        await websocket.accept()
        connection = Connection(websocket=websocket)
        self.active_connections[connection.id] = connection
        print(f"Client {connection.id} connected")
        try:
            message = await websocket.receive_json()
            # model_validate rejects non-object JSON with a ValidationError,
            # where ``Model(**message)`` would raise an unrelated TypeError.
            event = connection.event_validator.model_validate(message)
            print(f"Received event: {event}")
            connect_event = ServerWebsocketEvent(
                event="connect",
                data=WebsocketConnect(status="granted", client_id=connection.id),
            )
            print(connect_event.model_dump())
            await websocket.send_json(connect_event.model_dump())
        except BaseException:
            # The caller never receives an id when the handshake fails, so it
            # cannot call disconnect(); remove the registry entry here instead.
            self.active_connections.pop(connection.id, None)
            raise
        return connection.id

    async def disconnect(self, connection_id: str) -> None:
        """Forget *connection_id*; unknown ids are ignored."""
        self.active_connections.pop(connection_id, None)

    async def send_to_client(self, client_id: str, message: dict) -> None:
        """Send *message* as JSON to *client_id*; a no-op if it is not registered."""
        connection = self.active_connections.get(client_id)
        if connection is not None:
            await connection.websocket.send_json(message)
# Module-level registry shared by every /ws/py_client connection.
py_connection_manager = BaseConnectionManager()
@app.websocket("/ws/py_client")
async def websocket_endpoint(websocket: WebSocket):
    """Handle one Python client: run the handshake, then log each event name.

    The connection is removed from ``py_connection_manager`` however the loop
    ends: client disconnect, an invalid message, or task cancellation.
    """
    connection_id = None
    try:
        connection_id = await py_connection_manager.connect(websocket)
        while True:
            message = await websocket.receive_json()
            event = ClientWebsocketEvent.model_validate(message)
            print(f"Received event: {event.event}")
    except WebSocketDisconnect:
        print(f"Client {connection_id} disconnected")
    except Exception as e:
        # NOTE(review): a single invalid message ends the whole connection;
        # consider handling validation errors per message instead.
        print(f"Error: {e}")
    finally:
        # Single cleanup path. It also runs on cancellation (a BaseException),
        # which neither except clause above catches.
        if connection_id is not None:
            await py_connection_manager.disconnect(connection_id)
if __name__ == "__main__":
    # Development entry point: serve the app on localhost:8000, the host/port
    # the client's default URL targets.
    uvicorn.run("py_websockets.server.main:app", host="localhost", port=8000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment