diff --git a/Cargo.lock b/Cargo.lock index ae26ee1..fc31326 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1480,6 +1480,7 @@ name = "drift-rs" version = "1.0.0-alpha.1" dependencies = [ "abi_stable", + "ahash 0.8.11", "anchor-lang", "base64 0.22.1", "bytemuck", @@ -1487,7 +1488,6 @@ dependencies = [ "dashmap 6.1.0", "drift-idl-gen", "env_logger 0.11.5", - "fnv", "futures-util", "hex", "hex-literal", diff --git a/Cargo.toml b/Cargo.toml index edacafc..418ec3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,12 +23,12 @@ rpc_tests = [] [dependencies] abi_stable = "0.11" +ahash = "0.8.11" anchor-lang = { version = "0.30", features = ["derive"] } base64 = "0.22" bytemuck = "1.17" dashmap = "6" env_logger = "0.11" -fnv = "1" futures-util = "0.3" log = "0.4" rayon = { version = "1.9.0", optional = true } @@ -49,4 +49,4 @@ hex = "0.4" hex-literal = "0.4" [build-dependencies] -drift-idl-gen = { version = "0.1.1", path = "crates/drift-idl-gen"} \ No newline at end of file +drift-idl-gen = { version = "0.1.1", path = "crates/drift-idl-gen"} diff --git a/crates/src/account_map.rs b/crates/src/account_map.rs index 8b87cc0..c4358f8 100644 --- a/crates/src/account_map.rs +++ b/crates/src/account_map.rs @@ -1,7 +1,8 @@ use std::sync::{Arc, Mutex, RwLock}; use anchor_lang::AccountDeserialize; -use fnv::FnvHashMap; +use dashmap::DashMap; +use log::debug; use solana_sdk::{clock::Slot, commitment_config::CommitmentConfig, pubkey::Pubkey}; use crate::{ @@ -9,6 +10,8 @@ use crate::{ UnsubHandle, }; +const LOG_TARGET: &str = "accountmap"; + #[derive(Clone, Default)] pub struct AccountSlot { raw: Vec, @@ -24,7 +27,7 @@ pub struct DataAndSlot { pub struct AccountMap { endpoint: String, commitment: CommitmentConfig, - inner: RwLock>>, + inner: DashMap, ahash::RandomState>, } impl AccountMap { @@ -37,26 +40,23 @@ impl AccountMap { } /// Subscribe user account pub async fn subscribe_account(&self, account: &Pubkey) -> SdkResult<()> { - { - let map = self.inner.read().expect("acquired"); - if map.contains_key(account) { - return Ok(()); - } + if self.inner.contains_key(account) { + return Ok(()); } + debug!(target: LOG_TARGET, "subscribing: {account:?}"); let user = AccountSub::new(&self.endpoint, self.commitment, *account); let user = user.subscribe().await?; - let mut map = self.inner.write().expect("acquired"); - map.insert(*account, user); + self.inner.insert(*account, user); Ok(()) } /// Unsubscribe user account pub fn unsubscribe_account(&self, account: &Pubkey) { - let mut map = self.inner.write().expect("acquired"); - if let Some(u) = map.remove(account) { - let _ = u.unsubscribe(); + if let Some((acc, unsub)) = self.inner.remove(account) { + debug!(target: LOG_TARGET, "unsubscribing: {acc:?}"); + let _ = unsub.unsubscribe(); } } /// Return data of the given `account` as T, if it exists @@ -68,8 +68,9 @@ impl AccountMap { &self, account: &Pubkey, ) -> Option> { - let accounts = self.inner.read().expect("read"); - accounts.get(account).map(|u| u.get_account_data_and_slot()) + self.inner + .get(account) + .map(|u| u.get_account_data_and_slot()) } } @@ -111,7 +112,7 @@ impl AccountSub { let data_and_slot = Arc::new(RwLock::new(AccountSlot::default())); let unsub = self .subscription - .subscribe(Self::SUBSCRIPTION_ID, { + .subscribe(Self::SUBSCRIPTION_ID, true, { let data_and_slot = Arc::clone(&data_and_slot); move |update| { let mut guard = data_and_slot.write().expect("acquired"); diff --git a/crates/src/event_subscriber.rs b/crates/src/event_subscriber.rs index debf389..d403f6f 100644 --- a/crates/src/event_subscriber.rs +++ b/crates/src/event_subscriber.rs @@ -6,9 +6,9 @@ use std::{ time::Duration, }; +use ahash::HashSet; use anchor_lang::{AnchorDeserialize, Discriminator}; use base64::Engine; -use fnv::FnvHashSet; use futures_util::{future::BoxFuture, stream::FuturesOrdered, FutureExt, Stream, StreamExt}; use log::{debug, info, warn}; use regex::Regex; @@ -656,7 +656,7 @@ impl DriftEvent { /// fixed capacity cache of tx signatures struct TxSignatureCache { capacity: usize, - entries: FnvHashSet, + entries: HashSet, age: VecDeque, } @@ -664,7 +664,7 @@ impl TxSignatureCache { fn new(capacity: usize) -> Self { Self { capacity, - entries: FnvHashSet::::with_capacity_and_hasher(capacity, Default::default()), + entries: HashSet::::with_capacity_and_hasher(capacity, Default::default()), age: VecDeque::with_capacity(capacity), } } @@ -689,9 +689,9 @@ impl TxSignatureCache { #[cfg(test)] mod test { + use ahash::HashMap; use anchor_lang::prelude::*; use base64::Engine; - use fnv::FnvHashMap; use futures_util::future::ready; use solana_sdk::{ hash::Hash, @@ -852,7 +852,7 @@ mod test { async fn polled_event_stream_caching() { let _ = env_logger::try_init(); struct MockRpcProvider { - tx_responses: FnvHashMap, + tx_responses: HashMap, signatures: tokio::sync::Mutex>, } @@ -952,7 +952,7 @@ mod test { let signatures: Vec = (0..order_events.len()) .map(|_| Signature::new_unique().to_string()) .collect(); - let mut tx_responses = FnvHashMap::::default(); + let mut tx_responses = HashMap::::default(); for s in signatures.iter() { let (oar, or) = order_events.pop().unwrap(); tx_responses.insert( diff --git a/crates/src/jit_client.rs b/crates/src/jit_client.rs index 1f2204e..80daf36 100644 --- a/crates/src/jit_client.rs +++ b/crates/src/jit_client.rs @@ -20,7 +20,7 @@ use crate::{ accounts::User, build_accounts, constants::{self, state_account, JIT_PROXY_ID}, - DriftClient, MarketId, MarketType, PostOnlyParam, ReferrerInfo, SdkError, SdkResult, + drift_idl, DriftClient, MarketId, MarketType, PostOnlyParam, ReferrerInfo, SdkError, SdkResult, TransactionBuilder, Wallet, }; @@ -131,7 +131,7 @@ impl JitProxyClient { let program_data = tx_builder.program_data(); let account_data = tx_builder.account_data(); - let writable_markets = match order.market_type { + let writable_markets = match order.market_type.into() { MarketType::Perp => { vec![MarketId::perp(order.market_index)] } @@ -161,18 +161,14 @@ impl JitProxyClient { accounts.push(AccountMeta::new(referrer_info.referrer_stats(), false)); } - if order.market_type == MarketType::Spot { + if order.market_type == drift_idl::types::MarketType::Spot { let spot_market_vault = self .drift_client - .get_spot_market_account_and_slot(order.market_index) - .expect("spot market exists") - .data + .try_get_spot_market_account(order.market_index)? .vault; let quote_spot_market_vault = self .drift_client - .get_spot_market_account_and_slot(MarketId::QUOTE_SPOT.index()) - .expect("quote market exists") - .data + .try_get_spot_market_account(MarketId::QUOTE_SPOT.index())? .vault; accounts.push(AccountMeta::new_readonly(spot_market_vault, false)); accounts.push(AccountMeta::new_readonly(quote_spot_market_vault, false)); diff --git a/crates/src/lib.rs b/crates/src/lib.rs index 05cafd4..504908f 100644 --- a/crates/src/lib.rs +++ b/crates/src/lib.rs @@ -31,7 +31,7 @@ use crate::{ oraclemap::{Oracle, OracleMap}, types::{ accounts::{PerpMarket, SpotMarket, User, UserStats}, - *, + MarketType, *, }, utils::get_http_url, }; @@ -78,11 +78,11 @@ pub mod dlob; /// It is not recommended to create multiple instances with `::new()` as this will not re-use underlying resources such /// as network connections or memory allocations /// -/// The client can be used as is to fetch data ad-hoc over RPC or subscribed to receive live data changes +/// The client can be used as is to fetch data ad-hoc over RPC or subscribed to receive live updates /// ```example(no_run) /// let client = DriftClient::new( /// Context::MainNet, -/// RpcClient::new("https://"), +/// RpcClient::new("https://rpc.example.com"), /// key_pair.into() /// ).await.expect("initializes"); /// @@ -90,7 +90,9 @@ pub mod dlob; /// let sol_perp_price = client.oracle_price(MarketId::perp(0)).await; /// /// // Subscribe to live program changes e.g oracle prices, spot/perp market changes, user accounts -/// client.subscribe().await.expect("subscribes"); +/// let markets = [MarketId::perp(0), MarketId::spot(2)]; +/// client.subscribe_markets(&markets).await.expect("subscribes"); +/// client.subscribe_oracles(&markets).await.expect("subscribes"); /// /// // after subscribing, uses Ws-backed local storage /// let sol_perp_price = client.oracle_price(MarketId::perp(0)).await; @@ -108,53 +110,40 @@ pub struct DriftClient { impl DriftClient { /// Create a new `DriftClient` instance /// - /// `context` devnet or mainnet - /// `rpc_client` an RpcClient instance - /// `wallet` wallet to use for tx signing convenience + /// * `context` - devnet or mainnet + /// * `rpc_client` - an RpcClient instance + /// * `wallet` - wallet to use for tx signing convenience pub async fn new(context: Context, rpc_client: RpcClient, wallet: Wallet) -> SdkResult { // check URL format here to fail early, otherwise happens at request time. let _ = get_http_url(&rpc_client.url())?; Ok(Self { backend: Box::leak(Box::new( - DriftClientBackend::new(context, Arc::new(rpc_client), ConfiguredMarkets::All) - .await?, + DriftClientBackend::new(context, Arc::new(rpc_client)).await?, )), context, wallet, }) } - /// Create a new `DriftClient` instance configured for use with a subset of markets - /// Useful to reduce the quantity of network subscriptions/requests + /// Starts background subscriptions for live blockhashes /// - /// `context` devnet or mainnet - /// `rpc_client` an RpcClient instance - /// `wallet` wallet to use for tx signing convenience - /// `markets` subset of markets to use for program lifetime - pub async fn with_markets( - context: Context, - rpc_client: RpcClient, - wallet: Wallet, - markets: ConfiguredMarkets, - ) -> SdkResult { - // check URL format here to fail early, otherwise happens at request time. - let _ = get_http_url(&rpc_client.url())?; - Ok(Self { - context, - backend: Box::leak(Box::new( - DriftClientBackend::new(context, Arc::new(rpc_client), markets).await?, - )), - wallet, - }) + /// This is a no-op if already subscribed + pub async fn subscribe_blockhashes(&self) -> SdkResult<()> { + self.backend.subscribe_blockhashes().await } - /// Starts background subscriptions for live Solana and Drift data e.g. latest blockhashes, oracle prices, markets, etc. - /// The client will subsequently use these values from memory where possible rather - /// than perform network queries. + /// Starts background subscriptions for live market account updates /// /// This is a no-op if already subscribed - pub async fn subscribe(&self) -> SdkResult<()> { - self.backend.subscribe().await + pub async fn subscribe_markets(&self, markets: &[MarketId]) -> SdkResult<()> { + self.backend.subscribe_markets(markets).await + } + + /// Starts background subscriptions for live oracle account updates by market + /// + /// This is a no-op if already subscribed + pub async fn subscribe_oracles(&self, markets: &[MarketId]) -> SdkResult<()> { + self.backend.subscribe_oracles(markets).await } /// Unsubscribe from network resources @@ -179,7 +168,7 @@ impl DriftClient { /// Get an account's open order by id /// - /// `account` the drift user PDA + /// * `account` - the drift user PDA pub async fn get_order_by_id( &self, account: &Pubkey, @@ -192,7 +181,7 @@ impl DriftClient { /// Get an account's open order by user assigned id /// - /// `account` the drift user PDA + /// * `account` - the drift user PDA pub async fn get_order_by_user_id( &self, account: &Pubkey, @@ -209,7 +198,7 @@ impl DriftClient { /// Get all the account's open orders /// - /// `account` the drift user PDA + /// * `account` - the drift user PDA pub async fn all_orders(&self, account: &Pubkey) -> SdkResult> { let user = self.backend.get_user_account(account).await?; @@ -223,7 +212,7 @@ impl DriftClient { /// Get all the account's active positions /// - /// `account` the drift user PDA + /// * `account` - the drift user PDA pub async fn all_positions( &self, account: &Pubkey, @@ -246,7 +235,7 @@ impl DriftClient { /// Get a perp position by market /// - /// `account` the drift user PDA + /// * `account` - the drift user PDA /// /// Returns the position if it exists pub async fn perp_position( @@ -265,7 +254,7 @@ impl DriftClient { /// Get a spot position by market /// - /// `account` the drift user PDA + /// * `account` - the drift user PDA /// /// Returns the position if it exists pub async fn spot_position( @@ -288,9 +277,9 @@ impl DriftClient { } /// Get the user account data - /// Uses cached value if subscribed, fallsback to network query + /// Uses cached value if subscribed, falls back to network query /// - /// `account` the drift user PDA (subaccount) + /// * `account` - the drift user PDA (subaccount) /// /// Returns the deserialized account data (`User`) pub async fn get_user_account(&self, account: &Pubkey) -> SdkResult { @@ -306,15 +295,15 @@ impl DriftClient { } /// Get the latest recent_block_hash - /// uses latest cached if subscribed, otherwise fallsback to network query + /// uses latest cached if subscribed, otherwise falls back to network query pub async fn get_latest_blockhash(&self) -> SdkResult { self.backend.get_latest_blockhash().await } /// Get some account value deserialized as T - /// Uses cached value if subscribed, fallsback to network query + /// Uses cached value if subscribed, falls back to network query /// - /// `account` any onchain account + /// * `account` - any onchain account /// /// Returns the deserialized account data (`User`) pub async fn get_account_value(&self, account: &Pubkey) -> SdkResult { @@ -322,6 +311,7 @@ impl DriftClient { } /// Try to get `account` as `T` using latest local value + /// /// requires account was previously subscribed too. /// like `get_account_value` without async/network fallback pub fn try_get_account(&self, account: &Pubkey) -> SdkResult { @@ -347,8 +337,8 @@ impl DriftClient { /// Sign and send a tx to the network /// - /// `recent_block_hash` some block hash to use for tx signing, if not provided it will be automatically set - /// `config` custom RPC config to use when submitting the tx + /// * `recent_block_hash` - some block hash to use for tx signing, if not provided it will be automatically set + /// * `config` - custom RPC config to use when submitting the tx /// /// Returns the signature on success pub async fn sign_and_send_with_config( @@ -367,18 +357,46 @@ impl DriftClient { .map_err(|err| err.to_out_of_sol_error().unwrap_or(err)) } - /// Get live info of a spot market - /// uses latest cached if subscribed, otherwise fallsback to network query - pub async fn get_spot_market_info(&self, market_index: u16) -> SdkResult { - let market = derive_spot_market_account(market_index); - self.backend.get_account(&market).await + /// Get spot market account + /// uses latest cached if subscribed, otherwise falls back to network query + pub async fn get_spot_market_account(&self, market_index: u16) -> SdkResult { + match self.backend.get_spot_market_account_and_slot(market_index) { + Some(market) => Ok(market.data), + None => { + let market = derive_spot_market_account(market_index); + self.backend.get_account(&market).await + } + } + } + + /// Get perp market account + /// uses latest cached if subscribed, otherwise falls back to network query + pub async fn get_perp_market_account(&self, market_index: u16) -> SdkResult { + match self.backend.get_perp_market_account_and_slot(market_index) { + Some(market) => Ok(market.data), + None => { + let market = derive_perp_market_account(market_index); + self.backend.get_account(&market).await + } + } + } + + /// Try to spot market account from cache + pub fn try_get_spot_market_account(&self, market_index: u16) -> SdkResult { + if let Some(market) = self.backend.get_spot_market_account_and_slot(market_index) { + Ok(market.data) + } else { + Err(SdkError::NoData) + } } - /// Get live info of a perp market - /// uses latest cached if subscribed, otherwise fallsback to network query - pub async fn get_perp_market_info(&self, market_index: u16) -> SdkResult { - let market = derive_perp_market_account(market_index); - self.backend.get_account(&market).await + /// Try to get perp market account from cache + pub fn try_get_perp_market_account(&self, market_index: u16) -> SdkResult { + if let Some(market) = self.backend.get_perp_market_account_and_slot(market_index) { + Ok(market.data) + } else { + Err(SdkError::NoData) + } } /// Lookup a market by symbol @@ -409,7 +427,7 @@ impl DriftClient { } /// Get live oracle price for `market` - /// uses latest cached if subscribed, otherwise fallsback to network query + /// uses latest cached if subscribed, otherwise falls back to network query pub async fn oracle_price(&self, market: MarketId) -> SdkResult { self.backend.oracle_price(market).await } @@ -448,58 +466,11 @@ impl DriftClient { .await } - pub fn get_perp_market_account_and_slot( - &self, - market_index: u16, - ) -> Option> { - self.backend.get_perp_market_account_and_slot(market_index) - } - - pub fn get_spot_market_account_and_slot( - &self, - market_index: u16, - ) -> Option> { - self.backend.get_spot_market_account_and_slot(market_index) - } - - pub fn get_perp_market_account(&self, market_index: u16) -> Option { - self.backend - .get_perp_market_account_and_slot(market_index) - .map(|x| x.data) - } - - pub fn get_spot_market_account(&self, market_index: u16) -> Option { - self.backend - .get_spot_market_account_and_slot(market_index) - .map(|x| x.data) - } - - pub fn num_perp_markets(&self) -> usize { - self.backend.num_perp_markets() - } - - pub fn num_spot_markets(&self) -> usize { - self.backend.num_spot_markets() - } - - pub fn get_oracle_price_data_and_slot(&self, oracle_pubkey: &Pubkey) -> Option { - self.backend.get_oracle_price_data_and_slot(oracle_pubkey) - } - - pub fn get_oracle_price_data_and_slot_for_perp_market( - &self, - market_index: u16, - ) -> Option { - self.backend - .get_oracle_price_data_and_slot_for_perp_market(market_index) - } - - pub fn get_oracle_price_data_and_slot_for_spot_market( - &self, - market_index: u16, - ) -> Option { - self.backend - .get_oracle_price_data_and_slot_for_spot_market(market_index) + /// Try get the latest oracle data for `market` + /// + /// If only the price is required use `oracle_price` intstead + pub fn try_get_oracle_price_data_and_slot(&self, market: MarketId) -> Option { + self.backend.try_get_oracle_price_data_and_slot(market) } /// Subscribe to live updates for some `account` @@ -533,11 +504,7 @@ pub struct DriftClientBackend { impl DriftClientBackend { /// Initialize a new `DriftClientBackend` - async fn new( - context: Context, - rpc_client: Arc, - configured_markets: ConfiguredMarkets, - ) -> SdkResult { + async fn new(context: Context, rpc_client: Arc) -> SdkResult { let perp_market_map = MarketMap::::new(rpc_client.commitment(), rpc_client.url(), true); let spot_market_map = @@ -554,24 +521,24 @@ impl DriftClientBackend { )?; let lookup_table = utils::deserialize_alt(lookup_table_address, &lut)?; - let perp_oracles = perp_market_map - .oracles() - .into_iter() - .filter(|(idx, _, _)| configured_markets.wants(MarketId::perp(*idx))) - .collect(); - let spot_oracles = spot_market_map + let mut all_oracles = Vec::<(MarketId, Pubkey, OracleSource)>::with_capacity( + perp_market_map.size() + spot_market_map.size(), + ); + for market_oracle_info in perp_market_map .oracles() - .into_iter() - .filter(|(idx, _, _)| configured_markets.wants(MarketId::spot(*idx))) - .collect(); + .iter() + .chain(spot_market_map.oracles().iter()) + { + all_oracles.push(*market_oracle_info); + } let oracle_map = OracleMap::new( rpc_client.commitment(), rpc_client.url(), - true, - perp_oracles, - spot_oracles, + all_oracles.as_slice(), ); + let account_map = AccountMap::new(rpc_client.url(), rpc_client.commitment()); + account_map.subscribe_account(state_account()).await?; Ok(Self { rpc_client: Arc::clone(&rpc_client), @@ -584,33 +551,44 @@ impl DriftClientBackend { perp_market_map.values(), lookup_table, ), - account_map: AccountMap::new(rpc_client.url(), rpc_client.commitment()), + account_map, perp_market_map, spot_market_map, oracle_map, }) } - /// Start subscription workers for live program data - async fn subscribe(&self) -> SdkResult<()> { + /// Start subscription for latest block hashes + async fn subscribe_blockhashes(&self) -> SdkResult<()> { self.blockhash_subscriber.subscribe(); + Ok(()) + } + + /// Start subscriptions for market accounts + async fn subscribe_markets(&self, markets: &[MarketId]) -> SdkResult<()> { + let (perps, spot) = markets + .iter() + .partition::, _>(|x| x.is_perp()); let _ = tokio::try_join!( - self.perp_market_map.subscribe(), - self.spot_market_map.subscribe(), - self.oracle_map.subscribe(), - self.account_map.subscribe_account(state_account()), + self.perp_market_map.subscribe(perps.as_slice()), + self.spot_market_map.subscribe(spot.as_slice()), )?; Ok(()) } + /// Start subscriptions for market oracle accounts + async fn subscribe_oracles(&self, markets: &[MarketId]) -> SdkResult<()> { + self.oracle_map.subscribe(markets).await + } + /// End subscriptions to live program data async fn unsubscribe(&self) -> SdkResult<()> { self.blockhash_subscriber.unsubscribe(); - self.perp_market_map.unsubscribe()?; - self.spot_market_map.unsubscribe()?; + self.perp_market_map.unsubscribe_all()?; + self.spot_market_map.unsubscribe_all()?; self.account_map.unsubscribe_account(state_account()); - self.oracle_map.unsubscribe().await + self.oracle_map.unsubscribe_all() } fn get_perp_market_account_and_slot( @@ -627,48 +605,32 @@ impl DriftClientBackend { self.spot_market_map.get(&market_index) } - fn num_perp_markets(&self) -> usize { - self.perp_market_map.size() - } - - fn num_spot_markets(&self) -> usize { - self.spot_market_map.size() + fn try_get_oracle_price_data_and_slot(&self, market: MarketId) -> Option { + self.oracle_map.get_by_market(market) } - fn get_oracle_price_data_and_slot(&self, oracle_pubkey: &Pubkey) -> Option { - self.oracle_map.get(oracle_pubkey) - } - - fn get_oracle_price_data_and_slot_for_perp_market(&self, market_index: u16) -> Option { - let market = self.get_perp_market_account_and_slot(market_index)?; - - let oracle = market.data.amm.oracle; + /// Same as `get_oracle_price_data_and_slot` but checks the oracle pubkey has not changed + /// this can be useful if the oracle address changes in the program + fn get_oracle_price_data_and_slot_checked(&self, market: MarketId) -> Option { let current_oracle = self .oracle_map - .current_perp_oracle(market_index) - .expect("oracle"); - - if oracle != current_oracle { - panic!("invalid perp oracle: {}", market_index); - } - - self.get_oracle_price_data_and_slot(¤t_oracle) - } - - fn get_oracle_price_data_and_slot_for_spot_market(&self, market_index: u16) -> Option { - let market = self.get_spot_market_account_and_slot(market_index)?; + .get_by_market(market) + .expect("oracle") + .pubkey; - let oracle = market.data.oracle; - let current_oracle = self - .oracle_map - .current_spot_oracle(market_index) - .expect("oracle"); + let program_configured_oracle = if market.is_perp() { + let market = self.get_perp_market_account_and_slot(market.index())?; + market.data.amm.oracle + } else { + let market = self.get_spot_market_account_and_slot(market.index())?; + market.data.oracle + }; - if oracle != current_oracle { - panic!("invalid spot oracle: {}", market_index); + if program_configured_oracle != current_oracle { + panic!("invalid oracle: {}", market.index()); } - self.get_oracle_price_data_and_slot(&market.data.oracle) + self.try_get_oracle_price_data_and_slot(market) } /// Return a handle to the inner RPC client @@ -678,7 +640,8 @@ impl DriftClientBackend { /// Get recent tx priority fees /// - /// - `window` # of slots to include in the fee calculation + /// * `writable_markets` - markets to consider for write locks + /// * `window` - # of slots to include in the fee calculation async fn get_recent_priority_fees( &self, writable_markets: &[MarketId], @@ -723,9 +686,26 @@ impl DriftClientBackend { } } + /// Fetch `account` as an Anchor account type `T` along with the slot + async fn get_account_with_slot( + &self, + account: &Pubkey, + ) -> SdkResult> { + if let Some(value) = self.account_map.account_data_and_slot(account) { + Ok(value) + } else { + let (account, slot) = self.get_account_with_slot_raw(account).await?; + Ok(account_map::DataAndSlot { + slot, + data: T::try_deserialize(&mut account.data.as_slice()) + .map_err(|err| SdkError::Anchor(Box::new(err)))?, + }) + } + } + /// Fetch `account` as a drift User account /// - /// uses latest cached if subscribed, otherwise fallsback to network query + /// uses latest cached if subscribed, otherwise falls back to network query async fn get_user_account(&self, account: &Pubkey) -> SdkResult { self.get_account(account).await } @@ -740,7 +720,7 @@ impl DriftClientBackend { /// Returns latest blockhash /// - /// uses latest cached if subscribed, otherwise fallsback to network query + /// uses latest cached if subscribed, otherwise falls back to network query pub async fn get_latest_blockhash(&self) -> SdkResult { match self.blockhash_subscriber.get_latest_blockhash() { Some(hash) => Ok(hash), @@ -787,39 +767,39 @@ impl DriftClientBackend { } /// Fetch the live oracle price for `market` - /// Uses latest local value from an `OracleMap` if subscribed, fallsback to network query + /// + /// Uses latest local value from an `OracleMap` if subscribed, falls back to network query pub async fn oracle_price(&self, market: MarketId) -> SdkResult { - let (oracle, oracle_source) = match market.kind() { - MarketType::Perp => { - let market = self - .program_data - .perp_market_config_by_index(market.index()) - .ok_or(SdkError::InvalidOracle)?; - (market.amm.oracle, market.amm.oracle_source) - } - MarketType::Spot => { - let market = self - .program_data - .spot_market_config_by_index(market.index()) - .ok_or(SdkError::InvalidOracle)?; - (market.oracle, market.oracle_source) - } - }; - - if self.oracle_map.is_subscribed().await { + if self.oracle_map.is_subscribed(&market) { Ok(self - .get_oracle_price_data_and_slot(&oracle) + .try_get_oracle_price_data_and_slot(market) .expect("oracle exists") .data .price) } else { - let (account_data, slot) = self.get_account_with_slot(&oracle).await?; + let (oracle, oracle_source) = match market.kind() { + MarketType::Perp => { + let market = self + .program_data + .perp_market_config_by_index(market.index()) + .ok_or(SdkError::InvalidOracle)?; + (market.amm.oracle, market.amm.oracle_source) + } + MarketType::Spot => { + let market = self + .program_data + .spot_market_config_by_index(market.index()) + .ok_or(SdkError::InvalidOracle)?; + (market.oracle, market.oracle_source) + } + }; + let (account_data, slot) = self.get_account_with_slot_raw(&oracle).await?; ffi::get_oracle_price(oracle_source, &mut (oracle, account_data), slot).map(|o| o.price) } } /// Get account via rpc along with retrieved slot number - pub async fn get_account_with_slot(&self, pubkey: &Pubkey) -> SdkResult<(Account, Slot)> { + async fn get_account_with_slot_raw(&self, pubkey: &Pubkey) -> SdkResult<(Account, Slot)> { match self .rpc_client .get_account_with_commitment(pubkey, self.rpc_client.commitment()) @@ -908,10 +888,10 @@ pub struct TransactionBuilder<'a> { impl<'a> TransactionBuilder<'a> { /// Initialize a new `TransactionBuilder` for default signer /// - /// `program_data` program data from chain - /// `sub_account` drift sub-account address - /// `account_data` drift sub-account data - /// `delegated` set true to build tx for delegated signing + /// * `program_data` - program data from chain + /// * `sub_account` - drift sub-account address + /// * `user` - drift sub-account data + /// * `delegated` - set true to build tx for delegated signing pub fn new<'b>( program_data: &'b ProgramData, sub_account: Pubkey, @@ -956,7 +936,7 @@ impl<'a> TransactionBuilder<'a> { } /// Set the priority fee of the tx /// - /// `microlamports_per_cu` the price per unit of compute in µ-lamports + /// * `microlamports_per_cu` - the price per unit of compute in µ-lamports pub fn with_priority_fee(mut self, microlamports_per_cu: u64, cu_limit: Option) -> Self { let cu_limit_ix = ComputeBudgetInstruction::set_compute_unit_price(microlamports_per_cu); self.ixs.insert(0, cu_limit_ix); @@ -1052,7 +1032,7 @@ impl<'a> TransactionBuilder<'a> { pub fn place_orders(mut self, orders: Vec) -> Self { let mut readable_accounts: Vec = orders .iter() - .map(|o| (o.market_index, o.market_type).into()) + .map(|o| (o.market_index, o.market_type.into()).into()) .collect(); readable_accounts.extend(&self.force_markets.readable); @@ -1109,9 +1089,8 @@ impl<'a> TransactionBuilder<'a> { /// Cancel account's orders matching some criteria /// - /// `market` - tuple of market ID and type (spot or perp) - /// - /// `direction` - long or short + /// * `market` - tuple of market ID and type (spot or perp) + /// * `direction` - long or short pub fn cancel_orders( mut self, market: (u16, MarketType), @@ -1137,7 +1116,7 @@ impl<'a> TransactionBuilder<'a> { accounts, data: InstructionData::data(&drift_idl::instructions::CancelOrders { market_index: Some(idx), - market_type: Some(kind), + market_type: Some(kind.into()), direction, }), }; @@ -1258,11 +1237,11 @@ impl<'a> TransactionBuilder<'a> { /// Add a place and make instruction /// - /// `order` the order to place - /// `taker_info` taker account address and data - /// `taker_order_id` the id of the taker's order to match with - /// `referrer` pukey of the taker's referrer account, if any - /// `fulfilment_type` type of fill for spot orders, ignored for perp orders + /// * `order` - the order to place + /// * `taker_info` - taker account address and data + /// * `taker_order_id` - the id of the taker's order to match with + /// * `referrer` - pukey of the taker's referrer account, if any + /// * `fulfillment_type` - type of fill for spot orders, ignored for perp orders pub fn place_and_make( mut self, order: OrderParams, @@ -1272,7 +1251,7 @@ impl<'a> TransactionBuilder<'a> { fulfillment_type: Option, ) -> Self { let (taker, taker_account) = taker_info; - let is_perp = order.market_type == MarketType::Perp; + let is_perp = order.market_type == MarketType::Perp.into(); let perp_writable = [MarketId::perp(order.market_index)]; let spot_writable = [MarketId::spot(order.market_index), MarketId::QUOTE_SPOT]; let mut accounts = build_accounts( @@ -1303,7 +1282,7 @@ impl<'a> TransactionBuilder<'a> { accounts.push(AccountMeta::new(referrer, false)); } - let ix = if order.market_type == MarketType::Perp { + let ix = if order.market_type == MarketType::Perp.into() { Instruction { program_id: constants::PROGRAM_ID, accounts, @@ -1330,13 +1309,10 @@ impl<'a> TransactionBuilder<'a> { /// Add a place and take instruction /// - /// `order` the order to place - /// - /// `maker_info` pubkey of the maker/counterparty to take against and account data - /// - /// `referrer` pubkey of the maker's referrer account, if any - /// - /// `fulfilment_type` type of fill for spot orders, ignored for perp orders + /// * `order` - the order to place + /// * `maker_info` - pubkey of the maker/counterparty to take against and account data + /// * `referrer` - pubkey of the maker's referrer account, if any + /// * `fulfillment_type` - type of fill for spot orders, ignored for perp orders pub fn place_and_take( mut self, order: OrderParams, @@ -1350,7 +1326,7 @@ impl<'a> TransactionBuilder<'a> { user_accounts.push(maker); } - let is_perp = order.market_type == MarketType::Perp; + let is_perp = order.market_type == MarketType::Perp.into(); let perp_writable = [MarketId::perp(order.market_index)]; let spot_writable = [MarketId::spot(order.market_index), MarketId::QUOTE_SPOT]; @@ -1434,13 +1410,10 @@ impl<'a> TransactionBuilder<'a> { /// Builds a set of required accounts from a user's open positions and additional given accounts /// -/// `base_accounts` base anchor accounts -/// -/// `user` Drift user account data -/// -/// `markets_readable` IDs of markets to include as readable -/// -/// `markets_writable` IDs of markets to include as writable (takes priority over readable) +/// * `base_accounts` - base anchor accounts +/// * `user` - Drift user account data +/// * `markets_readable` - IDs of markets to include as readable +/// * `markets_writable` - IDs of markets to include as writable (takes priority over readable) /// /// # Panics /// if the user has positions in an unknown market (i.e unsupported by the SDK) @@ -1561,17 +1534,17 @@ impl Wallet { /// # panics /// if the key is invalid pub fn from_seed_bs58(seed: &str) -> Self { - let authority: Keypair = Keypair::from_base58_string(seed); + let authority = Keypair::from_base58_string(seed); Self::new(authority) } /// Init wallet from seed bytes, uses default sub-account pub fn from_seed(seed: &[u8]) -> SdkResult { - let authority: Keypair = keypair_from_seed(seed).map_err(|_| SdkError::InvalidSeed)?; + let authority = keypair_from_seed(seed).map_err(|_| SdkError::InvalidSeed)?; Ok(Self::new(authority)) } /// Init wallet with keypair /// - /// `authority` keypair for tx signing + /// * `authority` - keypair for tx signing pub fn new(authority: Keypair) -> Self { Self { stats: Wallet::derive_stats_account(&authority.pubkey()), @@ -1622,7 +1595,6 @@ impl Wallet { let signer: &dyn Signer = self.signer.as_ref(); Ok(signer.sign_message(message)) } - /// Return the wallet authority address pub fn authority(&self) -> &Pubkey { &self.authority @@ -1695,9 +1667,7 @@ mod tests { oracle_map: OracleMap::new( CommitmentConfig::processed(), DEVNET_ENDPOINT.to_string(), - true, - vec![], - vec![], + &[], ), blockhash_subscriber: BlockhashSubscriber::new( Duration::from_secs(2), diff --git a/crates/src/marketmap.rs b/crates/src/marketmap.rs index 6abfd88..fc27337 100644 --- a/crates/src/marketmap.rs +++ b/crates/src/marketmap.rs @@ -19,13 +19,11 @@ use solana_sdk::{clock::Slot, commitment_config::CommitmentConfig, pubkey::Pubke use crate::{ accounts::State, constants::{self, derive_perp_market_account, derive_spot_market_account, state_account}, - drift_idl::types::{MarketType, OracleSource}, + drift_idl::types::OracleSource, memcmp::get_market_filter, utils::get_ws_url, - websocket_program_account_subscriber::{ - ProgramAccountUpdate, WebsocketProgramAccountOptions, WebsocketProgramAccountSubscriber, - }, - DataAndSlot, PerpMarket, SdkError, SdkResult, SpotMarket, UnsubHandle, + websocket_account_subscriber::WebsocketAccountSubscriber, + DataAndSlot, MarketId, MarketType, PerpMarket, SdkError, SdkResult, SpotMarket, UnsubHandle, }; const LOG_TARGET: &str = "marketmap"; @@ -33,7 +31,7 @@ const LOG_TARGET: &str = "marketmap"; pub trait Market { const MARKET_TYPE: MarketType; fn market_index(&self) -> u16; - fn oracle_info(&self) -> (u16, Pubkey, OracleSource); + fn oracle_info(&self) -> (MarketId, Pubkey, OracleSource); } impl Market for PerpMarket { @@ -43,8 +41,12 @@ impl Market for PerpMarket { self.market_index } - fn oracle_info(&self) -> (u16, Pubkey, OracleSource) { - (self.market_index(), self.amm.oracle, self.amm.oracle_source) + fn oracle_info(&self) -> (MarketId, Pubkey, OracleSource) { + ( + MarketId::perp(self.market_index), + self.amm.oracle, + self.amm.oracle_source, + ) } } @@ -55,14 +57,18 @@ impl Market for SpotMarket { self.market_index } - fn oracle_info(&self) -> (u16, Pubkey, OracleSource) { - (self.market_index(), self.oracle, self.oracle_source) + fn oracle_info(&self) -> (MarketId, Pubkey, OracleSource) { + ( + MarketId::spot(self.market_index), + self.oracle, + self.oracle_source, + ) } } pub struct MarketMap { - subscription: WebsocketProgramAccountSubscriber, - marketmap: Arc>>, + marketmap: Arc, ahash::RandomState>>, + subscriptions: DashMap, sync_lock: Option>, latest_slot: Arc, rpc: RpcClient, @@ -76,22 +82,12 @@ where pub const SUBSCRIPTION_ID: &'static str = "marketmap"; pub fn new(commitment: CommitmentConfig, endpoint: String, sync: bool) -> Self { - let filters = vec![get_market_filter(T::MARKET_TYPE)]; - let options = WebsocketProgramAccountOptions { - filters, - commitment, - encoding: UiAccountEncoding::Base64Zstd, - }; - - let url = get_ws_url(&endpoint.clone()).unwrap(); - let subscription = WebsocketProgramAccountSubscriber::new(url, options); - let marketmap = Arc::new(DashMap::new()); let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment); let sync_lock = if sync { Some(Mutex::new(())) } else { None }; Self { - subscription, - marketmap, + subscriptions: Default::default(), + marketmap: Arc::default(), sync_lock, latest_slot: Arc::new(AtomicU64::new(0)), rpc, @@ -99,52 +95,94 @@ where } } - pub async fn subscribe(&self) -> SdkResult<()> { + /// Subscribe to market account updates + pub async fn subscribe(&self, markets: &[MarketId]) -> SdkResult<()> { log::debug!(target: LOG_TARGET, "subscribing: {:?}", T::MARKET_TYPE); if self.sync_lock.is_some() { self.sync().await?; } - let unsub = self.subscription.subscribe(Self::SUBSCRIPTION_ID, { + let url = get_ws_url(&self.rpc.url()).expect("valid url"); + + let mut pending_subscriptions = + Vec::<(u16, WebsocketAccountSubscriber)>::with_capacity(markets.len()); + + for market in markets { + let market_pubkey = match T::MARKET_TYPE { + MarketType::Perp => derive_perp_market_account(market.index()), + MarketType::Spot => derive_spot_market_account(market.index()), + }; + + let market_subscriber = + WebsocketAccountSubscriber::new(url.clone(), market_pubkey, self.rpc.commitment()); + + pending_subscriptions.push((market.index(), market_subscriber)); + } + + let futs_iter = pending_subscriptions.into_iter().map(|(idx, fut)| { let marketmap = Arc::clone(&self.marketmap); let latest_slot = self.latest_slot.clone(); - move |update: &ProgramAccountUpdate| { - if update.data_and_slot.slot > latest_slot.load(Ordering::Relaxed) { - latest_slot.store(update.data_and_slot.slot, Ordering::Relaxed); - } - marketmap.insert( - update.data_and_slot.data.market_index(), - update.data_and_slot.clone(), - ); + async move { + let unsub = fut + .subscribe(Self::SUBSCRIPTION_ID, false, { + move |update| { + if update.slot > latest_slot.load(Ordering::Relaxed) { + latest_slot.store(update.slot, Ordering::Relaxed); + } + marketmap.insert( + idx, + DataAndSlot { + slot: update.slot, + data: T::deserialize(&mut update.data.as_slice()) + .expect("valid market"), + }, + ); + } + }) + .await; + (idx, unsub) } }); - let mut guard = self.unsub.lock().unwrap(); - *guard = Some(unsub); + + let mut subscription_futs = FuturesUnordered::from_iter(futs_iter); + while let Some((market, unsub)) = subscription_futs.next().await { + self.subscriptions.insert(market, unsub?); + } + log::debug!(target: LOG_TARGET, "subscribed: {:?}", T::MARKET_TYPE); Ok(()) } - pub fn unsubscribe(&self) -> SdkResult<()> { - log::debug!(target: LOG_TARGET, "unsubscribing: {:?}", T::MARKET_TYPE); - let mut guard = self.unsub.lock().expect("uncontested"); - if let Some(unsub) = guard.take() { - if unsub.send(()).is_err() { - log::error!("couldn't unsubscribe"); + /// Unsubscribe from updates for the given `markets` + pub fn unsubscribe(&self, markets: &[MarketId]) -> SdkResult<()> { + for market in markets { + if let Some((market, unsub)) = self.subscriptions.remove(&market.index()) { + let _ = unsub.send(()); + self.marketmap.remove(&market); } - self.marketmap.clear(); - self.latest_slot.store(0, Ordering::Relaxed); } - log::debug!(target: LOG_TARGET, "unsubscribed: {:?}", T::MARKET_TYPE); + log::debug!(target: LOG_TARGET, "unsubscribed markets: {markets:?}"); Ok(()) } + /// Unsubscribe from all market updates + pub fn unsubscribe_all(&self) -> SdkResult<()> { + let all_markets: Vec = self + .subscriptions + .iter() + .map(|x| (*x.key(), T::MARKET_TYPE).into()) + .collect(); + self.unsubscribe(&all_markets) + } + pub fn values(&self) -> Vec { self.marketmap.iter().map(|x| x.data.clone()).collect() } - pub fn oracles(&self) -> Vec<(u16, Pubkey, OracleSource)> { + /// Returns a list of oracle info for each market + pub fn oracles(&self) -> Vec<(MarketId, Pubkey, OracleSource)> { self.values().iter().map(|x| x.oracle_info()).collect() } @@ -162,6 +200,7 @@ where .map(|market| market.clone()) } + /// Sync all market accounts #[allow(clippy::await_holding_lock)] pub(crate) async fn sync(&self) -> SdkResult<()> { if self.unsub.lock().unwrap().is_some() { @@ -200,10 +239,10 @@ where /// Fetch all market (program) accounts with multiple fallbacks /// -/// Tries progressively less intensive RPC methods for wider compatiblity with RPC providers: -/// getProgramAccounts, getMultipleAccounts, latstly multiple getAccountInfo +/// Tries progressively less intensive RPC methods for wider compatibility with RPC providers: +/// getProgramAccounts, getMultipleAccounts, lastly multiple getAccountInfo /// -/// Returns deserialized accounts and retreived slot +/// Returns deserialized accounts and retrieved slot pub async fn get_market_accounts_with_fallback( rpc: &RpcClient, ) -> SdkResult<(Vec, Slot)> { @@ -259,10 +298,10 @@ pub async fn get_market_accounts_with_fallback( }; // try 'getMultipleAccounts' - let market_respones = rpc + let market_responses = rpc .get_multiple_accounts_with_commitment(market_pdas.as_slice(), rpc.commitment()) .await; - if let Ok(response) = market_respones { + if let Ok(response) = market_responses { for account in response.value { match account { Some(account) => { @@ -285,8 +324,8 @@ pub async fn get_market_accounts_with_fallback( let mut market_requests = FuturesUnordered::from_iter(market_pdas.iter().map(|acc| rpc.get_account_data(acc))); - while let Some(market_repsonse) = market_requests.next().await { - match market_repsonse { + while let Some(market_response) = market_requests.next().await { + match market_response { Ok(data) => { markets .push(T::deserialize(&mut &data.as_slice()[8..]).expect("market deserializes")); diff --git a/crates/src/math/account_map_builder.rs b/crates/src/math/account_map_builder.rs index 1de9f13..fc7aca5 100644 --- a/crates/src/math/account_map_builder.rs +++ b/crates/src/math/account_map_builder.rs @@ -1,4 +1,4 @@ -use fnv::FnvHashSet; +use ahash::HashMap; use solana_sdk::{account::Account, pubkey::Pubkey}; use crate::{ @@ -21,33 +21,38 @@ pub(crate) struct AccountsListBuilder { impl AccountsListBuilder { /// Constructs the account map + drift state account pub fn build(&mut self, client: &DriftClient, user: &User) -> SdkResult { - let mut oracles = FnvHashSet::::default(); + let mut oracle_markets = + HashMap::::with_capacity_and_hasher(16, Default::default()); let mut spot_markets = Vec::::with_capacity(user.spot_positions.len()); let mut perp_markets = Vec::::with_capacity(user.perp_positions.len()); let drift_state = client.state_config()?; for p in user.spot_positions.iter().filter(|p| !p.is_available()) { - let market = client - .get_spot_market_account(p.market_index) - .expect("spot market"); - if oracles.insert(market.oracle) { + let market = client.try_get_spot_market_account(p.market_index)?; + if oracle_markets + .insert(market.oracle, MarketId::spot(market.market_index)) + .is_none() + { spot_markets.push(market); } } - let quote_market = client - .get_spot_market_account(MarketId::QUOTE_SPOT.index()) - .expect("spot market"); - if oracles.insert(quote_market.oracle) { + let quote_market = client.try_get_spot_market_account(MarketId::QUOTE_SPOT.index())?; + if oracle_markets + .insert(quote_market.oracle, MarketId::QUOTE_SPOT) + .is_none() + { spot_markets.push(quote_market); } for p in user.perp_positions.iter().filter(|p| !p.is_available()) { - let market = client - .get_perp_market_account(p.market_index) - .expect("perp market"); - oracles.insert(market.amm.oracle); - perp_markets.push(market); + let market = client.try_get_perp_market_account(p.market_index)?; + if oracle_markets + .insert(market.amm.oracle, MarketId::perp(market.market_index)) + .is_none() + { + perp_markets.push(market); + }; } for market in spot_markets.iter() { @@ -78,9 +83,9 @@ impl AccountsListBuilder { ); } - for oracle_key in oracles.iter() { + for (oracle_key, market) in oracle_markets.iter() { let oracle = client - .get_oracle_price_data_and_slot(oracle_key) + .try_get_oracle_price_data_and_slot(*market) .expect("oracle exists"); let oracle_owner = oracle_source_to_owner(client.context, oracle.source); diff --git a/crates/src/math/liquidation.rs b/crates/src/math/liquidation.rs index 6eda4c7..9a83bdb 100644 --- a/crates/src/math/liquidation.rs +++ b/crates/src/math/liquidation.rs @@ -19,7 +19,7 @@ use crate::{ accounts::{PerpMarket, SpotMarket, User}, MarginRequirementType, PerpPosition, }, - DriftClient, SdkError, SdkResult, SpotPosition, + DriftClient, MarketId, SdkError, SdkResult, SpotPosition, }; /// Info on a positions liquidation price and unrealized PnL @@ -37,12 +37,11 @@ pub fn calculate_liquidation_price_and_unrealized_pnl( user: &User, market_index: u16, ) -> SdkResult { - let perp_market = client - .get_perp_market_account(market_index) - .ok_or(SdkError::InvalidAccount)?; + let perp_market = client.try_get_perp_market_account(market_index)?; let oracle = client - .get_oracle_price_data_and_slot(&perp_market.amm.oracle) - .ok_or(SdkError::InvalidAccount)?; + .try_get_oracle_price_data_and_slot(MarketId::perp(market_index)) + .ok_or(SdkError::NoData)?; + let position = user .get_perp_position(market_index) .map_err(|_| SdkError::NoPosiiton(market_index))?; @@ -79,9 +78,10 @@ pub fn calculate_unrealized_pnl( ) -> SdkResult { if let Ok(position) = user.get_perp_position(market_index) { let oracle_price = client - .get_oracle_price_data_and_slot_for_perp_market(market_index) + .try_get_oracle_price_data_and_slot(MarketId::perp(market_index)) .map(|x| x.data.price) - .unwrap_or(0); + .ok_or(SdkError::NoData)?; + calculate_unrealized_pnl_inner(&position, oracle_price) } else { Err(SdkError::NoPosiiton(market_index)) @@ -108,12 +108,12 @@ pub fn calculate_liquidation_price( ) -> SdkResult { let mut accounts_builder = AccountsListBuilder::default(); let mut account_maps = accounts_builder.build(client, user)?; - let perp_market = client - .get_perp_market_account(market_index) - .ok_or(SdkError::InvalidAccount)?; + let perp_market = client.try_get_perp_market_account(market_index)?; + let oracle = client - .get_oracle_price_data_and_slot(&perp_market.amm.oracle) - .ok_or(SdkError::InvalidAccount)?; + .try_get_oracle_price_data_and_slot(MarketId::perp(market_index)) + .ok_or(SdkError::NoData)?; + // matching spot market e.g. sol-perp => SOL spot let spot_market = client .program_data() diff --git a/crates/src/oraclemap.rs b/crates/src/oraclemap.rs index a6926f3..e576ded 100644 --- a/crates/src/oraclemap.rs +++ b/crates/src/oraclemap.rs @@ -1,6 +1,9 @@ -use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, Mutex, +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, }; use dashmap::DashMap; @@ -10,20 +13,20 @@ use solana_client::nonblocking::rpc_client::RpcClient; use solana_sdk::{ account::Account, clock::Slot, commitment_config::CommitmentConfig, pubkey::Pubkey, }; -use tokio::sync::RwLock; use crate::{ drift_idl::types::OracleSource, ffi::{get_oracle_price, OraclePriceData}, utils::get_ws_url, websocket_account_subscriber::{AccountUpdate, WebsocketAccountSubscriber}, - SdkError, SdkResult, UnsubHandle, + MarketId, SdkError, SdkResult, UnsubHandle, }; const LOG_TARGET: &str = "oraclemap"; -#[derive(Clone, Debug)] +#[derive(Clone, Default, Debug)] pub struct Oracle { + pub market: MarketId, pub pubkey: Pubkey, pub data: OraclePriceData, pub source: OracleSource, @@ -31,130 +34,139 @@ pub struct Oracle { pub raw: Vec, } -pub(crate) struct OracleMap { - pub(crate) oraclemap: Arc>, - oracle_infos: DashMap, - sync_lock: Option>, +/// Dynamic map of Drift market oracle data +/// +/// Caller can subscribe to some subset of markets for Ws backed updates +/// Alternatively, the caller may drive the map by calling `sync` periodically +pub struct OracleMap { + /// Oracle info keyed by market + oraclemap: Arc>, + /// Oracle subscription handles keyed by market + oracle_subscriptions: DashMap, latest_slot: Arc, rpc: RpcClient, - oracle_subscribers: RwLock>, - perp_oracles: DashMap, - spot_oracles: DashMap, } impl OracleMap { - pub const SUBSCRIPTION_ID: &'static str = "oraclemap"; + pub const SUBSCRIPTION_ID: &str = "oraclemap"; + /// Create a new `OracleMap` + /// + /// * `all_oracles` - Exhaustive list of all Drift oracle pubkeys and source by market pub fn new( commitment: CommitmentConfig, endpoint: String, - sync: bool, - perp_oracles: Vec<(u16, Pubkey, OracleSource)>, - spot_oracles: Vec<(u16, Pubkey, OracleSource)>, + all_oracles: &[(MarketId, Pubkey, OracleSource)], ) -> Self { - let oraclemap = Arc::new(DashMap::new()); let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment); - let sync_lock = if sync { Some(Mutex::new(())) } else { None }; - - let oracle_infos_map: DashMap<_, _> = perp_oracles - .iter() - .chain(spot_oracles.iter()) - .map(|(_, pubkey, oracle_source)| (*pubkey, *oracle_source)) - .collect(); - - let perp_oracles_map: DashMap<_, _> = perp_oracles - .iter() - .map(|(market_index, pubkey, _)| (*market_index, *pubkey)) - .collect(); - - let spot_oracles_map: DashMap<_, _> = spot_oracles + let oraclemap = all_oracles .iter() - .map(|(market_index, pubkey, _)| (*market_index, *pubkey)) + .copied() + .map(|(market, pubkey, source)| { + ( + market, + Oracle { + pubkey, + source, + ..Default::default() + }, + ) + }) .collect(); Self { - oraclemap, - oracle_infos: oracle_infos_map, - sync_lock, + oraclemap: Arc::new(oraclemap), + oracle_subscriptions: Default::default(), latest_slot: Arc::new(AtomicU64::new(0)), rpc, - oracle_subscribers: Default::default(), - perp_oracles: perp_oracles_map, - spot_oracles: spot_oracles_map, } } - pub async fn subscribe(&self) -> SdkResult<()> { - log::debug!(target: LOG_TARGET, "subscribing"); - if self.sync_lock.is_some() { - self.sync().await?; - } - - if self.is_subscribed().await { - return Ok(()); - } + /// Subscribe to oracle updates for given `markets` + /// Can be called multiple times to subscribe to additional markets + /// + /// Panics + /// + /// If the `market` oracle pubkey is not loaded + pub async fn subscribe(&self, markets: &[MarketId]) -> SdkResult<()> { + log::debug!(target: LOG_TARGET, "subscribe market oracles: {markets:?}"); + self.sync(markets).await?; let url = get_ws_url(&self.rpc.url()).expect("valid url"); let mut pending_subscriptions = - Vec::::with_capacity(self.oracle_infos.len()); - for oracle_info in self.oracle_infos.iter() { - let oracle_pubkey = oracle_info.key(); - let oracle_subscriber = - WebsocketAccountSubscriber::new(url.clone(), *oracle_pubkey, self.rpc.commitment()); - pending_subscriptions.push(oracle_subscriber); + Vec::<(WebsocketAccountSubscriber, Oracle)>::with_capacity(markets.len()); + + for market in markets { + let oracle_info = self.oraclemap.get(market).expect("oracle exists"); // caller did not supply in `OracleMap::new()` + let oracle_subscriber = WebsocketAccountSubscriber::new( + url.clone(), + oracle_info.pubkey, + self.rpc.commitment(), + ); + + pending_subscriptions.push((oracle_subscriber, oracle_info.clone())); } - let futs_iter = pending_subscriptions.iter().map(|s| { - let source = *self.oracle_infos.get(&s.pubkey).expect("oracle source"); - s.subscribe(Self::SUBSCRIPTION_ID, { - let oracle_map = Arc::clone(&self.oraclemap); - move |update| handler_fn(&oracle_map, source, update) - }) + let futs_iter = pending_subscriptions.into_iter().map(|(sub_fut, info)| { + let oraclemap = Arc::clone(&self.oraclemap); + async move { + let unsub = sub_fut + .subscribe(Self::SUBSCRIPTION_ID, false, { + move |update| update_handler(update, info.market, info.source, &oraclemap) + }) + .await; + (info.market, unsub) + } }); + let mut subscription_futs = FuturesUnordered::from_iter(futs_iter); - let mut oracle_subscriptions = self.oracle_subscribers.write().await; - while let Some(unsub) = subscription_futs.next().await { - oracle_subscriptions.push(unsub.expect("oracle subscribed")); + while let Some((market, unsub)) = subscription_futs.next().await { + self.oracle_subscriptions.insert(market, unsub?); } log::debug!(target: LOG_TARGET, "subscribed"); Ok(()) } - pub async fn unsubscribe(&self) -> SdkResult<()> { - { - let mut oracle_subscribers = self.oracle_subscribers.write().await; - for unsub in oracle_subscribers.drain(..) { + /// Unsubscribe from oracle updates for the given `markets` + pub fn unsubscribe(&self, markets: &[MarketId]) -> SdkResult<()> { + for market in markets { + if let Some((market, unsub)) = self.oracle_subscriptions.remove(market) { let _ = unsub.send(()); + self.oraclemap.remove(&market); } } - - self.oraclemap.clear(); - self.latest_slot.store(0, Ordering::Relaxed); + log::debug!(target: LOG_TARGET, "unsubscribed markets: {markets:?}"); Ok(()) } - #[allow(clippy::await_holding_lock)] - async fn sync(&self) -> SdkResult<()> { - log::debug!(target: LOG_TARGET, "start sync"); - let sync_lock = self.sync_lock.as_ref().expect("expected sync lock"); + /// Unsubscribe from all oracle updates + pub fn unsubscribe_all(&self) -> SdkResult<()> { + let all_markets: Vec = + self.oracle_subscriptions.iter().map(|x| *x.key()).collect(); + self.unsubscribe(&all_markets) + } - let _lock = match sync_lock.try_lock() { - Ok(lock) => lock, - Err(_) => return Ok(()), - }; + /// Fetches account data for each market oracle set by `markets` + /// + /// This may be invoked manually to resync oracle data for some set of markets + pub async fn sync(&self, markets: &[MarketId]) -> SdkResult<()> { + log::debug!(target: LOG_TARGET, "sync oracles for: {markets:?}"); - let oralce_pubkeys = self - .oracle_infos - .iter() - .map(|oracle_info_ref| *oracle_info_ref.key()) - .collect::>(); + let mut market_by_oracle_key = HashMap::::with_capacity(markets.len()); + for market in markets { + if let Some(oracle) = self.oraclemap.get(market) { + market_by_oracle_key.insert(oracle.value().pubkey, *market); + } + } + + let oracle_pubkeys: Vec = market_by_oracle_key.keys().copied().collect(); let (synced_oracles, latest_slot) = - match get_multi_account_data_with_fallback(&self.rpc, &oralce_pubkeys).await { + match get_multi_account_data_with_fallback(&self.rpc, &oracle_pubkeys).await { Ok(result) => result, Err(err) => { warn!(target: LOG_TARGET, "failed to sync oracle accounts"); @@ -162,31 +174,27 @@ impl OracleMap { } }; - if synced_oracles.len() != oralce_pubkeys.len() { - warn!(target: LOG_TARGET, "failed to sync oracle all accounts"); + if synced_oracles.len() != oracle_pubkeys.len() { + warn!(target: LOG_TARGET, "failed to sync all oracle accounts"); return Err(SdkError::InvalidOracle); } for (oracle_pubkey, oracle_account) in synced_oracles.iter() { - let oracle_source = self - .oracle_infos + let market = market_by_oracle_key .get(oracle_pubkey) - .expect("oracle info exists"); - let price_data = get_oracle_price( - *oracle_source, - &mut (*oracle_pubkey, oracle_account.clone()), - latest_slot, - )?; - self.oraclemap.insert( - *oracle_pubkey, - Oracle { - pubkey: *oracle_pubkey, - data: price_data, - source: *oracle_source, - slot: latest_slot, - raw: oracle_account.data.clone(), - }, - ); + .expect("market oracle syncd"); + self.oraclemap.entry(*market).and_modify(|o| { + let price_data = get_oracle_price( + o.source, + &mut (*oracle_pubkey, oracle_account.clone()), + latest_slot, + ) + .expect("valid oracle data"); + + o.raw.clone_from(&oracle_account.data); + o.data = price_data; + o.slot = latest_slot; + }); } self.latest_slot.store(latest_slot, Ordering::Relaxed); @@ -195,63 +203,57 @@ impl OracleMap { Ok(()) } - /// Return whether the `OracleMap`` is subscribed to network changes - pub async fn is_subscribed(&self) -> bool { - let subscribers = self.oracle_subscribers.read().await; - !subscribers.is_empty() - } - + /// Number of oracles known to the `OracleMap` #[allow(dead_code)] - pub fn size(&self) -> usize { + pub fn len(&self) -> usize { self.oraclemap.len() } - pub fn contains(&self, key: &Pubkey) -> bool { - self.oracle_infos.contains_key(key) + /// Returns true if the oraclemap has a subscription for `market` + pub fn is_subscribed(&self, market: &MarketId) -> bool { + self.oracle_subscriptions.contains_key(market) } + /// Get the address of a perp market oracle pub fn current_perp_oracle(&self, market_index: u16) -> Option { - self.perp_oracles.get(&market_index).map(|x| *x) + self.oraclemap + .get(&MarketId::perp(market_index)) + .map(|x| x.pubkey) } + /// Get the address of a spot market oracle pub fn current_spot_oracle(&self, market_index: u16) -> Option { - self.spot_oracles.get(&market_index).map(|x| *x) + self.oraclemap + .get(&MarketId::spot(market_index)) + .map(|x| x.pubkey) } + /// Return Oracle data by pubkey, if known + /// deprecated, see `get_by_key` instead + #[deprecated] pub fn get(&self, key: &Pubkey) -> Option { - self.oraclemap.get(key).map(|x| x.clone()) + self.oraclemap + .iter() + .find(|o| &o.pubkey == key) + .map(|o| o.value().clone()) } - #[allow(dead_code)] - pub fn values(&self) -> Vec { - self.oraclemap.iter().map(|x| x.clone()).collect() + /// Return Oracle data by pubkey, if known + pub fn get_by_key(&self, key: &Pubkey) -> Option { + self.oraclemap + .iter() + .find(|o| &o.pubkey == key) + .map(|o| o.value().clone()) } - pub async fn add_oracle(&self, oracle: Pubkey, source: OracleSource) -> SdkResult<()> { - if self.contains(&oracle) { - return Ok(()); // don't add a duplicate - } - - self.oracle_infos.insert(oracle, source); - - let new_oracle_subscriber = WebsocketAccountSubscriber::new( - get_ws_url(&self.rpc.url()).expect("valid url"), - oracle, - self.rpc.commitment(), - ); - let oracle_source = *self.oracle_infos.get(&oracle).expect("oracle source"); - - let unsub = new_oracle_subscriber - .subscribe(Self::SUBSCRIPTION_ID, { - let oracle_map = Arc::clone(&self.oraclemap); - move |update| handler_fn(&oracle_map, oracle_source, update) - }) - .await?; - - let mut oracle_subscribers = self.oracle_subscribers.write().await; - oracle_subscribers.push(unsub); + /// Return Oracle data by market, if known + pub fn get_by_market(&self, market: MarketId) -> Option { + self.oraclemap.get(&market).map(|o| o.clone()) + } - Ok(()) + #[allow(dead_code)] + pub fn values(&self) -> Vec { + self.oraclemap.iter().map(|x| x.clone()).collect() } pub fn get_latest_slot(&self) -> u64 { @@ -260,10 +262,11 @@ impl OracleMap { } /// Handler fn for new oracle account data -fn handler_fn( - oracle_map: &Arc>, - oracle_source: OracleSource, +fn update_handler( update: &AccountUpdate, + oracle_market: MarketId, + oracle_source: OracleSource, + oracle_map: &DashMap, ) { let oracle_pubkey = update.pubkey; let lamports = update.lamports; @@ -282,13 +285,14 @@ fn handler_fn( ) { Ok(price_data) => { oracle_map - .entry(oracle_pubkey) + .entry(oracle_market) .and_modify(|o| { o.data = price_data; o.slot = update.slot; o.raw.clone_from(&update.data); }) .or_insert(Oracle { + market: oracle_market, pubkey: oracle_pubkey, data: price_data, source: oracle_source, @@ -304,10 +308,10 @@ fn handler_fn( /// Fetch all accounts with multiple fallbacks /// -/// Tries progressively less intensive RPC methods for wider compatiblity with RPC providers: -/// getMultipleAccounts, latstly multiple getAccountInfo +/// Tries progressively less intensive RPC methods for wider compatibility with RPC providers: +/// getMultipleAccounts, lastly multiple getAccountInfo /// -/// Returns deserialized accounts and retreived slot +/// Returns deserialized accounts and retrieved slot async fn get_multi_account_data_with_fallback( rpc: &RpcClient, pubkeys: &[Pubkey], diff --git a/crates/src/priority_fee_subscriber.rs b/crates/src/priority_fee_subscriber.rs index b1dd4ba..d620161 100644 --- a/crates/src/priority_fee_subscriber.rs +++ b/crates/src/priority_fee_subscriber.rs @@ -129,13 +129,13 @@ impl PriorityFeeSubscriber { } } - /// Returns the median priority fee in micro-lamports over the lookback window + /// Returns the median priority fee in micro-lamports over the look-back window pub fn priority_fee(&self) -> u64 { self.priority_fee_nth(0.5) } - /// Returns the n-th percentile priority fee in micro-lamports over the lookback window - /// `precentile` given as decimal 0.0 < n <= 1.0 + /// Returns the n-th percentile priority fee in micro-lamports over the look-back window + /// `percentile` given as decimal 0.0 < n <= 1.0 pub fn priority_fee_nth(&self, percentile: f32) -> u64 { let percentile = percentile.min(1.0); let lock = self.latest_fees.read().expect("acquired"); diff --git a/crates/src/types.rs b/crates/src/types.rs index 0abdb05..0c87c5b 100644 --- a/crates/src/types.rs +++ b/crates/src/types.rs @@ -28,7 +28,7 @@ pub use crate::drift_idl::{ }; use crate::{ constants::{ids, LUT_DEVNET, LUT_MAINNET}, - drift_idl::errors::ErrorCode, + drift_idl::{self, errors::ErrorCode}, Wallet, }; @@ -96,7 +96,7 @@ where } /// Id of a Drift market -#[derive(Copy, Clone, Debug, Default, PartialEq)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct MarketId { index: u16, kind: MarketType, @@ -148,6 +148,26 @@ impl From<(u16, MarketType)> for MarketId { } } +/// Shadow the IDL market type to add some extra traits e.g. Eq + Hash +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +pub enum MarketType { + #[default] + Spot, + Perp, +} + +impl From for MarketType { + fn from(value: drift_idl::types::MarketType) -> Self { + unsafe { std::mem::transmute(value) } + } +} + +impl From for drift_idl::types::MarketType { + fn from(value: MarketType) -> Self { + unsafe { std::mem::transmute(value) } + } +} + /// Provides builder API for Orders #[derive(Default)] pub struct NewOrder { @@ -222,7 +242,7 @@ impl NewOrder { OrderParams { order_type: self.order_type, market_index: self.market_id.index, - market_type: self.market_id.kind, + market_type: self.market_id.kind.into(), price: self.price, base_asset_amount: self.amount, reduce_only: self.reduce_only, @@ -283,7 +303,7 @@ pub enum SdkError { MaxReconnectionAttemptsReached, #[error("jit taker order not found")] JitOrderNotFound, - #[error("not data, client may be unsubsribed")] + #[error("no data, client may be unsubsribed")] NoData, #[error("component is already subscribed")] AlreadySubscribed, @@ -461,30 +481,6 @@ impl ReferrerInfo { } } -#[derive(Default)] -/// Confgured markets for DriftClient setup -pub enum ConfiguredMarkets { - #[default] - All, - Minimal { - perp: Vec, - spot: Vec, - }, -} - -impl ConfiguredMarkets { - /// Returns whether this config wants `market` - pub fn wants(&self, market: MarketId) -> bool { - match self { - Self::All => true, - Self::Minimal { perp, spot } => match market.kind() { - MarketType::Spot => spot.contains(&market), - MarketType::Perp => perp.contains(&market), - }, - } - } -} - impl ToString for MarketType { fn to_string(&self) -> String { match self { diff --git a/crates/src/websocket_account_subscriber.rs b/crates/src/websocket_account_subscriber.rs index cd5b451..60f67de 100644 --- a/crates/src/websocket_account_subscriber.rs +++ b/crates/src/websocket_account_subscriber.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use futures_util::StreamExt; use log::warn; use solana_account_decoder::UiAccountEncoding; @@ -41,46 +43,50 @@ impl WebsocketAccountSubscriber { /// Start a Ws account subscription task /// - /// `subscription_name` some user defined identifier for the subscription - /// `handler_fn` handles updates from the subscription task + /// * `subscription_name` - some user defined identifier for the subscription + /// * `sync` - true if subscription should fetch account data on start + /// * `handler_fn` - handles updates from the subscription task /// /// Fetches the account to set the initial value, then uses event based updates pub async fn subscribe( &self, subscription_name: &'static str, + sync: bool, handler_fn: F, ) -> SdkResult where F: 'static + Send + Fn(&AccountUpdate), { - // seed initial account state - log::debug!("seeding account: {subscription_name}-{:?}", self.pubkey); - let owner: Pubkey; - let rpc = RpcClient::new(get_http_url(&self.url)?); - match rpc - .get_account_with_commitment(&self.pubkey, self.commitment) - .await - { - Ok(response) => { - if let Some(account) = response.value { - owner = account.owner; - handler_fn(&AccountUpdate { - owner, - lamports: account.lamports, - pubkey: self.pubkey, - data: account.data, - slot: response.context.slot, - }); - } else { - return Err(SdkError::InvalidAccount); + if sync { + // seed initial account state + log::debug!("seeding account: {subscription_name}-{:?}", self.pubkey); + let owner: Pubkey; + let rpc = RpcClient::new(get_http_url(&self.url)?); + match rpc + .get_account_with_commitment(&self.pubkey, self.commitment) + .await + { + Ok(response) => { + if let Some(account) = response.value { + owner = account.owner; + handler_fn(&AccountUpdate { + owner, + lamports: account.lamports, + pubkey: self.pubkey, + data: account.data, + slot: response.context.slot, + }); + } else { + return Err(SdkError::InvalidAccount); + } + } + Err(err) => { + warn!("seeding account failed: {err:?}"); + return Err(err.into()); } } - Err(err) => { - warn!("seeding account failed: {err:?}"); - return Err(err.into()); - } + drop(rpc); } - drop(rpc); let account_config = RpcAccountInfoConfig { commitment: Some(self.commitment), @@ -138,7 +144,7 @@ impl WebsocketAccountSubscriber { latest_slot = slot; if let Some(data) = message.value.data.decode() { let account_update = AccountUpdate { - owner, + owner: Pubkey::from_str(&message.value.owner).unwrap(), lamports: message.value.lamports, pubkey, data, diff --git a/tests/integration.rs b/tests/integration.rs index b107118..7e65c24 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,14 +1,14 @@ use drift_rs::{ event_subscriber::RpcClient, math::constants::{BASE_PRECISION_I64, LAMPORTS_PER_SOL_I64, PRICE_PRECISION_U64}, - types::{accounts::User, ConfiguredMarkets, Context, MarketId, NewOrder, PostOnlyParam}, - utils::test_envs::{devnet_endpoint, test_keypair}, + types::{accounts::User, Context, MarketId, NewOrder, PostOnlyParam}, + utils::test_envs::{devnet_endpoint, mainnet_endpoint, test_keypair}, DriftClient, TransactionBuilder, Wallet, }; use solana_sdk::signature::Keypair; #[tokio::test] -async fn get_oracle_prices() { +async fn client_sync_subscribe_devnet() { let client = DriftClient::new( Context::DevNet, RpcClient::new(devnet_endpoint()), @@ -16,6 +16,50 @@ async fn get_oracle_prices() { ) .await .expect("connects"); + let markets = [ + MarketId::spot(1), + MarketId::spot(2), + MarketId::perp(0), + MarketId::perp(1), + MarketId::perp(2), + ]; + tokio::try_join!( + client.subscribe_markets(&markets), + client.subscribe_oracles(&markets), + ) + .expect("subscribes"); + + let price = client.oracle_price(MarketId::perp(1)).await.expect("ok"); + assert!(price > 0); + dbg!(price); + let price = client.oracle_price(MarketId::spot(2)).await.expect("ok"); + assert!(price > 0); + dbg!(price); +} + +#[tokio::test] +async fn client_sync_subscribe_mainnet() { + let _ = env_logger::try_init(); + let client = DriftClient::new( + Context::MainNet, + RpcClient::new(mainnet_endpoint()), + Keypair::new().into(), + ) + .await + .expect("connects"); + let markets = [ + MarketId::spot(1), + MarketId::spot(2), + MarketId::perp(0), + MarketId::perp(1), + MarketId::perp(2), + ]; + tokio::try_join!( + client.subscribe_markets(&markets), + client.subscribe_oracles(&markets), + ) + .expect("subscribes"); + let price = client.oracle_price(MarketId::perp(1)).await.expect("ok"); assert!(price > 0); dbg!(price); @@ -31,18 +75,13 @@ async fn place_and_cancel_orders() { let sol_spot = MarketId::spot(1); let wallet: Wallet = test_keypair().into(); - let client = DriftClient::with_markets( + let client = DriftClient::new( Context::DevNet, RpcClient::new(devnet_endpoint()), wallet.clone(), - ConfiguredMarkets::Minimal { - perp: vec![btc_perp], - spot: vec![sol_spot], - }, ) .await .expect("connects"); - client.subscribe().await.unwrap(); let user: User = client .get_user_account(&wallet.default_sub_account())