Created
June 11, 2023 15:42
-
-
Save ethereumdegen/17b6f35c4f49e191044a2701f67fa555 to your computer and use it in GitHub Desktop.
A WebSocket server and client implementation in Rust, built on tokio and tokio-tungstenite.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::sync::Arc ; | |
use tokio::sync::{Mutex,RwLock}; | |
mod websocket_messages; | |
use websocket_messages::{ | |
SocketMessage, | |
SocketMessageDestination, | |
InboundMessage, | |
OutboundMessage | |
}; | |
/* | |
This will start the websocket server. | |
You can use a utility such as 'websocat' to send messages to the server. | |
*/ | |
#[tokio::main] | |
async fn main() -> std::io::Result<()> { | |
let websocket_client = Arc::new( Mutex::new( WebsocketClient::new() ) ); | |
let websocket_server = Arc::new( Mutex::new( WebsocketServer::new() ) ); | |
let server_url:String = "localhost:9000".to_string(); | |
websocket_server.lock().await.start_in_thread(Some(server_url)); | |
client_socket_conn.lock().await.connect("ws://localhost:9000".to_string()).await; | |
if let Err(e) = client_socket_conn { | |
println!("Error connectiong to socket server {}", e); | |
}else { | |
println!("Connected to socket server"); | |
let msg = SocketMessage::Text("hello world".to_string()); | |
client_socket_conn.lock().await.send_message( msg ).await | |
} | |
} | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures_util::{ StreamExt, SinkExt}; | |
use tokio_tungstenite::{connect_async, tungstenite::Message}; | |
use tokio_tungstenite::{WebSocketStream,MaybeTlsStream}; | |
use tokio::net::TcpStream; | |
use std::sync::Arc ; | |
use tokio::sync::{Mutex}; | |
use std::thread; | |
use tokio::runtime::Runtime; | |
use crossbeam_channel::{ Receiver, Sender}; | |
use crate::util::websocket_messages::SocketMessage; | |
use super::websocket_messages::InboundMessage; | |
pub struct Connection { | |
write: futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>, | |
read: Option< futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>> > , //can be used like a one-time mutex ! | |
pub socket_connection_uuid: String | |
} | |
impl Connection { | |
/* | |
Consumes the single read stream and starts a new loop which continuously forwards received packets into a crossbeam channel | |
*/ | |
pub fn start_listening_on_new_thread( | |
&mut self, | |
sender_channel: Sender<InboundMessage>, | |
) { | |
let mut read = self.read.take().expect("The read stream has already been consumed."); | |
let socket_connection_uuid = self.socket_connection_uuid.clone(); | |
// Start a new OS thread | |
thread::spawn(move || { | |
// Create a new Tokio runtime | |
let rt = Runtime::new().unwrap(); | |
// Use the runtime | |
rt.block_on(async { | |
while let Some(message_result) = read.next().await { | |
match message_result { | |
Ok(message) => { | |
let inbound_msg = InboundMessage { | |
socket_connection_uuid: socket_connection_uuid.clone(), | |
message: SocketMessage::from_message(message), | |
}; | |
// Send the message into the crossbeam channel | |
sender_channel.send(inbound_msg).unwrap(); | |
} | |
Err(e) => { | |
eprintln!("Error while reading message: {:?}", e); | |
break; | |
} | |
} | |
} | |
//if stops looping then somehow notify self that we are disconnected / not listening ? | |
}); | |
}); | |
} | |
pub async fn send_message(&mut self, message: SocketMessage ) | |
{ | |
println!("sending message out of conn"); | |
let send_msg_result = self.write.send( message.to_message() ).await ; | |
println!("tried to send a msg out of websocket client conn "); | |
} | |
} | |
pub struct WebsocketClient{ | |
pub connection: Option<Connection>, | |
} | |
impl WebsocketClient { | |
pub fn new() -> Self { | |
Self { | |
connection: None, | |
} | |
} | |
pub async fn connect(&mut self, connect_addr: String ) -> std::io::Result<()> { | |
let url = url::Url::parse(&connect_addr).unwrap(); | |
loop { | |
match connect_async(url.clone()).await { | |
Ok((ws_stream, _)) => { | |
println!("WebSocket handshake has been successfully completed"); | |
let (write, read) = ws_stream.split(); | |
let socket_connection_uuid = uuid::Uuid::new_v4().to_string(); | |
self.connection = Some( Connection { | |
write, | |
read : Some(read), | |
socket_connection_uuid | |
}); | |
// once connected, break the loop | |
break; | |
}, | |
Err(e) => { | |
println!("Failed to connect, retrying in 1 second..."); | |
// wait for 1 second | |
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; | |
} | |
} | |
} | |
Ok(()) | |
} | |
pub fn listen_on_new_thread(&mut self, sender_channel: Sender<InboundMessage>){ | |
match &mut self.connection { | |
Some(conn) => {conn.start_listening_on_new_thread(sender_channel) } | |
None => { | |
println!("Could not start listening! No connection :( ") | |
} | |
} | |
} | |
pub async fn send_message(&mut self, message: SocketMessage ) | |
{ | |
match &mut self.connection { | |
Some(conn) => { conn.send_message(message).await } | |
None => { | |
println!("Could not send message! No connection :( ") | |
} | |
} | |
} | |
} | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use serde::{Serialize,Deserialize}; | |
use tokio_tungstenite::tungstenite::Message; | |
use uuid; | |
// Create custom message types for your app to serialize and send through; see the commented-out `WrappedMessage` struct below.
//use shared::net::message_types::{ ClientMessage, ServerMessage }; | |
// Mirrors tungstenite's `Message`, but is defined here so serde derives can be applied to it (a workaround; there may be a better way).
#[derive(Serialize,Deserialize,Clone)] | |
pub enum SocketMessage { | |
Text(String), | |
Binary(Vec<u8>), | |
Unknown, | |
Close | |
} | |
impl SocketMessage { | |
pub fn from_message(msg: Message) -> Self { | |
match msg { | |
Message::Text(inner) => SocketMessage::Text(inner), | |
Message::Binary(inner) => SocketMessage::Binary(inner.into_iter().collect()), | |
Message::Close(_) => SocketMessage::Close, | |
_ => SocketMessage::Unknown, | |
} | |
} | |
pub fn to_message(&self) -> Message{ | |
match self { | |
SocketMessage::Text(inner) => Message::Text(inner.to_string()), | |
SocketMessage::Binary(inner) => Message::Binary(inner.to_vec()), | |
_ => Message::Text("Unknown!".to_string()) | |
} | |
} | |
//should throw an error instead ! | |
pub fn to_string(&self) -> String{ | |
match self { | |
SocketMessage::Text(inner) => inner.to_string(), | |
SocketMessage::Binary(inner) => format!("{:?}",inner), | |
_ => "Unknown!".to_string() | |
} | |
} | |
} | |
#[derive(Serialize, Deserialize,Debug ,Clone)] | |
pub enum SocketMessageDestination { | |
All, | |
ClientConnection(String), //client connection uuid | |
Room(String), | |
} | |
#[derive(Serialize, Deserialize,Debug ,Clone)] | |
pub enum WrappedMessageDestination { | |
All, | |
Client(String), //client uuid | |
Room(String), | |
ResponseToMsg(String), //message uuid | |
Server //EcosystemServer) //server type | |
} | |
#[derive(Serialize,Deserialize,Clone)] | |
pub struct OutboundMessage { | |
pub destination: SocketMessageDestination, | |
pub message: SocketMessage | |
} | |
#[derive(Serialize,Deserialize,Clone)] | |
pub struct InboundMessage { | |
pub socket_connection_uuid: String, | |
pub message: SocketMessage | |
} | |
impl InboundMessage { | |
pub fn new(socket_connection_uuid:String, msg:Message) -> Self { | |
let message = SocketMessage::from_message(msg); //text( msg.clone().into_text().unwrap() ) ; | |
Self{ | |
socket_connection_uuid, | |
message | |
} | |
} | |
} | |
/* | |
#[derive(Serialize, Deserialize,Debug ,Clone)] | |
pub struct WrappedMessage { | |
pub destination: WrappedMessageDestination, | |
pub message_uuid: String, //used so we can respond to it | |
pub contents: WrappedMessageContents // the contents and the From info | |
} | |
impl WrappedMessage { | |
pub fn from_stringified( raw_msg_string:String ) -> Result<Self, serde_json::Error> { | |
let message:WrappedMessage = serde_json::from_str( &raw_msg_string )?; | |
Ok(message) | |
} | |
pub fn to_stringified(&self) -> Result<String, serde_json::Error> { | |
let message_string = serde_json::to_string(&self)?; | |
Ok(message_string) | |
} | |
pub fn wrap( | |
destination:WrappedMessageDestination, | |
contents:WrappedMessageContents | |
) -> Self { | |
let message_uuid = uuid::Uuid::new_v4().to_string(); | |
let wrapped_message = WrappedMessage{ | |
message_uuid, | |
destination, | |
contents | |
}; | |
wrapped_message | |
} | |
} | |
#[derive(Serialize, Deserialize,Debug ,Clone)] | |
#[serde(tag = "msg_type", content = "data")] | |
pub enum WrappedMessageContents { | |
ClientMsg(ClientMessage), | |
ServerMsg(ServerMessage) | |
} | |
impl WrappedMessageContents { | |
pub fn from_stringified( raw_msg_string:String ) -> Result<Self, serde_json::Error> { | |
let message:WrappedMessageContents = serde_json::from_str( &raw_msg_string )?; | |
Ok(message) | |
} | |
} | |
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures_util::StreamExt; | |
use futures_util::stream::SplitSink; | |
use futures_util::future::join_all; | |
use tokio_tungstenite::WebSocketStream; | |
use tokio::net::{TcpListener, TcpStream}; | |
use futures::SinkExt; | |
use std::collections::HashMap; | |
use std::thread; | |
use tokio::sync::RwLock; | |
use std::sync::{Arc}; | |
use tokio::sync:: Mutex; | |
use tokio_tungstenite::tungstenite::Message; | |
use shared::util::rand::generate_random_uuid; | |
use std::collections::HashSet; | |
use crossbeam_channel::{ unbounded, Receiver, Sender, TryRecvError}; | |
use super::websocket_messages::{ | |
SocketMessage, | |
SocketMessageDestination, | |
InboundMessage, | |
OutboundMessage | |
}; | |
type ClientsMap = Arc<RwLock<HashMap<String, ClientConnection>>>; | |
type RoomsMap = Arc<RwLock<HashMap<String, HashSet<String>>>>; | |
type TxSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>; | |
type RxSink = Arc<Mutex<SplitSink<WebSocketStream<TcpStream>, Message>>>; | |
#[derive(Clone)] | |
pub struct ClientConnection { | |
pub client_socket_uuid: String, | |
pub addr: String, | |
pub tx_sink: TxSink, | |
} | |
impl ClientConnection { | |
pub fn new( addr:String, client_tx: SplitSink<WebSocketStream<tokio::net::TcpStream>, Message> ) -> Self{ | |
Self { | |
client_socket_uuid: generate_random_uuid(), | |
addr: addr.clone(), | |
tx_sink: Arc::new(Mutex::new( client_tx )) | |
} | |
} | |
pub async fn send_message(&self, msg: Message) -> Result<(), tokio_tungstenite::tungstenite::error::Error> { | |
self.tx_sink.lock().await.send(msg).await | |
} | |
} | |
pub struct WebsocketServer{ | |
clients: ClientsMap, | |
rooms: RoomsMap, // room name -> Set[client_uuid] | |
//let (sender, receiver): (Sender<T>, Receiver<T>) = unbounded(); | |
global_recv_tx: Sender<InboundMessage>, //passed to each client connection | |
global_recv_rx: Receiver<InboundMessage>, | |
global_send_tx: Sender<OutboundMessage>, | |
global_send_rx: Receiver<OutboundMessage>, | |
} | |
impl WebsocketServer { | |
pub fn new() -> Self { | |
let (global_recv_tx, global_recv_rx): (Sender<InboundMessage>, Receiver<InboundMessage>) = unbounded(); | |
let (global_send_tx, global_send_rx): (Sender<OutboundMessage>, Receiver<OutboundMessage>) = unbounded(); | |
Self { | |
clients: Arc::new(RwLock::new(HashMap::new())), | |
rooms: Arc::new(RwLock::new(HashMap::new())), | |
global_recv_tx, | |
global_recv_rx, | |
global_send_tx, | |
global_send_rx, | |
} | |
} | |
pub fn start_in_thread(&mut self, url: Option<String>) -> | |
std::io::Result< std::thread::JoinHandle<()> > { | |
let clients = Arc::clone(&self.clients); | |
let rooms = Arc::clone(&self.rooms); | |
let global_recv_channel = self.global_recv_tx.clone(); | |
let global_send_channel = self.global_send_rx.clone(); | |
let accept_connections_thread = thread::spawn(move || { //use a non-tokio thread here | |
let runtime = tokio::runtime::Runtime::new().unwrap(); | |
runtime.block_on(async { | |
let addr: String = url.unwrap_or_else(|| "127.0.0.1:8080".to_string()); | |
// Create the event loop and TCP listener we'll accept connections on. | |
let try_socket = TcpListener::bind(&addr).await; | |
let listener = try_socket.expect("Failed to bind"); | |
println!("Listening on: {}", addr); | |
let accept_connections = Self::try_accept_new_connections( Arc::clone(&clients), listener,global_recv_channel ); | |
let send_outbound_messages = Self::try_send_outbound_messages( | |
Arc::clone(&clients), | |
Arc::clone(&rooms), | |
global_send_channel | |
); | |
tokio::try_join!(accept_connections,send_outbound_messages); | |
}); | |
}); | |
println!("Started websocket server"); | |
Ok( accept_connections_thread ) | |
} | |
pub async fn start(&mut self, url:Option<String>) -> std::io::Result<()> { | |
let clients = Arc::clone(&self.clients); | |
let rooms = Arc::clone(&self.rooms); | |
let global_recv_channel = self.global_recv_tx.clone(); | |
let global_send_channel = self.global_send_rx.clone(); | |
let addr: String = url.unwrap_or_else(|| "127.0.0.1:8080".to_string()); | |
// Create the event loop and TCP listener we'll accept connections on. | |
let try_socket = TcpListener::bind(&addr).await; | |
let listener = try_socket.expect("Failed to bind"); | |
println!("Listening on: {}", addr); | |
let accept_connections = Self::try_accept_new_connections( Arc::clone(&clients), listener,global_recv_channel ); | |
let send_outbound_messages = Self::try_send_outbound_messages( | |
Arc::clone(&clients) , | |
Arc::clone(&rooms), | |
global_send_channel | |
); | |
tokio::try_join!(accept_connections, send_outbound_messages); | |
Ok(()) | |
} | |
//recv'd client messages are fed into here | |
pub fn get_recv_channel(&self) -> Receiver<InboundMessage> { | |
self.global_recv_rx.clone() | |
} | |
pub fn get_send_channel(&self) -> Sender<OutboundMessage> { | |
self.global_send_tx.clone() | |
} | |
pub async fn send_outbound_message(&self, msg:OutboundMessage) { | |
Self::broadcast( | |
Arc::clone(&self.clients), | |
Arc::clone(&self.rooms), | |
msg ).await; | |
} | |
async fn get_cloned_clients(clients: &ClientsMap) -> Vec<ClientConnection> { | |
let clients_map = clients.read().await; | |
clients_map.values().cloned().collect() | |
} | |
async fn get_cloned_clients_in_room(clients: &ClientsMap, rooms: &RoomsMap, room_name: String ) -> Vec<ClientConnection> { | |
let client_connection_uuids = Vec::new(); | |
let rooms = rooms.read().await; | |
match rooms.get(&room_name) { | |
Some(uuid_set) => {} | |
None => {} | |
} | |
return Self::get_cloned_clients_filtered(clients, client_connection_uuids).await; | |
} | |
async fn get_cloned_clients_filtered(clients: &ClientsMap, client_connection_uuids: Vec<String> ) -> Vec<ClientConnection> { | |
let clients_map = clients.read().await; | |
let mut filtered_clients: Vec<ClientConnection> = Vec::new(); | |
for uuid in client_connection_uuids { | |
if let Some(client_conn) = clients_map.get(&uuid) { | |
filtered_clients.push(client_conn.clone()); | |
} | |
} | |
filtered_clients | |
} | |
async fn get_cloned_client_specific(clients: &ClientsMap, client_connection_uuid: String ) -> Vec<ClientConnection> { | |
let clients_map = clients.read().await; | |
let mut filtered_clients: Vec<ClientConnection> = Vec::new(); | |
if let Some(client_conn) = clients_map.get(&client_connection_uuid) { | |
filtered_clients.push(client_conn.clone()); | |
} | |
filtered_clients | |
} | |
pub async fn try_send_outbound_messages( | |
clients_map: ClientsMap, | |
rooms_map: RoomsMap, | |
global_send_rx: Receiver<OutboundMessage> | |
) -> std::io::Result<()> { | |
loop { | |
match global_send_rx.try_recv() { | |
Ok(msg) => { | |
// let message = msg; | |
let clients_map = Arc::clone(&clients_map); | |
let rooms_map = Arc::clone(&rooms_map); | |
println!("try send outbound message 2 " ); | |
Self::broadcast(clients_map, rooms_map, msg).await; | |
} | |
Err(TryRecvError::Empty) => { | |
// No messages available right now, sleep for a short duration | |
tokio::time::sleep(std::time::Duration::from_millis(100)).await; | |
} | |
Err(TryRecvError::Disconnected) => break, | |
} | |
} | |
Ok(()) | |
} | |
pub async fn add_client_to_room(&self, client_connection_uuid:String, room_name: String ) { | |
let mut rooms = self.rooms.write().await; | |
let room_clients = rooms.entry(room_name).or_insert_with(HashSet::new); | |
room_clients.insert(client_connection_uuid); | |
} | |
pub async fn remove_client_from_room(&self, client_connection_uuid:String, room_name: String ) { | |
let mut rooms = self.rooms.write().await; | |
if let Some(room_clients) = rooms.get_mut(&room_name) { | |
room_clients.remove(&client_connection_uuid); | |
// Optionally, you can remove the room if it's now empty | |
if room_clients.is_empty() { | |
rooms.remove(&room_name); | |
} | |
} | |
} | |
pub async fn broadcast( | |
clients_map: ClientsMap, | |
rooms_map:RoomsMap, | |
outbound_message: OutboundMessage | |
) { | |
println!("broadcasting msg: {} ", outbound_message.message.to_string() ); | |
let socket_message = outbound_message.message; | |
let client_connections = match outbound_message.destination { | |
SocketMessageDestination::All => Self::get_cloned_clients(&clients_map).await, | |
SocketMessageDestination::Room(room_name) => Self::get_cloned_clients_in_room(&clients_map,&rooms_map,room_name).await, | |
SocketMessageDestination::ClientConnection(client_connection_uuid) => Self::get_cloned_client_specific(&clients_map,client_connection_uuid).await, | |
// MessageDestination::ResponseToMsg(msg_uuid) => {}, | |
// MessageDestination::Server => {} | |
}; | |
Self::broadcast_to_connections(client_connections, socket_message).await; | |
} | |
pub async fn broadcast_to_connections( connections: Vec<ClientConnection>, socket_message: SocketMessage) { | |
let message = socket_message.to_message(); | |
//Could cause thread lock issue !? | |
let send_futures: Vec<_> = { | |
connections | |
.iter() | |
.map(|client| { | |
let message = message.clone(); | |
client.send_message(message) | |
}) | |
.collect() | |
}; | |
let results = join_all(send_futures).await; | |
for result in results { | |
if let Err(err) = result { | |
eprintln!("Failed to send a message: {}", err); | |
} | |
} | |
} | |
pub async fn try_accept_new_connections( | |
clients_map: ClientsMap, | |
listener: TcpListener, | |
global_recv_tx: Sender<InboundMessage> | |
) -> std::io::Result<()> { | |
while let Ok((stream, _)) = listener.accept().await { | |
let clients_map = Arc::clone(&clients_map); | |
tokio::spawn(Self::accept_connection(clients_map, stream, global_recv_tx.clone())); | |
} | |
Ok(()) | |
} | |
async fn accept_connection( | |
clients: ClientsMap, | |
raw_stream: TcpStream, | |
global_socket_tx: Sender<InboundMessage> | |
) { | |
let addr = raw_stream | |
.peer_addr() | |
.expect("connected streams should have a peer address") | |
.to_string(); | |
let ws_stream = tokio_tungstenite::accept_async(raw_stream) | |
.await | |
.expect("Error during the websocket handshake occurred"); | |
println!("New WebSocket connection: {}", addr); | |
let ( client_tx, mut client_rx) = ws_stream.split(); //this is how i can read and write to this client | |
let new_client_connection = ClientConnection::new( addr.clone(), client_tx ); | |
let client_uuid = new_client_connection.client_socket_uuid.clone(); | |
clients.write().await.insert( | |
new_client_connection.client_socket_uuid.clone(), | |
new_client_connection | |
); | |
//in this new thread for the socket connection, recv'd messages are constantly collected | |
while let Some(msg) = client_rx.next().await { | |
match msg { | |
Ok(msg) => { | |
if msg.is_text() || msg.is_binary() { | |
let data = msg.clone().into_data(); | |
println!("Received a message from {}: {:?}", addr, data); | |
// here you can consume your messages | |
let client_msg = InboundMessage::new( | |
client_uuid.clone(), | |
msg | |
); | |
global_socket_tx.send( client_msg ); | |
} | |
} | |
Err(e) => { | |
eprintln!( | |
"an error occurred while processing incoming messages: {:?}", | |
e | |
); | |
break; | |
} | |
} | |
} | |
// Remove the client from the map once it has disconnected. | |
clients.write().await.remove(&addr); | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment