Skip to content

refactor update_aum, add unit tests #1727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions programs/drift/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,8 @@ pub enum ErrorCode {
ConstituentOracleStale,
#[msg("LP Invariant failed")]
LpInvariantFailed,
#[msg("Invalid constituent derivative weights")]
InvalidConstituentDerivativeWeights,
}

#[macro_export]
Expand Down
191 changes: 21 additions & 170 deletions programs/drift/src/instructions/lp_pool.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::BTreeMap;

use anchor_lang::{prelude::*, Accounts, Key, Result};
use anchor_spl::token_interface::{Mint, TokenAccount, TokenInterface};

Expand All @@ -15,10 +13,7 @@ use crate::{
math::{
self,
casting::Cast,
constants::{
PERCENTAGE_PRECISION, PERCENTAGE_PRECISION_I128, PERCENTAGE_PRECISION_I64,
PERCENTAGE_PRECISION_U64, PRICE_PRECISION, PRICE_PRECISION_I128, QUOTE_PRECISION_I128,
},
constants::{PERCENTAGE_PRECISION_I64, PRICE_PRECISION},
oracle::{is_oracle_valid_for_action, oracle_validity, DriftAction},
safe_math::SafeMath,
},
Expand All @@ -27,9 +22,9 @@ use crate::{
constituent_map::{ConstituentMap, ConstituentSet},
events::{emit_stack, LPMintRedeemRecord, LPSwapRecord},
lp_pool::{
calculate_target_weight, AmmConstituentDatum, AmmConstituentMappingFixed, Constituent,
ConstituentCorrelationsFixed, ConstituentTargetBaseFixed, LPPool, TargetsDatum,
WeightValidationFlags, LP_POOL_SWAP_AUM_UPDATE_DELAY,
update_constituent_target_base_for_derivatives, AmmConstituentDatum,
AmmConstituentMappingFixed, Constituent, ConstituentCorrelationsFixed,
ConstituentTargetBaseFixed, LPPool, TargetsDatum, LP_POOL_SWAP_AUM_UPDATE_DELAY,
MAX_AMM_CACHE_STALENESS_FOR_TARGET_CALC,
},
oracle::OraclePriceData,
Expand Down Expand Up @@ -199,6 +194,7 @@ pub fn handle_update_lp_pool_aum<'c: 'info, 'info>(
let state = &ctx.accounts.state;

let slot = Clock::get()?.slot;
let now = Clock::get()?.unix_timestamp;

let remaining_accounts = &mut ctx.remaining_accounts.iter().peekable();

Expand Down Expand Up @@ -251,88 +247,14 @@ pub fn handle_update_lp_pool_aum<'c: 'info, 'info>(
"Amm cache PDA does not match expected PDA"
)?;

let mut aum: u128 = 0;
let mut crypto_delta = 0_i128;
let mut oldest_slot = u64::MAX;
let mut derivative_groups: BTreeMap<u16, Vec<u16>> = BTreeMap::new();
for i in 0..lp_pool.constituents as usize {
let constituent = constituent_map.get_ref(&(i as u16))?;
if slot.saturating_sub(constituent.last_oracle_slot)
> constituent.oracle_staleness_threshold
{
msg!(
"Constituent {} oracle slot is too stale: {}, current slot: {}",
constituent.constituent_index,
constituent.last_oracle_slot,
slot
);
return Err(ErrorCode::ConstituentOracleStale.into());
}

if constituent.constituent_derivative_index >= 0 && constituent.derivative_weight != 0 {
if !derivative_groups.contains_key(&(constituent.constituent_derivative_index as u16)) {
derivative_groups.insert(
constituent.constituent_derivative_index as u16,
vec![constituent.constituent_index],
);
} else {
derivative_groups
.get_mut(&(constituent.constituent_derivative_index as u16))
.unwrap()
.push(constituent.constituent_index);
}
}

let spot_market = spot_market_map.get_ref(&constituent.spot_market_index)?;

let oracle_slot = constituent.last_oracle_slot;

if oracle_slot < oldest_slot {
oldest_slot = oracle_slot;
}

let (numerator_scale, denominator_scale) = if spot_market.decimals > 6 {
(10_i128.pow(spot_market.decimals - 6), 1)
} else {
(1, 10_i128.pow(6 - spot_market.decimals))
};

let constituent_aum = constituent
.get_full_balance(&spot_market)?
.safe_mul(numerator_scale)?
.safe_div(denominator_scale)?
.safe_mul(constituent.last_oracle_price as i128)?
.safe_div(PRICE_PRECISION_I128)?
.max(0);
msg!(
"constituent: {}, aum: {}, deriv index: {}",
constituent.constituent_index,
constituent_aum,
constituent.constituent_derivative_index
);
if constituent.constituent_index != lp_pool.usdc_consituent_index
&& constituent.constituent_derivative_index != lp_pool.usdc_consituent_index as i16
{
let constituent_target_notional = constituent_target_base
.get(constituent.constituent_index as u32)
.target_base
.safe_mul(constituent.last_oracle_price)?
.safe_div(10_i64.pow(constituent.decimals as u32))?;
crypto_delta = crypto_delta.safe_add(constituent_target_notional.cast()?)?;
}
aum = aum.safe_add(constituent_aum.cast()?)?;
}

let mut aum_i128 = aum.cast::<i128>()?;
for cache_datum in amm_cache.iter() {
aum_i128 -= cache_datum.quote_owed_from_lp_pool as i128;
}
aum = aum_i128.max(0i128).cast::<u128>()?;

lp_pool.oldest_oracle_slot = oldest_slot;
lp_pool.last_aum = aum;
lp_pool.last_aum_slot = slot;
lp_pool.last_aum_ts = Clock::get()?.unix_timestamp;
let (aum, crypto_delta, derivative_groups) = lp_pool.update_aum(
now,
slot,
&constituent_map,
&spot_market_map,
&constituent_target_base,
&amm_cache,
)?;

// Set USDC stable weight
let total_stable_target_base = aum
Expand All @@ -341,16 +263,7 @@ pub fn handle_update_lp_pool_aum<'c: 'info, 'info>(
.max(0_i128);
constituent_target_base
.get_mut(lp_pool.usdc_consituent_index as u32)
.target_base = total_stable_target_base
.safe_mul(
10_i128.pow(
constituent_map
.get_ref(&lp_pool.usdc_consituent_index)?
.decimals as u32,
),
)?
.safe_div(QUOTE_PRECISION_I128)?
.cast::<i64>()?;
.target_base = total_stable_target_base.cast::<i64>()?;

msg!(
"stable target base: {}",
Expand All @@ -361,75 +274,13 @@ pub fn handle_update_lp_pool_aum<'c: 'info, 'info>(
msg!("aum: {}, crypto_delta: {}", aum, crypto_delta);
msg!("derivative groups: {:?}", derivative_groups);

// Handle all other derivatives
for (parent_index, constituent_indexes) in derivative_groups.iter() {
let parent_constituent = constituent_map.get_ref(&(parent_index))?;
let parent_target_base = constituent_target_base
.get(*parent_index as u32)
.target_base;
let target_parent_weight = calculate_target_weight(
parent_target_base,
&*spot_market_map.get_ref(&parent_constituent.spot_market_index)?,
parent_constituent.last_oracle_price,
aum,
WeightValidationFlags::NONE,
)?;
let mut derivative_weights_sum = 0;
for constituent_index in constituent_indexes {
let constituent = constituent_map.get_ref(constituent_index)?;
if constituent.last_oracle_price
< parent_constituent
.last_oracle_price
.safe_mul(constituent.constituent_derivative_depeg_threshold as i64)?
.safe_div(PERCENTAGE_PRECISION_I64)?
{
msg!(
"Constituent {} last oracle price {} is too low compared to parent constituent {} last oracle price {}. Assuming depegging and setting target base to 0.",
constituent.constituent_index,
constituent.last_oracle_price,
parent_constituent.constituent_index,
parent_constituent.last_oracle_price
);
constituent_target_base
.get_mut(*constituent_index as u32)
.target_base = 0_i64;
continue;
}

derivative_weights_sum += constituent.derivative_weight;

let target_weight = target_parent_weight
.safe_mul(constituent.derivative_weight as i64)?
.safe_div(PERCENTAGE_PRECISION_I64)?;

msg!(
"constituent: {}, target weight: {}",
constituent_index,
target_weight,
);
let target_base = lp_pool
.last_aum
.cast::<i128>()?
.safe_mul(target_weight as i128)?
.safe_div(PERCENTAGE_PRECISION_I128)?
.safe_mul(10_i128.pow(constituent.decimals as u32))?
.safe_div(constituent.last_oracle_price as i128)?;

msg!(
"constituent: {}, target base: {}",
constituent_index,
target_base
);
constituent_target_base
.get_mut(*constituent_index as u32)
.target_base = target_base.cast::<i64>()?;
}
constituent_target_base
.get_mut(*parent_index as u32)
.target_base = parent_target_base
.safe_mul(PERCENTAGE_PRECISION_U64.safe_sub(derivative_weights_sum)? as i64)?
.safe_div(PERCENTAGE_PRECISION_I64)?;
}
update_constituent_target_base_for_derivatives(
aum,
&derivative_groups,
&constituent_map,
&spot_market_map,
&mut constituent_target_base,
)?;

Ok(())
}
Expand Down
1 change: 0 additions & 1 deletion programs/drift/src/instructions/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ use crate::state::fulfillment_params::openbook_v2::OpenbookV2FulfillmentParams;
use crate::state::fulfillment_params::phoenix::PhoenixFulfillmentParams;
use crate::state::fulfillment_params::serum::SerumFulfillmentParams;
use crate::state::high_leverage_mode_config::HighLeverageModeConfig;
use crate::state::lp_pool::{Constituent, LPPool};
use crate::state::margin_calculation::MarginContext;
use crate::state::oracle::StrictOraclePrice;
use crate::state::order_params::{
Expand Down
40 changes: 40 additions & 0 deletions programs/drift/src/state/constituent_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,46 @@ impl<'a> ConstituentMap<'a> {
Ok(constituent_map)
}

pub fn load_multiple<'c: 'a>(
account_info: Vec<&'c AccountInfo<'a>>,
must_be_writable: bool,
) -> DriftResult<ConstituentMap<'a>> {
let mut constituent_map: ConstituentMap = ConstituentMap(BTreeMap::new());

let account_info_iter = account_info.into_iter();
for account_info in account_info_iter {
let constituent_discriminator: [u8; 8] = Constituent::discriminator();
let data = account_info
.try_borrow_data()
.or(Err(ErrorCode::ConstituentCouldNotLoad))?;

let expected_data_len = Constituent::SIZE;
if data.len() < expected_data_len {
return Err(ErrorCode::ConstituentCouldNotLoad);
}

let account_discriminator = array_ref![data, 0, 8];
if account_discriminator != &constituent_discriminator {
return Err(ErrorCode::ConstituentCouldNotLoad);
}

// constituent index 284 bytes from front of account
let constituent_index = u16::from_le_bytes(*array_ref![data, 284, 2]);

let is_writable = account_info.is_writable;
let account_loader: AccountLoader<Constituent> = AccountLoader::try_from(account_info)
.or(Err(ErrorCode::ConstituentCouldNotLoad))?;

if must_be_writable && !is_writable {
return Err(ErrorCode::ConstituentWrongMutability);
}

constituent_map.0.insert(constituent_index, account_loader);
}

Ok(constituent_map)
}

pub fn empty() -> Self {
ConstituentMap(BTreeMap::new())
}
Expand Down
Loading