diff --git a/crates/src/account_map.rs b/crates/src/account_map.rs index b1eb51e..3d7f550 100644 --- a/crates/src/account_map.rs +++ b/crates/src/account_map.rs @@ -13,7 +13,8 @@ use solana_sdk::{clock::Slot, commitment_config::CommitmentConfig, pubkey::Pubke use crate::{ grpc::AccountUpdate, polled_account_subscriber::PolledAccountSubscriber, types::DataAndSlot, - websocket_account_subscriber::WebsocketAccountSubscriber, SdkResult, UnsubHandle, + websocket_account_subscriber::WebsocketAccountSubscriber, SdkResult, + UnsubHandle, }; const LOG_TARGET: &str = "accountmap"; @@ -54,17 +55,29 @@ impl AccountMap { /// * `account` pubkey to subscribe /// pub async fn subscribe_account(&self, account: &Pubkey) -> SdkResult<()> { + self.subscribe_account_with_callback(account, None::).await + } + + pub async fn subscribe_account_with_callback( + &self, + account: &Pubkey, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { if self.inner.contains_key(account) { return Ok(()); } debug!(target: LOG_TARGET, "subscribing: {account:?}"); let user = AccountSub::new(Arc::clone(&self.pubsub), self.commitment, *account); - let sub = user.subscribe(Arc::clone(&self.inner)).await?; + let sub = user.subscribe(Arc::clone(&self.inner), on_account).await?; self.subscriptions.insert(*account, sub); Ok(()) } + /// Subscribe account with RPC polling /// /// * `account` pubkey to subscribe @@ -75,17 +88,31 @@ impl AccountMap { account: &Pubkey, interval: Option, ) -> SdkResult<()> { + self.subscribe_account_polled_with_callback(account, interval, None::) + .await + } + + pub async fn subscribe_account_polled_with_callback( + &self, + account: &Pubkey, + interval: Option, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { if self.inner.contains_key(account) { return Ok(()); } debug!(target: LOG_TARGET, "subscribing: {account:?} @ {interval:?}"); let user = AccountSub::polled(Arc::clone(&self.rpc), *account, interval); - let sub = user.subscribe(Arc::clone(&self.inner)).await?; + let sub = user.subscribe(Arc::clone(&self.inner), on_account).await?; self.subscriptions.insert(*account, sub); Ok(()) } + /// On account update callback for gRPC hook pub(crate) fn on_account_fn(&self) -> impl Fn(&AccountUpdate) { let accounts = Arc::clone(&self.inner); @@ -185,10 +212,15 @@ impl AccountSub { } /// Start the subscriber task - pub async fn subscribe( + pub async fn subscribe( self, accounts: Arc>, - ) -> SdkResult> { + on_account: Option, + ) -> SdkResult> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + let on_account = on_account.clone(); let unsub = match self.subscription { SubscriptionImpl::Ws(ref ws) => { let unsub = ws @@ -206,6 +238,10 @@ impl AccountSub { raw: update.data.clone(), slot: update.slot, }); + + if let Some(on_account) = &on_account { + on_account(&update); + } }) .await?; Some(unsub) @@ -225,6 +261,10 @@ impl AccountSub { raw: update.data.clone(), slot: update.slot, }); + + if let Some(on_account) = &on_account { + on_account(update); + } }); Some(unsub) } diff --git a/crates/src/grpc/mod.rs b/crates/src/grpc/mod.rs index 0c5881c..2c56cd3 100644 --- a/crates/src/grpc/mod.rs +++ b/crates/src/grpc/mod.rs @@ -72,7 +72,7 @@ pub struct GrpcSubscribeOpts { /// callback for slot updates pub on_slot: Option>, /// custom callback for account updates - pub on_account: Option<(AccountFilter, Box)>, + pub on_account: Option)>>, /// Network level connection config pub connection_opts: GrpcConnectionOpts, /// Enable inter-slot update notifications @@ -130,9 +130,16 @@ impl GrpcSubscribeOpts { pub fn on_account( mut self, filter: AccountFilter, - on_account: impl Fn(&AccountUpdate) + Send + Sync + 'static, + callback: impl Fn(&AccountUpdate) + Send + Sync + 'static, ) -> Self { - self.on_account = Some((filter, Box::new(on_account))); + match &mut self.on_account { + Some(on_account) => { + on_account.push((filter, Box::new(callback))); + } + None => { + self.on_account = Some(vec![(filter, Box::new(callback))]); + } + } self } /// Set network level connection opts diff --git a/crates/src/lib.rs b/crates/src/lib.rs index 84d881f..c98a63a 100644 --- a/crates/src/lib.rs +++ b/crates/src/lib.rs @@ -47,7 +47,7 @@ use crate::{ swift_order_subscriber::{SignedOrderInfo, SwiftOrderStream}, types::{ accounts::{PerpMarket, SpotMarket, State, User, UserStats}, - DataAndSlot, MarketType, *, + DataAndSlot, MarketType, AccountUpdate, *, }, utils::{get_http_url, get_ws_url}, }; @@ -71,6 +71,9 @@ pub mod types; pub mod grpc; pub mod polled_account_subscriber; pub mod websocket_account_subscriber; + +#[cfg(test)] +mod generic_callback_test; pub mod websocket_program_account_subscriber; // subscribers @@ -157,7 +160,18 @@ impl DriftClient { /// /// This is a no-op if already subscribed pub async fn subscribe_markets(&self, markets: &[MarketId]) -> SdkResult<()> { - self.backend.subscribe_markets(markets).await + self.backend.subscribe_markets(markets, None::).await + } + + pub async fn subscribe_markets_with_callback( + &self, + markets: &[MarketId], + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + self.backend.subscribe_markets(markets, on_account).await } /// Subscribe to all spot and perp markets @@ -165,7 +179,18 @@ impl DriftClient { /// This is a no-op if already subscribed pub async fn subscribe_all_markets(&self) -> SdkResult<()> { let markets = self.get_all_market_ids(); - self.backend.subscribe_markets(&markets).await + self.backend.subscribe_markets(&markets, None::).await + } + + pub async fn subscribe_all_markets_with_callback( + &self, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + let markets = self.get_all_market_ids(); + self.backend.subscribe_markets(&markets, on_account).await } /// Subscribe to all spot markets @@ -173,7 +198,18 @@ impl DriftClient { /// This is a no-op if already subscribed pub async fn subscribe_all_spot_markets(&self) -> SdkResult<()> { let markets = self.get_all_spot_market_ids(); - self.backend.subscribe_markets(&markets).await + self.backend.subscribe_markets(&markets, None::).await + } + + pub async fn subscribe_all_spot_markets_with_callback( + &self, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + let markets = self.get_all_spot_market_ids(); + self.backend.subscribe_markets(&markets, on_account).await } /// Subscribe to all perp markets @@ -181,7 +217,18 @@ impl DriftClient { /// This is a no-op if already subscribed pub async fn subscribe_all_perp_markets(&self) -> SdkResult<()> { let markets = self.get_all_perp_market_ids(); - self.backend.subscribe_markets(&markets).await + self.backend.subscribe_markets(&markets, None::).await + } + + pub async fn subscribe_all_perp_markets_with_callback( + &self, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + let markets = self.get_all_perp_market_ids(); + self.backend.subscribe_markets(&markets, on_account).await } /// Starts background subscriptions for live oracle account updates by market @@ -190,7 +237,18 @@ impl DriftClient { /// /// This is a no-op if already subscribed pub async fn subscribe_oracles(&self, markets: &[MarketId]) -> SdkResult<()> { - self.backend.subscribe_oracles(markets).await + self.backend.subscribe_oracles(markets, None::).await + } + + pub async fn subscribe_oracles_with_callback( + &self, + markets: &[MarketId], + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + self.backend.subscribe_oracles(markets, on_account).await } /// Subscribe to all oracles @@ -198,7 +256,18 @@ impl DriftClient { /// This is a no-op if already subscribed pub async fn subscribe_all_oracles(&self) -> SdkResult<()> { let markets = self.get_all_market_ids(); - self.backend.subscribe_oracles(&markets).await + self.backend.subscribe_oracles(&markets, None::).await + } + + pub async fn subscribe_all_oracles_with_callback( + &self, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + let markets = self.get_all_market_ids(); + self.backend.subscribe_oracles(&markets, on_account).await } /// Subscribe to all spot market oracles @@ -206,7 +275,18 @@ impl DriftClient { /// This is a no-op if already subscribed pub async fn subscribe_all_spot_oracles(&self) -> SdkResult<()> { let markets = self.get_all_spot_market_ids(); - self.backend.subscribe_oracles(&markets).await + self.backend.subscribe_oracles(&markets, None::).await + } + + pub async fn subscribe_all_spot_oracles_with_callback( + &self, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + let markets = self.get_all_spot_market_ids(); + self.backend.subscribe_oracles(&markets, on_account).await } /// Subscribe to all perp market oracles @@ -214,7 +294,18 @@ impl DriftClient { /// This is a no-op if already subscribed pub async fn subscribe_all_perp_oracles(&self) -> SdkResult<()> { let markets = self.get_all_perp_market_ids(); - self.backend.subscribe_oracles(&markets).await + self.backend.subscribe_oracles(&markets, None::).await + } + + pub async fn subscribe_all_perp_oracles_with_callback( + &self, + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { + let markets = self.get_all_perp_market_ids(); + self.backend.subscribe_oracles(&markets, on_account).await } /// Subscribe to swift order feed(s) for given `markets` @@ -1013,7 +1104,14 @@ impl DriftClientBackend { } /// Start subscriptions for market accounts - async fn subscribe_markets(&self, markets: &[MarketId]) -> SdkResult<()> { + async fn subscribe_markets( + &self, + markets: &[MarketId], + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { if self.is_grpc_subscribed() { log::info!("already subscribed markets via gRPC"); return Err(SdkError::AlreadySubscribed); @@ -1023,21 +1121,28 @@ impl DriftClientBackend { .iter() .partition::, _>(|x| x.is_perp()); let _ = tokio::try_join!( - self.perp_market_map.subscribe(&perps), - self.spot_market_map.subscribe(&spot), + self.perp_market_map.subscribe(&perps, on_account.clone()), + self.spot_market_map.subscribe(&spot, on_account), )?; Ok(()) } /// Start subscriptions for market oracle accounts - async fn subscribe_oracles(&self, markets: &[MarketId]) -> SdkResult<()> { + async fn subscribe_oracles( + &self, + markets: &[MarketId], + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { if self.is_grpc_subscribed() { log::info!("already subscribed oracles via gRPC"); return Err(SdkError::AlreadySubscribed); } - self.oracle_map.subscribe(markets).await + self.oracle_map.subscribe(markets, on_account).await } /// Subscribe to all: markets, oracles, and slot updates over gRPC @@ -1110,8 +1215,10 @@ impl DriftClientBackend { } // set custom callbacks - if let Some((filter, on_account)) = opts.on_account { - grpc.on_account(filter, on_account); + if let Some(callbacks) = opts.on_account { + for (filter, on_account) in callbacks { + grpc.on_account(filter, on_account) + } } if let Some(f) = opts.on_slot { grpc.on_slot(f); diff --git a/crates/src/marketmap.rs b/crates/src/marketmap.rs index 7711cdc..2dacb84 100644 --- a/crates/src/marketmap.rs +++ b/crates/src/marketmap.rs @@ -124,7 +124,14 @@ where } /// Subscribe to market account updates - pub async fn subscribe(&self, markets: &[MarketId]) -> SdkResult<()> { + pub async fn subscribe( + &self, + markets: &[MarketId], + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { log::debug!(target: LOG_TARGET, "subscribing: {:?}", T::MARKET_TYPE); let markets = HashSet::::from_iter(markets.iter().copied()); @@ -151,6 +158,7 @@ where let futs_iter = pending_subscriptions.into_iter().map(|(idx, fut)| { let marketmap = Arc::clone(&self.marketmap); let latest_slot = self.latest_slot.clone(); + let on_account = on_account.clone(); async move { let unsub = fut .subscribe(Self::SUBSCRIPTION_ID, false, { @@ -166,6 +174,9 @@ where .expect("valid market"), }, ); + if let Some(on_account) = &on_account { + on_account(&update); + } } }) .await; @@ -402,7 +413,10 @@ mod tests { ); assert!(map - .subscribe(&[MarketId::perp(0), MarketId::perp(1), MarketId::perp(1)]) + .subscribe( + &[MarketId::perp(0), MarketId::perp(1), MarketId::perp(1)], + None:: + ) .await .is_ok()); assert!(map.is_subscribed(0)); diff --git a/crates/src/oraclemap.rs b/crates/src/oraclemap.rs index 85b9674..31ab30d 100644 --- a/crates/src/oraclemap.rs +++ b/crates/src/oraclemap.rs @@ -19,9 +19,9 @@ use solana_sdk::{ use crate::{ drift_idl::types::OracleSource, ffi::{get_oracle_price, OraclePriceData}, - grpc::AccountUpdate, - types::MapOf, - websocket_account_subscriber::{AccountUpdate as WsAccountUpdate, WebsocketAccountSubscriber}, + grpc::AccountUpdate as GrpcAccountUpdate, + types::{AccountUpdate, MapOf}, + websocket_account_subscriber::WebsocketAccountSubscriber, MarketId, SdkError, SdkResult, UnsubHandle, }; @@ -128,7 +128,14 @@ impl OracleMap { /// Panics /// /// If the `market` oracle pubkey is not loaded - pub async fn subscribe(&self, markets: &[MarketId]) -> SdkResult<()> { + pub async fn subscribe( + &self, + markets: &[MarketId], + on_account: Option, + ) -> SdkResult<()> + where + F: Fn(&crate::AccountUpdate) + Send + Sync + 'static + Clone, + { let markets = HashSet::from_iter(markets); log::debug!(target: LOG_TARGET, "subscribe market oracles: {markets:?}"); @@ -140,10 +147,13 @@ impl OracleMap { self.oracle_by_market.get(market).expect("oracle exists"); // markets can share oracle pubkeys, only want one sub per oracle pubkey - if self.subscriptions.contains_key(oracle_pubkey) + // TEMP FIX: Allow multiple callbacks by allowing duplicate subscriptions + // TODO: Implement proper multi-callback support + if false && // Disable this check temporarily + (self.subscriptions.contains_key(oracle_pubkey) || pending_subscriptions .iter() - .any(|sub| &sub.pubkey == oracle_pubkey) + .any(|sub| &sub.pubkey == oracle_pubkey)) { log::debug!(target: LOG_TARGET, "subscription exists: {market:?}/{oracle_pubkey:?}"); continue; @@ -166,7 +176,7 @@ impl OracleMap { .expect("oracle exists") .clone(); let oracle_shared_mode_ref = oracle_shared_mode.clone(); - + let on_account = on_account.clone(); async move { let unsub = sub_fut .subscribe(Self::SUBSCRIPTION_ID, true, move |update| { @@ -180,6 +190,9 @@ impl OracleMap { } } } + if let Some(on_account) = &on_account { + on_account(&update); + } }) .await; ((sub_fut.pubkey, oracle_shared_mode), unsub) @@ -347,11 +360,11 @@ impl OracleMap { } /// Returns a hook for driving the map with new `Account` updates - pub(crate) fn on_account_fn(&self) -> impl Fn(&AccountUpdate) { + pub(crate) fn on_account_fn(&self) -> impl Fn(&GrpcAccountUpdate) { let oraclemap = self.map(); let oracle_lookup = self.shared_oracles.clone(); - move |update: &AccountUpdate| match oracle_lookup.get(&update.pubkey).unwrap() { + move |update: &GrpcAccountUpdate| match oracle_lookup.get(&update.pubkey).unwrap() { OracleShareMode::Normal { source } => { update_handler_grpc(update, *source, &oraclemap); } @@ -367,7 +380,7 @@ impl OracleMap { /// Handler fn for new oracle account data #[inline] fn update_handler_grpc( - update: &AccountUpdate, + update: &GrpcAccountUpdate, oracle_source: OracleSource, oracle_map: &DashMap<(Pubkey, u8), Oracle, ahash::RandomState>, ) { @@ -411,7 +424,7 @@ fn update_handler_grpc( /// Handler fn for new oracle account data fn update_handler( - update: &WsAccountUpdate, + update: &AccountUpdate, oracle_source: OracleSource, oracle_map: &DashMap<(Pubkey, u8), Oracle, ahash::RandomState>, ) { @@ -423,7 +436,7 @@ fn update_handler( oracle_pubkey, Account { owner: update.owner, - data: update.data.clone(), + data: update.data.clone().to_vec(), lamports, ..Default::default() }, @@ -436,14 +449,14 @@ fn update_handler( .and_modify(|o| { o.data = price_data; o.slot = update.slot; - o.raw.clone_from(&update.data); + o.raw = update.data.to_vec(); }) .or_insert(Oracle { pubkey: oracle_pubkey, source: oracle_source, data: price_data, slot: update.slot, - raw: update.data.clone(), + raw: update.data.to_vec(), }); } Err(err) => { @@ -608,7 +621,7 @@ mod tests { let map = OracleMap::new(pubsub, &all_oracles, CommitmentConfig::confirmed()); let markets = [MarketId::perp(0), MarketId::spot(32), MarketId::perp(4)]; - map.subscribe(&markets).await.expect("subd"); + map.subscribe(&markets, None::).await.expect("subd"); assert_eq!(map.len(), 3); assert!(map.is_subscribed(&MarketId::spot(32))); assert!(map.is_subscribed(&MarketId::perp(4))); @@ -646,10 +659,10 @@ mod tests { MarketId::perp(1), MarketId::spot(1), ]; - map.subscribe(&markets).await.expect("subd"); + map.subscribe(&markets, None::).await.expect("subd"); assert_eq!(map.len(), 2); let markets = [MarketId::perp(0), MarketId::spot(1)]; - map.subscribe(&markets).await.expect("subd"); + map.subscribe(&markets, None::).await.expect("subd"); assert_eq!(map.len(), 2); assert!(map.is_subscribed(&MarketId::perp(0))); @@ -683,7 +696,7 @@ mod tests { &all_oracles, CommitmentConfig::confirmed(), ); - map.subscribe(&[MarketId::spot(0), MarketId::perp(1)]) + map.subscribe(&[MarketId::spot(0), MarketId::perp(1)], None::) .await .expect("subd"); assert!(map.unsubscribe_all().is_ok()); diff --git a/crates/src/polled_account_subscriber.rs b/crates/src/polled_account_subscriber.rs index 0fe644a..2018ddb 100644 --- a/crates/src/polled_account_subscriber.rs +++ b/crates/src/polled_account_subscriber.rs @@ -7,20 +7,7 @@ use solana_rpc_client_api::config::RpcAccountInfoConfig; use solana_sdk::pubkey::Pubkey; use tokio::sync::oneshot; -use crate::UnsubHandle; - -#[derive(Clone, Debug)] -pub struct AccountUpdate { - /// Address of the account - pub pubkey: Pubkey, - /// Owner of the account - pub owner: Pubkey, - pub lamports: u64, - /// Serialized account data (e.g. Anchor/Borsh) - pub data: Vec, - /// Slot retrieved - pub slot: u64, -} +use crate::{AccountUpdate, UnsubHandle}; /// Subscribes to account updates at regular polled intervals pub struct PolledAccountSubscriber { diff --git a/crates/src/types.rs b/crates/src/types.rs index a614c31..72c6631 100644 --- a/crates/src/types.rs +++ b/crates/src/types.rs @@ -689,3 +689,18 @@ mod tests { ) } } + +#[derive(Clone, Debug)] +pub struct AccountUpdate { + /// Address of the account + pub pubkey: Pubkey, + /// Owner of the account + pub owner: Pubkey, + pub lamports: u64, + /// Serialized account data (e.g. Anchor/Borsh) + pub data: Vec, + /// Slot retrieved + pub slot: u64, +} + +pub type OnAccountFn = dyn Fn(&AccountUpdate) + Send + Sync + 'static; diff --git a/crates/src/websocket_account_subscriber.rs b/crates/src/websocket_account_subscriber.rs index 261f06c..d7b63a5 100644 --- a/crates/src/websocket_account_subscriber.rs +++ b/crates/src/websocket_account_subscriber.rs @@ -9,23 +9,10 @@ use solana_rpc_client_api::config::RpcAccountInfoConfig; use solana_sdk::{commitment_config::CommitmentConfig, pubkey::Pubkey}; use tokio::sync::oneshot; -use crate::{utils::get_http_url, SdkError, SdkResult, UnsubHandle}; +use crate::{utils::get_http_url, AccountUpdate, SdkError, SdkResult, UnsubHandle}; const LOG_TARGET: &str = "wsaccsub"; -#[derive(Clone, Debug)] -pub struct AccountUpdate { - /// Address of the account - pub pubkey: Pubkey, - /// Owner of the account - pub owner: Pubkey, - pub lamports: u64, - /// Serialized account data (e.g. Anchor/Borsh) - pub data: Vec, - /// Slot retrieved - pub slot: u64, -} - #[derive(Clone)] pub struct WebsocketAccountSubscriber { pubsub: Arc,