Created June 26, 2025 02:45
h2 panic client code
| use std::{ | |
| collections::{BTreeMap, HashMap, HashSet}, | |
| error::Error as StdError, | |
| future::Future, | |
| ops::Bound::{Included, Unbounded}, | |
| pin::Pin, | |
| sync::Arc, | |
| task::{Context, Poll}, | |
| time::{Duration, Instant}, | |
| }; | |
| use arrow_flight::{self, flight_service_client::FlightServiceClient}; | |
| use await_tree::InstrumentAwait; | |
| use backon::{BackoffBuilder, RetryableWithContext}; | |
| use futures::{FutureExt, future::BoxFuture}; | |
| use huge_pb::rbac::CreateTokenV2Request; | |
| use huge_pb::rbac::create_token_v2_request::TraditionalConfig; | |
| use huge_pb::{ | |
| common::Address, | |
| rbac::{CreateTokenRequest, rbac_client::RbacClient}, | |
| }; | |
| use hyper::http; | |
| use hyper_util::client::legacy::connect::HttpConnector; | |
| use moist::auth::Client as AuthClient; | |
| use moist::auth::{CacheSettings, ComplianceSettings, RequestSettings}; | |
| use prost::Message; | |
| use rand::prelude::*; | |
| use thiserror::Error; | |
| use tokio::sync::{RwLock, Semaphore, watch}; | |
| use tonic::{ | |
| Code, Request, Status, Streaming, | |
| body::BoxBody, | |
| codegen::InterceptedService, | |
| transport::{Channel, Endpoint, Uri}, | |
| }; | |
| use tower::Service; | |
| use tracing::{debug, error, info, instrument, warn}; | |
| use uuid::Uuid; | |
| use crate::common_error::{self, CommonError, CommonErrorCtx}; | |
| use crate::rbac::Interceptor; | |
| use crate::{AuthNConfig, Credential as CredentialType, utils}; | |
| use crate::{allocator::spawn, is_tonic_status_retryable}; | |
| /// Tonic hangs when sending >=2GiB requests (https://github.com/hyperium/tonic/issues/352) and | |
| /// reports errors when sending >=2GiB responses (https://github.com/hyperium/hyper/issues/2893). | |
| /// | |
| /// We limit the maximum message size to 2GiB (after compression) so that tonic reports an error instead of hanging. | |
| /// NOTE: The error reported by tonic in this case is misleading: | |
| /// "h2 protocol error: http2 error: stream error sent by user: unexpected internal error encountered" | |
| /// instead of | |
| /// "Error, message length too large: found {} bytes, the limit is: {} bytes" | |
| pub const MAX_TONIC_MESSAGE_SIZE: usize = i32::MAX as usize; | |
| pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20); | |
| const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(3); | |
| #[derive(Error, Debug)] | |
| pub enum PoolError { | |
| #[error("failed to connect: {0}")] | |
| ConnectionError(#[from] tonic::transport::Error), | |
| #[error("failed to check connection: {0}")] | |
| CheckConnectionError(#[from] tonic::Status), | |
| } | |
| impl From<PoolError> for tonic::Status { | |
| fn from(err: PoolError) -> Self { | |
| tonic::Status::internal(err.to_string()) | |
| } | |
| } | |
| const RECONNECT_INTERVAL: Duration = Duration::from_secs(1); | |
| const RECONNECT_INTERVAL_MAX: Duration = Duration::from_secs(10); | |
| pub async fn do_connect( | |
| endpoint: Endpoint, | |
| socks5_proxy: Option<String>, | |
| ) -> Result<Channel, crate::error::CommonError> { | |
| // Socks5 proxy | |
| if let Some(proxy_host) = socks5_proxy { | |
| // Randomly choose a proxy from the ip list | |
| let ips = tokio::net::lookup_host(proxy_host.clone()) | |
| .await | |
| .map_err(|e| { | |
| crate::error::CommonError::invalid_argument(format!( | |
| "Failed to lookup socks5 host: {e:?}" | |
| )) | |
| })? | |
| .collect::<Vec<_>>(); | |
| tracing::info!("lookup socks5 host: {} => {:?}", proxy_host, ips); | |
| let random_addr = *ips.choose(&mut rand::rng()).ok_or_else(|| { | |
| crate::error::CommonError::invalid_argument("No socks5 proxy found".to_string()) | |
| })?; | |
| // Connect to the proxy | |
| let mut connector = HttpConnector::new(); | |
| connector.enforce_http(false); | |
| tracing::info!("connecting to socks5 proxy: {:?}", random_addr); | |
| let proxy = crate::socks5::SocksConnector { | |
| proxy_addr: random_addr, | |
| connector, | |
| }; | |
| return endpoint | |
| .connect_with_connector(proxy) | |
| .await | |
| .map_err(Into::into); | |
| } | |
| // No proxy | |
| endpoint.connect().await.map_err(Into::into) | |
| } | |
| pub fn rewrite_uri(req: &mut hyper::Request<tonic::body::BoxBody>, origin: &Uri) { | |
| let uri = Uri::builder() | |
| .scheme(origin.scheme().unwrap().clone()) | |
| .authority(origin.authority().unwrap().clone()) | |
| .path_and_query(req.uri().path_and_query().unwrap().clone()) | |
| .build() | |
| .unwrap(); | |
| *req.uri_mut() = uri; | |
| } | |
| pub struct AddOrigin { | |
| channel: Channel, | |
| origin: Uri, | |
| } | |
| impl AddOrigin { | |
| pub fn new(channel: Channel, origin: Uri) -> Self { | |
| Self { channel, origin } | |
| } | |
| } | |
| impl Service<hyper::Request<tonic::body::BoxBody>> for AddOrigin { | |
| type Response = <Channel as Service<hyper::Request<tonic::body::BoxBody>>>::Response; | |
| type Error = <Channel as Service<hyper::Request<tonic::body::BoxBody>>>::Error; | |
| type Future = <Channel as Service<hyper::Request<tonic::body::BoxBody>>>::Future; | |
| fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { | |
| self.channel.poll_ready(ctx) | |
| } | |
| fn call(&mut self, mut req: hyper::Request<tonic::body::BoxBody>) -> Self::Future { | |
| rewrite_uri(&mut req, &self.origin); | |
| self.channel.call(req) | |
| } | |
| } | |
| pub struct ChannelPool { | |
| inner: RwLock<ChannelPoolInner>, | |
| // Each permit of the semaphore represents a connected channel. | |
| // `poll_ready` acquires a permit before selecting a channel from the pool, which guarantees | |
| // `channels` is not empty. If `poll_ready` returns pending because no channel is available, it | |
| // is woken up once a channel is reconnected and permits are added back to the semaphore. | |
| alive: Semaphore, | |
| is_draining: watch::Receiver<bool>, | |
| } | |
| struct ChannelPoolInner { | |
| // TODO: make it faster with indexed tree map | |
| channels: BTreeMap<(Address, u64), ChannelWithEpoch>, | |
| enabled_endpoints: HashSet<Address>, | |
| endpoint_config: EndpointConfig, | |
| } | |
| struct ChannelWithEpoch { | |
| channel: Channel, | |
| // epoch is increased after each reconnection to avoid removing a newly created channel. | |
| epoch: u64, | |
| } | |
| impl ChannelPool { | |
| pub fn new(endpoint_config: EndpointConfig) -> Arc<Self> { | |
| Arc::new(ChannelPool { | |
| inner: RwLock::new(ChannelPoolInner { | |
| channels: BTreeMap::new(), | |
| enabled_endpoints: HashSet::new(), | |
| endpoint_config, | |
| }), | |
| alive: Semaphore::new(0), | |
| is_draining: watch::channel(false).1, | |
| }) | |
| } | |
| // Creates a ChannelPool from an already-connected channel. Used for testing. | |
| pub fn from_channel(channel: Channel) -> Arc<Self> { | |
| let endpoint_config = EndpointConfig { | |
| conn_per_endpoint: 1, | |
| ..Default::default() | |
| }; | |
| let addr = Address::default(); | |
| Arc::new(ChannelPool { | |
| inner: RwLock::new(ChannelPoolInner { | |
| channels: [((addr.clone(), 0), ChannelWithEpoch { channel, epoch: 0 })] | |
| .into_iter() | |
| .collect(), | |
| enabled_endpoints: [addr].into_iter().collect(), | |
| endpoint_config, | |
| }), | |
| alive: Semaphore::new(Semaphore::MAX_PERMITS), | |
| is_draining: watch::channel(false).1, | |
| }) | |
| } | |
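| // A minimal usage sketch, mirroring the tests at the bottom of this file: build a pool from | |
| // URIs, turn it into a `TonicService`, and hand that to any tonic-generated client, e.g. | |
| //   let service = ChannelPool::new_with_uris([uri], Default::default()).await.make_tonic_service(); | |
| //   let mut client = HealthClient::new(service); | |
| // (`HealthClient` is just the client used in the tests; any generated client works the same way.) | |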
| pub fn make_tonic_service(self: &Arc<Self>) -> TonicService { | |
| debug!("new TonicService on channel pool"); | |
| TonicService { | |
| pool: self.clone(), | |
| poll_fut: None, | |
| is_draining: self.is_draining.clone(), | |
| ready: None, | |
| prefer_address: None, | |
| config: Default::default(), | |
| } | |
| } | |
| pub async fn make_tonic_service_with_preference( | |
| self: &Arc<Self>, | |
| addr: &Address, | |
| ) -> TonicService { | |
| debug!( | |
| "new TonicService on channel pool, with address preference {}", | |
| addr | |
| ); | |
| let _ = self.add_endpoint(addr).await; | |
| TonicService { | |
| pool: self.clone(), | |
| poll_fut: None, | |
| is_draining: self.is_draining.clone(), | |
| ready: None, | |
| prefer_address: addr.clone().into(), | |
| config: Default::default(), | |
| } | |
| } | |
| pub async fn add_endpoint(self: &Arc<Self>, addr: &Address) -> bool { | |
| let mut inner = self.inner.write().await; | |
| if !inner.enabled_endpoints.insert(addr.clone()) { | |
| return false; | |
| } | |
| for index in 0..inner.endpoint_config.conn_per_endpoint { | |
| self.clone().spawn_connect(addr.clone(), index, 0); | |
| } | |
| true | |
| } | |
| pub async fn remove_endpoint(self: &Arc<Self>, addr: &Address) -> bool { | |
| let mut inner = self.inner.write().await; | |
| if !inner.enabled_endpoints.remove(addr) { | |
| return false; | |
| } | |
| let mut key = (addr.clone(), 0); | |
| for index in 0..inner.endpoint_config.conn_per_endpoint { | |
| key.1 = index; | |
| if inner.channels.remove(&key).is_some() { | |
| self.alive.acquire().await.unwrap().forget(); | |
| } | |
| } | |
| true | |
| } | |
| pub async fn enabled_endpoints(&self) -> Vec<Address> { | |
| let inner = self.inner.read().await; | |
| inner.enabled_endpoints.iter().cloned().collect() | |
| } | |
| pub async fn new_with_uris( | |
| uris: impl IntoIterator<Item = Uri>, | |
| cfg: EndpointConfig, | |
| ) -> Arc<Self> { | |
| let pool = ChannelPool::new(cfg); | |
| for uri in uris { | |
| let address = uri.into(); | |
| pool.add_endpoint(&address).await; | |
| } | |
| pool | |
| } | |
| pub async fn new_with_pending( | |
| uris: impl IntoIterator<Item = Uri>, | |
| endpoint_config: EndpointConfig, | |
| pending: watch::Receiver<bool>, | |
| ) -> Arc<Self> { | |
| let pool = Arc::new(ChannelPool { | |
| inner: RwLock::new(ChannelPoolInner { | |
| channels: BTreeMap::new(), | |
| enabled_endpoints: HashSet::new(), | |
| endpoint_config, | |
| }), | |
| alive: Semaphore::new(0), | |
| is_draining: pending, | |
| }); | |
| for uri in uris { | |
| let address = uri.into(); | |
| pool.add_endpoint(&address).await; | |
| } | |
| pool | |
| } | |
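| // Called when a request on `(addr, index)` hits a transport error: if the stored channel still | |
| // has the same epoch, drop it from the pool, forget its "alive" permit, and spawn a reconnect | |
| // with a bumped epoch so a concurrently re-created channel is never removed by mistake. | |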
| async fn handle_transport_error(self: Arc<Self>, addr: Address, index: u64, epoch: u64) { | |
| let permit = self | |
| .alive | |
| .acquire() | |
| .instrument_await("acquire_alive") | |
| .await | |
| .unwrap(); | |
| let mut inner = self.inner.write().instrument_await("acquire_inner").await; | |
| let key = (addr.clone(), index); | |
| match inner.channels.get(&key) { | |
| Some(channel) if channel.epoch == epoch => { | |
| inner.channels.remove(&key); | |
| drop(inner); | |
| permit.forget(); | |
| self.spawn_connect(key.0, key.1, epoch + 1); | |
| } | |
| _ => {} | |
| } | |
| } | |
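| // Background reconnect loop: keeps dialing `addr` while holding only a weak reference to the | |
| // pool, sleeping with a growing, jittered interval (capped at RECONNECT_INTERVAL_MAX) between | |
| // attempts. It stops once the pool is dropped or the endpoint is no longer enabled. | |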
| fn spawn_connect(self: &Arc<Self>, addr: Address, index: u64, epoch: u64) { | |
| let weak = Arc::downgrade(self); | |
| let uri = addr.to_uri(); | |
| spawn(async move { | |
| let mut interval = RECONNECT_INTERVAL; | |
| while let Some(this) = weak.upgrade() { | |
| debug!("connect to {uri}, index = {index}, epoch = {epoch}"); | |
| let (endpoint, socks5_proxy) = { | |
| let cfg = &this.inner.read().await.endpoint_config; | |
| (cfg.build(uri.clone()), cfg.socks5_proxy.clone()) | |
| }; | |
| match do_connect(endpoint, socks5_proxy).await { | |
| Ok(channel) => { | |
| debug!("connected to {uri}, index = {index}"); | |
| let mut inner = this.inner.write().await; | |
| if inner.enabled_endpoints.contains(&addr) { | |
| let channel = ChannelWithEpoch { channel, epoch }; | |
| inner.channels.insert((addr.clone(), index), channel); | |
| this.alive.add_permits(1); | |
| } | |
| break; | |
| } | |
| Err(err) => { | |
| debug!("failed to connect to {uri}: {err:?}"); | |
| if !this.inner.read().await.enabled_endpoints.contains(&addr) { | |
| break; | |
| } | |
| let jitter = tokio::time::Duration::from_secs( | |
| rand::rng().random_range(0..interval.as_secs()), | |
| ); | |
| tokio::time::sleep(interval + jitter).await; | |
| interval += jitter; | |
| interval = interval.min(RECONNECT_INTERVAL_MAX); | |
| } | |
| } | |
| } | |
| }); | |
| } | |
| } | |
| pub struct TonicService { | |
| pool: Arc<ChannelPool>, | |
| #[allow(clippy::type_complexity)] | |
| poll_fut: | |
| Option<Pin<Box<dyn Future<Output = Result<ReadyChannel, Status>> + Send + Sync + 'static>>>, | |
| is_draining: watch::Receiver<bool>, | |
| ready: Option<ReadyChannel>, | |
| prefer_address: Option<Address>, | |
| config: TonicServiceConfig, | |
| } | |
| pub(crate) struct TonicServiceConfig { | |
| pub(crate) warn_interval: Duration, | |
| pub(crate) connection_timeout: Duration, | |
| } | |
| impl Default for TonicServiceConfig { | |
| fn default() -> Self { | |
| Self { | |
| warn_interval: Duration::from_secs(30), | |
| connection_timeout: Duration::from_secs(10 * 60), | |
| } | |
| } | |
| } | |
| #[derive(Clone, Debug)] | |
| pub struct EndpointConfig { | |
| pub conn_per_endpoint: u64, | |
| pub connection_window_size: u32, | |
| pub stream_window_size: u32, | |
| pub mstreams_per_node: u64, | |
| pub max_inflights_per_mstream: u64, | |
| /// HTTP/2 keep-alive is required for servers behind a load balancer; otherwise the client may | |
| /// not detect server changes behind the LB and can get stuck on a dead connection. | |
| pub enable_h2_keep_alive: bool, | |
| pub socks5_proxy: Option<String>, | |
| /// Maximum time to fetch the sender of a preferred address. If this time is exceeded, | |
| /// choose a random msender instead. | |
| pub fetch_preferred_sender_timeout: Duration, | |
| pub reset_mstream_interval: Duration, | |
| } | |
| // The default implementation is only used internally. | |
| // Each multiplexed stream can serve around 150 MB/s. Assuming the client has 1 GB/s of inbound | |
| // bandwidth, the recommended settings are: | |
| // Prod - 10 read pods: mstreams_per_node >= 1 | |
| // Res - 20 read pods: mstreams_per_node >= 1 | |
| // Bench - 6 read pods: mstreams_per_node >= 2 | |
| // Dev - 2 read pods: mstreams_per_node >= 4 | |
| // Single read pod: mstreams_per_node >= 8 | |
| // `mstreams_per_node` is parsed from `HugestoreOption.proxy_connections`, which can be set in ~/.huge-config.toml. | |
| impl Default for EndpointConfig { | |
| fn default() -> Self { | |
| Self { | |
| conn_per_endpoint: 3, | |
| connection_window_size: 1024 * 1024 * 1024, | |
| stream_window_size: 2 * 1024 * 1024, | |
| mstreams_per_node: 2, | |
| max_inflights_per_mstream: 32, | |
| enable_h2_keep_alive: true, | |
| socks5_proxy: None, | |
| fetch_preferred_sender_timeout: Duration::from_secs(1), | |
| reset_mstream_interval: Duration::ZERO, // disable reset by default | |
| } | |
| } | |
| } | |
| impl EndpointConfig { | |
| pub fn build(&self, uri: Uri) -> Endpoint { | |
| let mut builder = Channel::builder(uri) | |
| .initial_connection_window_size(Some(self.connection_window_size)) | |
| .initial_stream_window_size(Some(self.stream_window_size)) | |
| .tcp_keepalive(Some(DEFAULT_TIMEOUT)) | |
| .connect_timeout(DEFAULT_CONNECT_TIMEOUT); | |
| if self.enable_h2_keep_alive { | |
| builder = builder | |
| .keep_alive_while_idle(true) | |
| .http2_keep_alive_interval(Duration::from_secs(5)) | |
| .keep_alive_timeout(DEFAULT_TIMEOUT); | |
| } | |
| builder | |
| } | |
| } | |
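| // Snapshot of the channel picked by `poll_ready`: the (addr, index, epoch) triple lets `call` | |
| // report transport errors back to the pool so exactly that channel instance gets replaced. | |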
| struct ReadyChannel { | |
| addr: Address, | |
| index: u64, | |
| channel: Channel, | |
| epoch: u64, | |
| } | |
| impl Clone for TonicService { | |
| fn clone(&self) -> Self { | |
| Self { | |
| pool: self.pool.clone(), | |
| poll_fut: None, | |
| is_draining: self.is_draining.clone(), | |
| ready: None, | |
| prefer_address: self.prefer_address.clone(), | |
| config: Default::default(), | |
| } | |
| } | |
| } | |
| impl Service<http::Request<BoxBody>> for TonicService { | |
| type Response = <Channel as Service<http::Request<BoxBody>>>::Response; | |
| type Error = Box<dyn StdError + Send + Sync>; | |
| type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>; | |
| fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> { | |
| if *self.is_draining.borrow() { | |
| self.ready = None; | |
| } | |
| if self.ready.is_some() { | |
| return Poll::Ready(Ok(())); | |
| } | |
| let poll_fut = self.poll_fut.get_or_insert_with(|| { | |
| let pool = self.pool.clone(); | |
| let prefer_address = self.prefer_address.clone(); | |
| let mut is_draining = self.is_draining.clone(); | |
| let warn_interval = self.config.warn_interval; | |
| let connection_timeout = self.config.connection_timeout; | |
| Box::pin(async move { | |
| loop { | |
| // warn every 30s if no channel is ready | |
| let mut warn_interval = tokio::time::interval(warn_interval); | |
| let now = Instant::now(); | |
| // first tick happens immediately | |
| warn_interval.tick().await; | |
| loop { | |
| crate::spurious_select! { | |
| _ = warn_interval.tick() => { | |
| if now.elapsed() > connection_timeout { | |
| return Err(Status::unavailable(format!("no endpoint is connected in the past {connection_timeout:?}"))); | |
| } | |
| warn!("service not ready, no endpoint is connected"); | |
| } | |
| _permit = pool.alive.acquire().instrument_await("acquire_alive") => { | |
| break | |
| } | |
| } | |
| } | |
| if *is_draining.borrow_and_update() { | |
| info!("service is waiting for draining"); | |
| let _ = is_draining.changed().await; | |
| } | |
| let inner = pool.inner.read().instrument_await("acquire_inner").await; | |
| debug_assert!(!inner.channels.is_empty()); | |
| let candidates = match &prefer_address { | |
| Some(prefered_address) => { | |
| debug!("Poll with preference {:?}", prefer_address); | |
| let mut candidates: Vec<_> = inner | |
| .channels | |
| .range(( | |
| Included((prefered_address.clone(), 0)), | |
| Included(( | |
| prefered_address.clone(), | |
| inner.endpoint_config.conn_per_endpoint, | |
| )), | |
| )) | |
| .collect(); | |
| if candidates.is_empty() { | |
| debug!("Service not ready, Locality is changing"); | |
| // Prefered channel drained, try next | |
| candidates = inner | |
| .channels | |
| .range((Included((prefered_address.clone(), 0)), Unbounded)) | |
| .collect(); | |
| if candidates.is_empty() { | |
| // wrap around to the first address when nothing follows the preferred one | |
| let first_addr = | |
| inner.channels.first_key_value().unwrap().0 .0.clone(); | |
| candidates = inner | |
| .channels | |
| .range(( | |
| Included((first_addr.clone(), 0)), | |
| Included(( | |
| first_addr.clone(), | |
| inner.endpoint_config.conn_per_endpoint, | |
| )), | |
| )) | |
| .collect::<Vec<_>>(); | |
| } else { | |
| let next_addr = candidates[0].0 .0.clone(); | |
| candidates.retain(|candidate| candidate.0 .0 == next_addr); | |
| } | |
| } | |
| candidates | |
| } | |
| None => { | |
| debug!("Poll without preference"); | |
| inner.channels.iter().collect::<Vec<_>>() | |
| } | |
| }; | |
| let ((addr, index), channel) = | |
| candidates[rand::rng().random_range(0..candidates.len())]; | |
| let mut ready = ReadyChannel { | |
| addr: addr.clone(), | |
| index: *index, | |
| channel: channel.channel.clone(), | |
| epoch: channel.epoch, | |
| }; | |
| drop(inner); | |
| // `Channel::poll_ready()` returns an error only when a user request has failed on the | |
| // channel, so it cannot be used to detect whether the channel is still alive; the same | |
| // goes for the http2 keepalive request, which is not a user request. | |
| match futures::future::poll_fn(|cx| ready.channel.poll_ready(cx)) | |
| .instrument_await("poll_ready") | |
| .await | |
| { | |
| Ok(()) => { | |
| if *is_draining.borrow() { | |
| debug!("service is draining"); | |
| continue; | |
| } | |
| return Ok(ready); | |
| } | |
| Err(err) => { | |
| warn!("transport error {err:?} on {}", ready.addr); | |
| pool.clone() | |
| .handle_transport_error( | |
| ready.addr.clone(), | |
| ready.index, | |
| ready.epoch, | |
| ) | |
| .instrument_await("handle_transport_error_in_poll_ready") | |
| .await; | |
| } | |
| } | |
| } | |
| }) | |
| }); | |
| match poll_fut.poll_unpin(cx) { | |
| Poll::Ready(res) => { | |
| self.poll_fut = None; | |
| match res { | |
| Ok(ready) => { | |
| self.ready = Some(ready); | |
| Poll::Ready(Ok(())) | |
| } | |
| Err(err) => Poll::Ready(Err(Box::new(err))), | |
| } | |
| } | |
| Poll::Pending => Poll::Pending, | |
| } | |
| } | |
| fn call(&mut self, mut request: http::Request<BoxBody>) -> Self::Future { | |
| let pool = self.pool.clone(); | |
| let mut ready = self.ready.take().unwrap(); | |
| let uuid = utils::get_bin_metadata(request.headers(), "tracer-bin", &mut [0; 24]) | |
| .and_then(|v| Uuid::from_slice(v).ok()); | |
| rewrite_uri(&mut request, &ready.addr.to_uri()); | |
| let api = request.uri().path().split('/').last().unwrap_or(""); | |
| let span: await_tree::Span = match uuid { | |
| Some(v) => format!("call[{v:?}][{api}]").into(), | |
| None => format!("call[{api}]").into(), | |
| }; | |
| Box::pin(async move { | |
| let res = ready.channel.call(request).instrument_await(span).await; | |
| if let Err(err) = res.as_ref() { | |
| warn!("transport error {err:?} on {}", ready.addr); | |
| pool.handle_transport_error(ready.addr.clone(), ready.index, ready.epoch) | |
| .instrument_await("handle_transport_error_in_call") | |
| .await; | |
| } | |
| res.map_err(Into::into) | |
| }) | |
| } | |
| } | |
| pub type TonicServiceWithToken = InterceptedService<TonicService, Interceptor>; | |
| pub struct DataClientPool { | |
| pool: Arc<ChannelPool>, | |
| clients: HashMap<Address, FlightServiceClient<TonicServiceWithToken>>, | |
| } | |
| impl DataClientPool { | |
| pub fn new(config: EndpointConfig) -> Self { | |
| Self { | |
| pool: ChannelPool::new(config), | |
| clients: HashMap::new(), | |
| } | |
| } | |
| pub fn is_empty(&self) -> bool { | |
| self.clients.is_empty() | |
| } | |
| pub fn fetch_random(&self) -> Option<FlightServiceClient<TonicServiceWithToken>> { | |
| self.clients | |
| .iter() | |
| .choose(&mut rand::rng()) | |
| .map(|(_, client)| client.clone()) | |
| } | |
| pub fn fetch_with_preferred_address( | |
| &self, | |
| addr: &Address, | |
| ) -> Option<FlightServiceClient<TonicServiceWithToken>> { | |
| self.clients.get(addr).cloned() | |
| } | |
| pub fn fetch_all(&self) -> HashMap<Address, FlightServiceClient<TonicServiceWithToken>> { | |
| self.clients.clone() | |
| } | |
| pub async fn add_address( | |
| &mut self, | |
| addr: Address, | |
| interceptor: Interceptor, | |
| ) -> FlightServiceClient<TonicServiceWithToken> { | |
| if self.clients.contains_key(&addr) { | |
| return self.clients[&addr].clone(); | |
| } | |
| let service = self.pool.make_tonic_service_with_preference(&addr).await; | |
| let client = FlightServiceClient::with_interceptor(service, interceptor) | |
| .max_encoding_message_size(crate::MAX_TONIC_MESSAGE_SIZE) | |
| .max_decoding_message_size(crate::MAX_TONIC_MESSAGE_SIZE); | |
| self.clients.insert(addr, client.clone()); | |
| client | |
| } | |
| pub async fn remove_address(&mut self, addr: &Address) -> bool { | |
| self.clients.remove(addr); | |
| self.pool.remove_endpoint(addr).await | |
| } | |
| pub async fn enabled_endpoints(&self) -> Vec<Address> { | |
| self.pool.enabled_endpoints().await | |
| } | |
| pub fn make_tonic_service(&self) -> TonicService { | |
| self.pool.make_tonic_service() | |
| } | |
| } | |
| pub async fn create_token_v2( | |
| meta_service: TonicService, | |
| config: AuthNConfig, | |
| ) -> Result<String, CommonError> { | |
| moist::async_fail_point!("create_token_v2"); | |
| let mut rbac_client = RbacClient::new(meta_service); | |
| let max_nr_retry = 6; | |
| // extract traditional config | |
| let traditional_cfg = if let Some((username, password)) = &config.traditional { | |
| Some(TraditionalConfig { | |
| username: username.to_string(), | |
| password: password.to_string(), | |
| }) | |
| } else { | |
| None | |
| }; | |
| // extract auth service config | |
| let auth_options = if let Some(auth_config) = &config.auth_service { | |
| let (path, env, region) = if auth_config.credentail == CredentialType::Path { | |
| ( | |
| auth_config.config_path.clone(), | |
| Some(auth_config.env), | |
| Some(auth_config.region), | |
| ) | |
| } else { | |
| (auth_config.config_path.clone(), None, None) | |
| }; | |
| let auth_options = moist::auth::Options::from_file(path, env); | |
| match auth_options { | |
| Ok((options, env_region)) => { | |
| let (env, region) = env_region | |
| .or(match (env, region) { | |
| (Some(env), Some(region)) => Some((env, region)), | |
| _ => None, | |
| }).ok_or_else(|| common_error::rbac_error!( | |
| "env and region must be provided in hugestore config file or auth service config file" | |
| ) | |
| )?; | |
| Some((options, env, region)) | |
| } | |
| Err(e) => { | |
| if traditional_cfg.is_none() { | |
| return Err(common_error::rbac_error!( | |
| "Failed to read auth service config file: {}", | |
| e.to_string() | |
| )); | |
| } else { | |
| warn!( | |
| "Failed to read auth service config file: {}, use username/password to login hugestore..", | |
| e.to_string() | |
| ); | |
| None | |
| } | |
| } | |
| } | |
| } else { | |
| None | |
| }; | |
| let mut last_err = None; | |
| for _ in 0..max_nr_retry { | |
| moist::async_fail_point!("create_token_v2_loop"); | |
| let traditional_cfg = traditional_cfg.clone(); | |
| let create_token_result: Result<_, CommonError> = async { | |
| let token = if let Some((options, env, region)) = &auth_options { | |
| let auth_client = AuthClient::initialize( | |
| options, | |
| env.to_owned(), | |
| region.to_owned(), | |
| &RequestSettings::default(), | |
| &CacheSettings::default(), | |
| &ComplianceSettings::default(), | |
| ) | |
| .await?; | |
| let mut access_token = auth_client.self_access_token().await?; | |
| (|| { | |
| moist::fail_point!("common::self_access_token", |s| { | |
| access_token = s.unwrap(); | |
| }) | |
| })(); | |
| Some(access_token) | |
| } else { | |
| None | |
| }; | |
| let create_token_resp = rbac_client | |
| .create_token_v2(Request::new(CreateTokenV2Request { | |
| traditional_cfg, | |
| access_token: token, | |
| expire_in_days: 30, | |
| })) | |
| .await | |
| .map_err(|e| { | |
| common_error::rbac_error!("Failed to create token: {}", e.to_string()) | |
| })?; | |
| Ok(create_token_resp) | |
| } | |
| .await; | |
| match create_token_result { | |
| Ok(res) => { | |
| let res = res.into_inner(); | |
| let token = res.jwt.unwrap().token; | |
| // the token was created via the traditional config, but the auth service path reported an error | |
| if let Some(msg) = res.err_msg { | |
| moist::async_fail_point!("common::create_token_v2::auth_error"); | |
| warn!("error about auth service: {}", msg); | |
| } | |
| return Ok(token); | |
| } | |
| Err(err) => { | |
| warn!("failed to create token: {err:?}, retrying..."); | |
| last_err = Some(err); | |
| tokio::time::sleep(std::time::Duration::from_millis(100)).await; | |
| } | |
| } | |
| } | |
| Err(common_error::rbac_error!( | |
| "Failed to create token after maximum retry attempts, last error: {last_err:?}", | |
| )) | |
| } | |
| pub async fn create_token( | |
| meta_service: TonicService, | |
| username: String, | |
| password: String, | |
| ) -> Result<String, CommonError> { | |
| let mut rbac_client = RbacClient::new(meta_service); | |
| let mut nr_retry = 0; | |
| loop { | |
| let res = rbac_client | |
| .create_token(Request::new(CreateTokenRequest { | |
| user_name: username.clone(), | |
| password: password.clone(), | |
| expire_in_days: 30, | |
| })) | |
| .await; | |
| match res { | |
| Ok(res) => { | |
| let token = res.into_inner().token.unwrap().token; | |
| info!("token obtained"); | |
| return Ok(token); | |
| } | |
| Err(status) if status.code() == Code::Unauthenticated => { | |
| return Err(CommonError::unauthenticated(status)); | |
| } | |
| Err(err) => { | |
| warn!("failed to create token: {err:?}, retry"); | |
| nr_retry += 1; | |
| if nr_retry > 3 { | |
| return Err(common_error::rbac_error!( | |
| "Failed to create token after maximum retry attempts: {}", | |
| err | |
| )); | |
| } | |
| tokio::time::sleep(std::time::Duration::from_millis(nr_retry * 100)).await; | |
| } | |
| } | |
| } | |
| } | |
| // Use this function if you want a streaming response from the data server; otherwise | |
| // data_single_action is a better choice. | |
| #[instrument(skip_all)] | |
| pub async fn proxy_do_action( | |
| proxy: &mut FlightServiceClient<TonicServiceWithToken>, | |
| req: huge_pb::proxy::proxy_request::Request, | |
| ) -> Result<tonic::Response<Streaming<arrow_flight::Result>>, tonic::Status> { | |
| let proxy_req = huge_pb::proxy::ProxyRequest { request: Some(req) }; | |
| let mut body = Vec::new(); | |
| proxy_req | |
| .encode(&mut body) | |
| .map_err(|e| Status::internal(format!("failed to encode request: {:?}", e.to_string())))?; | |
| let action = arrow_flight::Action { | |
| r#type: "".into(), | |
| body: body.into(), | |
| }; | |
| proxy.do_action(tonic::Request::new(action)).await | |
| } | |
| // do the action and fetch only the first message from the stream response | |
| #[instrument(skip_all)] | |
| pub async fn data_single_action( | |
| data: &mut FlightServiceClient<TonicServiceWithToken>, | |
| req: huge_pb::proxy::proxy_request::Request, | |
| ) -> Result<huge_pb::proxy::ProxyResponse, tonic::Status> { | |
| moist::async_fail_point!("data_single_action_error", |_| { | |
| Err(tonic::Status::unknown("data_single_action fail_point")) | |
| }); | |
| let mut stream = proxy_do_action(data, req).await?.into_inner(); | |
| let body = stream | |
| .message() | |
| .await? | |
| .ok_or_else(|| Status::internal("empty response from do_action, this should not happen"))? | |
| .body; | |
| huge_pb::proxy::ProxyResponse::decode(body).map_err(|e| { | |
| error!("failed to decode proxy response: {}", e.to_string()); | |
| Status::internal(e.to_string()) | |
| }) | |
| } | |
| pub async fn data_action_with_backoff( | |
| data: FlightServiceClient<TonicServiceWithToken>, | |
| req_fn: impl Fn() -> huge_pb::proxy::proxy_request::Request + Clone, | |
| backoff: impl BackoffBuilder, | |
| ) -> Result<huge_pb::proxy::ProxyResponse, tonic::Status> { | |
| let (_, resp) = (|mut data| async { | |
| let req = req_fn(); | |
| let resp = data_single_action(&mut data, req).await; | |
| (data, resp) | |
| }) | |
| .retry(&backoff) | |
| .context(data) | |
| .when(is_tonic_status_retryable) | |
| .notify(|e, duration| { | |
| warn!("retryable error: {:?}, sleep after {:?}", e, duration); | |
| }) | |
| .await; | |
| resp | |
| } | |
| #[cfg(test)] | |
| mod tests { | |
| use std::{os::fd::AsRawFd, sync::Mutex, time::Instant}; | |
| use super::*; | |
| use futures::{Stream, TryStreamExt}; | |
| use tokio::net::TcpListener; | |
| use tonic::{ | |
| Response, async_trait, | |
| transport::{Server, server::TcpIncoming}, | |
| }; | |
| use tonic_health::{ | |
| ServingStatus, | |
| pb::{ | |
| HealthCheckRequest, HealthCheckResponse, | |
| health_client::HealthClient, | |
| health_server::{Health, HealthServer}, | |
| }, | |
| }; | |
| struct HealthService; | |
| #[async_trait] | |
| impl Health for HealthService { | |
| async fn check( | |
| &self, | |
| _request: Request<HealthCheckRequest>, | |
| ) -> Result<Response<HealthCheckResponse>, Status> { | |
| tokio::time::sleep(Duration::from_millis(100)).await; | |
| Ok(Response::new(HealthCheckResponse { | |
| status: ServingStatus::Serving as i32, | |
| })) | |
| } | |
| type WatchStream = | |
| Pin<Box<dyn Stream<Item = Result<HealthCheckResponse, Status>> + Send + 'static>>; | |
| async fn watch( | |
| &self, | |
| _request: Request<HealthCheckRequest>, | |
| ) -> Result<Response<Self::WatchStream>, Status> { | |
| let output = async_stream::try_stream! { | |
| loop { | |
| let status = ServingStatus::Serving as i32; | |
| yield HealthCheckResponse { status }; | |
| tokio::time::sleep(Duration::from_millis(100)).await; | |
| } | |
| #[allow(unreachable_code)] | |
| () | |
| }; | |
| Ok(Response::new(Box::pin(output) as Self::WatchStream)) | |
| } | |
| } | |
| #[tokio::test] | |
| async fn test_multiplex() { | |
| let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); | |
| let addr = listener.local_addr().unwrap(); | |
| spawn(async move { | |
| Server::builder() | |
| .add_service(HealthServer::new(HealthService)) | |
| .serve_with_incoming(TcpIncoming::from_listener(listener, true, None).unwrap()) | |
| .await | |
| .unwrap(); | |
| }); | |
| let uri = Uri::try_from(format!("http://{}", addr)).unwrap(); | |
| let service = ChannelPool::new_with_uris([uri], Default::default()) | |
| .await | |
| .make_tonic_service(); | |
| // unary | |
| let begin = Instant::now(); | |
| futures::future::join_all((0..100).map(|_| async { | |
| let mut client = HealthClient::new(service.clone()); | |
| let resp = client | |
| .check(HealthCheckRequest::default()) | |
| .await | |
| .unwrap() | |
| .into_inner(); | |
| assert_eq!(resp.status, ServingStatus::Serving as i32); | |
| })) | |
| .await; | |
| assert!(begin.elapsed() < Duration::from_secs(10)); | |
| // stream | |
| let begin = Instant::now(); | |
| futures::future::join_all((0..100).map(|_| async { | |
| let mut client = HealthClient::new(service.clone()); | |
| let mut stream = client | |
| .watch(HealthCheckRequest::default()) | |
| .await | |
| .unwrap() | |
| .into_inner(); | |
| for _ in 0..3 { | |
| let resp = stream.try_next().await.unwrap().unwrap(); | |
| assert_eq!(resp.status, ServingStatus::Serving as i32); | |
| } | |
| })) | |
| .await; | |
| assert!(begin.elapsed() < Duration::from_secs(10)); | |
| } | |
| #[tokio::test] | |
| async fn test_reconnect() { | |
| let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); | |
| let addr = listener.local_addr().unwrap(); | |
| let tcp_fds = Arc::new(Mutex::new(Vec::new())); | |
| let tcp_fds_ = tcp_fds.clone(); | |
| let listener_stream = | |
| tokio_stream::wrappers::TcpListenerStream::new(listener).inspect_ok(move |tcp| { | |
| tcp_fds_.lock().unwrap().push(tcp.as_raw_fd()); | |
| }); | |
| spawn(async move { | |
| Server::builder() | |
| .add_service(HealthServer::new(HealthService)) | |
| .serve_with_incoming(listener_stream) | |
| .await | |
| .unwrap(); | |
| }); | |
| let uri = Uri::try_from(format!("http://{}", addr)).unwrap(); | |
| let service = ChannelPool::new_with_uris([uri], Default::default()) | |
| .await | |
| .make_tonic_service(); | |
| let semaphore = Arc::new(Semaphore::new(0)); | |
| let mut handles = Vec::new(); | |
| for _ in 0..100 { | |
| let service = service.clone(); | |
| let semaphore = semaphore.clone(); | |
| handles.push(spawn(async move { | |
| let mut client = HealthClient::new(service.clone()); | |
| let mut stream = client | |
| .watch(HealthCheckRequest::default()) | |
| .await? | |
| .into_inner(); | |
| semaphore.add_permits(1); | |
| while let Some(resp) = stream.try_next().await? { | |
| assert_eq!(resp.status, ServingStatus::Serving as i32); | |
| } | |
| Ok::<_, anyhow::Error>(()) | |
| })); | |
| } | |
| semaphore.acquire_many(100).await.unwrap().forget(); | |
| let fds = tcp_fds.lock().unwrap().clone(); | |
| for fd in fds { | |
| unsafe { | |
| libc::shutdown(fd, libc::SHUT_RDWR); | |
| } | |
| } | |
| for handle in handles { | |
| assert!(handle.await.unwrap().is_err()); | |
| } | |
| let mut client = HealthClient::new(service.clone()); | |
| let resp = client | |
| .check(HealthCheckRequest::default()) | |
| .await | |
| .unwrap() | |
| .into_inner(); | |
| assert_eq!(resp.status, ServingStatus::Serving as i32); | |
| } | |
| #[tokio::test] | |
| async fn test_connection_timeout() { | |
| let mut tonic_service = ChannelPool::new(Default::default()).make_tonic_service(); | |
| tonic_service.config.warn_interval = Duration::from_millis(1); | |
| tonic_service.config.connection_timeout = Duration::from_millis(1); | |
| let mut client = HealthClient::new(tonic_service); | |
| let err = client | |
| .check(HealthCheckRequest::default()) | |
| .await | |
| .unwrap_err(); | |
| // errors returned by tonic's poll_ready are always surfaced as Code::Unknown | |
| assert_eq!(err.code(), Code::Unknown); | |
| assert!(err.message().contains("no endpoint is connected")); | |
| } | |
| } |
| /// Fault-tolerant Fragments Resolver | |
| /// | |
| /// Meta-Server | |
| /// ^ | | |
| /// | | -> agent requests stay on track and are retried after recovering from errors | |
| /// (resolve)| v (response) | |
| /// MetaResolveAgent -> rebuilds the meta connection when notified that the meta-server is offline; on a critical error, notifies all the code readers through ReaderProducer | |
| /// ^ | |
| /// | (resolve fragment) | |
| /// ------^------ | |
| /// | | | |
| /// ReaderProducer ReaderProducer -> forges readers from resolved tickets | |
| /// | | -> streams readers downstream | |
| /// v v | |
| /// CodeReader1, CodeReader2, ... | |
| use await_tree::InstrumentAwait; | |
| use futures::future::try_join_all; | |
| use huge_common::{ | |
| ALL_CODES, AbortableHandle, TonicServiceWithToken, | |
| error::{CommonErrorCtx, HugeErrorKind}, | |
| tracer::with_local_tracer, | |
| utils::is_tonic_status_retryable, | |
| }; | |
| use huge_pb::{ | |
| LOGICAL_TIME_MAX, LOGICAL_TIME_MIN, | |
| common::DataDescriptor, | |
| meta::{ | |
| FragmentsOfCodeV2, ResolveFragmentsRequest, ResolveFragmentsResponse, | |
| meta_client::MetaClient, | |
| }, | |
| proxy::Ticket, | |
| }; | |
| use moist::async_fail_point; | |
| use std::{ | |
| collections::{HashMap, VecDeque}, | |
| sync::Arc, | |
| time::Duration, | |
| }; | |
| use tokio::sync::{ | |
| Mutex, | |
| mpsc::{Receiver, Sender}, | |
| oneshot, | |
| }; | |
| use tokio_stream::{StreamExt, wrappers::ReceiverStream}; | |
| use tonic::Streaming; | |
| use tracing::info; | |
| use crate::{InternalReadOptions, errors::ClientError, utils::wait_exponential}; | |
| #[async_trait::async_trait] | |
| pub trait ResolveFragment: Send + Sync { | |
| async fn resolve( | |
| &self, | |
| code: String, | |
| request: ResolveFragmentsRequest, | |
| ) -> Result<ResolveFragmentsResponse, ClientError>; | |
| /// Resolve all fragments in the given range in time order. Can't handle read_latest, and only | |
| /// one code is allowed. | |
| async fn resolve_all( | |
| &self, | |
| code: String, | |
| mut req: ResolveFragmentsRequest, | |
| ) -> Result<ResolveFragmentsResponse, ClientError> { | |
| if req.read_latest { | |
| return Err(ClientError::internal( | |
| "resolve_all can't handle read_latest", | |
| )); | |
| } | |
| if req.version() == huge_pb::meta::resolve_fragments_request::Version::Deprecated { | |
| return Err(ClientError::internal("resolve_all can't handle v1 request")); | |
| } | |
| if req.data.is_none() { | |
| return Err(ClientError::internal("data descriptor is missing")); | |
| } | |
| if req.data.as_ref().unwrap().codes != vec![code.clone()] || code == *ALL_CODES { | |
| return Err(ClientError::internal( | |
| "resolve_all can't handle multiple codes", | |
| )); | |
| } | |
| let mut partitions = vec![]; | |
| let mut columns = vec![]; | |
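| // Repeatedly resolve and advance the start bound to the furthest resolved partition end, | |
| // until the requested range is fully covered or no more partitions are returned. | |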
| loop { | |
| let (start, end) = { | |
| let dd = req.data.as_ref().unwrap(); | |
| (dd.start, dd.end) | |
| }; | |
| match (start, end) { | |
| (Some(start), Some(end)) if start < end => { | |
| let foc = self | |
| .resolve(code.clone(), req.clone()) | |
| .await? | |
| .codes_v2 | |
| .remove(0); | |
| if columns.is_empty() { | |
| columns = foc.columns; | |
| } | |
| if foc.partitions.is_empty() { | |
| break; | |
| } | |
| let end = foc | |
| .partitions | |
| .iter() | |
| .map(|p| *p.end.as_ref().unwrap()) | |
| .max() | |
| .ok_or_else(|| { | |
| ClientError::internal("end of resolved partition is none") | |
| })?; | |
| req.data.as_mut().unwrap().start = Some(end); | |
| partitions.extend(foc.partitions); | |
| } | |
| _ => break, | |
| }; | |
| } | |
| for p in partitions.windows(2) { | |
| debug_assert_eq!(p[0].end.as_ref().unwrap(), p[1].start.as_ref().unwrap()); | |
| if p[0].end.as_ref().unwrap() != p[1].start.as_ref().unwrap() { | |
| tracing::warn!( | |
| "resolve dataset[{}] with code {}: partition end {:?} != next partition start {:?}", | |
| req.data.as_ref().unwrap().dataset_id, | |
| code, | |
| p[0].end.as_ref().unwrap(), | |
| p[1].start.as_ref().unwrap() | |
| ); | |
| } | |
| } | |
| Ok(ResolveFragmentsResponse { | |
| codes_v2: vec![FragmentsOfCodeV2 { | |
| code, | |
| columns, | |
| partitions, | |
| }], | |
| }) | |
| } | |
| } | |
| pub async fn resolve_all_codes<R: ResolveFragment + 'static>( | |
| resolver: Arc<R>, | |
| req: ResolveFragmentsRequest, | |
| concurrency: usize, | |
| ) -> Result<ResolveFragmentsResponse, ClientError> { | |
| assert!(req.version() == huge_pb::meta::resolve_fragments_request::Version::V2); | |
| let codes = req.data.as_ref().unwrap().codes.clone(); | |
| let mut codes_v2 = vec![]; | |
| let mut tasks = Vec::with_capacity(codes.len()); | |
| let sema = Arc::new(tokio::sync::Semaphore::new(concurrency)); | |
| for code in codes { | |
| let mut req = req.clone(); | |
| let sema = sema.clone(); | |
| let resolver = resolver.clone(); | |
| let task = tokio::spawn(async move { | |
| let _permit = sema.acquire().await; | |
| req.data.as_mut().unwrap().codes = vec![code.clone()]; | |
| resolver.resolve_all(code, req.clone()).await | |
| }); | |
| tasks.push(async move { AbortableHandle::new(task).await? }); | |
| } | |
| for resp in try_join_all(tasks).await? { | |
| codes_v2.extend(resp.codes_v2); | |
| } | |
| Ok(ResolveFragmentsResponse { codes_v2 }) | |
| } | |
| /// An agent to resolve fragments with `AgentRequest` | |
| #[derive(Clone)] | |
| pub struct MetaResolveAgent { | |
| agent_req_txs: Vec<Sender<AgentRequest>>, | |
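| // Round-robin counter used to pick which backend slot handles the next request. | |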
| router: Arc<std::sync::atomic::AtomicUsize>, | |
| } | |
| pub const SLOT_DIVISOR: u64 = 1000; | |
| impl MetaResolveAgent { | |
| // Build a MetaResolveAgent that accepts requests for a specific code | |
| pub fn new_per_code( | |
| meta: MetaClient<TonicServiceWithToken>, | |
| // The number of `resolve_fragments` streams. | |
| // Usually, it's `(columns.len() / SLOT_DIVISOR).max(1)`. | |
| num_slots: u64, | |
| max_retry_times: u32, | |
| ) -> Arc<Self> { | |
| Self::new::<PerCodeRequestTracker>(meta, num_slots, max_retry_times) | |
| } | |
| // Build a MetaResolveAgent that only accepts requests for __ALL_CODES__ | |
| pub fn new_all_codes( | |
| meta: MetaClient<TonicServiceWithToken>, | |
| max_retry_times: u32, | |
| ) -> Arc<Self> { | |
| Self::new::<AllCodesRequestTracker>(meta, 1, max_retry_times) | |
| } | |
| fn new<T>( | |
| meta: MetaClient<TonicServiceWithToken>, | |
| // The number of `resolve_fragments` streams. | |
| // Usually, it's `(columns.len() / SLOT_DIVISOR).max(1)`. | |
| num_slots: u64, | |
| max_retry_times: u32, | |
| ) -> Arc<Self> | |
| where | |
| T: AgentRequestTracker + Send + 'static, | |
| { | |
| let mut agent_req_txs = vec![]; | |
| for _ in 0..num_slots { | |
| let (agent_req_tx, agent_req_rx) = tokio::sync::mpsc::channel(64); | |
| let backend = | |
| MetaResolveAgentBackend::<T>::new(meta.clone(), agent_req_rx, max_retry_times); | |
| huge_common::spawn(with_local_tracer("resolve_backend", async move { | |
| backend.run().await; | |
| })); | |
| agent_req_txs.push(agent_req_tx); | |
| } | |
| Arc::new(Self { | |
| agent_req_txs, | |
| router: Arc::new(std::sync::atomic::AtomicUsize::new(0)), | |
| }) | |
| } | |
| } | |
| #[async_trait::async_trait] | |
| impl ResolveFragment for MetaResolveAgent { | |
| async fn resolve( | |
| &self, | |
| code: String, | |
| request: ResolveFragmentsRequest, | |
| ) -> Result<ResolveFragmentsResponse, ClientError> { | |
| let (resp_tx, resp_rx) = oneshot::channel(); | |
| let agent_req = AgentRequest { | |
| code, | |
| request, | |
| resp_tx, | |
| }; | |
| let slot = self | |
| .router | |
| .fetch_add(1, std::sync::atomic::Ordering::Release) | |
| % self.agent_req_txs.len(); | |
| self.agent_req_txs[slot] | |
| .send(agent_req) | |
| .instrument_await("send") | |
| .await | |
| .map_err(|_| { | |
| ClientError::cancelled_by_other_error("fragements resovler sender error") | |
| })?; | |
| match resp_rx.await { | |
| Ok(resp) => resp, | |
| Err(_) => Err(ClientError::cancelled_by_other_error( | |
| "fragements resovler receiver error", | |
| )), | |
| } | |
| } | |
| } | |
| // Helps resolve the desired fragments, then streams the flight info downstream. | |
| // After all fragments are resolved, it closes. | |
| pub struct TicketProducer { | |
| pub code: String, | |
| pub fragment_resolver: Arc<dyn ResolveFragment>, | |
| pub address: Sender<Result<Ticket, ClientError>>, | |
| pub data_descriptor: DataDescriptor, | |
| pub include_indices: [bool; 3], | |
| pub internal_read_options: InternalReadOptions, | |
| } | |
| impl TicketProducer { | |
| #[inline] | |
| pub async fn run(self) { | |
| debug_assert_eq!(self.data_descriptor.codes.len(), 1); | |
| // Otherwise response will contain multiple codes. | |
| debug_assert!( | |
| self.data_descriptor.use_range_partition | |
| || self.data_descriptor.codes != *crate::ALL_CODES_VEC | |
| ); | |
| if self.internal_read_options.read_latest { | |
| self.run_for_latest_day().await; | |
| } else { | |
| self.run_for_all().await; | |
| } | |
| } | |
| /// Resolves all partitions in the given range in time order. Produces | |
| /// readers iteratively. | |
| async fn run_for_all(mut self) { | |
| // endless resolve, get notified when new data comes | |
| let endless = self.data_descriptor.end.is_none(); | |
| if endless { | |
| self.data_descriptor.end = Some(*LOGICAL_TIME_MAX); | |
| } | |
| loop { | |
| // If start reaches end, we are done. | |
| match (self.data_descriptor.start, self.data_descriptor.end) { | |
| (Some(start), Some(end)) if start < end => {} | |
| _ => return, | |
| } | |
| let data = Some(self.data_descriptor.clone()); | |
| let req = ResolveFragmentsRequest { | |
| data, | |
| version: huge_pb::meta::resolve_fragments_request::Version::V2.into(), | |
| read_latest: false, | |
| ask_partitions: 0, | |
| not_use_cache: false, | |
| filter_empty: !endless, // endless resolve needs the empty-partition end hint to generate promises | |
| }; | |
| let resp = self | |
| .fragment_resolver | |
| .resolve(self.code.clone(), req) | |
| .instrument_await("resolve") | |
| .await; | |
| match resp { | |
| Ok(mut resp) => { | |
| let code = resp.codes_v2.remove(0); | |
| let ordered_part = code.partitions; | |
| if ordered_part.is_empty() { | |
| if endless { | |
| // Make a promise with next lower bound | |
| if self | |
| .address | |
| .send(Ok(Ticket::new_promise(self.data_descriptor.start.unwrap()))) | |
| .instrument_await("send_promise") | |
| .await | |
| .is_err() | |
| { | |
| return; | |
| } | |
| tokio::time::sleep(Duration::from_secs(1)).await; | |
| continue; | |
| } else { | |
| // resolve to the end, break the loop | |
| return; | |
| } | |
| } | |
| for partition in ordered_part { | |
| let partition_start = partition.start; | |
| let partition_end = partition.end; | |
| let start = match (partition_start, self.data_descriptor.start) { | |
| (Some(p_start), Some(dd_start)) => Some(p_start.max(dd_start)), | |
| _ => self.data_descriptor.start, | |
| }; | |
| let end = match (partition_end, self.data_descriptor.end) { | |
| (Some(p_end), Some(dd_end)) => Some(p_end.min(dd_end)), | |
| _ => self.data_descriptor.end, | |
| }; | |
| // filter out empty partition | |
| if partition.fragments.is_empty() { | |
| self.data_descriptor.start = end; | |
| continue; | |
| } | |
| // filter out out of range partition | |
| let (min_frag_start, max_frag_end) = partition | |
| .fragments | |
| .iter() | |
| .fold((*LOGICAL_TIME_MAX, *LOGICAL_TIME_MIN), |acc, f| { | |
| (acc.0.min(f.start.unwrap()), acc.1.max(f.end.unwrap())) | |
| }); | |
| if min_frag_start >= end.unwrap() || max_frag_end <= start.unwrap() { | |
| async_fail_point!("TicketProducer::fragments_out_of_range"); | |
| self.data_descriptor.start = end; | |
| continue; | |
| } | |
| let ticket = Ticket { | |
| partition_id: partition.id_to_deprecate, | |
| code_v2: Some(FragmentsOfCodeV2 { | |
| columns: code.columns.clone(), | |
| partitions: vec![partition], | |
| code: code.code.clone(), | |
| }), | |
| include_indices: self.include_indices.into(), | |
| start, | |
| end, | |
| dataset_id: self.data_descriptor.dataset_id, | |
| bypass_cache: false, | |
| support_nanosecond_index: self.support_nanosecond_index(), | |
| }; | |
| if self | |
| .address | |
| .send(Ok(ticket)) | |
| .instrument_await("send_address") | |
| .await | |
| .is_err() | |
| { | |
| return; | |
| } | |
| self.data_descriptor.start = end; | |
| } | |
| } | |
| Err(err) if err.huge_kind() == HugeErrorKind::CancelledByOtherError => { | |
| // Avoid broadcasting the error; close silently | |
| return; | |
| } | |
| Err(err) => { | |
| let _ = self.address.send(Err(err)).await; | |
| return; | |
| } | |
| } | |
| } | |
| } | |
| /// Resolves all partitions in the latest day, still producing readers in | |
| /// ascending time order. | |
| /// | |
| /// It does not guarantee that only data from a single day is output. Extra | |
| /// data can appear when one fragment contains data from multiple days and | |
| /// the user requested only part of them. The downstream consumer is | |
| /// responsible for additional filtering. | |
| /// | |
| /// The performance is worse than `run` because it cannot pipeline the | |
| /// handling of different partitions and readers. | |
| async fn run_for_latest_day(mut self) { | |
| // `run_for_latest_day` does not support endless resolve | |
| debug_assert!(self.data_descriptor.end.is_some()); | |
| // Match the default value of `max_partitions_returned`. | |
| const ASK_PARTITIONS: usize = 8; | |
| let mut columns = Vec::with_capacity(1); | |
| let mut partitions = VecDeque::with_capacity(ASK_PARTITIONS); | |
| let mut latest_date = None; | |
| let user_end = self.data_descriptor.end.unwrap(); | |
| // This is required by the assertion: | |
| // `latest_date < user_end.trading_date`. | |
| assert!(user_end.raw_time_is_zero()); | |
| #[cfg(debug_assertions)] | |
| let mut received_focs = Vec::new(); | |
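| // Walk partitions backwards from `user_end`, shrinking `data_descriptor.end` each round, and | |
| // collect every partition whose data falls on the latest trading date seen so far. | |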
| 'out: loop { | |
| let req = ResolveFragmentsRequest { | |
| data: Some(self.data_descriptor.clone()), | |
| version: huge_pb::meta::resolve_fragments_request::Version::V2.into(), | |
| read_latest: true, | |
| ask_partitions: ASK_PARTITIONS as i32, | |
| not_use_cache: false, | |
| filter_empty: false, | |
| }; | |
| let resp = self.fragment_resolver.resolve(self.code.clone(), req).await; | |
| match resp { | |
| Ok(mut resp) => { | |
| let foc = resp.codes_v2.remove(0); | |
| let len = foc.partitions.len(); | |
| #[cfg(debug_assertions)] | |
| received_focs.push(foc.clone()); | |
| if foc.partitions.is_empty() { | |
| // resolve to the end, break the loop | |
| break; | |
| } | |
| debug_assert!( | |
| foc.partitions | |
| .windows(2) | |
| .all(|p| { p[0].end.unwrap() <= p[1].start.unwrap() }), | |
| "{:?}", | |
| foc | |
| ); | |
| columns.push(foc.columns); | |
| for partition in foc.partitions.into_iter().rev() { | |
| assert!(partition.start.is_some()); | |
| assert!(partition.end.is_some()); | |
| if partition.fragments.is_empty() { | |
| self.data_descriptor.end = partition.start; | |
| continue; | |
| } | |
| // TODO: Is this true for data truncate? | |
| debug_assert!(partition.fragments[0].num_rows > 0); | |
| // For feature fragments, the boundaries might not be aligned. | |
| let tight_start_time = std::cmp::max( | |
| partition.start.unwrap(), | |
| partition | |
| .fragments | |
| .iter() | |
| .map(|f| f.start.unwrap()) | |
| .min() | |
| .unwrap(), | |
| ); | |
| let tight_end_time = std::cmp::min( | |
| partition.end.unwrap(), | |
| partition | |
| .fragments | |
| .iter() | |
| .map(|f| f.end.unwrap()) | |
| .max() | |
| .unwrap(), | |
| ); | |
| if tight_start_time >= self.data_descriptor.end.unwrap() { | |
| debug_assert!(self.data_descriptor.end >= partition.start); | |
| self.data_descriptor.end = partition.start; | |
| continue; | |
| } | |
| if let Some(latest_date) = latest_date { | |
| if latest_date > tight_end_time.trading_date { | |
| break 'out; | |
| } | |
| } else { | |
| // TODO: | |
| // We want to use `tight_end_date` as the latest date if it | |
| // is within range; this relies on the invariant: | |
| // `frag.end_time.date == frag.last_row.date` | |
| // | |
| // But the invariant does not hold when truncation has happened. | |
| assert!(tight_start_time.trading_date < user_end.trading_date); | |
| // This is not accurate; the actual date might be larger. | |
| // Consider this case: | |
| // part = [day_0, day_1, day_10], part.end = day_10_1, | |
| // read_latest(until = day_9). | |
| // latest_date will be day_0. | |
| latest_date = Some(tight_start_time.trading_date); | |
| } | |
| self.data_descriptor.end = partition.start; | |
| partitions.push_front((partition, columns.len() - 1)); | |
| } | |
| if len < ASK_PARTITIONS { | |
| break; | |
| } | |
| } | |
| Err(err) if err.huge_kind() == HugeErrorKind::CancelledByOtherError => { | |
| // Avoid broadcasting the error; close silently | |
| return; | |
| } | |
| Err(err) => { | |
| let _ = self.address.send(Err(err)).await; | |
| return; | |
| } | |
| } | |
| } | |
| #[cfg(debug_assertions)] | |
| { | |
| let key = format!( | |
| "read_latest_{}_{}", | |
| self.data_descriptor.dataset_id, self.code | |
| ); | |
| let msg = format!( | |
| "user_end = {:?}, received_focs = {:?}, latest_date = {:?}", | |
| user_end, received_focs, latest_date | |
| ); | |
| huge_common::test::debug_info_map().insert(key, msg); | |
| } | |
| if latest_date.is_none() { | |
| return; | |
| } | |
| let window_start = huge_pb::common::LogicalTime::new_with_micro(latest_date.unwrap(), 0); | |
| for (partition, column_idx) in partitions { | |
| let start = partition.start.unwrap().max(window_start); | |
| let end = user_end.min(partition.end.unwrap()); | |
| let ticket = Ticket { | |
| partition_id: partition.id_to_deprecate, | |
| code_v2: Some(FragmentsOfCodeV2 { | |
| columns: columns[column_idx].clone(), | |
| partitions: vec![partition], | |
| code: self.code.clone(), | |
| }), | |
| include_indices: self.include_indices.into(), | |
| start: Some(start), | |
| end: Some(end), | |
| dataset_id: self.data_descriptor.dataset_id, | |
| bypass_cache: false, | |
| support_nanosecond_index: self.support_nanosecond_index(), | |
| }; | |
| if self.address.send(Ok(ticket)).await.is_err() { | |
| return; | |
| } | |
| } | |
| } | |
| fn support_nanosecond_index(&self) -> bool { | |
| moist::fail_point!("TicketProducer::old_client", |_| false); | |
| true | |
| } | |
| } | |
| /// State transition | |
| /// | |
| /// 0. Init | |
| /// * Upstream is built | |
| /// Init --> Recover(0) | |
| /// * build upstream failed, but could retry | |
| /// Init --> RetryableFailure | |
| /// * build upstream failed, but could not retry | |
| /// Init --> Finish with error | |
| /// 1. Normal | |
| /// * a retryable error is received | |
| /// Normal --> RetryableFailure | |
| /// * a critical error is received | |
| /// Normal --> Finish with error | |
| /// * order channel closed | |
| /// Normal --> Finish OK | |
| /// 2. Finish | |
| /// * exit state | |
| /// 3. RetryableFailure | |
| /// * upstream is rebuilt | |
| /// RetryableFailure --> Normal | |
| /// * rebuild upstream failed, but could retry | |
| /// RetryableFailure --> RetryableFailure | |
| /// * rebuild upstream failed, and could not retry | |
| /// RetryableFailure --> Finish with error | |
| #[derive(Debug)] | |
| enum AgentState { | |
| Init, | |
| Normal { | |
| sender: Sender<ResolveFragmentsRequest>, | |
| receiver: Streaming<ResolveFragmentsResponse>, | |
| }, | |
| Finish { | |
| error: Option<ClientError>, | |
| }, | |
| RetryableFailure { | |
| retry_left: u32, | |
| }, | |
| } | |
| trait AgentRequestTracker: Default { | |
| fn complete_agent_req(&mut self, resp: ResolveFragmentsResponse); | |
| fn refs(&self) -> Box<dyn Iterator<Item = &AgentRequest> + '_ + Send>; | |
| fn drain(&mut self) -> Box<dyn Iterator<Item = AgentRequest> + '_>; | |
| fn insert(&mut self, req: AgentRequest) -> bool; | |
| fn is_empty(&self) -> bool { | |
| self.refs().next().is_none() | |
| } | |
| } | |
| #[derive(Default)] | |
| struct PerCodeRequestTracker { | |
| reqs_ontrack: HashMap<String, AgentRequest>, | |
| } | |
| impl AgentRequestTracker for PerCodeRequestTracker { | |
| fn complete_agent_req(&mut self, resp: ResolveFragmentsResponse) { | |
| assert_eq!(resp.codes_v2.len(), 1); | |
| let req = self | |
| .reqs_ontrack | |
| .remove(&resp.codes_v2[0].code) | |
| .expect("unexpected: phantom agent request"); | |
| req.resp_tx | |
| .send(Ok(resp)) | |
| .expect("unexpected: Reader closed"); | |
| } | |
| fn refs(&self) -> Box<dyn Iterator<Item = &AgentRequest> + '_ + Send> { | |
| Box::new(self.reqs_ontrack.values()) | |
| } | |
| fn drain(&mut self) -> Box<dyn Iterator<Item = AgentRequest> + '_> { | |
| Box::new(self.reqs_ontrack.drain().map(|(_, req)| req)) | |
| } | |
| fn insert(&mut self, req: AgentRequest) -> bool { | |
| self.reqs_ontrack.insert(req.code.clone(), req).is_none() | |
| } | |
| } | |
| #[derive(Default)] | |
| struct AllCodesRequestTracker { | |
| req: Option<AgentRequest>, | |
| } | |
| impl AgentRequestTracker for AllCodesRequestTracker { | |
| fn complete_agent_req(&mut self, resp: ResolveFragmentsResponse) { | |
| self.req | |
| .take() | |
| .expect("unexpected: phantom agent request") | |
| .resp_tx | |
| .send(Ok(resp)) | |
| .expect("unexpected: Reader closed"); | |
| } | |
| fn refs(&self) -> Box<dyn Iterator<Item = &AgentRequest> + '_ + Send> { | |
| Box::new(self.req.iter()) | |
| } | |
| fn drain(&mut self) -> Box<dyn Iterator<Item = AgentRequest> + '_> { | |
| Box::new(self.req.take().into_iter()) | |
| } | |
| fn insert(&mut self, req: AgentRequest) -> bool { | |
| assert!(req.code == *ALL_CODES); | |
| self.req.replace(req).is_none() | |
| } | |
| } | |
| /// `MetaResolveAgentBackend` is the bridge between `MetaResolveAgent` and the meta connection. | |
| /// It automatically retries retryable failures reported by the meta connection. | |
| struct MetaResolveAgentBackend<T> | |
| where | |
| T: AgentRequestTracker, | |
| { | |
| meta: MetaClient<TonicServiceWithToken>, | |
| agent_req_rx: Receiver<AgentRequest>, | |
| max_retry_times: u32, | |
| agent_reqs_ontrack: T, | |
| } | |
| struct AgentRequest { | |
| code: String, // the Reader should guarantee at most one in-flight order per code | |
| request: ResolveFragmentsRequest, | |
| resp_tx: oneshot::Sender<Result<ResolveFragmentsResponse, ClientError>>, | |
| } | |
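| // Usage sketch (illustrative; `agent_req_tx` is a hypothetical name for the sender half | |
| // paired with `agent_req_rx`): the frontend pairs every request with a oneshot channel | |
| // and awaits the backend's reply. | |
| // | |
| //     let (resp_tx, resp_rx) = oneshot::channel(); | |
| //     agent_req_tx.send(AgentRequest { code, request, resp_tx }).await?; | |
| //     let resp = resp_rx.await??; | |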
| impl<T> MetaResolveAgentBackend<T> | |
| where | |
| T: AgentRequestTracker, | |
| { | |
| pub fn new( | |
| meta: MetaClient<TonicServiceWithToken>, | |
| agent_req_rx: Receiver<AgentRequest>, | |
| max_retry_times: u32, | |
| ) -> Self { | |
| Self { | |
| meta, | |
| agent_req_rx, | |
| max_retry_times, | |
| agent_reqs_ontrack: T::default(), | |
| } | |
| } | |
| pub async fn build_meta_connection(&mut self, retry_left: u32) -> AgentState { | |
| let (req_tx, req_rx) = tokio::sync::mpsc::channel::<ResolveFragmentsRequest>(64); | |
| let res = match self | |
| .meta | |
| .resolve_fragments_v2(ReceiverStream::new(req_rx)) | |
| .await | |
| .map(|resp| resp.into_inner()) | |
| { | |
| Ok(streaming) => AgentState::Normal { | |
| sender: req_tx, | |
| receiver: streaming, | |
| }, | |
| Err(err) => self.handle_error(err, retry_left), | |
| }; | |
| moist::async_fail_point!("MetaResolveAgentBackend::after_build_meta_connection"); | |
| res | |
| } | |
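| // `handle_error` maps a tonic status to the next state: retryable statuses decrement | |
| // `retry_left`, and once it reaches zero the backend finishes, surfacing the original | |
| // status when retries are disabled (`max_retry_times == 0`) and a generic | |
| // "Reach retry limit" error otherwise; non-retryable statuses finish immediately. | |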
| fn handle_error(&mut self, status: tonic::Status, retry_left: u32) -> AgentState { | |
| if is_tonic_status_retryable(&status) { | |
| if retry_left > 0 { | |
| info!("Resolve fragments connection error: {:?}, retry...", status); | |
| AgentState::RetryableFailure { | |
| retry_left: retry_left - 1, | |
| } | |
| } else { | |
| let error = if self.max_retry_times == 0 { | |
| Some(status.into()) | |
| } else { | |
| Some(ClientError::internal("Reach retry limit")) | |
| }; | |
| AgentState::Finish { error } | |
| } | |
| } else { | |
| AgentState::Finish { | |
| error: Some(status.into()), | |
| } | |
| } | |
| } | |
| async fn execute_normal( | |
| &mut self, | |
| sender: Sender<ResolveFragmentsRequest>, | |
| mut receiver: Streaming<ResolveFragmentsResponse>, | |
| ) -> AgentState { | |
| // re-send the in-flight agent requests recovered from the last failure | |
| for agent_req in self.agent_reqs_ontrack.refs() { | |
| let _ = sender | |
| .send(agent_req.request.clone()) | |
| .instrument_await("recover_send") | |
| .await; | |
| } | |
| loop { | |
| huge_common::spurious_select! { | |
| resp = receiver.next() => { | |
| match resp { | |
| Some(Ok(resp)) => { | |
| self.agent_reqs_ontrack.complete_agent_req(resp); | |
| } | |
| Some(Err(error)) => { | |
| info!("MetaResolveAgentBackend encounters error: {:?}", error); | |
| return self.handle_error(error, self.max_retry_times); | |
| } | |
| None => unreachable!("unexpected: meta-server closed the stream"), | |
| } | |
| } | |
| agent_req = self.agent_req_rx.recv() => { | |
| if let Some(agent_req) = agent_req { | |
| let req = agent_req.request.clone(); | |
| assert!(self.agent_reqs_ontrack.insert(agent_req)); | |
| let _ = sender.send(req).instrument_await("send").await; // ignore the error in send | |
| } else { | |
| // all request senders are dropped, which means resolve-fragments has completed | |
| debug_assert!(self.agent_reqs_ontrack.is_empty()); | |
| return AgentState::Finish { error: None }; | |
| } | |
| } | |
| } | |
| } | |
| } | |
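| // Only the first drained request receives the real error; the remaining ones are failed | |
| // with `CancelledByOtherError`, which consumers treat as a silent close (see the | |
| // `CancelledByOtherError` branch in the ticket producer above), so each failure is | |
| // surfaced to the caller exactly once. | |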
| fn cancel_agent_reqs_ontrack(&mut self, error: ClientError) { | |
| let mut reqs = self.agent_reqs_ontrack.drain(); | |
| if let Some(req) = reqs.next() { | |
| let _ = req.resp_tx.send(Err(error)); | |
| } | |
| for req in reqs { | |
| let _ = req.resp_tx.send(Err(ClientError::cancelled_by_other_error( | |
| "cancel agent requests ontrack", | |
| ))); | |
| } | |
| } | |
| pub async fn run(mut self) { | |
| let mut state = AgentState::Init; | |
| loop { | |
| state = match state { | |
| AgentState::Init => self.build_meta_connection(self.max_retry_times).await, | |
| AgentState::Normal { sender, receiver } => { | |
| self.execute_normal(sender, receiver) | |
| .instrument_await("execute_normal") | |
| .await | |
| } | |
| AgentState::Finish { error } => { | |
| if let Some(error) = error { | |
| tracing::error!("Resolve fragments meets critical error {:?}", error); | |
| self.cancel_agent_reqs_ontrack(error); | |
| } | |
| break; | |
| } | |
| AgentState::RetryableFailure { retry_left } => { | |
| moist::async_fail_point!("MetaResolveAgentBackend::handle_retryable_failure"); | |
| wait_exponential(self.max_retry_times - retry_left - 1) | |
| .instrument_await(format!("wait[retry_left={retry_left}]")) | |
| .await; | |
| self.build_meta_connection(retry_left).await | |
| } | |
| } | |
| } | |
| tracing::debug!("MetaResolveAgentBackend closed"); | |
| } | |
| } | |
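| /// Mock implementation of `ResolveFragment`: each code maps to a queue of | |
| /// `FragmentsOfCodeV2`, every `resolve` call pops the next queued entry, and an | |
| /// exhausted queue yields a response with empty partitions and columns. | |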
| #[derive(Clone)] | |
| pub struct MockResolver { | |
| all_codes_v2: HashMap<String, Arc<Mutex<VecDeque<FragmentsOfCodeV2>>>>, | |
| } | |
| impl MockResolver { | |
| pub fn new(all_codes_v2: HashMap<String, VecDeque<FragmentsOfCodeV2>>) -> Self { | |
| Self { | |
| all_codes_v2: all_codes_v2 | |
| .into_iter() | |
| .map(|(code, partitions)| (code, Arc::new(Mutex::new(partitions)))) | |
| .collect(), | |
| } | |
| } | |
| pub fn new_from_resolve_response(resp: ResolveFragmentsResponse) -> Self { | |
| let mut all_codes_v2 = HashMap::new(); | |
| resp.codes_v2.into_iter().for_each(|foc| { | |
| all_codes_v2 | |
| .entry(foc.code.clone()) | |
| .or_insert_with(VecDeque::new) | |
| .push_back(foc); | |
| }); | |
| Self::new(all_codes_v2) | |
| } | |
| } | |
| #[async_trait::async_trait] | |
| impl ResolveFragment for MockResolver { | |
| async fn resolve( | |
| &self, | |
| code: String, | |
| _req: ResolveFragmentsRequest, | |
| ) -> Result<ResolveFragmentsResponse, ClientError> { | |
| let mut codes = self | |
| .all_codes_v2 | |
| .get(&code) | |
| .ok_or_else(|| ClientError::internal(format!("code {} not found", code).as_str()))? | |
| .lock() | |
| .await; | |
| let (partitions, columns) = if codes.is_empty() { | |
| (vec![], vec![]) | |
| } else { | |
| let FragmentsOfCodeV2 { | |
| columns, | |
| partitions, | |
| .. | |
| } = codes.pop_front().unwrap(); | |
| (partitions, columns) | |
| }; | |
| Ok(ResolveFragmentsResponse { | |
| codes_v2: vec![FragmentsOfCodeV2 { | |
| code: code.clone(), | |
| partitions, | |
| columns, | |
| }], | |
| }) | |
| } | |
| } | |
| #[cfg(test)] | |
| mod tests { | |
| use huge_common::test::{assert_contains, make_time}; | |
| use huge_pb::meta::PartitionInfoV2; | |
| use super::*; | |
| #[tokio::test] | |
| async fn test_resolve_all() { | |
| let foc = FragmentsOfCodeV2 { | |
| code: "btc".to_string(), | |
| partitions: vec![PartitionInfoV2 { | |
| id_to_deprecate: 0, | |
| code: "btc".to_string(), | |
| start: Some(make_time("20230101").into()), | |
| end: Some(make_time("20230102").into()), | |
| fragments: vec![], | |
| ..Default::default() | |
| }], | |
| columns: vec![], | |
| }; | |
| let resolver = MockResolver::new(HashMap::from([( | |
| "btc".to_string(), | |
| VecDeque::from(vec![foc.clone()]), | |
| )])); | |
| let data = DataDescriptor { | |
| dataset_id: 0, | |
| // Compaction is only supported for the patch version 0 (major version) | |
| start: Some(make_time("20230101").into()), | |
| end: Some(make_time("20230102").into()), | |
| codes: vec!["btc".to_string()], | |
| use_range_partition: false, | |
| targets: vec![], | |
| }; | |
| let res = resolver | |
| .resolve_all( | |
| "btc".to_string(), | |
| ResolveFragmentsRequest { | |
| data: Some(data.clone()), | |
| version: huge_pb::meta::resolve_fragments_request::Version::V2.into(), | |
| read_latest: true, | |
| ask_partitions: 0, | |
| not_use_cache: false, | |
| filter_empty: false, | |
| }, | |
| ) | |
| .await; | |
| assert_contains( | |
| res.unwrap_err().to_string(), | |
| "resolve_all can't handle read_latest", | |
| ); | |
| let res = resolver | |
| .resolve_all( | |
| "btc".to_string(), | |
| ResolveFragmentsRequest { | |
| data: None, | |
| version: huge_pb::meta::resolve_fragments_request::Version::V2.into(), | |
| read_latest: false, | |
| ask_partitions: 0, | |
| not_use_cache: false, | |
| filter_empty: false, | |
| }, | |
| ) | |
| .await; | |
| assert_contains(res.unwrap_err().to_string(), "data descriptor is missing"); | |
| let res = resolver | |
| .resolve_all( | |
| "apple".to_string(), | |
| ResolveFragmentsRequest { | |
| data: Some(data.clone()), | |
| version: huge_pb::meta::resolve_fragments_request::Version::V2.into(), | |
| read_latest: false, | |
| ask_partitions: 0, | |
| not_use_cache: false, | |
| filter_empty: false, | |
| }, | |
| ) | |
| .await; | |
| assert_contains( | |
| res.unwrap_err().to_string(), | |
| "resolve_all can't handle multiple codes", | |
| ); | |
| let res = resolver | |
| .resolve_all( | |
| "btc".to_string(), | |
| ResolveFragmentsRequest { | |
| data: Some(data.clone()), | |
| version: huge_pb::meta::resolve_fragments_request::Version::Deprecated.into(), | |
| read_latest: false, | |
| ask_partitions: 0, | |
| not_use_cache: false, | |
| filter_empty: false, | |
| }, | |
| ) | |
| .await; | |
| assert_contains( | |
| res.unwrap_err().to_string(), | |
| "resolve_all can't handle v1 request", | |
| ); | |
| let res = resolver | |
| .resolve_all( | |
| "btc".to_string(), | |
| ResolveFragmentsRequest { | |
| data: Some(data.clone()), | |
| version: huge_pb::meta::resolve_fragments_request::Version::V2.into(), | |
| read_latest: false, | |
| ask_partitions: 0, | |
| not_use_cache: false, | |
| filter_empty: false, | |
| }, | |
| ) | |
| .await; | |
| assert_eq!( | |
| res.unwrap(), | |
| ResolveFragmentsResponse { | |
| codes_v2: vec![foc], | |
| } | |
| ); | |
| } | |
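| // A minimal additional sketch exercising `MockResolver::resolve` directly; it assumes | |
| // the prost-generated `ResolveFragmentsRequest` implements `Default`. | |
| #[tokio::test] | |
| async fn test_mock_resolver_resolve_sketch() { | |
| let foc = FragmentsOfCodeV2 { | |
| code: "btc".to_string(), | |
| partitions: vec![], | |
| columns: vec![], | |
| }; | |
| let resolver = MockResolver::new_from_resolve_response(ResolveFragmentsResponse { | |
| codes_v2: vec![foc.clone()], | |
| }); | |
| // The first call pops the queued fragments for "btc". | |
| let resp = resolver | |
| .resolve("btc".to_string(), ResolveFragmentsRequest::default()) | |
| .await | |
| .unwrap(); | |
| assert_eq!(resp.codes_v2, vec![foc]); | |
| // An unknown code is reported as an error. | |
| assert!( | |
| resolver | |
| .resolve("apple".to_string(), ResolveFragmentsRequest::default()) | |
| .await | |
| .is_err() | |
| ); | |
| } | |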
| } |