h2 panic client code
@tabokie — Created June 26, 2025
use std::{
collections::{BTreeMap, HashMap, HashSet},
error::Error as StdError,
future::Future,
ops::Bound::{Included, Unbounded},
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
use arrow_flight::{self, flight_service_client::FlightServiceClient};
use await_tree::InstrumentAwait;
use backon::{BackoffBuilder, RetryableWithContext};
use futures::{FutureExt, future::BoxFuture};
use huge_pb::rbac::CreateTokenV2Request;
use huge_pb::rbac::create_token_v2_request::TraditionalConfig;
use huge_pb::{
common::Address,
rbac::{CreateTokenRequest, rbac_client::RbacClient},
};
use hyper::http;
use hyper_util::client::legacy::connect::HttpConnector;
use moist::auth::Client as AuthClient;
use moist::auth::{CacheSettings, ComplianceSettings, RequestSettings};
use prost::Message;
use rand::prelude::*;
use thiserror::Error;
use tokio::sync::{RwLock, Semaphore, watch};
use tonic::{
Code, Request, Status, Streaming,
body::BoxBody,
codegen::InterceptedService,
transport::{Channel, Endpoint, Uri},
};
use tower::Service;
use tracing::{debug, error, info, instrument, warn};
use uuid::Uuid;
use crate::common_error::{self, CommonError, CommonErrorCtx};
use crate::rbac::Interceptor;
use crate::{AuthNConfig, Credential as CredentialType, utils};
use crate::{allocator::spawn, is_tonic_status_retryable};
/// Tonic hangs when sending >=2GiB requests (https://github.com/hyperium/tonic/issues/352) and
/// reports errors when sending >=2GiB responses (https://github.com/hyperium/hyper/issues/2893).
///
/// We cap the maximum message size (after compression) just below 2GiB, so that tonic reports
/// an error instead of hanging.
/// NOTE: The error tonic reports in this case is cryptic:
/// "h2 protocol error: http2 error: stream error sent by user: unexpected internal error encountered"
/// instead of the expected
/// "Error, message length too large: found {} bytes, the limit is: {} bytes"
pub const MAX_TONIC_MESSAGE_SIZE: usize = i32::MAX as usize;
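// A minimal usage sketch (not from the original source): apply the cap to a
// generated tonic client so oversized messages fail fast instead of hanging.
// `DataClientPool` below uses `FlightServiceClient` this way.
//
//     let client = FlightServiceClient::new(service)
//         .max_encoding_message_size(MAX_TONIC_MESSAGE_SIZE)
//         .max_decoding_message_size(MAX_TONIC_MESSAGE_SIZE);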
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20);
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(3);
#[derive(Error, Debug)]
pub enum PoolError {
#[error("failed to connect: {0}")]
ConnectionError(#[from] tonic::transport::Error),
#[error("failed to check connection: {0}")]
CheckConnectionError(#[from] tonic::Status),
}
impl From<PoolError> for tonic::Status {
fn from(err: PoolError) -> Self {
tonic::Status::internal(err.to_string())
}
}
const RECONNECT_INTERVAL: Duration = Duration::from_secs(1);
const RECONNECT_INTERVAL_MAX: Duration = Duration::from_secs(10);
pub async fn do_connect(
endpoint: Endpoint,
socks5_proxy: Option<String>,
) -> Result<Channel, crate::error::CommonError> {
// Socks5 proxy
if let Some(proxy_host) = socks5_proxy {
// Randomly choose a proxy from the ip list
let ips = tokio::net::lookup_host(proxy_host.clone())
.await
.map_err(|e| {
crate::error::CommonError::invalid_argument(format!(
"Failed to lookup socks5 host: {e:?}"
))
})?
.collect::<Vec<_>>();
tracing::info!("lookup socks5 host: {} => {:?}", proxy_host, ips);
let random_addr = *ips.choose(&mut rand::rng()).ok_or_else(|| {
crate::error::CommonError::invalid_argument("No socks5 proxy found".to_string())
})?;
// Connect to the proxy
let mut connector = HttpConnector::new();
connector.enforce_http(false);
tracing::info!("connecting to socks5 proxy: {:?}", random_addr);
let proxy = crate::socks5::SocksConnector {
proxy_addr: random_addr,
connector,
};
return endpoint
.connect_with_connector(proxy)
.await
.map_err(Into::into);
}
// No proxy
endpoint.connect().await.map_err(Into::into)
}
pub fn rewrite_uri(req: &mut hyper::Request<tonic::body::BoxBody>, origin: &Uri) {
let uri = Uri::builder()
.scheme(origin.scheme().unwrap().clone())
.authority(origin.authority().unwrap().clone())
.path_and_query(req.uri().path_and_query().unwrap().clone())
.build()
.unwrap();
*req.uri_mut() = uri;
}
pub struct AddOrigin {
channel: Channel,
origin: Uri,
}
impl AddOrigin {
pub fn new(channel: Channel, origin: Uri) -> Self {
Self { channel, origin }
}
}
impl Service<hyper::Request<tonic::body::BoxBody>> for AddOrigin {
type Response = <Channel as Service<hyper::Request<tonic::body::BoxBody>>>::Response;
type Error = <Channel as Service<hyper::Request<tonic::body::BoxBody>>>::Error;
type Future = <Channel as Service<hyper::Request<tonic::body::BoxBody>>>::Future;
fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.channel.poll_ready(ctx)
}
fn call(&mut self, mut req: hyper::Request<tonic::body::BoxBody>) -> Self::Future {
rewrite_uri(&mut req, &self.origin);
self.channel.call(req)
}
}
pub struct ChannelPool {
inner: RwLock<ChannelPoolInner>,
// Each permit of the semaphore represents one connected channel.
// `poll_ready` acquires a permit before selecting a channel from the pool, which guarantees
// `channels` is not empty. If `poll_ready` returns pending because no channel is available,
// it is woken up once a channel reconnects and permits are added back.
alive: Semaphore,
is_draining: watch::Receiver<bool>,
}
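// A sketch of the permit discipline above (a restatement of the invariants in
// the normal, non-test construction; not new behavior):
//
//     spawn_connect success     -> alive.add_permits(1)   // channel enters pool
//     remove_endpoint / error   -> acquire().forget()     // channel leaves pool
//     poll_ready                -> acquire(), pick a channel, release the permit
//
// so `alive.available_permits()` tracks `inner.channels.len()`.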
struct ChannelPoolInner {
// TODO: make it faster with indexed tree map
channels: BTreeMap<(Address, u64), ChannelWithEpoch>,
enabled_endpoints: HashSet<Address>,
endpoint_config: EndpointConfig,
}
struct ChannelWithEpoch {
channel: Channel,
// epoch is increased after each reconnection to avoid removing a newly created channel.
epoch: u64,
}
impl ChannelPool {
pub fn new(endpoint_config: EndpointConfig) -> Arc<Self> {
Arc::new(ChannelPool {
inner: RwLock::new(ChannelPoolInner {
channels: BTreeMap::new(),
enabled_endpoints: HashSet::new(),
endpoint_config,
}),
alive: Semaphore::new(0),
is_draining: watch::channel(false).1,
})
}
// Creates a ChannelPool from an existing channel. Used for testing.
pub fn from_channel(channel: Channel) -> Arc<Self> {
let endpoint_config = EndpointConfig {
conn_per_endpoint: 1,
..Default::default()
};
let addr = Address::default();
Arc::new(ChannelPool {
inner: RwLock::new(ChannelPoolInner {
channels: [((addr.clone(), 0), ChannelWithEpoch { channel, epoch: 0 })]
.into_iter()
.collect(),
enabled_endpoints: [addr].into_iter().collect(),
endpoint_config,
}),
alive: Semaphore::new(Semaphore::MAX_PERMITS),
is_draining: watch::channel(false).1,
})
}
pub fn make_tonic_service(self: &Arc<Self>) -> TonicService {
debug!("new TonicService on channel pool");
TonicService {
pool: self.clone(),
poll_fut: None,
is_draining: self.is_draining.clone(),
ready: None,
prefer_address: None,
config: Default::default(),
}
}
pub async fn make_tonic_service_with_preference(
self: &Arc<Self>,
addr: &Address,
) -> TonicService {
debug!(
"new TonicService on channel pool, with address preference {}",
addr
);
let _ = self.add_endpoint(addr).await;
TonicService {
pool: self.clone(),
poll_fut: None,
is_draining: self.is_draining.clone(),
ready: None,
prefer_address: addr.clone().into(),
config: Default::default(),
}
}
pub async fn add_endpoint(self: &Arc<Self>, addr: &Address) -> bool {
let mut inner = self.inner.write().await;
if !inner.enabled_endpoints.insert(addr.clone()) {
return false;
}
for index in 0..inner.endpoint_config.conn_per_endpoint {
self.clone().spawn_connect(addr.clone(), index, 0);
}
true
}
pub async fn remove_endpoint(self: &Arc<Self>, addr: &Address) -> bool {
let mut inner = self.inner.write().await;
if !inner.enabled_endpoints.remove(addr) {
return false;
}
let mut key = (addr.clone(), 0);
for index in 0..inner.endpoint_config.conn_per_endpoint {
key.1 = index;
if inner.channels.remove(&key).is_some() {
self.alive.acquire().await.unwrap().forget();
}
}
true
}
pub async fn enabled_endpoints(&self) -> Vec<Address> {
let inner = self.inner.read().await;
inner.enabled_endpoints.iter().cloned().collect()
}
pub async fn new_with_uris(
uris: impl IntoIterator<Item = Uri>,
cfg: EndpointConfig,
) -> Arc<Self> {
let pool = ChannelPool::new(cfg);
for uri in uris {
let address = uri.into();
pool.add_endpoint(&address).await;
}
pool
}
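// Usage sketch (the URI below is a placeholder, not from the original source):
// build a pool over known endpoints, then hand each RPC client its own service.
//
//     let pool = ChannelPool::new_with_uris(
//         [Uri::from_static("http://10.0.0.1:9000")],
//         EndpointConfig::default(),
//     )
//     .await;
//     let mut rbac = RbacClient::new(pool.make_tonic_service());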
pub async fn new_with_pending(
uris: impl IntoIterator<Item = Uri>,
endpoint_config: EndpointConfig,
pending: watch::Receiver<bool>,
) -> Arc<Self> {
let pool = Arc::new(ChannelPool {
inner: RwLock::new(ChannelPoolInner {
channels: BTreeMap::new(),
enabled_endpoints: HashSet::new(),
endpoint_config,
}),
alive: Semaphore::new(0),
is_draining: pending,
});
for uri in uris {
let address = uri.into();
pool.add_endpoint(&address).await;
}
pool
}
async fn handle_transport_error(self: Arc<Self>, addr: Address, index: u64, epoch: u64) {
let permit = self
.alive
.acquire()
.instrument_await("acquire_alive")
.await
.unwrap();
let mut inner = self.inner.write().instrument_await("acquire_inner").await;
let key = (addr.clone(), index);
match inner.channels.get(&key) {
Some(channel) if channel.epoch == epoch => {
inner.channels.remove(&key);
drop(inner);
permit.forget();
self.spawn_connect(key.0, key.1, epoch + 1);
}
_ => {}
}
}
fn spawn_connect(self: &Arc<Self>, addr: Address, index: u64, epoch: u64) {
let weak = Arc::downgrade(self);
let uri = addr.to_uri();
spawn(async move {
let mut interval = RECONNECT_INTERVAL;
while let Some(this) = weak.upgrade() {
debug!("connect to {uri}, index = {index}, epoch = {epoch}");
let (endpoint, socks5_proxy) = {
let cfg = &this.inner.read().await.endpoint_config;
(cfg.build(uri.clone()), cfg.socks5_proxy.clone())
};
match do_connect(endpoint, socks5_proxy).await {
Ok(channel) => {
debug!("connected to {uri}, index = {index}");
let mut inner = this.inner.write().await;
if inner.enabled_endpoints.contains(&addr) {
let channel = ChannelWithEpoch { channel, epoch };
inner.channels.insert((addr.clone(), index), channel);
this.alive.add_permits(1);
}
break;
}
Err(err) => {
debug!("failed to connect to {uri}: {err:?}");
if !this.inner.read().await.enabled_endpoints.contains(&addr) {
break;
}
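// Randomized backoff: sleep for interval plus a jitter in [0, interval);
// the interval grows by the jitter each attempt, capped at RECONNECT_INTERVAL_MAX.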
let jitter = tokio::time::Duration::from_secs(
rand::rng().random_range(0..interval.as_secs()),
);
tokio::time::sleep(interval + jitter).await;
interval += jitter;
interval = interval.min(RECONNECT_INTERVAL_MAX);
}
}
}
});
}
}
pub struct TonicService {
pool: Arc<ChannelPool>,
#[allow(clippy::type_complexity)]
poll_fut:
Option<Pin<Box<dyn Future<Output = Result<ReadyChannel, Status>> + Send + Sync + 'static>>>,
is_draining: watch::Receiver<bool>,
ready: Option<ReadyChannel>,
prefer_address: Option<Address>,
config: TonicServiceConfig,
}
pub(crate) struct TonicServiceConfig {
pub(crate) warn_interval: Duration,
pub(crate) connection_timeout: Duration,
}
impl Default for TonicServiceConfig {
fn default() -> Self {
Self {
warn_interval: Duration::from_secs(30),
connection_timeout: Duration::from_secs(10 * 60),
}
}
}
#[derive(Clone, Debug)]
pub struct EndpointConfig {
pub conn_per_endpoint: u64,
pub connection_window_size: u32,
pub stream_window_size: u32,
pub mstreams_per_node: u64,
pub max_inflights_per_mstream: u64,
/// HTTP/2 keep alive is required for server behind LB, otherwise server
/// changes behind LB may not be detected by the client and the client may get stuck.
pub enable_h2_keep_alive: bool,
pub socks5_proxy: Option<String>,
/// Maximum time to fetch the sender of a preferred address. If this time is exceeded,
/// choose a random msender instead.
pub fetch_preferred_sender_timeout: Duration,
pub reset_mstream_interval: Duration,
}
// The default implementation is only used internally.
// Each multiplexed stream can serve around 150 MB/s; assuming the client has
// 1 GB/s inbound bandwidth, below are some recommended settings:
// Prod - 10 read pods: mstreams_per_node >= 1
// Res - 20 read pods: mstreams_per_node >= 1
// Bench - 6 read pods: mstreams_per_node >= 2
// Dev - 2 read pods: mstreams_per_node >= 4
// Single read pod: mstreams_per_node >= 8
// `mstreams_per_node` is parsed from `HugestoreOption.proxy_connections`, which can be set in ~/.huge-config.toml
impl Default for EndpointConfig {
fn default() -> Self {
Self {
conn_per_endpoint: 3,
connection_window_size: 1024 * 1024 * 1024,
stream_window_size: 2 * 1024 * 1024,
mstreams_per_node: 2,
max_inflights_per_mstream: 32,
enable_h2_keep_alive: true,
socks5_proxy: None,
fetch_preferred_sender_timeout: Duration::from_secs(1),
reset_mstream_interval: Duration::ZERO, // disable reset by default
}
}
}
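// Sketch of overriding the defaults per the table above (illustrative only):
//
//     let cfg = EndpointConfig {
//         mstreams_per_node: 8, // single read pod
//         ..Default::default()
//     };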
impl EndpointConfig {
pub fn build(&self, uri: Uri) -> Endpoint {
let mut builder = Channel::builder(uri)
.initial_connection_window_size(Some(self.connection_window_size))
.initial_stream_window_size(Some(self.stream_window_size))
.tcp_keepalive(Some(DEFAULT_TIMEOUT))
.connect_timeout(DEFAULT_CONNECT_TIMEOUT);
if self.enable_h2_keep_alive {
builder = builder
.keep_alive_while_idle(true)
.http2_keep_alive_interval(Duration::from_secs(5))
.keep_alive_timeout(DEFAULT_TIMEOUT);
}
builder
}
}
struct ReadyChannel {
addr: Address,
index: u64,
channel: Channel,
epoch: u64,
}
impl Clone for TonicService {
fn clone(&self) -> Self {
Self {
pool: self.pool.clone(),
poll_fut: None,
is_draining: self.is_draining.clone(),
ready: None,
prefer_address: self.prefer_address.clone(),
config: Default::default(),
}
}
}
impl Service<http::Request<BoxBody>> for TonicService {
type Response = <Channel as Service<http::Request<BoxBody>>>::Response;
type Error = Box<dyn StdError + Send + Sync>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
if *self.is_draining.borrow() {
self.ready = None;
}
if self.ready.is_some() {
return Poll::Ready(Ok(()));
}
let poll_fut = self.poll_fut.get_or_insert_with(|| {
let pool = self.pool.clone();
let prefer_address = self.prefer_address.clone();
let mut is_draining = self.is_draining.clone();
let warn_interval = self.config.warn_interval;
let connection_timeout = self.config.connection_timeout;
Box::pin(async move {
loop {
// warn every 30s if no channel is ready
let mut warn_interval = tokio::time::interval(warn_interval);
let now = Instant::now();
// first tick happens immediately
warn_interval.tick().await;
loop {
crate::spurious_select! {
_ = warn_interval.tick() => {
if now.elapsed() > connection_timeout {
return Err(Status::unavailable(format!("no endpoint is connected in the past {connection_timeout:?}")));
}
warn!("service not ready, no endpoint is connected");
}
_permit = pool.alive.acquire().instrument_await("acquire_alive") => {
break
}
}
}
if *is_draining.borrow_and_update() {
info!("service is waiting for draining");
let _ = is_draining.changed().await;
}
let inner = pool.inner.read().instrument_await("acquire_inner").await;
debug_assert!(!inner.channels.is_empty());
let candidates = match &prefer_address {
Some(prefered_address) => {
debug!("Poll with preference {:?}", prefer_address);
let mut candidates: Vec<_> = inner
.channels
.range((
Included((prefered_address.clone(), 0)),
Included((
prefered_address.clone(),
inner.endpoint_config.conn_per_endpoint,
)),
))
.collect();
if candidates.is_empty() {
debug!("Service not ready, Locality is changing");
// Prefered channel drained, try next
candidates = inner
.channels
.range((Included((prefered_address.clone(), 0)), Unbounded))
.collect();
if candidates.is_empty() {
// wrap around to the first address if there is no next one
let first_addr =
inner.channels.first_key_value().unwrap().0 .0.clone();
candidates = inner
.channels
.range((
Included((first_addr.clone(), 0)),
Included((
first_addr.clone(),
inner.endpoint_config.conn_per_endpoint,
)),
))
.collect::<Vec<_>>();
} else {
let next_addr = candidates[0].0 .0.clone();
candidates.retain(|candidate| candidate.0 .0 == next_addr);
}
}
candidates
}
None => {
debug!("Poll without preference");
inner.channels.iter().collect::<Vec<_>>()
}
};
let ((addr, index), channel) =
candidates[rand::rng().random_range(0..candidates.len())];
let mut ready = ReadyChannel {
addr: addr.clone(),
index: *index,
channel: channel.channel.clone(),
epoch: channel.epoch,
};
drop(inner);
// `Channel::poll_ready()` only returns an error after a user request has failed
// on the channel, so it can't be used to detect whether the channel is still
// alive; the same holds for the HTTP/2 keepalive ping, which is not a user request.
match futures::future::poll_fn(|cx| ready.channel.poll_ready(cx))
.instrument_await("poll_ready")
.await
{
Ok(()) => {
if *is_draining.borrow() {
debug!("service is draining");
continue;
}
return Ok(ready);
}
Err(err) => {
warn!("transport error {err:?} on {}", ready.addr);
pool.clone()
.handle_transport_error(
ready.addr.clone(),
ready.index,
ready.epoch,
)
.instrument_await("handle_transport_error_in_poll_ready")
.await;
}
}
}
})
});
match poll_fut.poll_unpin(cx) {
Poll::Ready(res) => {
self.poll_fut = None;
match res {
Ok(ready) => {
self.ready = Some(ready);
Poll::Ready(Ok(()))
}
Err(err) => Poll::Ready(Err(Box::new(err))),
}
}
Poll::Pending => Poll::Pending,
}
}
fn call(&mut self, mut request: http::Request<BoxBody>) -> Self::Future {
let pool = self.pool.clone();
let mut ready = self.ready.take().unwrap();
let uuid = utils::get_bin_metadata(request.headers(), "tracer-bin", &mut [0; 24])
.and_then(|v| Uuid::from_slice(v).ok());
rewrite_uri(&mut request, &ready.addr.to_uri());
let api = request.uri().path().split('/').last().unwrap_or("");
let span: await_tree::Span = match uuid {
Some(v) => format!("call[{v:?}][{api}]").into(),
None => format!("call[{api}]").into(),
};
Box::pin(async move {
let res = ready.channel.call(request).instrument_await(span).await;
if let Err(err) = res.as_ref() {
warn!("transport error {err:?} on {}", ready.addr);
pool.handle_transport_error(ready.addr.clone(), ready.index, ready.epoch)
.instrument_await("handle_transport_error_in_call")
.await;
}
res.map_err(Into::into)
})
}
}
pub type TonicServiceWithToken = InterceptedService<TonicService, Interceptor>;
pub struct DataClientPool {
pool: Arc<ChannelPool>,
clients: HashMap<Address, FlightServiceClient<TonicServiceWithToken>>,
}
impl DataClientPool {
pub fn new(config: EndpointConfig) -> Self {
Self {
pool: ChannelPool::new(config),
clients: HashMap::new(),
}
}
pub fn is_empty(&self) -> bool {
self.clients.is_empty()
}
pub fn fetch_random(&self) -> Option<FlightServiceClient<TonicServiceWithToken>> {
self.clients
.iter()
.choose(&mut rand::rng())
.map(|(_, client)| client.clone())
}
pub fn fetch_with_preferred_address(
&self,
addr: &Address,
) -> Option<FlightServiceClient<TonicServiceWithToken>> {
self.clients.get(addr).cloned()
}
pub fn fetch_all(&self) -> HashMap<Address, FlightServiceClient<TonicServiceWithToken>> {
self.clients.clone()
}
pub async fn add_address(
&mut self,
addr: Address,
interceptor: Interceptor,
) -> FlightServiceClient<TonicServiceWithToken> {
if self.clients.contains_key(&addr) {
return self.clients[&addr].clone();
}
let service = self.pool.make_tonic_service_with_preference(&addr).await;
let client = FlightServiceClient::with_interceptor(service, interceptor)
.max_encoding_message_size(crate::MAX_TONIC_MESSAGE_SIZE)
.max_decoding_message_size(crate::MAX_TONIC_MESSAGE_SIZE);
self.clients.insert(addr, client.clone());
client
}
pub async fn remove_address(&mut self, addr: &Address) -> bool {
self.clients.remove(addr);
self.pool.remove_endpoint(addr).await
}
pub async fn enabled_endpoints(&self) -> Vec<Address> {
self.pool.enabled_endpoints().await
}
pub fn make_tonic_service(&self) -> TonicService {
self.pool.make_tonic_service()
}
}
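// Usage sketch (the address and interceptor are placeholders, not from the
// original source):
//
//     let mut pool = DataClientPool::new(EndpointConfig::default());
//     let mut client = pool.add_address(addr, interceptor).await;
//     // later: pick any connected client at random
//     if let Some(c) = pool.fetch_random() { /* issue do_action calls */ }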
pub async fn create_token_v2(
meta_service: TonicService,
config: AuthNConfig,
) -> Result<String, CommonError> {
moist::async_fail_point!("create_token_v2");
let mut rbac_client = RbacClient::new(meta_service);
let max_nr_retry = 6;
// extract traditional config
let traditional_cfg = if let Some((username, password)) = &config.traditional {
Some(TraditionalConfig {
username: username.to_string(),
password: password.to_string(),
})
} else {
None
};
// extract auth service config
let auth_options = if let Some(auth_config) = &config.auth_service {
let (path, env, region) = if auth_config.credentail == CredentialType::Path {
(
auth_config.config_path.clone(),
Some(auth_config.env),
Some(auth_config.region),
)
} else {
(auth_config.config_path.clone(), None, None)
};
let auth_options = moist::auth::Options::from_file(path, env);
match auth_options {
Ok((options, env_region)) => {
let (env, region) = env_region
.or(match (env, region) {
(Some(env), Some(region)) => Some((env, region)),
_ => None,
}).ok_or_else(|| common_error::rbac_error!(
"env and region must be provided in hugestore config file or auth service config file"
)
)?;
Some((options, env, region))
}
Err(e) => {
if traditional_cfg.is_none() {
return Err(common_error::rbac_error!(
"Failed to read auth service config file: {}",
e.to_string()
));
} else {
warn!(
"Failed to read auth service config file: {}, falling back to username/password to log in to hugestore.",
e.to_string()
);
None
}
}
}
} else {
None
};
let mut last_err = None;
for _ in 0..max_nr_retry {
moist::async_fail_point!("create_token_v2_loop");
let traditional_cfg = traditional_cfg.clone();
let create_token_result: Result<_, CommonError> = async {
let token = if let Some((options, env, region)) = &auth_options {
let auth_client = AuthClient::initialize(
options,
env.to_owned(),
region.to_owned(),
&RequestSettings::default(),
&CacheSettings::default(),
&ComplianceSettings::default(),
)
.await?;
let mut access_token = auth_client.self_access_token().await?;
(|| {
moist::fail_point!("common::self_access_token", |s| {
access_token = s.unwrap();
})
})();
Some(access_token)
} else {
None
};
let create_token_resp = rbac_client
.create_token_v2(Request::new(CreateTokenV2Request {
traditional_cfg,
access_token: token,
expire_in_days: 30,
}))
.await
.map_err(|e| {
common_error::rbac_error!("Failed to create token: {}", e.to_string())
})?;
Ok(create_token_resp)
}
.await;
match create_token_result {
Ok(res) => {
let res = res.into_inner();
let token = res.jwt.unwrap().token;
// traditional config succeeded, but the auth service path failed
if let Some(msg) = res.err_msg {
moist::async_fail_point!("common::create_token_v2::auth_error");
warn!("error about auth service: {}", msg);
}
return Ok(token);
}
Err(err) => {
warn!("failed to create token: {err:?}, retrying...");
last_err = Some(err);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
}
Err(common_error::rbac_error!(
"Failed to create token after maximum retry attempts, last error: {last_err:?}",
))
}
pub async fn create_token(
meta_service: TonicService,
username: String,
password: String,
) -> Result<String, CommonError> {
let mut rbac_client = RbacClient::new(meta_service);
let mut nr_retry = 0;
loop {
let res = rbac_client
.create_token(Request::new(CreateTokenRequest {
user_name: username.clone(),
password: password.clone(),
expire_in_days: 30,
}))
.await;
match res {
Ok(res) => {
let token = res.into_inner().token.unwrap().token;
info!("token obtained");
return Ok(token);
}
Err(status) if status.code() == Code::Unauthenticated => {
return Err(CommonError::unauthenticated(status));
}
Err(err) => {
warn!("failed to create token: {err:?}, retry");
nr_retry += 1;
if nr_retry > 3 {
return Err(common_error::rbac_error!(
"Failed to create token after maximum retry attempts: {}",
err
));
}
tokio::time::sleep(std::time::Duration::from_millis(nr_retry * 100)).await;
}
}
}
}
// Use this function if you want a streaming response from the data server;
// otherwise `data_single_action` is a better choice.
#[instrument(skip_all)]
pub async fn proxy_do_action(
proxy: &mut FlightServiceClient<TonicServiceWithToken>,
req: huge_pb::proxy::proxy_request::Request,
) -> Result<tonic::Response<Streaming<arrow_flight::Result>>, tonic::Status> {
let proxy_req = huge_pb::proxy::ProxyRequest { request: Some(req) };
let mut body = Vec::new();
proxy_req
.encode(&mut body)
.map_err(|e| Status::internal(format!("failed to encode request: {:?}", e.to_string())))?;
let action = arrow_flight::Action {
r#type: "".into(),
body: body.into(),
};
proxy.do_action(tonic::Request::new(action)).await
}
// do action and fetch only the first message from stream response
#[instrument(skip_all)]
pub async fn data_single_action(
data: &mut FlightServiceClient<TonicServiceWithToken>,
req: huge_pb::proxy::proxy_request::Request,
) -> Result<huge_pb::proxy::ProxyResponse, tonic::Status> {
moist::async_fail_point!("data_single_action_error", |_| {
Err(tonic::Status::unknown("data_single_action fail_point"))
});
let mut stream = proxy_do_action(data, req).await?.into_inner();
let body = stream
.message()
.await?
.ok_or_else(|| Status::internal("empty response from do_action, this should not happen"))?
.body;
huge_pb::proxy::ProxyResponse::decode(body).map_err(|e| {
error!("failed to decode proxy response: {}", e.to_string());
Status::internal(e.to_string())
})
}
pub async fn data_action_with_backoff(
data: FlightServiceClient<TonicServiceWithToken>,
req_fn: impl Fn() -> huge_pb::proxy::proxy_request::Request + Clone,
backoff: impl BackoffBuilder,
) -> Result<huge_pb::proxy::ProxyResponse, tonic::Status> {
let (_, resp) = (|mut data| async {
let req = req_fn();
let resp = data_single_action(&mut data, req).await;
(data, resp)
})
.retry(&backoff)
.context(data)
.when(is_tonic_status_retryable)
.notify(|e, duration| {
warn!("retryable error: {:?}, sleep after {:?}", e, duration);
})
.await;
resp
}
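// Usage sketch (assumptions: backon's `ExponentialBuilder` as the backoff, and
// a hypothetical `make_request()` constructor for the proxy request):
//
//     let resp = data_action_with_backoff(
//         client,
//         || make_request(),
//         backon::ExponentialBuilder::default(),
//     )
//     .await?;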
#[cfg(test)]
mod tests {
use std::{os::fd::AsRawFd, sync::Mutex, time::Instant};
use super::*;
use futures::{Stream, TryStreamExt};
use tokio::net::TcpListener;
use tonic::{
Response, async_trait,
transport::{Server, server::TcpIncoming},
};
use tonic_health::{
ServingStatus,
pb::{
HealthCheckRequest, HealthCheckResponse,
health_client::HealthClient,
health_server::{Health, HealthServer},
},
};
struct HealthService;
#[async_trait]
impl Health for HealthService {
async fn check(
&self,
_request: Request<HealthCheckRequest>,
) -> Result<Response<HealthCheckResponse>, Status> {
tokio::time::sleep(Duration::from_millis(100)).await;
Ok(Response::new(HealthCheckResponse {
status: ServingStatus::Serving as i32,
}))
}
type WatchStream =
Pin<Box<dyn Stream<Item = Result<HealthCheckResponse, Status>> + Send + 'static>>;
async fn watch(
&self,
_request: Request<HealthCheckRequest>,
) -> Result<Response<Self::WatchStream>, Status> {
let output = async_stream::try_stream! {
loop {
let status = ServingStatus::Serving as i32;
yield HealthCheckResponse { status };
tokio::time::sleep(Duration::from_millis(100)).await;
}
#[allow(unreachable_code)]
()
};
Ok(Response::new(Box::pin(output) as Self::WatchStream))
}
}
#[tokio::test]
async fn test_multiplex() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
spawn(async move {
Server::builder()
.add_service(HealthServer::new(HealthService))
.serve_with_incoming(TcpIncoming::from_listener(listener, true, None).unwrap())
.await
.unwrap();
});
let uri = Uri::try_from(format!("http://{}", addr)).unwrap();
let service = ChannelPool::new_with_uris([uri], Default::default())
.await
.make_tonic_service();
// unary
let begin = Instant::now();
futures::future::join_all((0..100).map(|_| async {
let mut client = HealthClient::new(service.clone());
let resp = client
.check(HealthCheckRequest::default())
.await
.unwrap()
.into_inner();
assert_eq!(resp.status, ServingStatus::Serving as i32);
}))
.await;
assert!(begin.elapsed() < Duration::from_secs(10));
// stream
let begin = Instant::now();
futures::future::join_all((0..100).map(|_| async {
let mut client = HealthClient::new(service.clone());
let mut stream = client
.watch(HealthCheckRequest::default())
.await
.unwrap()
.into_inner();
for _ in 0..3 {
let resp = stream.try_next().await.unwrap().unwrap();
assert_eq!(resp.status, ServingStatus::Serving as i32);
}
}))
.await;
assert!(begin.elapsed() < Duration::from_secs(10));
}
#[tokio::test]
async fn test_reconnect() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let tcp_fds = Arc::new(Mutex::new(Vec::new()));
let tcp_fds_ = tcp_fds.clone();
let listener_stream =
tokio_stream::wrappers::TcpListenerStream::new(listener).inspect_ok(move |tcp| {
tcp_fds_.lock().unwrap().push(tcp.as_raw_fd());
});
spawn(async move {
Server::builder()
.add_service(HealthServer::new(HealthService))
.serve_with_incoming(listener_stream)
.await
.unwrap();
});
let uri = Uri::try_from(format!("http://{}", addr)).unwrap();
let service = ChannelPool::new_with_uris([uri], Default::default())
.await
.make_tonic_service();
let semaphore = Arc::new(Semaphore::new(0));
let mut handles = Vec::new();
for _ in 0..100 {
let service = service.clone();
let semaphore = semaphore.clone();
handles.push(spawn(async move {
let mut client = HealthClient::new(service.clone());
let mut stream = client
.watch(HealthCheckRequest::default())
.await?
.into_inner();
semaphore.add_permits(1);
while let Some(resp) = stream.try_next().await? {
assert_eq!(resp.status, ServingStatus::Serving as i32);
}
Ok::<_, anyhow::Error>(())
}));
}
semaphore.acquire_many(100).await.unwrap().forget();
let fds = tcp_fds.lock().unwrap().clone();
for fd in fds {
unsafe {
libc::shutdown(fd, libc::SHUT_RDWR);
}
}
for handle in handles {
assert!(handle.await.unwrap().is_err());
}
let mut client = HealthClient::new(service.clone());
let resp = client
.check(HealthCheckRequest::default())
.await
.unwrap()
.into_inner();
assert_eq!(resp.status, ServingStatus::Serving as i32);
}
#[tokio::test]
async fn test_connection_timeout() {
let mut tonic_service = ChannelPool::new(Default::default()).make_tonic_service();
tonic_service.config.warn_interval = Duration::from_millis(1);
tonic_service.config.connection_timeout = Duration::from_millis(1);
let mut client = HealthClient::new(tonic_service);
let err = client
.check(HealthCheckRequest::default())
.await
.unwrap_err();
// error returned by tonic poll_ready is always unknown
assert_eq!(err.code(), Code::Unknown);
assert!(err.message().contains("no endpoint is connected"));
}
}
/// Fault-tolerant Fragments Resolver
///
///        Meta-Server
///          ^     |
///          |     |  -> agent requests on-track, retried when recovering from error
/// (resolve)|     v (response)
///     MetaResolveAgent -> rebuilds the meta connection when notified that the
///          ^              meta-server is offline; on a critical error, notifies
///          |              all code readers through the ReaderProducers
///          | (resolve fragment)
///      ----^----
///      |       |
/// ReaderProducer ReaderProducer -> forge readers from resolved tickets
///      |       |                -> stream readers downstream
///      v       v
/// CodeReader1, CodeReader2, ...
use await_tree::InstrumentAwait;
use futures::future::try_join_all;
use huge_common::{
ALL_CODES, AbortableHandle, TonicServiceWithToken,
error::{CommonErrorCtx, HugeErrorKind},
tracer::with_local_tracer,
utils::is_tonic_status_retryable,
};
use huge_pb::{
LOGICAL_TIME_MAX, LOGICAL_TIME_MIN,
common::DataDescriptor,
meta::{
FragmentsOfCodeV2, ResolveFragmentsRequest, ResolveFragmentsResponse,
meta_client::MetaClient,
},
proxy::Ticket,
};
use moist::async_fail_point;
use std::{
collections::{HashMap, VecDeque},
sync::Arc,
time::Duration,
};
use tokio::sync::{
Mutex,
mpsc::{Receiver, Sender},
oneshot,
};
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
use tonic::Streaming;
use tracing::info;
use crate::{InternalReadOptions, errors::ClientError, utils::wait_exponential};
#[async_trait::async_trait]
pub trait ResolveFragment: Send + Sync {
async fn resolve(
&self,
code: String,
request: ResolveFragmentsRequest,
) -> Result<ResolveFragmentsResponse, ClientError>;
/// Resolves all fragments in the given range in time order. Cannot handle
/// `read_latest`; exactly one code is allowed.
async fn resolve_all(
&self,
code: String,
mut req: ResolveFragmentsRequest,
) -> Result<ResolveFragmentsResponse, ClientError> {
if req.read_latest {
return Err(ClientError::internal(
"resolve_all can't handle read_latest",
));
}
if req.version() == huge_pb::meta::resolve_fragments_request::Version::Deprecated {
return Err(ClientError::internal("resolve_all can't handle v1 request"));
}
if req.data.is_none() {
return Err(ClientError::internal("data descriptor is missing"));
}
if req.data.as_ref().unwrap().codes != vec![code.clone()] || code == *ALL_CODES {
return Err(ClientError::internal(
"resolve_all can't handle multiple codes",
));
}
let mut partitions = vec![];
let mut columns = vec![];
loop {
let (start, end) = {
let dd = req.data.as_ref().unwrap();
(dd.start, dd.end)
};
match (start, end) {
(Some(start), Some(end)) if start < end => {
let foc = self
.resolve(code.clone(), req.clone())
.await?
.codes_v2
.remove(0);
if columns.is_empty() {
columns = foc.columns;
}
if foc.partitions.is_empty() {
break;
}
let end = foc
.partitions
.iter()
.map(|p| *p.end.as_ref().unwrap())
.max()
.ok_or_else(|| {
ClientError::internal("end of resolved partition is none")
})?;
req.data.as_mut().unwrap().start = Some(end);
partitions.extend(foc.partitions);
}
_ => break,
};
}
for p in partitions.windows(2) {
debug_assert_eq!(p[0].end.as_ref().unwrap(), p[1].start.as_ref().unwrap());
if p[0].end.as_ref().unwrap() != p[1].start.as_ref().unwrap() {
tracing::warn!(
"resolve dataset[{}] with code {}: partition end {:?} != next partition start {:?}",
req.data.as_ref().unwrap().dataset_id,
code,
p[0].end.as_ref().unwrap(),
p[1].start.as_ref().unwrap()
);
}
}
Ok(ResolveFragmentsResponse {
codes_v2: vec![FragmentsOfCodeV2 {
code,
columns,
partitions,
}],
})
}
}
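// A worked sketch of the `resolve_all` loop above (dates are placeholders):
// with start = d0 and end = d3, each `resolve` call returns partitions covering
// some prefix of the remaining range, and `start` then advances to the maximum
// resolved partition end, e.g. d0 -> d1 -> d3. The loop stops once start >= end
// or a call returns no partitions.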
pub async fn resolve_all_codes<R: ResolveFragment + 'static>(
resolver: Arc<R>,
req: ResolveFragmentsRequest,
concurrency: usize,
) -> Result<ResolveFragmentsResponse, ClientError> {
assert!(req.version() == huge_pb::meta::resolve_fragments_request::Version::V2);
let codes = req.data.as_ref().unwrap().codes.clone();
let mut codes_v2 = vec![];
let mut tasks = Vec::with_capacity(codes.len());
let sema = Arc::new(tokio::sync::Semaphore::new(concurrency));
for code in codes {
let mut req = req.clone();
let sema = sema.clone();
let resolver = resolver.clone();
let task = tokio::spawn(async move {
let _permit = sema.acquire().await;
req.data.as_mut().unwrap().codes = vec![code.clone()];
resolver.resolve_all(code, req.clone()).await
});
tasks.push(async move { AbortableHandle::new(task).await? });
}
for resp in try_join_all(tasks).await? {
codes_v2.extend(resp.codes_v2);
}
Ok(ResolveFragmentsResponse { codes_v2 })
}
/// An agent to resolve fragments with `AgentRequest`
#[derive(Clone)]
pub struct MetaResolveAgent {
agent_req_txs: Vec<Sender<AgentRequest>>,
router: Arc<std::sync::atomic::AtomicUsize>,
}
pub const SLOT_DIVISOR: u64 = 1000;
impl MetaResolveAgent {
// Builds a MetaResolveAgent that accepts requests for a specific code
pub fn new_per_code(
meta: MetaClient<TonicServiceWithToken>,
// The number of `resolve_fragments` streams.
// Usually, it's `(columns.len() / SLOT_DIVISOR).max(1)`.
num_slots: u64,
max_retry_times: u32,
) -> Arc<Self> {
Self::new::<PerCodeRequestTracker>(meta, num_slots, max_retry_times)
}
// Builds a MetaResolveAgent that only accepts requests for __ALL_CODES__
pub fn new_all_codes(
meta: MetaClient<TonicServiceWithToken>,
max_retry_times: u32,
) -> Arc<Self> {
Self::new::<AllCodesRequestTracker>(meta, 1, max_retry_times)
}
fn new<T>(
meta: MetaClient<TonicServiceWithToken>,
// The number of `resolve_fragments` streams.
// Usually, it's `(columns.len() / SLOT_DIVISOR).max(1)`.
num_slots: u64,
max_retry_times: u32,
) -> Arc<Self>
where
T: AgentRequestTracker + Send + 'static,
{
let mut agent_req_txs = vec![];
for _ in 0..num_slots {
let (agent_req_tx, agent_req_rx) = tokio::sync::mpsc::channel(64);
let backend =
MetaResolveAgentBackend::<T>::new(meta.clone(), agent_req_rx, max_retry_times);
huge_common::spawn(with_local_tracer("resolve_backend", async move {
backend.run().await;
}));
agent_req_txs.push(agent_req_tx);
}
Arc::new(Self {
agent_req_txs,
router: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
})
}
}
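// Sketch of the slot-count guideline from the comments above (`columns` is a
// hypothetical resolved column list; `3` is a placeholder retry budget):
//
//     let num_slots = (columns.len() as u64 / SLOT_DIVISOR).max(1);
//     let agent = MetaResolveAgent::new_per_code(meta_client, num_slots, 3);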
#[async_trait::async_trait]
impl ResolveFragment for MetaResolveAgent {
async fn resolve(
&self,
code: String,
request: ResolveFragmentsRequest,
) -> Result<ResolveFragmentsResponse, ClientError> {
let (resp_tx, resp_rx) = oneshot::channel();
let agent_req = AgentRequest {
code,
request,
resp_tx,
};
let slot = self
.router
.fetch_add(1, std::sync::atomic::Ordering::Release)
% self.agent_req_txs.len();
self.agent_req_txs[slot]
.send(agent_req)
.instrument_await("send")
.await
.map_err(|_| {
ClientError::cancelled_by_other_error("fragements resovler sender error")
})?;
match resp_rx.await {
Ok(resp) => resp,
Err(_) => Err(ClientError::cancelled_by_other_error(
"fragements resovler receiver error",
)),
}
}
}
// Helps resolve the desired fragments, then streams the flight info downstream.
// After all fragments are resolved, it closes.
pub struct TicketProducer {
pub code: String,
pub fragment_resolver: Arc<dyn ResolveFragment>,
pub address: Sender<Result<Ticket, ClientError>>,
pub data_descriptor: DataDescriptor,
pub include_indices: [bool; 3],
pub internal_read_options: InternalReadOptions,
}
impl TicketProducer {
#[inline]
pub async fn run(self) {
debug_assert_eq!(self.data_descriptor.codes.len(), 1);
// Otherwise response will contain multiple codes.
debug_assert!(
self.data_descriptor.use_range_partition
|| self.data_descriptor.codes != *crate::ALL_CODES_VEC
);
if self.internal_read_options.read_latest {
self.run_for_latest_day().await;
} else {
self.run_for_all().await;
}
}
/// Resolves all partitions in the given range in time order. Produces
/// readers iteratively.
async fn run_for_all(mut self) {
// endless resolve, get notified when new data comes
let endless = self.data_descriptor.end.is_none();
if endless {
self.data_descriptor.end = Some(*LOGICAL_TIME_MAX);
}
loop {
// If start reaches end, we are done.
match (self.data_descriptor.start, self.data_descriptor.end) {
(Some(start), Some(end)) if start < end => {}
_ => return,
}
let data = Some(self.data_descriptor.clone());
let req = ResolveFragmentsRequest {
data,
version: huge_pb::meta::resolve_fragments_request::Version::V2.into(),
read_latest: false,
ask_partitions: 0,
not_use_cache: false,
filter_empty: !endless, // endless resolve needs the empty-partition end hint to generate a promise
};
let resp = self
.fragment_resolver
.resolve(self.code.clone(), req)
.instrument_await("resolve")
.await;
match resp {
Ok(mut resp) => {
let code = resp.codes_v2.remove(0);
let ordered_part = code.partitions;
if ordered_part.is_empty() {
if endless {
// Make a promise with next lower bound
if self
.address
.send(Ok(Ticket::new_promise(self.data_descriptor.start.unwrap())))
.instrument_await("send_promise")
.await
.is_err()
{
return;
}
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
} else {
// resolve to the end, break the loop
return;
}
}
for partition in ordered_part {
let partition_start = partition.start;
let partition_end = partition.end;
let start = match (partition_start, self.data_descriptor.start) {
(Some(p_start), Some(dd_start)) => Some(p_start.max(dd_start)),
_ => self.data_descriptor.start,
};
let end = match (partition_end, self.data_descriptor.end) {
(Some(p_end), Some(dd_end)) => Some(p_end.min(dd_end)),
_ => self.data_descriptor.end,
};
// filter out empty partition
if partition.fragments.is_empty() {
self.data_descriptor.start = end;
continue;
}
// filter out out of range partition
let (min_frag_start, max_frag_end) = partition
.fragments
.iter()
.fold((*LOGICAL_TIME_MAX, *LOGICAL_TIME_MIN), |acc, f| {
(acc.0.min(f.start.unwrap()), acc.1.max(f.end.unwrap()))
});
if min_frag_start >= end.unwrap() || max_frag_end <= start.unwrap() {
async_fail_point!("TicketProducer::fragments_out_of_range");
self.data_descriptor.start = end;
continue;
}
let ticket = Ticket {
partition_id: partition.id_to_deprecate,
code_v2: Some(FragmentsOfCodeV2 {
columns: code.columns.clone(),
partitions: vec![partition],
code: code.code.clone(),
}),
include_indices: self.include_indices.into(),
start,
end,
dataset_id: self.data_descriptor.dataset_id,
bypass_cache: false,
support_nanosecond_index: self.support_nanosecond_index(),
};
if self
.address
.send(Ok(ticket))
.instrument_await("send_address")
.await
.is_err()
{
return;
}
self.data_descriptor.start = end;
}
}
Err(err) if err.huge_kind() == HugeErrorKind::CancelledByOtherError => {
// Avoid broadcasting the error; silently close
return;
}
Err(err) => {
let _ = self.address.send(Err(err)).await;
return;
}
}
}
}
/// Resolves all partitions in the latest day, still producing readers in
/// timely ascending order.
///
/// It does not guarantee outputting data from only one single day. This
/// happens only when one fragment contains data from multiple days and the
/// user requested only part of them. The downstream consumer is responsible
/// for additional filtering.
///
/// The performance is worse than `run` because it cannot pipeline the
/// handling of different partitions and readers.
async fn run_for_latest_day(mut self) {
// `run_for_latest_day` does not support endless resolve
debug_assert!(self.data_descriptor.end.is_some());
// Match the default value of `max_partitions_returned`.
const ASK_PARTITIONS: usize = 8;
let mut columns = Vec::with_capacity(1);
let mut partitions = VecDeque::with_capacity(ASK_PARTITIONS);
let mut latest_date = None;
let user_end = self.data_descriptor.end.unwrap();
// This is required by the assertion:
// `latest_date < user_end.trading_date`.
assert!(user_end.raw_time_is_zero());
#[cfg(debug_assertions)]
let mut received_focs = Vec::new();
'out: loop {
let req = ResolveFragmentsRequest {
data: Some(self.data_descriptor.clone()),
version: huge_pb::meta::resolve_fragments_request::Version::V2.into(),
read_latest: true,
ask_partitions: ASK_PARTITIONS as i32,
not_use_cache: false,
filter_empty: false,
};
let resp = self.fragment_resolver.resolve(self.code.clone(), req).await;
match resp {
Ok(mut resp) => {
let foc = resp.codes_v2.remove(0);
let len = foc.partitions.len();
#[cfg(debug_assertions)]
received_focs.push(foc.clone());
if foc.partitions.is_empty() {
// resolve to the end, break the loop
break;
}
debug_assert!(
foc.partitions
.windows(2)
.all(|p| { p[0].end.unwrap() <= p[1].start.unwrap() }),
"{:?}",
foc
);
columns.push(foc.columns);
for partition in foc.partitions.into_iter().rev() {
assert!(partition.start.is_some());
assert!(partition.end.is_some());
if partition.fragments.is_empty() {
self.data_descriptor.end = partition.start;
continue;
}
// TODO: Is this true for data truncate?
debug_assert!(partition.fragments[0].num_rows > 0);
// For feature fragments, the boundaries might not be aligned.
let tight_start_time = std::cmp::max(
partition.start.unwrap(),
partition
.fragments
.iter()
.map(|f| f.start.unwrap())
.min()
.unwrap(),
);
let tight_end_time = std::cmp::min(
partition.end.unwrap(),
partition
.fragments
.iter()
.map(|f| f.end.unwrap())
.max()
.unwrap(),
);
if tight_start_time >= self.data_descriptor.end.unwrap() {
debug_assert!(self.data_descriptor.end >= partition.start);
self.data_descriptor.end = partition.start;
continue;
}
if let Some(latest_date) = latest_date {
if latest_date > tight_end_time.trading_date {
break 'out;
}
} else {
// TODO:
// We want to use `tight_end_date` as the latest date if it
// is within range, it relies on invariant:
// `frag.end_time.date == frag.last_row.date`
//
// But the invariant does not hold when truncation happened.
assert!(tight_start_time.trading_date < user_end.trading_date);
// This is not accurate, the actual date might be larger.
// Consider this case:
// part = [day_0, day_1, day_10], part.end = day_10_1,
// read_latest(until = day_9).
// latest_date will be day_0.
latest_date = Some(tight_start_time.trading_date);
}
self.data_descriptor.end = partition.start;
partitions.push_front((partition, columns.len() - 1));
}
if len < ASK_PARTITIONS {
break;
}
}
Err(err) if err.huge_kind() == HugeErrorKind::CancelledByOtherError => {
// Avoid broadcasting the error; silently close
return;
}
Err(err) => {
let _ = self.address.send(Err(err)).await;
return;
}
}
}
#[cfg(debug_assertions)]
{
let key = format!(
"read_latest_{}_{}",
self.data_descriptor.dataset_id, self.code
);
let msg = format!(
"user_end = {:?}, received_focs = {:?}, latest_date = {:?}",
user_end, received_focs, latest_date
);
huge_common::test::debug_info_map().insert(key, msg);
}
if latest_date.is_none() {
return;
}
let window_start = huge_pb::common::LogicalTime::new_with_micro(latest_date.unwrap(), 0);
for (partition, column_idx) in partitions {
let start = partition.start.unwrap().max(window_start);
let end = user_end.min(partition.end.unwrap());
let ticket = Ticket {
partition_id: partition.id_to_deprecate,
code_v2: Some(FragmentsOfCodeV2 {
columns: columns[column_idx].clone(),
partitions: vec![partition],
code: self.code.clone(),
}),
include_indices: self.include_indices.into(),
start: Some(start),
end: Some(end),
dataset_id: self.data_descriptor.dataset_id,
bypass_cache: false,
support_nanosecond_index: self.support_nanosecond_index(),
};
if self.address.send(Ok(ticket)).await.is_err() {
return;
}
}
}
fn support_nanosecond_index(&self) -> bool {
moist::fail_point!("TicketProducer::old_client", |_| false);
true
}
}
/// State transitions
///
/// 0. Init
///    * upstream is built:
///      Init --> Normal
///    * building upstream failed, but can retry:
///      Init --> RetryableFailure
///    * building upstream failed, and cannot retry:
///      Init --> Finish with error
/// 1. Normal
///    * a retryable error is received:
///      Normal --> RetryableFailure
///    * a critical error is received:
///      Normal --> Finish with error
///    * order channel closed:
///      Normal --> Finish OK
/// 2. Finish
///    * exit state
/// 3. RetryableFailure
///    * upstream is rebuilt:
///      RetryableFailure --> Normal
///    * rebuilding upstream failed, but can retry:
///      RetryableFailure --> RetryableFailure
///    * rebuilding upstream failed, and cannot retry:
///      RetryableFailure --> Finish with error
#[derive(Debug)]
enum AgentState {
Init,
Normal {
sender: Sender<ResolveFragmentsRequest>,
receiver: Streaming<ResolveFragmentsResponse>,
},
Finish {
error: Option<ClientError>,
},
RetryableFailure {
retry_left: u32,
},
}
trait AgentRequestTracker: Default {
fn complete_agent_req(&mut self, resp: ResolveFragmentsResponse);
fn refs(&self) -> Box<dyn Iterator<Item = &AgentRequest> + '_ + Send>;
fn drain(&mut self) -> Box<dyn Iterator<Item = AgentRequest> + '_>;
fn insert(&mut self, req: AgentRequest) -> bool;
fn is_empty(&self) -> bool {
self.refs().next().is_none()
}
}
#[derive(Default)]
struct PerCodeRequestTracker {
reqs_ontrack: HashMap<String, AgentRequest>,
}
impl AgentRequestTracker for PerCodeRequestTracker {
fn complete_agent_req(&mut self, resp: ResolveFragmentsResponse) {
assert_eq!(resp.codes_v2.len(), 1);
let req = self
.reqs_ontrack
.remove(&resp.codes_v2[0].code)
.expect("unexpected: phantom agent request");
req.resp_tx
.send(Ok(resp))
.expect("unexpected: Reader closed");
}
fn refs(&self) -> Box<dyn Iterator<Item = &AgentRequest> + '_ + Send> {
Box::new(self.reqs_ontrack.values())
}
fn drain(&mut self) -> Box<dyn Iterator<Item = AgentRequest> + '_> {
Box::new(self.reqs_ontrack.drain().map(|(_, req)| req))
}
fn insert(&mut self, req: AgentRequest) -> bool {
self.reqs_ontrack.insert(req.code.clone(), req).is_none()
}
}
#[derive(Default)]
struct AllCodesRequestTracker {
req: Option<AgentRequest>,
}
impl AgentRequestTracker for AllCodesRequestTracker {
fn complete_agent_req(&mut self, resp: ResolveFragmentsResponse) {
self.req
.take()
.expect("unexpected: phantom agent request")
.resp_tx
.send(Ok(resp))
.expect("unexpected: Reader closed");
}
fn refs(&self) -> Box<dyn Iterator<Item = &AgentRequest> + '_ + Send> {
Box::new(self.req.iter())
}
fn drain(&mut self) -> Box<dyn Iterator<Item = AgentRequest> + '_> {
Box::new(self.req.take().into_iter())
}
fn insert(&mut self, req: AgentRequest) -> bool {
assert!(req.code == *ALL_CODES);
self.req.replace(req).is_none()
}
}
/// MetaResolveAgentBackend runs as the bridge between MetaResolveAgent and the meta connection.
/// It automatically retries on retryable failures from the meta connection.
struct MetaResolveAgentBackend<T>
where
T: AgentRequestTracker,
{
meta: MetaClient<TonicServiceWithToken>,
agent_req_rx: Receiver<AgentRequest>,
max_retry_times: u32,
agent_reqs_ontrack: T,
}
struct AgentRequest {
code: String, // the Reader should guarantee there is at most 1 order in flight
request: ResolveFragmentsRequest,
resp_tx: oneshot::Sender<Result<ResolveFragmentsResponse, ClientError>>,
}
impl<T> MetaResolveAgentBackend<T>
where
T: AgentRequestTracker,
{
pub fn new(
meta: MetaClient<TonicServiceWithToken>,
agent_req_rx: Receiver<AgentRequest>,
max_retry_times: u32,
) -> Self {
Self {
meta,
agent_req_rx,
max_retry_times,
agent_reqs_ontrack: T::default(),
}
}
pub async fn build_meta_connection(&mut self, retry_left: u32) -> AgentState {
let (req_tx, req_rx) = tokio::sync::mpsc::channel::<ResolveFragmentsRequest>(64);
let res = match self
.meta
.resolve_fragments_v2(ReceiverStream::new(req_rx))
.await
.map(|resp| resp.into_inner())
{
Ok(streaming) => AgentState::Normal {
sender: req_tx,
receiver: streaming,
},
Err(err) => self.handle_error(err, retry_left),
};
moist::async_fail_point!("MetaResolveAgentBackend::after_build_meta_connection");
res
}
fn handle_error(&mut self, status: tonic::Status, retry_left: u32) -> AgentState {
if is_tonic_status_retryable(&status) {
if retry_left > 0 {
info!("Resolve fragments connection error: {:?}, retry...", status);
AgentState::RetryableFailure {
retry_left: retry_left - 1,
}
} else {
let error = if self.max_retry_times == 0 {
Some(status.into())
} else {
Some(ClientError::internal("Reach retry limit"))
};
AgentState::Finish { error }
}
} else {
AgentState::Finish {
error: Some(status.into()),
}
}
}
async fn execute_normal(
&mut self,
sender: Sender<ResolveFragmentsRequest>,
mut receiver: Streaming<ResolveFragmentsResponse>,
) -> AgentState {
// recover the agent reqs in flight from the last failure
for agent_req in self.agent_reqs_ontrack.refs() {
let _ = sender
.send(agent_req.request.clone())
.instrument_await("recover_send")
.await;
}
loop {
huge_common::spurious_select! {
resp = receiver.next() => {
match resp {
Some(Ok(resp)) => {
self.agent_reqs_ontrack.complete_agent_req(resp);
}
Some(Err(error)) => {
info!("MetaResolveAgentBackend encounters error: {:?}", error);
return self.handle_error(error, self.max_retry_times);
}
None => unreachable!("unexpected: meta-server closed the stream"),
}
}
agent_req = self.agent_req_rx.recv() => {
if let Some(agent_req) = agent_req {
let req = agent_req.request.clone();
assert!(self.agent_reqs_ontrack.insert(agent_req));
let _ = sender.send(req).instrument_await("send").await; // ignore the error in send
} else {
// all request senders were dropped, which means resolve-fragments completed
debug_assert!(self.agent_reqs_ontrack.is_empty());
return AgentState::Finish { error: None };
}
}
}
}
}
fn cancel_agent_reqs_ontrack(&mut self, error: ClientError) {
let mut reqs = self.agent_reqs_ontrack.drain();
if let Some(req) = reqs.next() {
let _ = req.resp_tx.send(Err(error));
}
for req in reqs {
let _ = req.resp_tx.send(Err(ClientError::cancelled_by_other_error(
"cancel agent requests ontrack",
)));
}
}
pub async fn run(mut self) {
let mut state = AgentState::Init;
loop {
state = match state {
AgentState::Init => self.build_meta_connection(self.max_retry_times).await,
AgentState::Normal { sender, receiver } => {
self.execute_normal(sender, receiver)
.instrument_await("execute_normal")
.await
}
AgentState::Finish { error } => {
if let Some(error) = error {
tracing::error!("Resolve fragments meets critical error {:?}", error);
self.cancel_agent_reqs_ontrack(error);
}
break;
}
AgentState::RetryableFailure { retry_left } => {
moist::async_fail_point!("MetaResolveAgentBackend::handle_retryable_failure");
wait_exponential(self.max_retry_times - retry_left - 1)
.instrument_await(format!("wait[retry_left={retry_left}]"))
.await;
self.build_meta_connection(retry_left).await
}
}
}
tracing::debug!("MetaResolveAgentBackend closed");
}
}
#[derive(Clone)]
pub struct MockResolver {
all_codes_v2: HashMap<String, Arc<Mutex<VecDeque<FragmentsOfCodeV2>>>>,
}
impl MockResolver {
pub fn new(all_codes_v2: HashMap<String, VecDeque<FragmentsOfCodeV2>>) -> Self {
Self {
all_codes_v2: all_codes_v2
.into_iter()
.map(|(code, partitions)| (code, Arc::new(Mutex::new(partitions))))
.collect(),
}
}
pub fn new_from_resolve_response(resp: ResolveFragmentsResponse) -> Self {
let mut all_codes_v2 = HashMap::new();
resp.codes_v2.into_iter().for_each(|foc| {
all_codes_v2
.entry(foc.code.clone())
.or_insert_with(VecDeque::new)
.push_back(foc);
});
Self::new(all_codes_v2)
}
}
#[async_trait::async_trait]
impl ResolveFragment for MockResolver {
async fn resolve(
&self,
code: String,
_req: ResolveFragmentsRequest,
) -> Result<ResolveFragmentsResponse, ClientError> {
let mut codes = self
.all_codes_v2
.get(&code)
.ok_or_else(|| ClientError::internal(format!("code {} not found", code).as_str()))?
.lock()
.await;
let (partitions, columns) = if codes.is_empty() {
(vec![], vec![])
} else {
let FragmentsOfCodeV2 {
columns,
partitions,
..
} = codes.pop_front().unwrap();
(partitions, columns)
};
Ok(ResolveFragmentsResponse {
codes_v2: vec![FragmentsOfCodeV2 {
code: code.clone(),
partitions,
columns,
}],
})
}
}
#[cfg(test)]
mod tests {
use huge_common::test::{assert_contains, make_time};
use huge_pb::meta::PartitionInfoV2;
use super::*;
#[tokio::test]
async fn test_resolve_all() {
let foc = FragmentsOfCodeV2 {
code: "btc".to_string(),
partitions: vec![PartitionInfoV2 {
id_to_deprecate: 0,
code: "btc".to_string(),
start: Some(make_time("20230101").into()),
end: Some(make_time("20230102").into()),
fragments: vec![],
..Default::default()
}],
columns: vec![],
};
let resolver = MockResolver::new(HashMap::from([(
"btc".to_string(),
VecDeque::from(vec![foc.clone()]),
)]));
let data = DataDescriptor {
dataset_id: 0,
// Compaction is only supported for the patch version 0 (major version)
start: Some(make_time("20230101").into()),
end: Some(make_time("20230102").into()),
codes: vec!["btc".to_string()],
use_range_partition: false,
targets: vec![],
};
let res = resolver
.resolve_all(
"btc".to_string(),
ResolveFragmentsRequest {
data: Some(data.clone()),
version: huge_pb::meta::resolve_fragments_request::Version::V2.into(),
read_latest: true,
ask_partitions: 0,
not_use_cache: false,
filter_empty: false,
},
)
.await;
assert_contains(
res.unwrap_err().to_string(),
"resolve_all can't handle read_latest",
);
let res = resolver
.resolve_all(
"btc".to_string(),
ResolveFragmentsRequest {
data: None,
version: huge_pb::meta::resolve_fragments_request::Version::V2.into(),
read_latest: false,
ask_partitions: 0,
not_use_cache: false,
filter_empty: false,
},
)
.await;
assert_contains(res.unwrap_err().to_string(), "data descriptor is missing");
let res = resolver
.resolve_all(
"apple".to_string(),
ResolveFragmentsRequest {
data: Some(data.clone()),
version: huge_pb::meta::resolve_fragments_request::Version::V2.into(),
read_latest: false,
ask_partitions: 0,
not_use_cache: false,
filter_empty: false,
},
)
.await;
assert_contains(
res.unwrap_err().to_string(),
"resolve_all can't handle multiple codes",
);
let res = resolver
.resolve_all(
"btc".to_string(),
ResolveFragmentsRequest {
data: Some(data.clone()),
version: huge_pb::meta::resolve_fragments_request::Version::Deprecated.into(),
read_latest: false,
ask_partitions: 0,
not_use_cache: false,
filter_empty: false,
},
)
.await;
assert_contains(
res.unwrap_err().to_string(),
"resolve_all can't handle v1 request",
);
let res = resolver
.resolve_all(
"btc".to_string(),
ResolveFragmentsRequest {
data: Some(data.clone()),
version: huge_pb::meta::resolve_fragments_request::Version::V2.into(),
read_latest: false,
ask_partitions: 0,
not_use_cache: false,
filter_empty: false,
},
)
.await;
assert_eq!(
res.unwrap(),
ResolveFragmentsResponse {
codes_v2: vec![foc],
}
);
}
}