Skip to content
This repository was archived by the owner on Jan 6, 2025. It is now read-only.

Commit 99b2960

Browse files
authored
Merge pull request #105 from wvanlint/fix_proportional_fees
Fixes proportional fees for MPP
2 parents c530515 + 7a73914 commit 99b2960

File tree

1 file changed

+66
-44
lines changed

1 file changed

+66
-44
lines changed

src/lsps2/service.rs

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ enum OutboundJITChannelState {
7777
},
7878
ChannelReady {
7979
htlcs: Vec<InterceptedHTLC>,
80-
amt_to_forward_msat: u64,
80+
opening_fee_msat: u64,
8181
},
8282
}
8383

@@ -183,10 +183,10 @@ impl OutboundJITChannelState {
183183

184184
fn channel_ready(&self) -> Result<Self, ChannelStateError> {
185185
match self {
186-
OutboundJITChannelState::PendingChannelOpen { htlcs, amt_to_forward_msat, .. } => {
186+
OutboundJITChannelState::PendingChannelOpen { htlcs, opening_fee_msat, .. } => {
187187
Ok(OutboundJITChannelState::ChannelReady {
188188
htlcs: htlcs.clone(),
189-
amt_to_forward_msat: *amt_to_forward_msat,
189+
opening_fee_msat: *opening_fee_msat,
190190
})
191191
}
192192
state => Err(ChannelStateError(format!(
@@ -241,8 +241,8 @@ impl OutboundJITChannel {
241241
self.state = self.state.channel_ready()?;
242242

243243
match &self.state {
244-
OutboundJITChannelState::ChannelReady { htlcs, amt_to_forward_msat } => {
245-
Ok((htlcs.clone(), *amt_to_forward_msat))
244+
OutboundJITChannelState::ChannelReady { htlcs, opening_fee_msat } => {
245+
Ok((htlcs.clone(), *opening_fee_msat))
246246
}
247247
impossible_state => Err(LightningError {
248248
err: format!(
@@ -529,11 +529,9 @@ where
529529
peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
530530
{
531531
match jit_channel.channel_ready() {
532-
Ok((htlcs, total_amt_to_forward_msat)) => {
533-
let amounts_to_forward_msat = calculate_amount_to_forward_per_htlc(
534-
&htlcs,
535-
total_amt_to_forward_msat,
536-
);
532+
Ok((htlcs, opening_fee_msat)) => {
533+
let amounts_to_forward_msat =
534+
calculate_amount_to_forward_per_htlc(&htlcs, opening_fee_msat);
537535

538536
for (intercept_id, amount_to_forward_msat) in
539537
amounts_to_forward_msat
@@ -759,40 +757,38 @@ where
759757
}
760758

761759
fn calculate_amount_to_forward_per_htlc(
762-
htlcs: &[InterceptedHTLC], total_amt_to_forward_msat: u64,
760+
htlcs: &[InterceptedHTLC], total_fee_msat: u64,
763761
) -> Vec<(InterceptId, u64)> {
764762
// TODO: we should eventually make sure the HTLCs are all above ChannelDetails::next_outbound_minimum_msat
765-
let total_received_msat: u64 =
763+
let total_expected_outbound_msat: u64 =
766764
htlcs.iter().map(|htlc| htlc.expected_outbound_amount_msat).sum();
765+
if total_fee_msat > total_expected_outbound_msat {
766+
debug_assert!(false, "Fee is larger than the total expected outbound amount.");
767+
return Vec::new();
768+
}
767769

768-
match total_received_msat.checked_sub(total_amt_to_forward_msat) {
769-
Some(total_fee_msat) => {
770-
let mut fee_remaining_msat = total_fee_msat;
771-
772-
let mut per_htlc_forwards = vec![];
773-
774-
for (index, htlc) in htlcs.iter().enumerate() {
775-
let proportional_fee_amt_msat =
776-
total_fee_msat * (htlc.expected_outbound_amount_msat / total_received_msat);
777-
778-
let mut actual_fee_amt_msat =
779-
core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat);
780-
fee_remaining_msat -= actual_fee_amt_msat;
770+
let mut fee_remaining_msat = total_fee_msat;
771+
let mut per_htlc_forwards = vec![];
772+
for (index, htlc) in htlcs.iter().enumerate() {
773+
let proportional_fee_amt_msat = (total_fee_msat as u128
774+
* htlc.expected_outbound_amount_msat as u128
775+
/ total_expected_outbound_msat as u128) as u64;
781776

782-
if index == htlcs.len() - 1 {
783-
actual_fee_amt_msat += fee_remaining_msat;
784-
}
777+
let mut actual_fee_amt_msat = core::cmp::min(fee_remaining_msat, proportional_fee_amt_msat);
778+
actual_fee_amt_msat =
779+
core::cmp::min(actual_fee_amt_msat, htlc.expected_outbound_amount_msat);
780+
fee_remaining_msat -= actual_fee_amt_msat;
785781

786-
let amount_to_forward_msat =
787-
htlc.expected_outbound_amount_msat.saturating_sub(actual_fee_amt_msat);
782+
if index == htlcs.len() - 1 {
783+
actual_fee_amt_msat += fee_remaining_msat;
784+
}
788785

789-
per_htlc_forwards.push((htlc.intercept_id, amount_to_forward_msat))
790-
}
786+
let amount_to_forward_msat =
787+
htlc.expected_outbound_amount_msat.saturating_sub(actual_fee_amt_msat);
791788

792-
per_htlc_forwards
793-
}
794-
None => Vec::new(),
789+
per_htlc_forwards.push((htlc.intercept_id, amount_to_forward_msat))
795790
}
791+
per_htlc_forwards
796792
}
797793

798794
#[cfg(test)]
@@ -812,7 +808,7 @@ mod tests {
812808

813809
proptest! {
814810
#[test]
815-
fn test_calculate_amount_to_forward((o_0, o_1, o_2, total_amt_to_forward_msat) in arb_forward_amounts()) {
811+
fn proptest_calculate_amount_to_forward((o_0, o_1, o_2, total_fee_msat) in arb_forward_amounts()) {
816812
let htlcs = vec![
817813
InterceptedHTLC {
818814
intercept_id: InterceptId([0; 32]),
@@ -828,10 +824,10 @@ mod tests {
828824
},
829825
];
830826

831-
let result = calculate_amount_to_forward_per_htlc(&htlcs, total_amt_to_forward_msat);
827+
let result = calculate_amount_to_forward_per_htlc(&htlcs, total_fee_msat);
832828
let total_received_msat = o_0 + o_1 + o_2;
833829

834-
if total_received_msat < total_amt_to_forward_msat {
830+
if total_received_msat < total_fee_msat {
835831
assert_eq!(result.len(), 0);
836832
} else {
837833
assert_ne!(result.len(), 0);
@@ -843,16 +839,42 @@ mod tests {
843839
assert!(result[2].1 <= o_2);
844840

845841
let result_sum = result.iter().map(|(_, f)| f).sum::<u64>();
846-
assert!(result_sum >= total_amt_to_forward_msat);
847-
let five_pct = result_sum as f32 * 0.1;
848-
let fair_share_0 = ((o_0 as f32 / total_received_msat as f32) * result_sum as f32).max(o_0 as f32);
842+
assert_eq!(total_received_msat - result_sum, total_fee_msat);
843+
let five_pct = result_sum as f32 * 0.05;
844+
let fair_share_0 = (o_0 as f32 / total_received_msat as f32) * result_sum as f32;
849845
assert!(result[0].1 as f32 <= fair_share_0 + five_pct);
850-
let fair_share_1 = ((o_1 as f32 / total_received_msat as f32) * result_sum as f32).max(o_1 as f32);
846+
let fair_share_1 = (o_1 as f32 / total_received_msat as f32) * result_sum as f32;
851847
assert!(result[1].1 as f32 <= fair_share_1 + five_pct);
852-
let fair_share_2 = ((o_2 as f32 / total_received_msat as f32) * result_sum as f32).max(o_2 as f32);
848+
let fair_share_2 = (o_2 as f32 / total_received_msat as f32) * result_sum as f32;
853849
assert!(result[2].1 as f32 <= fair_share_2 + five_pct);
854850
}
855-
856851
}
857852
}
853+
854+
#[test]
855+
fn test_calculate_amount_to_forward() {
856+
let htlcs = vec![
857+
InterceptedHTLC {
858+
intercept_id: InterceptId([0; 32]),
859+
expected_outbound_amount_msat: 2,
860+
},
861+
InterceptedHTLC {
862+
intercept_id: InterceptId([1; 32]),
863+
expected_outbound_amount_msat: 6,
864+
},
865+
InterceptedHTLC {
866+
intercept_id: InterceptId([2; 32]),
867+
expected_outbound_amount_msat: 2,
868+
},
869+
];
870+
let result = calculate_amount_to_forward_per_htlc(&htlcs, 5);
871+
assert_eq!(
872+
result,
873+
vec![
874+
(htlcs[0].intercept_id, 1),
875+
(htlcs[1].intercept_id, 3),
876+
(htlcs[2].intercept_id, 1),
877+
]
878+
);
879+
}
858880
}

0 commit comments

Comments
 (0)