Created
March 15, 2025 17:16
-
-
Save jb55/82f4bd23b458311df6666e305ceb5cd3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 6df138a2c634f5dbdb5a131add00bb1bb5118756 | |
Author: William Casarin <[email protected]> | |
Date: Mon Dec 9 16:39:37 2024 -0800 | |
async: adding efficient, poll-based stream support | |
This is a much more efficient, polling-based stream implementation that | |
doesn't rely on horrible things like spawning threads just to do async. | |
Changelog-Added: Add async stream support | |
Signed-off-by: William Casarin <[email protected]> | |
diff --git a/Cargo.toml b/Cargo.toml | |
index f6870adc3075..7569fb171511 100644 | |
--- a/Cargo.toml | |
+++ b/Cargo.toml | |
@@ -21,6 +21,7 @@ bindgen = [] | |
flatbuffers = "23.5.26" | |
libc = "0.2.151" | |
thiserror = "2.0.7" | |
+futures = "0.3.31" | |
tokio = { version = "1", features = ["rt-multi-thread", "macros"] } | |
tracing = "0.1.40" | |
tracing-subscriber = "0.3.18" | |
diff --git a/nostrdb b/nostrdb | |
index 3260fa14639c..423598b0f747 160000 | |
--- a/nostrdb | |
+++ b/nostrdb | |
@@ -1 +1 @@ | |
-Subproject commit 3260fa14639cf2adfec69b6a2bb000047f038e18 | |
+Subproject commit 423598b0f747920369a8625d9aca5298b8e6aa59 | |
diff --git a/src/config.rs b/src/config.rs | |
index 6c5d889124c2..cf0f566449f8 100644 | |
--- a/src/config.rs | |
+++ b/src/config.rs | |
@@ -3,8 +3,6 @@ use crate::bindings; | |
+/// Options used to open an `Ndb`; wraps the C `ndb_config`. | |
+/// | |
+/// NOTE(review): `Config` is `Copy` and carries a raw callback context | |
+/// pointer. With `is_rust_closure` gone, a closure set via | |
+/// `set_sub_callback` is still invoked (the callback `Ndb::new` installs | |
+/// chains to it), but its boxed context no longer appears to be freed | |
+/// anywhere; confirm that leak is acceptable. | |
#[derive(Copy, Clone)] | |
pub struct Config { | |
pub config: bindings::ndb_config, | |
- // We add a flag to know if we've installed a Rust closure so we can clean it up in Drop. | |
- is_rust_closure: bool, | |
} | |
impl Default for Config { | |
@@ -29,11 +27,7 @@ impl Config { | |
bindings::ndb_default_config(&mut config); | |
} | |
- let is_rust_closure = false; | |
- Config { | |
- config, | |
- is_rust_closure, | |
- } | |
+ Config { config } | |
} | |
// | |
@@ -54,7 +48,8 @@ impl Config { | |
self | |
} | |
- /// Set a callback for when we have | |
+ /// Set a callback to be notified whenever a subscription is updated. The | |
+ /// function is called with the corresponding subscription id. | |
pub fn set_sub_callback<F>(mut self, closure: F) -> Self | |
where | |
F: FnMut(u64) + 'static, | |
@@ -67,7 +62,6 @@ impl Config { | |
self.config.sub_cb = Some(sub_callback_trampoline); | |
self.config.sub_cb_ctx = ctx_ptr; | |
- self.is_rust_closure = true; | |
self | |
} | |
diff --git a/src/future.rs b/src/future.rs | |
new file mode 100644 | |
index 000000000000..08d3c61f0155 | |
--- /dev/null | |
+++ b/src/future.rs | |
@@ -0,0 +1,125 @@ | |
+use crate::{Ndb, NoteKey, Subscription}; | |
+ | |
+use std::{ | |
+    pin::Pin, | |
+    sync::PoisonError, | |
+    task::{Context, Poll, Waker}, | |
+}; | |
+ | |
+use futures::Stream; | |
+ | |
+/// Per-subscription wake-up state shared between [`SubscriptionStream`]s and | |
+/// the subscription callback installed by `Ndb::new`. | |
+#[derive(Debug, Clone)] | |
+pub(crate) struct SubscriptionState { | |
+    /// Set by the subscription callback when ndb signals an update for this | |
+    /// subscription. `poll_next` clears it but deliberately does not gate on | |
+    /// it: notes are fetched on every poll, so a signal that fires before the | |
+    /// first poll (when no entry exists yet) cannot be lost. | |
+    #[allow(dead_code)] // not read anywhere in this patch; kept so Debug output shows it | |
+    pub ready: bool, | |
+    /// Set once the subscription is unsubscribed; the stream then yields `None`. | |
+    /// NOTE(review): `Ndb::unsubscribe` sets this without waking `waker`, so a | |
+    /// stream that is already `Pending` only notices on its next wake-up. | |
+    pub done: bool, | |
+    /// Waker of the task that most recently polled this subscription. | |
+    pub waker: Option<Waker>, | |
+} | |
+ | |
+/// A [`Stream`] of [`NoteKey`] batches for one subscription. It is driven by | |
+/// ndb's subscription callback instead of a blocking thread, so it can be | |
+/// awaited directly (e.g. with `StreamExt::next`) from async code. | |
+/// | |
+/// The stream ends (yields `None`) once the subscription is unsubscribed. | |
+pub struct SubscriptionStream { | |
+    ndb: Ndb, | |
+    sub_id: Subscription, | |
+    max_notes: u32, | |
+} | |
+ | |
+impl SubscriptionStream { | |
+    /// Create a stream over `sub_id` that fetches at most 32 notes per item. | |
+    /// Most callers only want a few notes at a time; when expecting lots of | |
+    /// data, raise the limit with [`SubscriptionStream::notes_per_await`]. | |
+    pub fn new(ndb: Ndb, sub_id: Subscription) -> Self { | |
+        SubscriptionStream { | |
+            ndb, | |
+            sub_id, | |
+            max_notes: 32, | |
+        } | |
+    } | |
+ | |
+    /// Set the maximum number of notes yielded per item. A value of 0 is | |
+    /// treated as 1, because a zero-sized fetch could never make progress. | |
+    #[must_use] | |
+    pub fn notes_per_await(mut self, max_notes: u32) -> Self { | |
+        self.max_notes = max_notes.max(1); | |
+        self | |
+    } | |
+ | |
+    /// The subscription this stream reads from. | |
+    pub fn sub_id(&self) -> Subscription { | |
+        self.sub_id | |
+    } | |
+} | |
+ | |
+impl Drop for SubscriptionStream { | |
+    fn drop(&mut self) { | |
+        // Remove this subscription's entry from the shared state map. A | |
+        // poisoned lock still guards a usable map, and a panic inside `drop` | |
+        // during unwinding aborts the process, so recover instead of unwrapping. | |
+        // | |
+        // NOTE(review): entries are keyed only by subscription id, so dropping | |
+        // one of several streams for the same subscription (e.g. the | |
+        // short-lived stream ndb.rs builds to await a single batch) also | |
+        // discards the state the remaining streams rely on. | |
+        self.ndb | |
+            .subs | |
+            .lock() | |
+            .unwrap_or_else(PoisonError::into_inner) | |
+            .remove(&self.sub_id); | |
+    } | |
+} | |
+ | |
+impl Stream for SubscriptionStream { | |
+    type Item = Vec<NoteKey>; | |
+ | |
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | |
+        // Only shared access is needed, and `Pin<&mut Self>` derefs to `Self`. | |
+        let me: &Self = &self; | |
+ | |
+        { | |
+            let mut subs = me.ndb.subs.lock().unwrap_or_else(PoisonError::into_inner); | |
+            let state = subs.entry(me.sub_id).or_insert(SubscriptionState { | |
+                ready: false, | |
+                done: false, | |
+                waker: None, | |
+            }); | |
+ | |
+            // we've unsubscribed | |
+            if state.done { | |
+                return Poll::Ready(None); | |
+            } | |
+ | |
+            // Register the waker *before* asking ndb for notes: a callback that | |
+            // fires after this point finds a waker to wake, so no notification | |
+            // can slip in between the fetch below and returning `Pending`. | |
+            state.ready = false; | |
+            if !matches!(&state.waker, Some(w) if w.will_wake(cx.waker())) { | |
+                state.waker = Some(cx.waker().clone()); | |
+            } | |
+        } | |
+ | |
+        // Fetch outside the lock: the callback also takes this mutex, so this | |
+        // keeps the critical section short and avoids any lock-ordering | |
+        // dependency on ndb's internals. Fetching on every poll (rather than | |
+        // only after a notification) also picks up notes left behind when an | |
+        // earlier batch was capped at `max_notes`. | |
+        let notes = me.ndb.poll_for_notes(me.sub_id, me.max_notes); | |
+        if notes.is_empty() { | |
+            Poll::Pending | |
+        } else { | |
+            Poll::Ready(Some(notes)) | |
+        } | |
+    } | |
+} | |
diff --git a/src/lib.rs b/src/lib.rs | |
index 0b9d15e3d457..aa465c430948 100644 | |
--- a/src/lib.rs | |
+++ b/src/lib.rs | |
@@ -12,6 +12,9 @@ mod bindings; | |
mod ndb_profile; | |
mod block; | |
+ | |
+mod future; | |
+ | |
mod config; | |
mod error; | |
mod filter; | |
@@ -30,6 +33,8 @@ pub use block::{Block, BlockType, Blocks, Mention}; | |
pub use config::Config; | |
pub use error::{Error, FilterError}; | |
pub use filter::{Filter, FilterBuilder}; | |
+pub(crate) use future::SubscriptionState; | |
+pub use future::SubscriptionStream; | |
pub use ndb::Ndb; | |
pub use ndb_profile::{NdbProfile, NdbProfileRecord}; | |
pub use ndb_str::{NdbStr, NdbStrVariant}; | |
diff --git a/src/ndb.rs b/src/ndb.rs | |
index c5ef1f481841..9d76abacf87b 100644 | |
--- a/src/ndb.rs | |
+++ b/src/ndb.rs | |
@@ -3,22 +3,20 @@ use std::ptr; | |
use crate::{ | |
bindings, Blocks, Config, Error, Filter, Note, NoteKey, ProfileKey, ProfileRecord, QueryResult, | |
- Result, Subscription, Transaction, | |
+ Result, Subscription, SubscriptionState, SubscriptionStream, Transaction, | |
}; | |
+use futures::StreamExt; | |
+use std::collections::hash_map::Entry; | |
+use std::collections::HashMap; | |
use std::fs; | |
use std::os::raw::c_int; | |
use std::path::Path; | |
-use std::sync::Arc; | |
-use tokio::task; // Make sure to import the task module | |
+use std::sync::{Arc, Mutex}; | |
use tracing::debug; | |
#[derive(Debug)] | |
struct NdbRef { | |
ndb: *mut bindings::ndb, | |
- | |
- /// Have we configured a rust closure for our callback? If so we need | |
- /// to clean that up when this is dropped | |
- has_rust_closure: bool, | |
+ /// Context of the Rust subscription callback the constructor always | |
+ /// installs; reclaimed in `Drop` after `ndb_destroy`. | |
+ /// NOTE(review): `Drop` rebuilds it as `Box<dyn FnMut()>` while the | |
+ /// callback is `FnMut(u64)`; confirm this matches how `set_sub_callback` | |
+ /// boxes the closure. | |
rust_cb_ctx: *mut ::std::os::raw::c_void, | |
} | |
@@ -34,7 +32,7 @@ impl Drop for NdbRef { | |
unsafe { | |
bindings::ndb_destroy(self.ndb); | |
- if self.has_rust_closure && !self.rust_cb_ctx.is_null() { | |
+ if !self.rust_cb_ctx.is_null() { | |
// Rebuild the Box from the raw pointer and drop it. | |
let _ = Box::from_raw(self.rust_cb_ctx as *mut Box<dyn FnMut()>); | |
} | |
@@ -42,10 +40,15 @@ impl Drop for NdbRef { | |
} | |
} | |
+/// Per-subscription state for `SubscriptionStream`s, keyed by subscription. | |
+type SubMap = HashMap<Subscription, SubscriptionState>; | |
+ | |
/// A nostrdb context. Construct one of these with [Ndb::new]. | |
#[derive(Debug, Clone)] | |
pub struct Ndb { | |
refs: Arc<NdbRef>, | |
+ | |
+ /// Wake-up state for `SubscriptionStream`s, shared (via `Arc`) by all clones | |
+ /// of this `Ndb` and by the subscription callback installed in `Ndb::new`. | |
+ pub(crate) subs: Arc<Mutex<SubMap>>, | |
} | |
impl Ndb { | |
@@ -65,7 +68,30 @@ impl Ndb { | |
let min_mapsize = 1024 * 1024 * 512; | |
let mut mapsize = config.config.mapsize; | |
- let mut config = *config; | |
+ let config = *config; | |
+ | |
+ let prev_callback = config.config.sub_cb; | |
+ let prev_callback_ctx = config.config.sub_cb_ctx; | |
+ let subs = Arc::new(Mutex::new(SubMap::default())); | |
+ let subs_clone = subs.clone(); | |
+ | |
+ // We need to register our own callback so that we can wake | |
+ // query futures | |
+ let mut config = config.set_sub_callback(move |sub_id: u64| { | |
+ let mut map = subs_clone.lock().unwrap(); | |
+ if let Some(s) = map.get_mut(&Subscription::new(sub_id)) { | |
+ s.ready = true; | |
+ if let Some(w) = s.waker.take() { | |
+ w.wake(); | |
+ } | |
+ } | |
+ | |
+ if let Some(pcb) = prev_callback { | |
+ unsafe { | |
+ pcb(prev_callback_ctx, sub_id); | |
+ }; | |
+ } | |
+ }); | |
let result = loop { | |
let result = | |
@@ -90,15 +116,10 @@ impl Ndb { | |
return Err(Error::DbOpenFailed); | |
} | |
- let has_rust_closure = !config.config.sub_cb_ctx.is_null(); | |
let rust_cb_ctx = config.config.sub_cb_ctx; | |
- let refs = Arc::new(NdbRef { | |
- ndb, | |
- has_rust_closure, | |
- rust_cb_ctx, | |
- }); | |
+ let refs = Arc::new(NdbRef { ndb, rust_cb_ctx }); | |
- Ok(Ndb { refs }) | |
+ Ok(Ndb { refs, subs }) | |
} | |
/// Ingest a relay-sent event in the form `["EVENT","subid", {"id:"...}]` | |
@@ -155,9 +176,17 @@ impl Ndb { | |
unsafe { bindings::ndb_num_subscriptions(self.as_ptr()) as u32 } | |
} | |
- pub fn unsubscribe(&self, sub: Subscription) -> Result<()> { | |
+ pub fn unsubscribe(&mut self, sub: Subscription) -> Result<()> { | |
let r = unsafe { bindings::ndb_unsubscribe(self.as_ptr(), sub.id()) }; | |
+ // mark the subscription as done if it exists in our stream map | |
+ { | |
+ let mut map = self.subs.lock().unwrap(); | |
+ if let Entry::Occupied(mut entry) = map.entry(sub) { | |
+ entry.get_mut().done = true; | |
+ } | |
+ } | |
+ | |
if r == 0 { | |
Err(Error::SubscriptionError) | |
} else { | |
@@ -204,32 +233,11 @@ impl Ndb { | |
sub_id: Subscription, | |
max_notes: u32, | |
) -> Result<Vec<NoteKey>> { | |
- let ndb = self.clone(); | |
- let handle = task::spawn_blocking(move || { | |
- let mut vec: Vec<u64> = vec![]; | |
- vec.reserve_exact(max_notes as usize); | |
- let res = unsafe { | |
- bindings::ndb_wait_for_notes( | |
- ndb.as_ptr(), | |
- sub_id.id(), | |
- vec.as_mut_ptr(), | |
- max_notes as c_int, | |
- ) | |
- }; | |
- if res == 0 { | |
- Err(Error::SubscriptionError) | |
- } else { | |
- unsafe { | |
- vec.set_len(res as usize); | |
- }; | |
- Ok(vec) | |
- } | |
- }); | |
+ let mut stream = SubscriptionStream::new(self.clone(), sub_id).notes_per_await(max_notes); | |
- match handle.await { | |
- Ok(Ok(res)) => Ok(res.into_iter().map(NoteKey::new).collect()), | |
- Ok(Err(err)) => Err(err), | |
- Err(_) => Err(Error::SubscriptionError), | |
+ match stream.next().await { | |
+ Some(res) => Ok(res), | |
+ None => Err(Error::SubscriptionError), | |
} | |
} | |
@@ -527,4 +535,40 @@ mod tests { | |
// we should definitely clean this up... especially on windows | |
test_util::cleanup_db(&db); | |
} | |
+ | |
+    #[tokio::test] | |
+    async fn test_stream() { | |
+        let db = "target/testdbs/test_stream"; // named after this test; a shared dir races under parallel tests | |
+        test_util::cleanup_db(&db); | |
+        // Scope every Ndb handle so ndb_destroy runs before the final cleanup. | |
+        { | |
+            let mut ndb = Ndb::new(db, &Config::new()).expect("ndb"); | |
+            let sub_id = { | |
+                let filter = Filter::new().kinds(vec![1]).build(); | |
+                let filters = vec![filter]; | |
+ | |
+                let sub_id = ndb.subscribe(&filters).expect("sub_id"); | |
+                let mut sub = sub_id.stream(&ndb).notes_per_await(1); | |
+                // Create the `next()` future before the event is ingested. | |
+                let res = sub.next(); | |
+ | |
+                ndb.process_event(r#"["EVENT","b",{"id": "702555e52e82cc24ad517ba78c21879f6e47a7c0692b9b20df147916ae8731a3","pubkey": "32bf915904bfde2d136ba45dde32c88f4aca863783999faea2e847a8fafd2f15","created_at": 1702675561,"kind": 1,"tags": [],"content": "hello, world","sig": "2275c5f5417abfd644b7bc74f0388d70feb5d08b6f90fa18655dda5c95d013bfbc5258ea77c05b7e40e0ee51d8a2efa931dc7a0ec1db4c0a94519762c6625675"}]"#).expect("process ok"); | |
+ | |
+                let res = res.await.expect("await ok"); | |
+                assert_eq!(res, vec![NoteKey::new(1)]); | |
+ | |
+                // ensure that unsubscribing kills the stream | |
+                assert!(ndb.unsubscribe(sub_id).is_ok()); | |
+                assert!(sub.next().await.is_none()); | |
+                // the stream is still alive here, so its state entry must remain | |
+                assert!(ndb.subs.lock().unwrap().contains_key(&sub_id)); | |
+                sub_id | |
+            }; | |
+ | |
+            // ensure subscription state is removed after stream is dropped | |
+            assert!(!ndb.subs.lock().unwrap().contains_key(&sub_id)); | |
+        } | |
+ | |
+        test_util::cleanup_db(&db); | |
+    } | |
} | |
diff --git a/src/subscription.rs b/src/subscription.rs | |
index 8e77d6a8c87e..905642bbe72e 100644 | |
--- a/src/subscription.rs | |
+++ b/src/subscription.rs | |
@@ -1,4 +1,6 @@ | |
-#[derive(Debug, Clone, Copy, Eq, PartialEq)] | |
+use crate::{Ndb, SubscriptionStream}; | |
+ | |
+/// Handle to an ndb subscription, wrapping its `u64` id. `Hash` allows it to | |
+/// key the per-subscription stream state that `Ndb` keeps. | |
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] | |
pub struct Subscription(u64); | |
impl Subscription { | |
@@ -8,4 +10,8 @@ impl Subscription { | |
pub fn id(self) -> u64 { | |
self.0 | |
} | |
+ | |
+ /// Create a [`SubscriptionStream`] over this subscription on `ndb` | |
+ /// (at most 32 notes per item by default; see `notes_per_await`). Like any | |
+ /// stream, it does nothing until polled. | |
+ pub fn stream(&self, ndb: &Ndb) -> SubscriptionStream { | |
+ SubscriptionStream::new(ndb.clone(), *self) | |
+ } | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment