I'll make this into a proper PR, but I figured I'd share what I have so far.
use async_trait::async_trait;
use ethers::{
    contract::{
        multicall_contract::{Call3, Multicall3},
        ContractError, Multicall,
    },
    providers::{Middleware, MiddlewareError},
    types::{transaction::eip2718::TypedTransaction, Address, BlockId, Bytes},
};
use futures::{FutureExt, TryFutureExt};
use std::{collections::HashMap, fmt::Debug, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::{
    pin, select,
    sync::oneshot,
    task::JoinHandle,
    time::{sleep_until, Instant},
};
use tracing::{debug, error, trace};

pub type MulticallResponse<M> = Result<Bytes, ContractError<M>>;
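
/// Middleware that transparently batches `eth_call` requests into Multicall3
/// `aggregate3` calls via a background task.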
#[derive(Debug)]
pub struct MulticallMiddleware<M: Middleware> {
    inner: Arc<M>,
    multicall_address: Address,
    task_tx: flume::Sender<PendingAction<M>>,
    pub task_handle: JoinHandle<anyhow::Result<()>>,
}

#[derive(Debug)]
pub enum MulticallAction {
    Call(TypedTransaction, Option<BlockId>),
    // Balance(Address, Option<BlockId>),
    // TODO: there are more things that could be batched (balances, etc.)
}
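
/// Background task state. Pending calls are queued per `BlockId` so that calls
/// against different blocks are never mixed into one aggregate.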
struct MulticallMiddlewareTask<M: Middleware> {
    multicall: Multicall<M>,
    batch_size: usize,
    max_wait: Duration,
    queue: HashMap<Option<BlockId>, PendingActions<M>>,
    rx: flume::Receiver<PendingAction<M>>,
}

#[derive(Error, Debug)]
pub enum MulticallMiddlewareError<M: Middleware> {
    #[error("{0}")]
    Middleware(M::Error),
    #[error("{0}")]
    Recv(#[from] oneshot::error::RecvError),
    #[error("{0:?}")]
    PendingActionSend(#[from] flume::SendError<PendingAction<M>>),
    /// TODO: this doesn't feel right
    #[error("{0}")]
    Contract(#[from] ContractError<M>),
}

impl MulticallAction {
    fn block_id(&self) -> &Option<BlockId> {
        match self {
            MulticallAction::Call(_, block) => block,
            // MulticallAction::Balance(_, block) => block,
        }
    }
}

impl<M: Middleware + 'static> MulticallMiddleware<M> {
    pub async fn new(
        inner: Arc<M>,
        capacity: Option<usize>,
        address: Option<Address>,
        batch_size: Option<usize>,
        max_wait: Option<Duration>,
    ) -> Result<Self, MulticallMiddlewareError<M>> {
        let (tx, rx) = if capacity.is_some() {
            // TODO: be careful with bounded capacity. deadlocks are possible because we use
            // the blocking `send`. we could `try_send` and fall back to a future if that fails?
            // flume::bounded(capacity)
            error!("bounded capacity is currently ignored; it doesn't play well with the blocking send");
            flume::unbounded()
        } else {
            flume::unbounded()
        };

        let batch_size = batch_size.unwrap_or(200);
        let max_wait = max_wait.unwrap_or_else(|| Duration::from_millis(1));

        // TODO: don't unwrap
        let multicall = Multicall::<M>::new(inner.clone(), address).await.unwrap();

        let address = multicall.contract.address();

        let queue = HashMap::new();

        let task = MulticallMiddlewareTask {
            multicall,
            batch_size,
            max_wait,
            queue,
            rx,
        };

        let task_handle = tokio::spawn(task.run().inspect_err(|e| {
            // TODO: I think this needs to be a panic, but I'm not positive
            panic!("MulticallMiddlewareTask error: {:?}", e);
        }));

        Ok(Self {
            inner,
            multicall_address: address,
            task_tx: tx,
            task_handle,
        })
    }
}

impl<M: Middleware + 'static> MiddlewareError for MulticallMiddlewareError<M> {
    type Inner = M::Error;

    fn from_err(src: M::Error) -> MulticallMiddlewareError<M> {
        MulticallMiddlewareError::Middleware(src)
    }

    fn as_inner(&self) -> Option<&Self::Inner> {
        match self {
            MulticallMiddlewareError::Middleware(e) => Some(e),
            _ => None,
        }
    }
}

#[async_trait]
impl<M> Middleware for MulticallMiddleware<M>
where
    M: Middleware + 'static,
{
    type Error = MulticallMiddlewareError<M>;
    type Provider = M::Provider;
    type Inner = M;

    fn inner(&self) -> &M {
        &self.inner
    }

    /// Note that this is written without the `async` keyword: the oneshot channel is set up
    /// and the call is queued synchronously, then a future for the result is returned.
    fn call<'life0, 'life1, 'async_trait>(
        &'life0 self,
        tx: &'life1 TypedTransaction,
        block: Option<BlockId>,
    ) -> ::core::pin::Pin<
        Box<
            dyn ::core::future::Future<Output = Result<Bytes, Self::Error>>
                + ::core::marker::Send
                + 'async_trait,
        >,
    >
    where
        'life0: 'async_trait,
        'life1: 'async_trait,
        Self: 'async_trait,
    {
        let tx_to = tx.to().and_then(|x| x.as_address());

        // if the call is already a multicall, do not batch it
        if tx_to == Some(&self.multicall_address) {
            let direct_f = self
                .inner
                .call(tx, block)
                .map_err(MulticallMiddlewareError::from_err);

            return direct_f.boxed();
        }

        // set up a channel for the result. the background task will do the actual querying
        let (result_tx, result_rx) = oneshot::channel();

        let pending_action = PendingAction {
            action: MulticallAction::Call(tx.clone(), block),
            result_tx,
        };

        // be very careful with bounded channels! they can block the tokio runtime!
        if let Err(err) = self.task_tx.send(pending_action) {
            let err_f = async move { Err(err.into()) };

            return err_f.boxed();
        }

        // flatten the nested Result: a RecvError here means the batching task dropped our oneshot
        result_rx
            .map(|x| match x {
                Ok(Ok(x)) => Ok(x),
                Ok(Err(e)) => Err(e.into()),
                Err(e) => Err(e.into()),
            })
            .boxed()
    }
}

#[derive(Debug)]
pub struct PendingAction<M: Middleware> {
    action: MulticallAction,
    result_tx: oneshot::Sender<MulticallResponse<M>>,
}

// TODO: this type is way too complex. think about how to re-arrange it
pub struct PendingActions<M: Middleware> {
    must_send_by: Instant,
    actions: Vec<PendingAction<M>>,
}

impl<M: Middleware> PendingActions<M> {
    fn new(timeout: Duration) -> Self {
        Self {
            must_send_by: Instant::now() + timeout,
            actions: vec![],
        }
    }
}

impl<M: Middleware + 'static> MulticallMiddlewareTask<M> {
    /// TODO: this probably needs to be in an Arc
    async fn run(mut self) -> anyhow::Result<()> {
        let mut batches = 0;
        let mut total_calls = 0;
        let mut single_item_batches = 0;

        loop {
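            // block until at least one pending call arrives; there is nothing to batch until then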
            let first = self.rx.recv_async().await?;

            let first_id = *first.action.block_id();

            {
                let first_entry = self
                    .queue
                    .entry(first_id)
                    .or_insert_with(|| PendingActions::new(self.max_wait));

                first_entry.actions.push(first);
            }

            // there might be items still in the queue. we want to use their must_send_by
            let must_send_at = self
                .queue
                .values()
                .min_by_key(|a| a.must_send_by)
                .expect("queue is non-empty: we just inserted `first`")
                .must_send_by;

            let wait_until = sleep_until(must_send_at);
            pin!(wait_until);

            loop {
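                // keep collecting calls until some queue reaches batch_size
                // or the oldest deadline fires, whichever comes first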
                select! {
                    x = self.rx.recv_async() => {
                        let x = x?;

                        let block_id = *x.action.block_id();

                        let entry = self.queue.entry(block_id).or_insert_with(|| PendingActions::new(self.max_wait));

                        entry.actions.push(x);

                        let key_len = entry.actions.len();

                        if key_len >= self.batch_size {
                            trace!("size met");
                            // TODO: breaking here drains ALL the queues, but we only want to drain
                            // this key. maybe call flush_id(key) here instead of breaking
                            break;
                        }
                    },
                    _ = &mut wait_until => {
                        trace!("multicall aged out");
                        break;
                    },
                };
            }
            // TODO: don't drain everything. only drain queues that are full or have been waiting long enough
            for (block_id, pending_actions) in self.queue.drain() {
                let multicall_contract = self.multicall.contract.clone();

                let mut calls = vec![];

                for pending_action in pending_actions.actions.iter() {
                    match &pending_action.action {
                        // note: `to_addr().unwrap()` assumes every queued call targets a plain
                        // address (no ENS names and no contract deploys)
                        MulticallAction::Call(tx, _) => calls.push(Call3 {
                            target: *tx.to_addr().unwrap(),
                            call_data: tx.data().cloned().unwrap_or_else(Bytes::new),
                            allow_failure: true,
                        }),
                    }
                }

                let new_calls = calls.len();

                total_calls += new_calls;
                batches += 1;

                if new_calls == 1 {
                    // TODO: call directly?
                    // log this because it might mean that we aren't looping in a way that allows batching
                    debug!("single call: {:?}", calls[0]);
                    single_item_batches += 1;
                }

                debug!(
                    "batching {} calls. ({} reduced to {}. {} singles)",
                    new_calls, total_calls, batches, single_item_batches,
                );

                let f = multicall_aggregate(multicall_contract, pending_actions, calls, block_id);

                tokio::spawn(f);
            }
        }
    }
}

/// spawn this to run the multicall in the background
async fn multicall_aggregate<M: Middleware + 'static>(
    multicall_contract: Multicall3<M>,
    pending_actions: PendingActions<M>,
    calls: Vec<Call3>,
    block_id: Option<BlockId>,
) {
    let mut aggregate_call = multicall_contract.aggregate_3(calls);

    if let Some(x) = block_id {
        aggregate_call = aggregate_call.block(x);
    }

    let results = match aggregate_call.await {
        Err(err) => {
            // drop the pending_actions. they will resolve with RecvError and can be retried
            // TODO: do something else?
            error!("multicall aggregate3 failed: {:?}", err);
            return;
        }
        Ok(x) => x,
    };
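
    // aggregate3 returns results in the same order as the submitted calls,
    // so zipping pairs each result with the action that requested it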
    for (pending_action, result) in pending_actions.actions.into_iter().zip(results) {
        let result = if result.success {
            Ok(result.return_data)
        } else {
            // TODO: can we use the inner's existing Error::Revert?
            // TODO: should this return a ContractError instead?
            Err(ContractError::Revert(result.return_data))
        };

        // the caller may have dropped its receiver; ignore a failed send
        let _ = pending_action.result_tx.send(result);
    }
}
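
For anyone who wants to try this before the PR lands, here's a minimal usage sketch. It's untested; the RPC URL and the parameter values are just placeholders I picked, and it lets `Multicall::new` resolve the default Multicall3 deployment when `address` is `None`:

use ethers::providers::{Http, Provider};
use std::{sync::Arc, time::Duration};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // hypothetical endpoint; any Middleware works as the inner layer
    let provider = Arc::new(Provider::<Http>::try_from("http://localhost:8545")?);

    let middleware = Arc::new(
        MulticallMiddleware::new(
            provider,
            None,                           // capacity (currently ignored)
            None,                           // multicall address: use the default deployment
            Some(100),                      // batch_size
            Some(Duration::from_millis(2)), // max_wait
        )
        .await?,
    );

    // `call`s issued through `middleware` (e.g. by contract bindings built on
    // top of it) are now queued and batched through Multicall3's aggregate3
    Ok(())
}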