Skip to content

Fix bonk oracle precision with subscription #126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/src/blockhash_subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ mod tests {

// after unsub, returns none
blockhash_subscriber.unsubscribe();
tokio::time::sleep(Duration::from_secs(2)).await;
tokio::time::sleep(Duration::from_secs(4)).await;
assert!(blockhash_subscriber.get_latest_blockhash().is_none());
}
}
15 changes: 15 additions & 0 deletions crates/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,21 @@ impl DriftClient {

Ok(())
}

/// Return a reference to the internal spot market map
pub fn spot_market_map(&self) -> Arc<MapOf<u16, DataAndSlot<SpotMarket>>> {
self.backend.spot_market_map.map()
}

/// Return a reference to the internal perp market map
pub fn perp_market_map(&self) -> Arc<MapOf<u16, DataAndSlot<PerpMarket>>> {
self.backend.perp_market_map.map()
}

/// Return a reference to the internal oracle map
pub fn oracle_map(&self) -> Arc<MapOf<(Pubkey, u8), Oracle>> {
self.backend.oracle_map.map()
}
}

/// Provides the heavy-lifting and network facing features of the SDK
Expand Down
6 changes: 6 additions & 0 deletions crates/src/marketmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::{
constants::{self, derive_perp_market_account, derive_spot_market_account, state_account},
drift_idl::types::OracleSource,
memcmp::get_market_filter,
types::MapOf,
websocket_account_subscriber::WebsocketAccountSubscriber,
DataAndSlot, MarketId, MarketType, PerpMarket, SdkResult, SpotMarket, UnsubHandle,
};
Expand Down Expand Up @@ -97,6 +98,11 @@ where
}
}

/// Return a reference to the internal map data structure
pub fn map(&self) -> Arc<MapOf<u16, DataAndSlot<T>>> {
Arc::clone(&self.marketmap)
}

