diff --git a/src/agent/utils/rpc_multi_client.rs b/src/agent/utils/rpc_multi_client.rs index 7cc5fbc..09415ef 100644 --- a/src/agent/utils/rpc_multi_client.rs +++ b/src/agent/utils/rpc_multi_client.rs @@ -17,32 +17,189 @@ use { transaction::Transaction, }, solana_transaction_status::TransactionStatus, - std::time::Duration, + std::{ + future::Future, + pin::Pin, + sync::Arc, + time::{ + Duration, + Instant, + }, + }, + tokio::sync::Mutex, url::Url, }; + +#[derive(Debug, Clone)] +struct EndpointState { + last_failure: Option, + is_healthy: bool, +} + +#[derive(Debug)] +struct RoundRobinState { + current_index: usize, + endpoint_states: Vec, + cooldown_duration: Duration, +} + + +impl RoundRobinState { + fn new(endpoint_count: usize, cooldown_duration: Duration) -> Self { + Self { + current_index: 0, + endpoint_states: vec![ + EndpointState { + last_failure: None, + is_healthy: true, + }; + endpoint_count + ], + cooldown_duration, + } + } +} + pub struct RpcMultiClient { - rpc_clients: Vec, + rpc_clients: Vec, + round_robin_state: Arc>, } impl RpcMultiClient { + async fn retry_with_round_robin<'a, T, F>( + &'a self, + operation_name: &str, + operation: F, + ) -> anyhow::Result + where + F: Fn(&'a RpcClient) -> Pin> + Send + 'a>>, + { + let mut attempts = 0; + let max_attempts = self.rpc_clients.len() * 2; + + while attempts < max_attempts { + let index_option = self.get_next_endpoint().await; + + if let Some(index) = index_option { + let future = operation( + self.rpc_clients + .get(index) + .ok_or(anyhow::anyhow!("Index out of bounds"))?, + ); + match future.await { + Ok(result) => { + let mut state = self.round_robin_state.lock().await; + + #[allow(clippy::indexing_slicing, reason = "index is checked")] + if index < state.endpoint_states.len() { + state.endpoint_states[index].is_healthy = true; + state.endpoint_states[index].last_failure = None; + } + return Ok(result); + } + Err(e) => { + #[allow(clippy::indexing_slicing, reason = "index is checked")] + let client = &self.rpc_clients[index]; + tracing::warn!( + "{} error for rpc endpoint {}: {}", + operation_name, + client.url(), + e + ); + let mut state = self.round_robin_state.lock().await; + + #[allow(clippy::indexing_slicing, reason = "index is checked")] + if index < state.endpoint_states.len() { + state.endpoint_states[index].last_failure = Some(Instant::now()); + state.endpoint_states[index].is_healthy = false; + } + } + } + } + attempts += 1; + } + + bail!( + "{} failed for all RPC endpoints after {} attempts", + operation_name, + attempts + ) + } + + async fn get_next_endpoint(&self) -> Option { + let mut state = self.round_robin_state.lock().await; + let now = Instant::now(); + let start_index = state.current_index; + + let mut found_index = None; + for _ in 0..state.endpoint_states.len() { + let index = state.current_index; + state.current_index = (state.current_index + 1) % state.endpoint_states.len(); + + #[allow(clippy::indexing_slicing, reason = "index is checked")] + let endpoint_state = &state.endpoint_states[index]; + if endpoint_state.is_healthy + || endpoint_state.last_failure.is_none_or(|failure_time| { + now.duration_since(failure_time) >= state.cooldown_duration + }) + { + found_index = Some(index); + break; + } + } + + if found_index.is_none() { + let index = start_index; + state.current_index = (start_index + 1) % state.endpoint_states.len(); + found_index = Some(index); + } + found_index + } + pub fn new_with_timeout(rpc_urls: Vec, timeout: Duration) -> Self { - let clients = rpc_urls + Self::new_with_timeout_and_cooldown(rpc_urls, timeout, Duration::from_secs(30)) + } + + pub fn new_with_timeout_and_cooldown( + rpc_urls: Vec, + timeout: Duration, + cooldown_duration: Duration, + ) -> Self { + let clients: Vec = rpc_urls .iter() .map(|rpc_url| RpcClient::new_with_timeout(rpc_url.to_string(), timeout)) .collect(); + let round_robin_state = Arc::new(Mutex::new(RoundRobinState::new( + clients.len(), + cooldown_duration, + ))); Self { rpc_clients: clients, + round_robin_state, } } pub fn new_with_commitment(rpc_urls: Vec, commitment_config: CommitmentConfig) -> Self { - let clients = rpc_urls + Self::new_with_commitment_and_cooldown(rpc_urls, commitment_config, Duration::from_secs(30)) + } + + pub fn new_with_commitment_and_cooldown( + rpc_urls: Vec, + commitment_config: CommitmentConfig, + cooldown_duration: Duration, + ) -> Self { + let clients: Vec = rpc_urls .iter() .map(|rpc_url| RpcClient::new_with_commitment(rpc_url.to_string(), commitment_config)) .collect(); + let round_robin_state = Arc::new(Mutex::new(RoundRobinState::new( + clients.len(), + cooldown_duration, + ))); Self { rpc_clients: clients, + round_robin_state, } } @@ -51,7 +208,21 @@ impl RpcMultiClient { timeout: Duration, commitment_config: CommitmentConfig, ) -> Self { - let clients = rpc_urls + Self::new_with_timeout_commitment_and_cooldown( + rpc_urls, + timeout, + commitment_config, + Duration::from_secs(30), + ) + } + + pub fn new_with_timeout_commitment_and_cooldown( + rpc_urls: Vec, + timeout: Duration, + commitment_config: CommitmentConfig, + cooldown_duration: Duration, + ) -> Self { + let clients: Vec = rpc_urls .iter() .map(|rpc_url| { RpcClient::new_with_timeout_and_commitment( @@ -61,148 +232,135 @@ impl RpcMultiClient { ) }) .collect(); + let round_robin_state = Arc::new(Mutex::new(RoundRobinState::new( + clients.len(), + cooldown_duration, + ))); Self { rpc_clients: clients, + round_robin_state, } } pub async fn get_balance(&self, kp: &Keypair) -> anyhow::Result { - for client in self.rpc_clients.iter() { - match client.get_balance(&kp.pubkey()).await { - Ok(balance) => return Ok(balance), - Err(e) => { - tracing::warn!("getBalance error for rpc endpoint {}: {}", client.url(), e) - } - } - } - bail!("getBalance failed for all RPC endpoints") + let pubkey = kp.pubkey(); + self.retry_with_round_robin("getBalance", |client| { + Box::pin(async move { + client + .get_balance(&pubkey) + .await + .map_err(anyhow::Error::from) + }) + }) + .await } pub async fn send_transaction_with_config( &self, transaction: &Transaction, ) -> anyhow::Result { - for rpc_client in self.rpc_clients.iter() { - match rpc_client - .send_transaction_with_config( - transaction, - RpcSendTransactionConfig { - skip_preflight: true, - ..RpcSendTransactionConfig::default() - }, - ) - .await - { - Ok(signature) => return Ok(signature), - Err(e) => tracing::warn!( - "sendTransactionWithConfig failed for rpc endpoint {}: {:?}", - rpc_client.url(), - e - ), - } - } - bail!("sendTransactionWithConfig failed for all rpc endpoints") + self.retry_with_round_robin("sendTransactionWithConfig", |client| { + let transaction = transaction.clone(); + Box::pin(async move { + client + .send_transaction_with_config( + &transaction, + RpcSendTransactionConfig { + skip_preflight: true, + ..RpcSendTransactionConfig::default() + }, + ) + .await + .map_err(anyhow::Error::from) + }) + }) + .await } pub async fn get_signature_statuses( &self, signatures_contiguous: &mut [Signature], ) -> anyhow::Result>> { - for rpc_client in self.rpc_clients.iter() { - match rpc_client - .get_signature_statuses(signatures_contiguous) - .await - { - Ok(statuses) => return Ok(statuses.value), - Err(e) => tracing::warn!( - "getSignatureStatus failed for rpc endpoint {}: {:?}", - rpc_client.url(), - e - ), - } - } - bail!("getSignatureStatuses failed for all rpc endpoints") + self.retry_with_round_robin("getSignatureStatuses", |client| { + let signatures = signatures_contiguous.to_vec(); + Box::pin(async move { + client + .get_signature_statuses(&signatures) + .await + .map(|statuses| statuses.value) + .map_err(anyhow::Error::from) + }) + }) + .await } pub async fn get_recent_prioritization_fees( &self, price_accounts: &[Pubkey], ) -> anyhow::Result> { - for rpc_client in self.rpc_clients.iter() { - match rpc_client - .get_recent_prioritization_fees(price_accounts) - .await - { - Ok(fees) => return Ok(fees), - Err(e) => tracing::warn!( - "getRecentPrioritizationFee failed for rpc endpoint {}: {:?}", - rpc_client.url(), - e - ), - } - } - bail!("getRecentPrioritizationFees failed for every rpc endpoint") + self.retry_with_round_robin("getRecentPrioritizationFees", |client| { + let price_accounts = price_accounts.to_vec(); + Box::pin(async move { + client + .get_recent_prioritization_fees(&price_accounts) + .await + .map_err(anyhow::Error::from) + }) + }) + .await } pub async fn get_program_accounts( &self, oracle_program_key: Pubkey, ) -> anyhow::Result> { - for rpc_client in self.rpc_clients.iter() { - match rpc_client.get_program_accounts(&oracle_program_key).await { - Ok(accounts) => return Ok(accounts), - Err(e) => tracing::warn!( - "getProgramAccounts failed for rpc endpoint {}: {}", - rpc_client.url(), - e - ), - } - } - bail!("getProgramAccounts failed for all rpc endpoints") + self.retry_with_round_robin("getProgramAccounts", |client| { + Box::pin(async move { + client + .get_program_accounts(&oracle_program_key) + .await + .map_err(anyhow::Error::from) + }) + }) + .await } pub async fn get_account_data(&self, publisher_config_key: &Pubkey) -> anyhow::Result> { - for rpc_client in self.rpc_clients.iter() { - match rpc_client.get_account_data(publisher_config_key).await { - Ok(data) => return Ok(data), - Err(e) => tracing::warn!( - "getAccountData failed for rpc endpoint {}: {:?}", - rpc_client.url(), - e - ), - } - } - bail!("getAccountData failed for all rpc endpoints") + self.retry_with_round_robin("getAccountData", |client| { + Box::pin(async move { + client + .get_account_data(publisher_config_key) + .await + .map_err(anyhow::Error::from) + }) + }) + .await } pub async fn get_slot_with_commitment( &self, commitment_config: CommitmentConfig, ) -> anyhow::Result { - for rpc_client in self.rpc_clients.iter() { - match rpc_client.get_slot_with_commitment(commitment_config).await { - Ok(slot) => return Ok(slot), - Err(e) => tracing::warn!( - "getSlotWithCommitment failed for rpc endpoint {}: {:?}", - rpc_client.url(), - e - ), - } - } - bail!("getSlotWithCommitment failed for all rpc endpoints") + self.retry_with_round_robin("getSlotWithCommitment", |client| { + Box::pin(async move { + client + .get_slot_with_commitment(commitment_config) + .await + .map_err(anyhow::Error::from) + }) + }) + .await } pub async fn get_latest_blockhash(&self) -> anyhow::Result { - for rpc_client in self.rpc_clients.iter() { - match rpc_client.get_latest_blockhash().await { - Ok(hash) => return Ok(hash), - Err(e) => tracing::warn!( - "getLatestBlockhash failed for rpc endpoint {}: {:?}", - rpc_client.url(), - e - ), - } - } - bail!("getLatestBlockhash failed for all rpc endpoints") + self.retry_with_round_robin("getLatestBlockhash", |client| { + Box::pin(async move { + client + .get_latest_blockhash() + .await + .map_err(anyhow::Error::from) + }) + }) + .await } }