diff --git a/crates/src/blockhash_subscriber.rs b/crates/src/blockhash_subscriber.rs index bc03d10..0ab727b 100644 --- a/crates/src/blockhash_subscriber.rs +++ b/crates/src/blockhash_subscriber.rs @@ -169,7 +169,7 @@ mod tests { // after unsub, returns none blockhash_subscriber.unsubscribe(); - tokio::time::sleep(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(4)).await; assert!(blockhash_subscriber.get_latest_blockhash().is_none()); } } diff --git a/crates/src/lib.rs b/crates/src/lib.rs index 5b536cb..b5b07a6 100644 --- a/crates/src/lib.rs +++ b/crates/src/lib.rs @@ -696,6 +696,21 @@ impl DriftClient { Ok(()) } + + /// Return a reference to the internal spot market map + pub fn spot_market_map(&self) -> Arc>> { + self.backend.spot_market_map.map() + } + + /// Return a reference to the internal perp market map + pub fn perp_market_map(&self) -> Arc>> { + self.backend.perp_market_map.map() + } + + /// Return a reference to the internal oracle map + pub fn oracle_map(&self) -> Arc> { + self.backend.oracle_map.map() + } } /// Provides the heavy-lifting and network facing features of the SDK diff --git a/crates/src/marketmap.rs b/crates/src/marketmap.rs index b7fc4d3..8cc98a1 100644 --- a/crates/src/marketmap.rs +++ b/crates/src/marketmap.rs @@ -25,6 +25,7 @@ use crate::{ constants::{self, derive_perp_market_account, derive_spot_market_account, state_account}, drift_idl::types::OracleSource, memcmp::get_market_filter, + types::MapOf, websocket_account_subscriber::WebsocketAccountSubscriber, DataAndSlot, MarketId, MarketType, PerpMarket, SdkResult, SpotMarket, UnsubHandle, }; @@ -97,6 +98,11 @@ where } } + /// Return a reference to the internal map data structure + pub fn map(&self) -> Arc>> { + Arc::clone(&self.marketmap) + } + /// Subscribe to market account updates pub async fn subscribe(&self, markets: &[MarketId]) -> SdkResult<()> { log::debug!(target: LOG_TARGET, "subscribing: {:?}", T::MARKET_TYPE); diff --git a/crates/src/oraclemap.rs b/crates/src/oraclemap.rs index 5709420..28d6b29 100644 --- a/crates/src/oraclemap.rs +++ b/crates/src/oraclemap.rs @@ -16,6 +16,7 @@ use solana_sdk::{ use crate::{ drift_idl::types::OracleSource, ffi::{get_oracle_price, OraclePriceData}, + types::MapOf, websocket_account_subscriber::{AccountUpdate, WebsocketAccountSubscriber}, MarketId, SdkError, SdkResult, UnsubHandle, }; @@ -38,11 +39,11 @@ pub struct Oracle { /// Alternatively, the caller may drive the map by calling `sync` periodically pub struct OracleMap { /// Oracle data keyed by pubkey - oraclemap: Arc>, + oraclemap: Arc>, /// Oracle subscription handles by pubkey - subcriptions: DashMap, - /// Oracle pubkey by MarketId (immutable) - oracle_by_market: ReadOnlyView, + subcriptions: DashMap<(Pubkey, u8), UnsubHandle, ahash::RandomState>, + /// Oracle (pubkey, source) by MarketId (immutable) + oracle_by_market: ReadOnlyView, latest_slot: Arc, commitment: CommitmentConfig, pubsub: Arc, @@ -68,7 +69,7 @@ impl OracleMap { .copied() .map(|(market, pubkey, source)| { ( - pubkey, + (pubkey, source as u8), Oracle { market, pubkey, @@ -78,10 +79,10 @@ impl OracleMap { ) }) .collect(); - let oracle_by_market: DashMap = all_oracles + let oracle_by_market: DashMap = all_oracles .iter() .copied() - .map(|(market, pubkey, _)| (market, pubkey)) + .map(|(market, pubkey, source)| (market, (pubkey, source))) .collect(); Self { @@ -109,11 +110,17 @@ impl OracleMap { Vec::<(WebsocketAccountSubscriber, Oracle)>::with_capacity(markets.len()); for market in markets { - let oracle_pubkey = self.oracle_by_market.get(market).expect("oracle exists"); - let oracle_info = self.oraclemap.get(oracle_pubkey).expect("oracle exists"); // caller did not supply in `OracleMap::new()` + let (oracle_pubkey, oracle_source) = + self.oracle_by_market.get(market).expect("oracle exists"); + let oracle_info = self + .oraclemap + .get(&(*oracle_pubkey, *oracle_source as u8)) + .expect("oracle exists"); // caller did not supply in `OracleMap::new()` // markets can share oracle pubkeys, only want one sub per oracle pubkey - if self.subcriptions.contains_key(oracle_pubkey) + if self + .subcriptions + .contains_key(&(*oracle_pubkey, *oracle_source as u8)) || pending_subscriptions .iter() .any(|(_, o)| &o.pubkey == oracle_pubkey) @@ -149,7 +156,8 @@ impl OracleMap { while let Some((info, unsub)) = subscription_futs.next().await { log::debug!(target: LOG_TARGET, "subscribed market oracle: {:?}", info.market); - self.subcriptions.insert(info.pubkey, unsub?); + self.subcriptions + .insert((info.pubkey, info.source as u8), unsub?); } log::debug!(target: LOG_TARGET, "subscribed"); @@ -159,8 +167,11 @@ impl OracleMap { /// Unsubscribe from oracle updates for the given `markets` pub fn unsubscribe(&self, markets: &[MarketId]) -> SdkResult<()> { for market in markets { - if let Some(oracle_pubkey) = self.oracle_by_market.get(market) { - if let Some((market, unsub)) = self.subcriptions.remove(oracle_pubkey) { + if let Some((oracle_pubkey, oracle_source)) = self.oracle_by_market.get(market) { + if let Some((market, unsub)) = self + .subcriptions + .remove(&(*oracle_pubkey, *oracle_source as u8)) + { let _ = unsub.send(()); self.oraclemap.remove(&market); } @@ -188,17 +199,17 @@ impl OracleMap { let markets = HashSet::::from_iter(markets.iter().copied()); log::debug!(target: LOG_TARGET, "sync oracles for: {markets:?}"); - let oracle_pubkeys: Vec = self + let mut oracle_sources = Vec::with_capacity(markets.len()); + let mut oracle_pubkeys = Vec::with_capacity(markets.len()); + + for (_, (pubkey, source)) in self .oracle_by_market .iter() - .filter_map(|(market, pubkey)| { - if markets.contains(market) { - Some(*pubkey) - } else { - None - } - }) - .collect(); + .filter(|(m, _)| markets.contains(m)) + { + oracle_pubkeys.push(*pubkey); + oracle_sources.push(*source); + } let (synced_oracles, latest_slot) = match get_multi_account_data_with_fallback(rpc, &oracle_pubkeys).await { @@ -214,19 +225,23 @@ impl OracleMap { return Err(SdkError::InvalidOracle); } - for (oracle_pubkey, oracle_account) in synced_oracles.iter() { - self.oraclemap.entry(*oracle_pubkey).and_modify(|o| { - let price_data = get_oracle_price( - o.source, - &mut (*oracle_pubkey, oracle_account.clone()), - latest_slot, - ) - .expect("valid oracle data"); - - o.raw.clone_from(&oracle_account.data); - o.data = price_data; - o.slot = latest_slot; - }); + for ((oracle_pubkey, oracle_account), oracle_source) in + synced_oracles.iter().zip(oracle_sources) + { + self.oraclemap + .entry((*oracle_pubkey, oracle_source as u8)) + .and_modify(|o| { + let price_data = get_oracle_price( + o.source, + &mut (*oracle_pubkey, oracle_account.clone()), + latest_slot, + ) + .expect("valid oracle data"); + + o.raw.clone_from(&oracle_account.data); + o.data = price_data; + o.slot = latest_slot; + }); } self.latest_slot.store(latest_slot, Ordering::Relaxed); @@ -243,8 +258,9 @@ impl OracleMap { /// Returns true if the oraclemap has a subscription for `market` pub fn is_subscribed(&self, market: &MarketId) -> bool { - if let Some(oracle_pubkey) = self.oracle_by_market.get(market) { - self.subcriptions.contains_key(oracle_pubkey) + if let Some((oracle_pubkey, oracle_source)) = self.oracle_by_market.get(market) { + self.subcriptions + .contains_key(&(*oracle_pubkey, *oracle_source as u8)) } else { false } @@ -264,20 +280,22 @@ impl OracleMap { /// Return Oracle data by pubkey, if known /// deprecated, see `get_by_key` instead - #[deprecated] - pub fn get(&self, key: &Pubkey) -> Option { - self.get_by_key(key) - } + // #[deprecated] + // pub fn get(&self, key: &Pubkey) -> Option { + // self.get_by_key(key) + // } - /// Return Oracle data by pubkey, if known - pub fn get_by_key(&self, key: &Pubkey) -> Option { - self.oraclemap.get(key).map(|o| o.value().clone()) - } + // /// Return Oracle data by pubkey, if known + // pub fn get_by_key(&self, key: &Pubkey) -> Option { + // self.oraclemap.get(key).map(|o| o.value().clone()) + // } /// Return Oracle data by market, if known pub fn get_by_market(&self, market: &MarketId) -> Option { - if let Some(oracle_pubkey) = self.oracle_by_market.get(market) { - self.oraclemap.get(oracle_pubkey).map(|o| o.clone()) + if let Some((oracle_pubkey, oracle_source)) = self.oracle_by_market.get(market) { + self.oraclemap + .get(&(*oracle_pubkey, *oracle_source as u8)) + .map(|o| o.clone()) } else { None } @@ -291,6 +309,10 @@ impl OracleMap { pub fn get_latest_slot(&self) -> u64 { self.latest_slot.load(Ordering::Relaxed) } + /// Return a reference to the internal map data structure + pub fn map(&self) -> Arc> { + Arc::clone(&self.oraclemap) + } } /// Handler fn for new oracle account data @@ -298,7 +320,7 @@ fn update_handler( update: &AccountUpdate, oracle_market: MarketId, oracle_source: OracleSource, - oracle_map: &DashMap, + oracle_map: &DashMap<(Pubkey, u8), Oracle, ahash::RandomState>, ) { let oracle_pubkey = update.pubkey; let lamports = update.lamports; @@ -317,7 +339,7 @@ fn update_handler( ) { Ok(price_data) => { oracle_map - .entry(oracle_pubkey) + .entry((oracle_pubkey, oracle_source as u8)) .and_modify(|o| { o.data = price_data; o.slot = update.slot; diff --git a/crates/src/types.rs b/crates/src/types.rs index 6c0cce7..850257e 100644 --- a/crates/src/types.rs +++ b/crates/src/types.rs @@ -4,6 +4,7 @@ use std::{ str::FromStr, }; +use dashmap::DashMap; pub use solana_rpc_client_api::config::RpcSendTransactionConfig; pub use solana_sdk::{ commitment_config::CommitmentConfig, message::VersionedMessage, @@ -32,6 +33,9 @@ use crate::{ Wallet, }; +/// Map from K => V +pub type MapOf = DashMap; + /// Handle for unsubscribing from network updates pub type UnsubHandle = oneshot::Sender<()>; diff --git a/tests/integration.rs b/tests/integration.rs index 4915899..03f923e 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use drift_rs::{ event_subscriber::RpcClient, math::constants::{BASE_PRECISION_I64, LAMPORTS_PER_SOL_I64, PRICE_PRECISION_U64}, @@ -199,3 +201,36 @@ async fn client_subscribe_swift_orders() { recv_count += 1; } } + +#[tokio::test] +async fn oracle_source_mixed_precision() { + let _ = env_logger::try_init(); + let client = DriftClient::new( + Context::MainNet, + RpcClient::new(mainnet_endpoint()), + Keypair::new().into(), + ) + .await + .expect("connects"); + + let price = client + .get_oracle_price_data_and_slot(MarketId::perp(4)) + .await + .unwrap() + .data + .price; + println!("Bonk: {price}"); + assert!(price % 100_000 > 0); + + tokio::time::sleep(Duration::from_secs(1)).await; + assert!(client.subscribe_oracles(&[MarketId::perp(4)]).await.is_ok()); + + let price = client + .try_get_oracle_price_data_and_slot(MarketId::perp(4)) + .unwrap() + .data + .price; + + println!("Bonk: {price}"); + assert!(price % 100_000 > 0); +}