Created
May 23, 2024 11:02
-
-
Save Pzixel/3928af2b3dce9baf2ece569beedc14a8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::sync::atomic::AtomicBool; | |
use std::sync::Arc; | |
use std::sync::Mutex; | |
use std::task::Waker; | |
use std::time::Duration; | |
use futures_util::SinkExt; | |
use futures_util::StreamExt; | |
use reqwest::header::HeaderName; | |
use tokio_tungstenite; | |
use tokio_tungstenite::tungstenite::client::IntoClientRequest; | |
#[derive(serde::Deserialize, Debug, Clone)] | |
pub struct Update { | |
pub data: u32 | |
} | |
#[derive(serde::Deserialize, Debug, Clone)] | |
pub struct Config { | |
pub pricing_url: String, | |
pub auth_name: String, | |
pub api_password: String, | |
} | |
fn get_stream(config: Config, ping_interval: Duration, pong_timout: Duration, reconnect_delay: Duration) -> impl futures_core::stream::Stream<Item = Update> { | |
async_stream::stream! { | |
// async { | |
loop { | |
let mut request = config.pricing_url.as_str().into_client_request().unwrap(); | |
request.headers_mut().insert( | |
HeaderName::from_static("name"), | |
config.auth_name.parse().unwrap(), | |
); | |
request.headers_mut().insert( | |
reqwest::header::AUTHORIZATION, | |
config.api_password.parse().unwrap(), | |
); | |
let ws_stream = match tokio_tungstenite::connect_async(request).await { | |
Ok((ws_stream, _)) => ws_stream, | |
Err(e) => { | |
tracing::warn!("[wss] Cannot connect to RPC endpoint '{}': {}. Waiting and reconnecting", &config.pricing_url, e); | |
tokio::time::sleep(reconnect_delay).await; | |
continue; | |
} | |
}; | |
tracing::info!("[] Connected to {}", config.pricing_url); | |
let (mut write, mut read) = ws_stream.split(); | |
let should_keep_pinging = Arc::new(AtomicBool::new(true)); | |
{ | |
let should_keep_pinging = Arc::clone(&should_keep_pinging); | |
tokio::spawn({ | |
async move { | |
while should_keep_pinging.load(std::sync::atomic::Ordering::Relaxed) { | |
tokio::time::sleep(ping_interval).await; | |
if let Err(e) = write.send(tokio_tungstenite::tungstenite::Message::Ping(Vec::new())).await { | |
tracing::error!("[] Error sending ping for subscription: {}. Expecting disconnect to be discovered in {}s", e, pong_timout.as_secs_f32()); | |
} | |
}; | |
} | |
}); | |
} | |
loop { | |
let msg = match tokio::time::timeout(pong_timout, read.next()).await { | |
Ok(Some(Ok(msg))) => msg, | |
Ok(Some(Err(e))) => { | |
tracing::warn!("[] Error reading from stream: {}", e); | |
break; | |
} | |
Ok(None) => { | |
tracing::warn!("[] Stream closed"); | |
break; | |
} | |
Err(e) => { | |
tracing::warn!("[] Pong timeout of {}s has been elasped: {}", pong_timout.as_secs_f32(), e); | |
break; | |
} | |
}; | |
let msg = match msg { | |
tokio_tungstenite::tungstenite::Message::Text(x) => x, | |
tokio_tungstenite::tungstenite::Message::Ping(_) => { | |
tracing::info!("[] Received ping"); | |
// TODO: send pong back? | |
continue; | |
}, | |
tokio_tungstenite::tungstenite::Message::Pong(_) => { | |
tracing::info!("[] Received pong"); | |
continue; | |
} | |
x => { | |
tracing::warn!("[] Received unexpected message: {:?}. Breaking listening", x); | |
break; | |
} | |
}; | |
let Ok(response) = serde_json::from_str::<Update>(&msg) else { | |
tracing::error!("[] Error parsing following response as NewHeadsResponse: {}", msg); | |
continue; | |
}; | |
yield response; | |
} | |
should_keep_pinging.store(false, std::sync::atomic::Ordering::Relaxed); | |
tracing::info!("[] Reconnecting"); | |
} | |
} | |
} | |
pub struct BufferedMessages { | |
future: tokio::task::JoinHandle<()>, | |
shared_state: Arc<Mutex<SharedState>>, | |
} | |
struct SharedState { | |
buffer: Vec<Update>, | |
waker: Option<Waker>, | |
} | |
impl BufferedMessages { | |
pub fn new(config: Config, ping_interval: Duration, pong_timout: Duration, reconnect_delay: Duration) -> Self { | |
let shared_state = Arc::new(Mutex::new(SharedState { | |
buffer: Vec::new(), | |
waker: None, | |
})); | |
let future = tokio::spawn({ | |
let shared_state = Arc::clone(&shared_state); | |
async move { | |
let mut stream = Box::pin(get_stream(config, ping_interval, pong_timout, reconnect_delay)); | |
while let Some(x) = stream.next().await { | |
let mut shared_state = shared_state.lock().unwrap(); | |
shared_state.buffer.push(x); | |
if let Some(waker) = shared_state.waker.take() { | |
waker.wake() | |
} | |
} | |
} | |
}); | |
Self { | |
future, | |
shared_state, | |
} | |
} | |
} | |
impl futures_core::stream::Stream for BufferedMessages { | |
type Item = Vec<Update>; | |
fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> { | |
let mut shared_state = self.shared_state.lock().unwrap(); | |
if shared_state.buffer.is_empty() { | |
shared_state.waker = Some(cx.waker().clone()); | |
std::task::Poll::Pending | |
} else { | |
std::task::Poll::Ready(Some(std::mem::take( &mut shared_state.buffer))) | |
} | |
} | |
} | |
impl Drop for BufferedMessages { | |
fn drop(&mut self) { | |
tracing::info!("[] Dropping BufferedMessages"); | |
self.future.abort(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment