Created
July 19, 2025 14:50
-
-
Save mgild/53d335f751b3c0e22e13af8de10ff760 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::{env, str::FromStr, sync::Arc, time::Duration}; | |
| use anyhow_ext::{anyhow, Context, Error as AnyhowError, Result}; | |
| use cached::{ | |
| proc_macro::{cached, once}, | |
| SizedCache, | |
| }; | |
| use fancy_regex::Regex; | |
| use futures::future::join_all; | |
| use num_traits::FromPrimitive; | |
| use reqwest::{Client, ClientBuilder}; | |
| use rust_decimal::{Decimal, MathematicalOps}; | |
| use serde::{Deserialize, Serialize}; | |
| use solana_client::nonblocking::rpc_client::RpcClient; | |
| use solana_sdk::{program_pack::Pack, pubkey::Pubkey}; | |
| use spl_token::state::Mint; | |
| use tracing::{info, warn}; | |
| use crate::{ | |
| oracle_job::{jupiter_swap_task::SwapAmount, JupiterSwapTask}, | |
| protos::OracleJob, | |
| utils::{handle_reqwest_err, url_or_from_env}, | |
| TaskInterface, TaskInterfaceAsync, TaskOutput, TaskResult, TaskRunnerContext, | |
| }; | |
| // pub const JUPITER_URL: &str = "https://switchboard.rpcpool.com/5f7ec120-9628-40c4-997f-0cdb4135e059/jupiter"; | |
| pub const JUPITER_URL: &str = "https://jupiter-api.switchboard-oracles.xyz"; | |
| const JUPITER_URL_ENV_VAR: &str = "TASK_RUNNER_JUPITER_URL"; | |
| const HTTP_TIMEOUT: u64 = 5; | |
| // Retry configuration constants | |
| const PRICE_DIRECTION_RETRY_SCALE_FACTOR: u32 = 10; | |
| const PRICE_DIRECTION_MAX_RETRIES: usize = 1; | |
| const PRICE_DIRECTION_MIN_AMOUNT: Decimal = Decimal::ONE; | |
| // Returns the Jupiter URL. | |
| // If the TASK_RUNNER_JUPITER_URL env var is set, returns the value set in there. | |
| pub fn jupiter_url() -> String { | |
| let url = env::var(JUPITER_URL_ENV_VAR).unwrap_or_else(|_| JUPITER_URL.to_string()); | |
| url_or_from_env(&url) | |
| } | |
| pub fn http_client_init(timeout: u64) -> Client { | |
| let timeout = Duration::from_secs(timeout); | |
| ClientBuilder::new() | |
| .tcp_keepalive(Some(Duration::from_secs(75))) | |
| .connect_timeout(timeout) | |
| .timeout(timeout) | |
| .build() | |
| .unwrap() | |
| } | |
| #[once(time = 300)] | |
| fn http_client() -> Arc<Client> { | |
| Arc::new(http_client_init(HTTP_TIMEOUT)) | |
| } | |
| #[cached( | |
| ty = "SizedCache<Vec<Pubkey>, Vec<Mint>>", | |
| create = "{ SizedCache::with_size(1000) }", | |
| convert = r#"{ mint_keys.clone() }"#, | |
| result = true | |
| )] | |
| pub async fn get_mints(rpc_client: &RpcClient, mint_keys: Vec<Pubkey>) -> Result<Vec<Mint>> { | |
| let mint_results = rpc_client | |
| .get_multiple_accounts(&mint_keys) | |
| .await | |
| .with_context(|| "Failed to get mint account data")?; | |
| let mints: Vec<Mint> = mint_results | |
| .iter() | |
| .map(|result| { | |
| let mint_account = result.as_ref().context("Failed to get mint account data")?; | |
| let mint_bytes: &[u8] = &mint_account.data[..]; | |
| Mint::unpack_from_slice(mint_bytes) | |
| .with_context(|| "Failed to unpack mint account data") | |
| }) | |
| .collect::<Result<Vec<Mint>>>()?; | |
| Ok(mints) | |
| } | |
| // pub const ALL_JUPITER_DEXES = [ | |
| // "Lifinity V1", | |
| // "Marinade", | |
| // "Meteora", | |
| // "Penguin", | |
| // "Mercurial", | |
| // "Oasis", | |
| // "Phoenix", | |
| // "Raydium", | |
| // "Jupiter LO", | |
| // "Openbook", | |
| // "StepN", | |
| // "Raydium CLMM", | |
| // "Aldrin V2", | |
| // "Symmetry", | |
| // "Lifinity V2", | |
| // "Bonkswap", | |
| // "Cropper", | |
| // "Balansol", | |
| // "Sanctum", | |
| // "Saber", | |
| // "Invariant", | |
| // "Helium Network", | |
| // "Saros", | |
| // "Orca V1", | |
| // "Crema", | |
| // "Saber (Decimals)", | |
| // "Orca V2", | |
| // "Whirlpool", | |
| // "Aldrin", | |
| // "FluxBeam", | |
| // ]; | |
/// Drops a fractional part made only of zeros, e.g. `"1.000"` → `"1"`.
///
/// This reproduces the regex `\.0+$` without compiling a regex on every call. The last `.`
/// and everything after it are removed only when at least one character follows the dot
/// and every such character is `'0'`. Other inputs are returned unchanged, e.g.
/// `"1.50"` and `"1."`.
fn remove_trailing_zeros(s: &str) -> String {
    match s.rsplit_once('.') {
        Some((head, tail)) if !tail.is_empty() && tail.bytes().all(|b| b == b'0') => {
            head.to_string()
        }
        _ => s.to_string(),
    }
}
| fn deserialize_u64_from_string<'de, D>(deserializer: D) -> Result<u64, D::Error> | |
| where | |
| D: serde::Deserializer<'de>, | |
| { | |
| let s = String::deserialize(deserializer)?; | |
| s.parse::<u64>().map_err(serde::de::Error::custom) | |
| } | |
| #[derive(Serialize, Deserialize, Default, PartialEq, Clone, Debug)] | |
| pub enum SwapMode { | |
| #[default] | |
| ExactIn, | |
| ExactOut, | |
| } | |
| impl FromStr for SwapMode { | |
| type Err = AnyhowError; | |
| fn from_str(s: &str) -> Result<Self> { | |
| match s { | |
| "ExactIn" => Ok(Self::ExactIn), | |
| "ExactOut" => Ok(Self::ExactOut), | |
| _ => Err(anyhow!("Failed to parse SwapMode enum ({})", s)), | |
| } | |
| } | |
| } | |
| impl std::fmt::Display for SwapMode { | |
| fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
| match *self { | |
| Self::ExactIn => write!(f, "ExactIn"), | |
| Self::ExactOut => write!(f, "ExactOut"), | |
| } | |
| } | |
| } | |
| #[derive(Clone, Debug, Deserialize)] | |
| #[serde(rename_all = "camelCase")] | |
| pub struct JupiterSwapQuoteResponse { | |
| pub input_mint: String, | |
| #[serde(deserialize_with = "deserialize_u64_from_string")] | |
| pub in_amount: u64, | |
| pub output_mint: String, | |
| #[serde(deserialize_with = "deserialize_u64_from_string")] | |
| pub out_amount: u64, | |
| #[serde(deserialize_with = "deserialize_u64_from_string")] | |
| pub other_amount_threshold: u64, | |
| pub swap_mode: SwapMode, | |
| pub slippage_bps: u64, | |
| pub platform_fee: Option<u64>, | |
| pub price_impact_pct: Option<String>, | |
| pub context_slot: Option<u64>, | |
| pub time_taken: Option<f64>, | |
| // pub route_plan: Vec<JupiterRoutePlan>, | |
| #[serde(default)] | |
| pub route_plan: Option<serde_json::Value>, | |
| } | |
/// Client for the Jupiter quote endpoint returned by `jupiter_url()`.
pub struct JupiterSwapClient {
    /// API key chosen by `JupiterSwapClient::new`.
    /// NOTE(review): no request built in this file reads this field — confirm whether it
    /// should be sent to the API.
    pub api_key: String,
}