/// Subscribe to market account updates
pub async fn subscribe(&self, markets: &[MarketId]) -> SdkResult<()> {
log::debug!(target: LOG_TARGET, "subscribing: {:?}", T::MARKET_TYPE);
Expand Down
120 changes: 71 additions & 49 deletions crates/src/oraclemap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use solana_sdk::{
use crate::{
drift_idl::types::OracleSource,
ffi::{get_oracle_price, OraclePriceData},
types::MapOf,
websocket_account_subscriber::{AccountUpdate, WebsocketAccountSubscriber},
MarketId, SdkError, SdkResult, UnsubHandle,
};
Expand All @@ -38,11 +39,11 @@ pub struct Oracle {
/// Alternatively, the caller may drive the map by calling `sync` periodically
pub struct OracleMap {
/// Oracle data keyed by pubkey
oraclemap: Arc<DashMap<Pubkey, Oracle, ahash::RandomState>>,
oraclemap: Arc<DashMap<(Pubkey, u8), Oracle, ahash::RandomState>>,
/// Oracle subscription handles by pubkey
subcriptions: DashMap<Pubkey, UnsubHandle, ahash::RandomState>,
/// Oracle pubkey by MarketId (immutable)
oracle_by_market: ReadOnlyView<MarketId, Pubkey>,
subcriptions: DashMap<(Pubkey, u8), UnsubHandle, ahash::RandomState>,
/// Oracle (pubkey, source) by MarketId (immutable)
oracle_by_market: ReadOnlyView<MarketId, (Pubkey, OracleSource)>,
latest_slot: Arc<AtomicU64>,
commitment: CommitmentConfig,
pubsub: Arc<PubsubClient>,
Expand All @@ -68,7 +69,7 @@ impl OracleMap {
.copied()
.map(|(market, pubkey, source)| {
(
pubkey,
(pubkey, source as u8),
Oracle {
market,
pubkey,
Expand All @@ -78,10 +79,10 @@ impl OracleMap {
)
})
.collect();
let oracle_by_market: DashMap<MarketId, Pubkey> = all_oracles
let oracle_by_market: DashMap<MarketId, (Pubkey, OracleSource)> = all_oracles
.iter()
.copied()
.map(|(market, pubkey, _)| (market, pubkey))
.map(|(market, pubkey, source)| (market, (pubkey, source)))
.collect();

Self {
Expand Down Expand Up @@ -109,11 +110,17 @@ impl OracleMap {
Vec::<(WebsocketAccountSubscriber, Oracle)>::with_capacity(markets.len());

for market in markets {
let oracle_pubkey = self.oracle_by_market.get(market).expect("oracle exists");
let oracle_info = self.oraclemap.get(oracle_pubkey).expect("oracle exists"); // caller did not supply in `OracleMap::new()`
let (oracle_pubkey, oracle_source) =
self.oracle_by_market.get(market).expect("oracle exists");
let oracle_info = self
.oraclemap
.get(&(*oracle_pubkey, *oracle_source as u8))
.expect("oracle exists"); // caller did not supply in `OracleMap::new()`

// markets can share oracle pubkeys, only want one sub per oracle pubkey
if self.subcriptions.contains_key(oracle_pubkey)
if self
.subcriptions
.contains_key(&(*oracle_pubkey, *oracle_source as u8))
|| pending_subscriptions
.iter()
.any(|(_, o)| &o.pubkey == oracle_pubkey)
Expand Down Expand Up @@ -149,7 +156,8 @@ impl OracleMap {

while let Some((info, unsub)) = subscription_futs.next().await {
log::debug!(target: LOG_TARGET, "subscribed market oracle: {:?}", info.market);
self.subcriptions.insert(info.pubkey, unsub?);
self.subcriptions
.insert((info.pubkey, info.source as u8), unsub?);
}

log::debug!(target: LOG_TARGET, "subscribed");
Expand All @@ -159,8 +167,11 @@ impl OracleMap {
/// Unsubscribe from oracle updates for the given `markets`
pub fn unsubscribe(&self, markets: &[MarketId]) -> SdkResult<()> {
for market in markets {
if let Some(oracle_pubkey) = self.oracle_by_market.get(market) {
if let Some((market, unsub)) = self.subcriptions.remove(oracle_pubkey) {
if let Some((oracle_pubkey, oracle_source)) = self.oracle_by_market.get(market) {
if let Some((market, unsub)) = self
.subcriptions
.remove(&(*oracle_pubkey, *oracle_source as u8))
{
let _ = unsub.send(());
self.oraclemap.remove(&market);
}
Expand Down Expand Up @@ -188,17 +199,17 @@ impl OracleMap {
let markets = HashSet::<MarketId>::from_iter(markets.iter().copied());
log::debug!(target: LOG_TARGET, "sync oracles for: {markets:?}");

let oracle_pubkeys: Vec<Pubkey> = self
let mut oracle_sources = Vec::with_capacity(markets.len());
let mut oracle_pubkeys = Vec::with_capacity(markets.len());

for (_, (pubkey, source)) in self
.oracle_by_market
.iter()
.filter_map(|(market, pubkey)| {
if markets.contains(market) {
Some(*pubkey)
} else {
None
}
})
.collect();
.filter(|(m, _)| markets.contains(m))
{
oracle_pubkeys.push(*pubkey);
oracle_sources.push(*source);
}

let (synced_oracles, latest_slot) =
match get_multi_account_data_with_fallback(rpc, &oracle_pubkeys).await {
Expand All @@ -214,19 +225,23 @@ impl OracleMap {
return Err(SdkError::InvalidOracle);
}

for (oracle_pubkey, oracle_account) in synced_oracles.iter() {
self.oraclemap.entry(*oracle_pubkey).and_modify(|o| {
let price_data = get_oracle_price(
o.source,
&mut (*oracle_pubkey, oracle_account.clone()),
latest_slot,
)
.expect("valid oracle data");

o.raw.clone_from(&oracle_account.data);
o.data = price_data;
o.slot = latest_slot;
});
for ((oracle_pubkey, oracle_account), oracle_source) in
synced_oracles.iter().zip(oracle_sources)
{
self.oraclemap
.entry((*oracle_pubkey, oracle_source as u8))
.and_modify(|o| {
let price_data = get_oracle_price(
o.source,
&mut (*oracle_pubkey, oracle_account.clone()),
latest_slot,
)
.expect("valid oracle data");

o.raw.clone_from(&oracle_account.data);
o.data = price_data;
o.slot = latest_slot;
});
}

self.latest_slot.store(latest_slot, Ordering::Relaxed);
Expand All @@ -243,8 +258,9 @@ impl OracleMap {

/// Returns true if the oraclemap has a subscription for `market`
pub fn is_subscribed(&self, market: &MarketId) -> bool {
if let Some(oracle_pubkey) = self.oracle_by_market.get(market) {
self.subcriptions.contains_key(oracle_pubkey)
if let Some((oracle_pubkey, oracle_source)) = self.oracle_by_market.get(market) {
self.subcriptions
.contains_key(&(*oracle_pubkey, *oracle_source as u8))
} else {
false
}
Expand All @@ -264,20 +280,22 @@ impl OracleMap {

/// Return Oracle data by pubkey, if known
/// deprecated, see `get_by_key` instead
#[deprecated]
pub fn get(&self, key: &Pubkey) -> Option<Oracle> {
self.get_by_key(key)
}
// #[deprecated]
// pub fn get(&self, key: &Pubkey) -> Option<Oracle> {
// self.get_by_key(key)
// }

/// Return Oracle data by pubkey, if known
pub fn get_by_key(&self, key: &Pubkey) -> Option<Oracle> {
self.oraclemap.get(key).map(|o| o.value().clone())
}
// /// Return Oracle data by pubkey, if known
// pub fn get_by_key(&self, key: &Pubkey) -> Option<Oracle> {
// self.oraclemap.get(key).map(|o| o.value().clone())
// }

/// Return Oracle data by market, if known
pub fn get_by_market(&self, market: &MarketId) -> Option<Oracle> {
if let Some(oracle_pubkey) = self.oracle_by_market.get(market) {
self.oraclemap.get(oracle_pubkey).map(|o| o.clone())
if let Some((oracle_pubkey, oracle_source)) = self.oracle_by_market.get(market) {
self.oraclemap
.get(&(*oracle_pubkey, *oracle_source as u8))
.map(|o| o.clone())
} else {
None
}
Expand All @@ -291,14 +309,18 @@ impl OracleMap {
pub fn get_latest_slot(&self) -> u64 {
self.latest_slot.load(Ordering::Relaxed)
}
/// Return a reference to the internal map data structure
pub fn map(&self) -> Arc<MapOf<(Pubkey, u8), Oracle>> {
Arc::clone(&self.oraclemap)
}
}

/// Handler fn for new oracle account data
fn update_handler(
update: &AccountUpdate,
oracle_market: MarketId,
oracle_source: OracleSource,
oracle_map: &DashMap<Pubkey, Oracle, ahash::RandomState>,
oracle_map: &DashMap<(Pubkey, u8), Oracle, ahash::RandomState>,
) {
let oracle_pubkey = update.pubkey;
let lamports = update.lamports;
Expand All @@ -317,7 +339,7 @@ fn update_handler(
) {
Ok(price_data) => {
oracle_map
.entry(oracle_pubkey)
.entry((oracle_pubkey, oracle_source as u8))
.and_modify(|o| {
o.data = price_data;
o.slot = update.slot;
Expand Down
4 changes: 4 additions & 0 deletions crates/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::{
str::FromStr,
};

use dashmap::DashMap;
pub use solana_rpc_client_api::config::RpcSendTransactionConfig;
pub use solana_sdk::{
commitment_config::CommitmentConfig, message::VersionedMessage,
Expand Down Expand Up @@ -32,6 +33,9 @@ use crate::{
Wallet,
};

/// Map from K => V
pub type MapOf<K, V> = DashMap<K, V, ahash::RandomState>;

/// Handle for unsubscribing from network updates
pub type UnsubHandle = oneshot::Sender<()>;

Expand Down
35 changes: 35 additions & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::time::Duration;

use drift_rs::{
event_subscriber::RpcClient,
math::constants::{BASE_PRECISION_I64, LAMPORTS_PER_SOL_I64, PRICE_PRECISION_U64},
Expand Down Expand Up @@ -199,3 +201,36 @@ async fn client_subscribe_swift_orders() {
recv_count += 1;
}
}

#[tokio::test]
async fn oracle_source_mixed_precision() {
let _ = env_logger::try_init();
let client = DriftClient::new(
Context::MainNet,
RpcClient::new(mainnet_endpoint()),
Keypair::new().into(),
)
.await
.expect("connects");

let price = client
.get_oracle_price_data_and_slot(MarketId::perp(4))
.await
.unwrap()
.data
.price;
println!("Bonk: {price}");
assert!(price % 100_000 > 0);

tokio::time::sleep(Duration::from_secs(1)).await;
assert!(client.subscribe_oracles(&[MarketId::perp(4)]).await.is_ok());

let price = client
.try_get_oracle_price_data_and_slot(MarketId::perp(4))
.unwrap()
.data
.price;

println!("Bonk: {price}");
assert!(price % 100_000 > 0);
}
Loading