Is there a way to share a `SplitSink` between two different threads?
Yes. The best way to do this is to create an `mpsc` channel and have a single task forward every message it receives on that channel to the `SplitSink`. In the example below, several tasks (which may run on different threads) send to the sink through clones of the channel's sender, made with `sender.clone()`. Both `send_task` and `recv_task` do this, and in principle you can create as many senders as you like.
use std::net::SocketAddr;
use std::sync::Arc;
use axum::extract::ws::Message;
use axum::extract::WebSocketUpgrade;
use axum::extract::{ws::WebSocket, State};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use futures::{sink::SinkExt, stream::StreamExt};
use tokio::sync::{broadcast, mpsc};
/// State shared by every request handler.
struct AppState {
    /// Broadcast sender used to fan chat messages out to all connected clients.
    tx: broadcast::Sender<String>,
}

impl Default for AppState {
    /// Builds the state around a fresh broadcast channel with capacity 16.
    fn default() -> Self {
        // Only the sending half is kept. Each connection obtains its own
        // receiver via `subscribe()`, so the initial receiver is discarded.
        let tx = broadcast::channel::<String>(16).0;
        AppState { tx }
    }
}
#[tokio::main]
async fn main() {
// Set up application state for use with with_state().
let app = Router::new()
.route("/ws", get(websocket_handler))
.with_state(Arc::new(AppState::default()));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
println!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
// split the websocket stream into a sender (sink) and receiver (stream)
let (mut sink, mut stream) = stream.split();
// create an mpsc so we can send messages to the sink from multiple threads
let (sender, mut receiver) = mpsc::channel::<String>(16);
// spawn a task that forwards messages from the mpsc to the sink
tokio::spawn(async move {
while let Some(message) = receiver.recv().await {
if sink.send(message.into()).await.is_err() {
break;
}
}
});
// subscribe to the chat channel
let mut rx_chat = state.tx.subscribe();
// whenever a chat is sent to rx_chat, forward it to the mpsc
let send_task_sender = sender.clone();
let mut send_task = tokio::spawn(async move {
while let Ok(msg) = rx_chat.recv().await {
if send_task_sender
.send(format!("New message: {}", msg))
.await
.is_err()
{
break;
}
}
});
// clone the tx channel so we can send messages to it
let tx_chat = state.tx.clone();
// whenever a user sends a chat, send it to the tx_chat
let recv_task_sender = sender.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(text))) = stream.next().await {
let _ = tx_chat.send(format!("{}", text));
if recv_task_sender
.send(String::from("Your message has been sent"))
.await
.is_err()
{
break;
}
}
});
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
}