/// One side (input or output token) of a quote request.
pub struct TokenInput {
    /// Mint address, sent as `inputMint` / `outputMint`.
    pub address: String,
    /// Mint decimals. NOTE(review): not read by `get_quote` in this file.
    pub decimals: u32,
}
| impl JupiterSwapClient { | |
| pub fn new(api_key: Option<String>) -> Self { | |
| let mut jupiter_api_key = api_key.unwrap_or_default(); | |
| if jupiter_api_key.is_empty() { | |
| let jupiter_api_key_env = std::env::var("JUPITER_SWAP_API_KEY").unwrap_or_default(); | |
| if !jupiter_api_key_env.is_empty() { | |
| jupiter_api_key = jupiter_api_key_env; | |
| } | |
| } | |
| Self { | |
| api_key: jupiter_api_key, | |
| } | |
| } | |
| pub fn get_url( | |
| &self, | |
| in_token_address: &str, | |
| out_token_address: &str, | |
| amount: &str, | |
| slippage_bps: u64, | |
| ) -> String { | |
| let base = format!("{}/quote?", jupiter_url()); | |
| let url = format!( | |
| "{}inputMint={}&outputMint={}&amount={}&slippageBps={}&onlyDirectRoutes=false", | |
| base, in_token_address, out_token_address, amount, slippage_bps | |
| ); | |
| url | |
| } | |
| async fn get_quote( | |
| &self, | |
| in_token: &TokenInput, | |
| out_token: &TokenInput, | |
| amount: &str, | |
| slippage_bps: Option<f64>, | |
| ) -> Result<Decimal> { | |
| let in_token_address = in_token.address.clone(); | |
| let in_token_amount: Decimal = | |
| Decimal::from_str(amount).with_context(|| "Failed to parse in token amount")?; | |
| let amount = in_token_amount; | |
| let out_token_address = out_token.address.clone(); | |
| let slippage_bps = (slippage_bps.unwrap_or(1.0) * 100.0) as u64; | |
| let url = self.get_url( | |
| &in_token_address, | |
| &out_token_address, | |
| &remove_trailing_zeros(&amount.to_string()), | |
| slippage_bps, | |
| ); | |
| let response = http_client() | |
| .get(url) | |
| .send() | |
| .await | |
| .map_err(handle_reqwest_err)? | |
| .error_for_status() | |
| .map_err(handle_reqwest_err)?; | |
| if response.status() != 200 { | |
| return Err(anyhow!( | |
| "Jupiter Swap API returned status code {}", | |
| response.status() | |
| )); | |
| } | |
| // Get the response text as a string | |
| let text = response.text().await.map_err(handle_reqwest_err)?; | |
| let response: JupiterSwapQuoteResponse = serde_json::from_str(&text) | |
| .map_err(|_e| anyhow!("Failed to parse JupiterSwapQuoteResponse: {}", text))?; | |
| if response.slippage_bps > slippage_bps { | |
| return Err(anyhow!( | |
| "Slippage exceeded: expected {}, got {}", | |
| slippage_bps, | |
| response.slippage_bps | |
| )); | |
| } | |
| Ok(Decimal::from(response.out_amount)) | |
| } | |
| } | |
| #[derive(Debug, Clone)] | |
| struct RetryAmountCandidate { | |
| amount: Decimal, | |
| scale_factor: i32, | |
| } | |
| async fn jupiter_swap_task_with_retry( | |
| in_token_address: &str, | |
| out_token_address: &str, | |
| original_amount: Decimal, | |
| slippage: Option<f64>, | |
| in_token_mint: &Mint, | |
| out_token_mint: &Mint, | |
| ) -> Result<Decimal> { | |
| let mut retry_amounts = Vec::new(); | |
| let scale_factor = Decimal::from(PRICE_DIRECTION_RETRY_SCALE_FACTOR); | |
| // Try original amount first | |
| retry_amounts.push(RetryAmountCandidate { | |
| amount: original_amount, | |
| scale_factor: 0, | |
| }); | |
| // Generate scaled amounts | |
| for i in 1..=PRICE_DIRECTION_MAX_RETRIES { | |
| // Scale down | |
| let scaled_down = original_amount / scale_factor.powi(i as i64); | |
| if scaled_down >= PRICE_DIRECTION_MIN_AMOUNT { | |
| retry_amounts.push(RetryAmountCandidate { | |
| amount: scaled_down, | |
| scale_factor: -(i as i32), | |
| }); | |
| } | |
| // Scale up | |
| let scaled_up = original_amount * scale_factor.powi(i as i64); | |
| retry_amounts.push(RetryAmountCandidate { | |
| amount: scaled_up, | |
| scale_factor: i as i32, | |
| }); | |
| } | |
| // Create futures for all retry attempts | |
| let futures: Vec<_> = retry_amounts | |
| .iter() | |
| .enumerate() | |
| .map(|(index, candidate)| { | |
| let true_amount = candidate.amount * Decimal::TEN.powi(in_token_mint.decimals.into()); | |
| let in_token_address = in_token_address.to_string(); | |
| let out_token_address = out_token_address.to_string(); | |
| let candidate = candidate.clone(); | |
| let in_decimals = in_token_mint.decimals; | |
| let out_decimals = out_token_mint.decimals; | |
| async move { | |
| let result = jupiter_swap_task_attempt( | |
| &in_token_address, | |
| &out_token_address, | |
| true_amount, | |
| candidate.amount, | |
| slippage, | |
| in_decimals.into(), | |
| out_decimals.into(), | |
| ) | |
| .await; | |
| (index, candidate, result) | |
| } | |
| }) | |
| .collect(); | |
| // Execute all futures in parallel | |
| let results = join_all(futures).await; | |
| // Process results in priority order (maintaining original order) | |
| let mut sorted_results = results; | |
| sorted_results.sort_by_key(|(index, _, _)| *index); | |
| let mut last_price_direction_error = None; | |
| for (_, candidate, result) in sorted_results { | |
| match result { | |
| Ok(price) => { | |
| if candidate.scale_factor != 0 { | |
| warn!( | |
| "Jupiter swap succeeded with scaled amount: original={}, scaled={}, scale_factor={}", | |
| original_amount, candidate.amount, candidate.scale_factor | |
| ); | |
| } | |
| let adjusted_price = price * scale_factor.powi(-candidate.scale_factor as i64); | |
| return Ok(adjusted_price); | |
| } | |
| Err(e) => { | |
| if e.to_string().contains("PriceDirectionDiffOutOfBounds") { | |
| warn!( | |
| "Price direction error with amount {}: {}", | |
| candidate.amount, e | |
| ); | |
| last_price_direction_error = Some(e); | |
| // Continue checking other results | |
| } else { | |
| // For non-price-direction errors, we still check if a better result exists | |
| // but track this error in case all attempts fail | |
| if last_price_direction_error.is_none() { | |
| last_price_direction_error = Some(e); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| // If we exhausted all retries, return the last error | |
| Err(last_price_direction_error.unwrap_or_else(|| anyhow!("Failed to execute Jupiter swap after all retries"))) | |
| } | |
| async fn jupiter_swap_task_attempt( | |
| in_token_address: &str, | |
| out_token_address: &str, | |
| true_amount: Decimal, | |
| swap_amount: Decimal, | |
| slippage: Option<f64>, | |
| in_token_decimals: u32, | |
| out_token_decimals: u32, | |
| ) -> Result<Decimal> { | |
| let price = jupiter_swap_task_inner( | |
| in_token_address, | |
| out_token_address, | |
| &true_amount.round().to_string(), | |
| slippage, | |
| in_token_decimals, | |
| out_token_decimals, | |
| ) | |
| .await?; | |
| let rev_price = jupiter_swap_task_inner( | |
| out_token_address, | |
| in_token_address, | |
| &price.round().to_string(), | |
| slippage, | |
| out_token_decimals, | |
| in_token_decimals, | |
| ) | |
| .await?; | |
| info!("{}, {}, {}", true_amount, price, rev_price); | |
| let price_normalized = price / Decimal::TEN.powi(out_token_decimals.into()); | |
| let rev_price_normalized = rev_price / Decimal::TEN.powi(in_token_decimals.into()); | |
| let numerator = std::cmp::min(swap_amount, rev_price_normalized); | |
| let denominator = std::cmp::max(swap_amount, rev_price_normalized); | |
| let pcnt = Decimal::ONE - (numerator / denominator); | |
| // Use the task's slippage value or default to 3.5% | |
| let slippage_threshold = slippage.unwrap_or(3.5) / 100.0; | |
| let slippage_decimal = Decimal::from_f64(slippage_threshold).unwrap(); | |
| if pcnt > slippage_decimal { | |
| return Err(anyhow!( | |
| "PriceDirectionDiffOutOfBounds: price difference {}% exceeds slippage tolerance {}%", | |
| (pcnt * Decimal::from(100)).round_dp(2), | |
| (slippage_decimal * Decimal::from(100)).round_dp(2) | |
| )); | |
| } | |
| Ok(price_normalized) | |
| } | |
| async fn jupiter_swap_task(ctx: &TaskRunnerContext, task: &JupiterSwapTask) -> TaskResult { | |
| let rpc_client = ctx.mainnet_rpc(); | |
| let mut is_flipped = false; | |
| if task.in_token_address.is_none() { | |
| return Err(anyhow!("JupiterSwapTask.in_token_address is empty")); | |
| } | |
| if task.out_token_address.is_none() { | |
| return Err(anyhow!("JupiterSwapTask.out_token_address is empty")); | |
| } | |
| let swap_amount = if let Some(amount) = task.swap_amount.clone() { | |
| match amount { | |
| SwapAmount::BaseAmount(amount) => amount.to_string(), | |
| SwapAmount::BaseAmountString(amount_str) => amount_str, | |
| SwapAmount::QuoteAmount(amount) => { | |
| is_flipped = true; | |
| amount.to_string() | |
| } | |
| SwapAmount::QuoteAmountString(amount_str) => { | |
| is_flipped = true; | |
| amount_str | |
| } | |
| } | |
| } else { | |
| "1".into() | |
| }; | |
| // Get in token mint address | |
| let mut in_token_address = &task | |
| .in_token_address | |
| .as_ref() | |
| .context("Failed to parse in token address")?; | |
| // Get out token mint address | |
| let mut out_token_address = &task | |
| .out_token_address | |
| .as_ref() | |
| .context("Failed to parse out token address")?; | |
| if is_flipped { | |
| std::mem::swap(&mut in_token_address, &mut out_token_address); | |
| } | |
| let in_token_pubkey = | |
| Pubkey::from_str(in_token_address).context("Failed to parse in token address")?; | |
| let out_token_pubkey = | |
| Pubkey::from_str(out_token_address).context("Failed to parse out token address")?; | |
| // Fetch mints for the tokens | |
| let mint_keys = vec![in_token_pubkey, out_token_pubkey]; | |
| let mints = get_mints(&rpc_client, mint_keys).await?; | |
| let [in_token_mint, out_token_mint] = mints[..] else { | |
| return Err(anyhow!("Failed to get mint account data")); | |
| }; | |
| let swap_amount = Decimal::from_str(&swap_amount)?; | |
| let price = jupiter_swap_task_with_retry( | |
| in_token_address, | |
| out_token_address, | |
| swap_amount, | |
| task.slippage, | |
| &in_token_mint, | |
| &out_token_mint, | |
| ) | |
| .await?; | |
| Ok(TaskOutput::Num(price)) | |
| } | |
| impl TaskInterface for JupiterSwapTask { | |
| fn children(&self) -> Vec<OracleJob> { | |
| Vec::new() | |
| } | |
| fn uses_input(&self) -> bool { | |
| false | |
| } | |
| } | |
| #[async_trait::async_trait] | |
| impl TaskInterfaceAsync for JupiterSwapTask { | |
| async fn execute<'a>(&'a self, ctx: &'a mut TaskRunnerContext) -> TaskResult { | |
| jupiter_swap_task(ctx, self).await | |
| } | |
| } | |
| async fn jupiter_swap_task_inner_write( | |
| in_token_address: &str, | |
| out_token_address: &str, | |
| amount: &str, | |
| slippage_bps: Option<f64>, | |
| in_token_decimals: u32, | |
| out_token_decimals: u32, | |
| ) -> Result<Decimal> { | |
| let jupiter = JupiterSwapClient::new(None); | |
| jupiter | |
| .get_quote( | |
| &TokenInput { | |
| address: in_token_address.to_string(), | |
| decimals: in_token_decimals, | |
| }, | |
| &TokenInput { | |
| address: out_token_address.to_string(), | |
| decimals: out_token_decimals, | |
| }, | |
| amount, | |
| slippage_bps, | |
| ) | |
| .await | |
| } | |
| async fn jupiter_swap_task_inner( | |
| in_token_address: &str, | |
| out_token_address: &str, | |
| amount: &str, | |
| slippage_bps: Option<f64>, | |
| in_token_decimals: u32, | |
| out_token_decimals: u32, | |
| ) -> Result<Decimal> { | |
| jupiter_swap_task_inner_write( | |
| in_token_address, | |
| out_token_address, | |
| amount, | |
| slippage_bps, | |
| in_token_decimals, | |
| out_token_decimals, | |
| ) | |
| .await | |
| } | |
#[cfg(test)]
mod tests {
    use rust_decimal_macros::dec;

    use super::*;
    use crate::test_utils;

    /// Mirrors the candidate-generation loop of `jupiter_swap_task_with_retry`.
    ///
    /// The assertions below therefore follow whatever the retry constants are set to,
    /// rather than a hard-coded retry count.
    fn generate_retry_candidates(original_amount: Decimal) -> Vec<RetryAmountCandidate> {
        let scale_factor = Decimal::from(PRICE_DIRECTION_RETRY_SCALE_FACTOR);
        let mut retry_amounts = vec![RetryAmountCandidate {
            amount: original_amount,
            scale_factor: 0,
        }];
        for i in 1..=PRICE_DIRECTION_MAX_RETRIES {
            let scaled_down = original_amount / scale_factor.powi(i as i64);
            if scaled_down >= PRICE_DIRECTION_MIN_AMOUNT {
                retry_amounts.push(RetryAmountCandidate {
                    amount: scaled_down,
                    scale_factor: -(i as i32),
                });
            }
            let scaled_up = original_amount * scale_factor.powi(i as i64);
            retry_amounts.push(RetryAmountCandidate {
                amount: scaled_up,
                scale_factor: i as i32,
            });
        }
        retry_amounts
    }

    #[tokio::test]
    async fn test_jupiter_swap_task1() {
        let ctx = test_utils::get_test_task_runner_context(true);
        let task = JupiterSwapTask {
            in_token_address: Some("DriFtupJYLTosbwoN8koMbEYSx54aFAVLddWsbksjwg7".to_string()),
            out_token_address: Some("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v".to_string()),
            swap_amount: Some(SwapAmount::BaseAmountString("1000".to_string())),
            routes_filters: None,
            slippage: Some(2.0),
            version: None,
        };
        let value = jupiter_swap_task(&ctx, &task).await.unwrap();
        println!("Value: {:?}", value);
    }

    #[tokio::test]
    async fn test_jupiter_swap_task_quote_amount() {
        let ctx = test_utils::get_test_task_runner_context(true);
        let task = JupiterSwapTask {
            in_token_address: Some("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v".to_string()),
            out_token_address: Some("DriFtupJYLTosbwoN8koMbEYSx54aFAVLddWsbksjwg7".to_string()),
            swap_amount: Some(SwapAmount::QuoteAmountString("1000".to_string())),
            routes_filters: None,
            slippage: Some(2.0),
            version: None,
        };
        let value = jupiter_swap_task(&ctx, &task).await.unwrap();
        // `println!` (not `info!`) so the value is visible without a tracing subscriber.
        println!("Value: {:?}", value);
    }

    #[test]
    fn test_retry_amount_candidates_generation() {
        let original_amount = dec!(100);
        let scale_factor = Decimal::from(PRICE_DIRECTION_RETRY_SCALE_FACTOR);
        let retry_amounts = generate_retry_candidates(original_amount);

        // With a scale factor of 10, 100 can be scaled down at most twice before dropping
        // below the minimum amount of 1; every step contributes one scale-up.
        let expected_scale_downs = PRICE_DIRECTION_MAX_RETRIES.min(2);
        assert_eq!(
            retry_amounts.len(),
            1 + expected_scale_downs + PRICE_DIRECTION_MAX_RETRIES
        );

        // The original amount is always tried first.
        assert_eq!(retry_amounts[0].amount, dec!(100));
        assert_eq!(retry_amounts[0].scale_factor, 0);

        // Each step yields its scale-down (while >= minimum) before its scale-up.
        assert_eq!(retry_amounts[1].amount, dec!(10));
        assert_eq!(retry_amounts[1].scale_factor, -1);
        assert_eq!(retry_amounts[2].amount, dec!(1000));
        assert_eq!(retry_amounts[2].scale_factor, 1);

        // Every candidate is the original amount scaled by factor^scale_factor, and no
        // scaled-down candidate falls below the minimum.
        for candidate in &retry_amounts {
            let expected = original_amount * scale_factor.powi(candidate.scale_factor as i64);
            assert_eq!(candidate.amount, expected);
            assert!(candidate.scale_factor >= 0 || candidate.amount >= PRICE_DIRECTION_MIN_AMOUNT);
        }
    }

    #[test]
    fn test_retry_amount_candidates_min_amount_constraint() {
        let original_amount = dec!(5);
        let scale_factor = Decimal::from(PRICE_DIRECTION_RETRY_SCALE_FACTOR);
        let retry_amounts = generate_retry_candidates(original_amount);

        // 5 / 10 = 0.5 < 1, so no scaled-down amount is produced: the original amount is
        // followed only by the scale-ups, in increasing order.
        assert_eq!(retry_amounts.len(), 1 + PRICE_DIRECTION_MAX_RETRIES);
        for (step, candidate) in retry_amounts.iter().enumerate() {
            assert_eq!(candidate.scale_factor, step as i32);
            assert_eq!(
                candidate.amount,
                original_amount * scale_factor.powi(step as i64)
            );
        }
        assert_eq!(retry_amounts[0].amount, dec!(5));
        assert_eq!(retry_amounts[1].amount, dec!(50));
    }

    #[test]
    fn test_price_adjustment_after_retry() {
        let price = dec!(100);
        let scale_factor = Decimal::from(PRICE_DIRECTION_RETRY_SCALE_FACTOR);
        // Same expression as `jupiter_swap_task_with_retry`: price * factor^(-scale_factor).
        let adjust = |candidate_scale_factor: i32| {
            price * scale_factor.powi(-(candidate_scale_factor as i64))
        };

        // Input scaled down by 10 => output scaled back up by 10.
        assert_eq!(adjust(-1), dec!(1000));
        // Input scaled up by 10 => output scaled back down by 10.
        assert_eq!(adjust(1), dec!(10));
        // No scaling.
        assert_eq!(adjust(0), dec!(100));
    }

    #[test]
    fn test_constant_values() {
        assert_eq!(PRICE_DIRECTION_RETRY_SCALE_FACTOR, 10);
        // The retry tests above adapt to the configured step count; they only need one step.
        assert_ne!(PRICE_DIRECTION_MAX_RETRIES, 0);
        assert_eq!(PRICE_DIRECTION_MIN_AMOUNT, Decimal::ONE);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment