Last active
November 18, 2024 04:38
-
-
Save s3rius/3bf4a0bd6b28ca1ae94376aa290f8f1c to your computer and use it in GitHub Desktop.
PyO3-asyncio async streams
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "itertest"
version = "0.1.0"
edition = "2021"
[dependencies]
futures = "0.3.28"
pyo3 = "0.19.2"
pyo3-asyncio = { version = "0.19.0", features = ["tokio-runtime"] }
tokio = { version = "1.32.0", features = ["sync"] }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::sync::Arc; | |
use futures::{Stream, StreamExt}; | |
use pyo3::{ | |
exceptions::PyStopAsyncIteration, pymethods, pymodule, types::PyModule, PyObject, PyRef, | |
PyResult, Python, | |
}; | |
/// A simple Rust type that implements the `Stream` trait.
///
/// It counts upward and yields `1, 2, ..., i` (inclusive). `current` stores
/// the last value handed out; it starts at `0`, meaning nothing has been
/// yielded yet. The stream is finished once `current` reaches `i`.
#[derive(Debug, Clone)]
pub struct RustStreamer {
    /// Inclusive upper bound of the produced sequence.
    i: u32,
    /// Last value yielded so far (`0` before the first item).
    current: u32,
}

impl RustStreamer {
    /// Creates a streamer that will yield `1, 2, ..., i`.
    ///
    /// With `i == 0` the stream is empty from the start.
    pub fn new(i: u32) -> Self {
        Self { i, current: 0 }
    }
}
/// Here goes stream implementation. | |
/// | |
/// | |
/// It's a simple stream. On each poll_next call it returns next value. | |
/// If current value is equal to i, it returns None, which means, | |
/// that stream is finished. | |
impl Stream for RustStreamer { | |
type Item = u32; | |
fn poll_next( | |
self: std::pin::Pin<&mut Self>, | |
_cx: &mut std::task::Context<'_>, | |
) -> std::task::Poll<Option<Self::Item>> { | |
let this = self.get_mut(); | |
if this.current < this.i { | |
this.current += 1; | |
std::task::Poll::Ready(Some(this.current)) | |
} else { | |
std::task::Poll::Ready(None) | |
} | |
} | |
} | |
/// Here I defined a class that can be used, | |
/// as an async iterator. | |
/// | |
/// It's a simple class, that has an inner field, | |
/// which is an object that implements the Stream trait. | |
/// | |
/// But I wrap it in Mutex<...> to make it thread safe | |
/// and shareable between tokio-threads. Arc here, because | |
/// it's cheap to clone. | |
/// | |
/// Also, without mutex, it's not possible to mutate | |
/// the data inside the Arc. | |
#[pyo3::pyclass] | |
struct TestIterator { | |
pub inner: Arc<tokio::sync::Mutex<RustStreamer>>, | |
} | |
#[pymethods] | |
impl TestIterator { | |
#[new] | |
fn new(i: u32) -> Self { | |
TestIterator { | |
inner: Arc::new(tokio::sync::Mutex::new(RustStreamer::new(i))), | |
} | |
} | |
/// We don't want to create another classes, we want this | |
/// class to be iterable. Since we implemented __anext__ method, | |
/// we can return self here. | |
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { | |
slf | |
} | |
/// This is an anext implementation. | |
/// | |
/// Notable thing here is that we return PyResult<Option<PyObject>>. | |
/// We cannot return &PyAny directly here, because of pyo3 limitations. | |
/// Here's the issue about it: https://github.com/PyO3/pyo3/issues/3190 | |
fn __anext__<'a>(&self, py: Python<'a>) -> PyResult<Option<PyObject>> { | |
// Here we clone the inner field, so we can use it | |
// in our future. | |
let streamer = self.inner.clone(); | |
let future = pyo3_asyncio::tokio::future_into_py(py, async move { | |
// Here we lock the mutex to access the data inside | |
// and call next() method to get the next value. | |
let val = streamer.lock().await.next().await; | |
match val { | |
Some(val) => Ok(val), | |
// Here we return PyStopAsyncIteration error, | |
// because python needs exceptions to tell that iterator | |
// has ended. | |
None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")), | |
} | |
}); | |
Ok(Some(future?.into())) | |
} | |
} | |
#[pymodule] | |
fn _internal(_py: Python<'_>, pymod: &PyModule) -> PyResult<()> { | |
pymod.add_class::<TestIterator>()?; | |
Ok(()) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[project]
name = "itertest"
[tool.maturin]
python-source = "python"
module-name = "itertest._internal"
features = ["pyo3/extension-module"]
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from itertest._internal import TestIterator | |
async def main():
    """Drain a ``TestIterator`` built with ``i=5``, printing each value it yields."""
    stream = TestIterator(i=5)
    async for value in stream:
        print(value)


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment