Created
May 5, 2021 09:17
-
-
Save ctron/3fa7f9912da044bd1e15331a3676cfd6 to your computer and use it in GitHub Desktop.
A streaming JSON serializer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright 2021 Red Hat Inc. | |
// | |
// Licensed under the Apache License, Version 2.0 (the "License"); | |
// you may not use this file except in compliance with the License. | |
// You may obtain a copy of the License at | |
// | |
// http://www.apache.org/licenses/LICENSE-2.0 | |
// | |
// Unless required by applicable law or agreed to in writing, software | |
// distributed under the License is distributed on an "AS IS" BASIS, | |
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
// See the License for the specific language governing permissions and | |
// limitations under the License. | |
use actix_http::{http::StatusCode, Response}; | |
use bytes::{BufMut, Bytes, BytesMut}; | |
use core::fmt::Debug; | |
use futures::{ | |
task::{Context, Poll}, | |
{ready, Stream}, | |
}; | |
use pin_project::pin_project; | |
use serde::Serialize; | |
use std::{ | |
fmt::{Display, Formatter}, | |
pin::Pin, | |
}; | |
/// Progress of an `ArrayStreamer` through the JSON array it produces.
enum State {
    /// Nothing has been emitted yet; the opening `[` is still owed to the consumer.
    Start,
    /// At least one element has been emitted; the next element needs a `,` separator.
    Data,
    /// The closing bracket has been emitted; no further output will be produced.
    End,
}
#[derive(Debug)] | |
pub enum ArrayStreamerError<E> | |
where | |
E: Debug + Display, | |
{ | |
Source(E), | |
Serializer(serde_json::Error), | |
} | |
impl<E> Display for ArrayStreamerError<E> | |
where | |
E: Debug + Display, | |
{ | |
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { | |
match self { | |
Self::Source(err) => write!(f, "Source error: {}", err), | |
Self::Serializer(err) => write!(f, "Serializer error: {}", err), | |
} | |
} | |
} | |
impl<E> actix_http::ResponseError for ArrayStreamerError<E> | |
where | |
E: Debug + Display + actix_http::ResponseError, | |
{ | |
fn status_code(&self) -> StatusCode { | |
match self { | |
Self::Source(err) => err.status_code(), | |
_ => StatusCode::INTERNAL_SERVER_ERROR, | |
} | |
} | |
fn error_response(&self) -> Response { | |
match self { | |
Self::Source(err) => err.error_response(), | |
Self::Serializer(err) => Response::InternalServerError().body(err.to_string()), | |
} | |
} | |
} | |
#[pin_project] | |
pub struct ArrayStreamer<S, T, E> | |
where | |
S: Stream<Item = Result<T, E>>, | |
T: Serialize, | |
E: Debug + Display, | |
{ | |
#[pin] | |
stream: S, | |
state: State, | |
} | |
impl<S, T, E> ArrayStreamer<S, T, E> | |
where | |
S: Stream<Item = Result<T, E>>, | |
T: Serialize, | |
E: Debug + Display, | |
{ | |
pub fn new(stream: S) -> Self { | |
Self { | |
stream, | |
state: State::Start, | |
} | |
} | |
} | |
impl<S, T, E> Stream for ArrayStreamer<S, T, E> | |
where | |
S: Stream<Item = Result<T, E>>, | |
T: Serialize, | |
E: Debug + Display, | |
{ | |
type Item = Result<Bytes, ArrayStreamerError<E>>; | |
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | |
if matches!(self.state, State::End) { | |
return Poll::Ready(None); | |
} | |
let mut this = self.project(); | |
let mut data = BytesMut::new(); | |
if matches!(this.state, State::Start) { | |
data.put_u8(b'[') | |
} | |
let res = ready!(this.stream.as_mut().poll_next(cx)); | |
match res { | |
Some(Err(err)) => return Poll::Ready(Some(Err(ArrayStreamerError::Source(err)))), | |
Some(Ok(item)) => { | |
// first/next item | |
if matches!(this.state, State::Data) { | |
data.put_u8(b','); | |
} | |
// serialize | |
match serde_json::to_vec(&item) { | |
Ok(buffer) => data.put(Bytes::from(buffer)), | |
Err(err) => return Poll::Ready(Some(Err(ArrayStreamerError::Serializer(err)))), | |
} | |
// change state after encoding | |
*this.state = State::Data; | |
} | |
None => { | |
// no more content | |
*this.state = State::End; | |
} | |
}; | |
if matches!(this.state, State::End) { | |
data.put_u8(b']'); | |
} | |
if data.is_empty() { | |
Poll::Ready(None) | |
} else { | |
Poll::Ready(Some(Ok(data.into()))) | |
} | |
} | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
self.stream.size_hint() | |
} | |
} | |
#[cfg(test)]
mod test {
    use super::*;
    use futures::{stream, StreamExt, TryStreamExt};

    /// Joins the emitted chunks and decodes them as UTF-8, panicking on invalid data so that
    /// encoding bugs fail the test instead of being silently dropped.
    fn concat(chunks: &[Bytes]) -> String {
        let mut bytes = Vec::new();
        for chunk in chunks {
            bytes.extend_from_slice(chunk);
        }
        String::from_utf8(bytes).expect("streamer emitted invalid UTF-8")
    }

    #[tokio::test]
    async fn test_streamer_default() {
        let data: Vec<Result<_, String>> = vec![Ok("foo"), Ok("bar")];
        let streamer = ArrayStreamer::new(stream::iter(data));
        let outcome: Vec<Bytes> = streamer.try_collect().await.unwrap();
        assert_eq!(concat(&outcome), r#"["foo","bar"]"#);
    }

    #[tokio::test]
    async fn test_streamer_empty() {
        let data: Vec<Result<String, String>> = vec![];
        let streamer = ArrayStreamer::new(stream::iter(data));
        let outcome: Vec<Bytes> = streamer.try_collect().await.unwrap();
        assert_eq!(concat(&outcome), r#"[]"#);
    }

    #[tokio::test]
    async fn test_streamer_single() {
        let data: Vec<Result<_, String>> = vec![Ok(42)];
        let streamer = ArrayStreamer::new(stream::iter(data));
        let outcome: Vec<Bytes> = streamer.try_collect().await.unwrap();
        assert_eq!(concat(&outcome), "[42]");
    }

    #[tokio::test]
    async fn test_streamer_source_error_passes_through() {
        let data: Vec<Result<&str, String>> =
            vec![Ok("foo"), Err("boom".to_string()), Ok("bar")];
        let results: Vec<_> = ArrayStreamer::new(stream::iter(data)).collect().await;

        let mut chunks = Vec::new();
        let mut errors = Vec::new();
        for result in results {
            match result {
                Ok(chunk) => chunks.push(chunk),
                Err(ArrayStreamerError::Source(err)) => errors.push(err),
                Err(err) => panic!("unexpected error: {}", err),
            }
        }

        // The error is surfaced exactly once and the array around it stays well-formed.
        assert_eq!(errors, vec!["boom".to_string()]);
        assert_eq!(concat(&chunks), r#"["foo","bar"]"#);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment