Skip to content

Instantly share code, notes, and snippets.

@hexcowboy
Last active October 13, 2024 01:37
Show Gist options
  • Save hexcowboy/8ebcf13a5d3b681aa6c684ad51dd6e0c to your computer and use it in GitHub Desktop.
Save hexcowboy/8ebcf13a5d3b681aa6c684ad51dd6e0c to your computer and use it in GitHub Desktop.

Is there a way to share a SplitSink between two different threads (or tasks)?

Yes

The usual way to do this is to hand the SplitSink to a single task that owns it, and have that task forward every message it receives on an mpsc channel to the sink. Because mpsc::Sender can be cloned, any number of tasks (or threads) can then write to the socket by sending on their own clone, created with sender.clone(). In the example below, both send_task and recv_task do this, and you can create as many senders as you need.

use std::net::SocketAddr;
use std::sync::Arc;

use axum::extract::ws::Message;
use axum::extract::WebSocketUpgrade;
use axum::extract::{ws::WebSocket, State};
use axum::response::IntoResponse;
use axum::routing::get;
use axum::Router;
use futures::{sink::SinkExt, stream::StreamExt};
use tokio::sync::{broadcast, mpsc};

/// State shared by every WebSocket connection.
struct AppState {
    /// Broadcast channel that fans chat messages out to all connected clients.
    /// Each connection calls `subscribe()` to get its own receiver.
    tx: broadcast::Sender<String>,
}

impl Default for AppState {
    /// Creates the chat channel with room for 16 in-flight messages. The initial receiver is
    /// discarded immediately; connections obtain receivers later via `subscribe()`.
    fn default() -> Self {
        Self {
            tx: broadcast::channel::<String>(16).0,
        }
    }
}

/// Starts the HTTP server on 127.0.0.1:3000 and serves the chat WebSocket at `/ws`.
///
/// A single `AppState` is created here and shared with every request handler through an `Arc`.
/// Any server error is fatal (`unwrap`).
#[tokio::main]
async fn main() {
    // One state instance for the whole app, shared by all handlers via `with_state`.
    let shared_state = Arc::new(AppState::default());
    let app = Router::new()
        .route("/ws", get(websocket_handler))
        .with_state(shared_state);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("listening on {}", addr);

    let server = axum::Server::bind(&addr).serve(app.into_make_service());
    server.await.unwrap();
}

async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
) -> impl IntoResponse {
    ws.on_upgrade(|socket| websocket(socket, state))
}

async fn websocket(stream: WebSocket, state: Arc<AppState>) {
    // split the websocket stream into a sender (sink) and receiver (stream)
    let (mut sink, mut stream) = stream.split();
    // create an mpsc so we can send messages to the sink from multiple threads
    let (sender, mut receiver) = mpsc::channel::<String>(16);

    // spawn a task that forwards messages from the mpsc to the sink
    tokio::spawn(async move {
        while let Some(message) = receiver.recv().await {
            if sink.send(message.into()).await.is_err() {
                break;
            }
        }
    });

    // subscribe to the chat channel
    let mut rx_chat = state.tx.subscribe();

    // whenever a chat is sent to rx_chat, forward it to the mpsc
    let send_task_sender = sender.clone();
    let mut send_task = tokio::spawn(async move {
        while let Ok(msg) = rx_chat.recv().await {
            if send_task_sender
                .send(format!("New message: {}", msg))
                .await
                .is_err()
            {
                break;
            }
        }
    });

    // clone the tx channel so we can send messages to it
    let tx_chat = state.tx.clone();

    // whenever a user sends a chat, send it to the tx_chat
    let recv_task_sender = sender.clone();
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(Message::Text(text))) = stream.next().await {
            let _ = tx_chat.send(format!("{}", text));
            if recv_task_sender
                .send(String::from("Your message has been sent"))
                .await
                .is_err()
            {
                break;
            }
        }
    });

    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    };
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment