Skip to content

Feat/max trade size #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 168 additions & 4 deletions crates/src/math/leverage.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
use super::{account_map_builder::AccountsListBuilder, constants::PRICE_PRECISION};
use solana_sdk::pubkey::Pubkey;

use super::{
account_map_builder::AccountsListBuilder,
constants::{AMM_RESERVE_PRECISION, BASE_PRECISION, MARGIN_PRECISION, PRICE_PRECISION},
};
use crate::{
accounts::PerpMarket,
ffi::{
calculate_margin_requirement_and_total_collateral_and_liability_info, MarginContextMode,
calculate_margin_requirement_and_total_collateral_and_liability_info, MarginCalculation,
MarginContextMode,
},
types::accounts::User,
DriftClient, SdkError, SdkResult,
ContractType, DriftClient, MarginMode, MarginRequirementType, MarketId, PositionDirection,
SdkError, SdkResult,
};

pub fn get_leverage(client: &DriftClient, user: &User) -> SdkResult<u128> {
Expand Down Expand Up @@ -72,8 +80,164 @@ fn calculate_leverage(total_liability_value: u128, net_asset_value: i128) -> u12
sign as u128 * (leverage * PRICE_PRECISION as f64) as u128
}

#[cfg(feature = "rpc_tests")]
/// Provides margin calculation helpers for User accounts
///
/// sync, requires client is subscribed to necessary markets beforehand
pub trait UserMargin {
/// Calculate user's max. trade size in USDC for a given market and direction
///
/// * `user` - the user account
/// * `market` - the market to trade
/// * `trade_side` - the direction of the trade
///
/// Returns max USDC trade size (PRICE_PRECISION)
fn max_trade_size(
&self,
user: &Pubkey,
market: MarketId,
trade_side: PositionDirection,
) -> SdkResult<u64>;
fn calculate_perp_buying_power(
&self,
user: &User,
market: &PerpMarket,
oracle_price: i64,
collateral_buffer: u64,
) -> SdkResult<u128>;
/// Calculate the user's live margin information
fn calculate_margin_info(&self, user: &User) -> SdkResult<MarginCalculation>;
}

impl UserMargin for DriftClient {
fn calculate_margin_info(&self, user: &User) -> SdkResult<MarginCalculation> {
let mut builder = AccountsListBuilder::default();
let mut accounts = builder.try_build(self, user)?;
calculate_margin_requirement_and_total_collateral_and_liability_info(
user,
&mut accounts,
MarginContextMode::StandardMaintenance,
)
}
fn max_trade_size(
&self,
user: &Pubkey,
market: MarketId,
trade_side: PositionDirection,
) -> SdkResult<u64> {
let oracle = self
.try_get_oracle_price_data_and_slot(market)
.ok_or(SdkError::NoMarketData(market))?;
let oracle_price = oracle.data.price;
let user_account = self.try_get_account::<User>(user)?;

if market.is_perp() {
let market_account = self.try_get_perp_market_account(market.index())?;

let position = user_account
.get_perp_position(market_account.market_index)
.map_err(|_| SdkError::NoMarketData(MarketId::perp(market_account.market_index)))?;
// add any position we have on the opposite side of the current trade
// because we can "flip" the size of this position without taking any extra leverage.
let is_reduce_only = position.base_asset_amount.is_negative() as u8 != trade_side as u8;
let opposite_side_liability_value = calculate_perp_liability_value(
position.base_asset_amount,
oracle_price,
market_account.contract_type == ContractType::Prediction,
);

let lp_buffer = ((oracle_price as u64 * market_account.amm.order_step_size)
/ AMM_RESERVE_PRECISION as u64)
* position.lp_shares.max(1);

let max_position_size = self.calculate_perp_buying_power(
&user_account,
&market_account,
oracle_price,
lp_buffer,
)?;

Ok(max_position_size as u64 + opposite_side_liability_value * is_reduce_only as u64)
} else {
// TODO: implement for spot
Err(SdkError::Generic("spot market unimplemented".to_string()))
}
}
/// Calculate buying power = free collateral / initial margin ratio
///
/// Returns buying power in `QUOTE_PRECISION` units
fn calculate_perp_buying_power(
&self,
user: &User,
market: &PerpMarket,
oracle_price: i64,
collateral_buffer: u64,
) -> SdkResult<u128> {
let position = user
.get_perp_position(market.market_index)
.map_err(|_| SdkError::NoMarketData(MarketId::perp(market.market_index)))?;
let position_with_lp_settle =
position.simulate_settled_lp_position(market, oracle_price)?;

let worst_case_base_amount = position_with_lp_settle
.worst_case_base_asset_amount(oracle_price, market.contract_type)?;

let margin_info = self.calculate_margin_info(user)?;
let free_collateral = margin_info.get_free_collateral() - collateral_buffer as u128;

let margin_ratio = market
.get_margin_ratio(
worst_case_base_amount.unsigned_abs(),
MarginRequirementType::Initial,
user.margin_mode == MarginMode::HighLeverage,
)
.expect("got margin ratio");
let margin_ratio = margin_ratio.max(user.max_margin_ratio);

Ok((free_collateral * MARGIN_PRECISION as u128) / margin_ratio as u128)
}
}

#[inline]
pub fn calculate_perp_liability_value(
base_asset_amount: i64,
price: i64,
is_prediction_market: bool,
) -> u64 {
let max_prediction_price = PRICE_PRECISION as i64;
let max_price =
max_prediction_price * base_asset_amount.is_negative() as i64 * is_prediction_market as i64;
(base_asset_amount * (max_price - price) / BASE_PRECISION as i64).unsigned_abs()
}

#[cfg(test)]
mod tests {
use super::calculate_perp_liability_value;

#[test]
fn calculate_perp_liability_value_works() {
use crate::math::constants::{BASE_PRECISION_I64, PRICE_PRECISION_I64};
// test values taken from TS sdk
assert_eq!(
calculate_perp_liability_value(1 * BASE_PRECISION_I64, 5 * PRICE_PRECISION_I64, false),
5_000_000
);
assert_eq!(
calculate_perp_liability_value(-1 * BASE_PRECISION_I64, 5 * PRICE_PRECISION_I64, false),
5_000_000
);
assert_eq!(
calculate_perp_liability_value(-1 * BASE_PRECISION_I64, 10_000, true),
990_000
);
assert_eq!(
calculate_perp_liability_value(1 * BASE_PRECISION_I64, 90_000, true),
90_000
);
}
}

#[cfg(feature = "rpc_tests")]
mod rpc_tests {
use solana_sdk::signature::Keypair;

use super::*;
Expand Down
Loading