@@ -77,7 +77,7 @@ enum OutboundJITChannelState {
77
77
} ,
78
78
ChannelReady {
79
79
htlcs : Vec < InterceptedHTLC > ,
80
- amt_to_forward_msat : u64 ,
80
+ opening_fee_msat : u64 ,
81
81
} ,
82
82
}
83
83
@@ -183,10 +183,10 @@ impl OutboundJITChannelState {
183
183
184
184
fn channel_ready ( & self ) -> Result < Self , ChannelStateError > {
185
185
match self {
186
- OutboundJITChannelState :: PendingChannelOpen { htlcs, amt_to_forward_msat , .. } => {
186
+ OutboundJITChannelState :: PendingChannelOpen { htlcs, opening_fee_msat , .. } => {
187
187
Ok ( OutboundJITChannelState :: ChannelReady {
188
188
htlcs : htlcs. clone ( ) ,
189
- amt_to_forward_msat : * amt_to_forward_msat ,
189
+ opening_fee_msat : * opening_fee_msat ,
190
190
} )
191
191
}
192
192
state => Err ( ChannelStateError ( format ! (
@@ -241,8 +241,8 @@ impl OutboundJITChannel {
241
241
self . state = self . state . channel_ready ( ) ?;
242
242
243
243
match & self . state {
244
- OutboundJITChannelState :: ChannelReady { htlcs, amt_to_forward_msat } => {
245
- Ok ( ( htlcs. clone ( ) , * amt_to_forward_msat ) )
244
+ OutboundJITChannelState :: ChannelReady { htlcs, opening_fee_msat } => {
245
+ Ok ( ( htlcs. clone ( ) , * opening_fee_msat ) )
246
246
}
247
247
impossible_state => Err ( LightningError {
248
248
err : format ! (
@@ -529,11 +529,9 @@ where
529
529
peer_state. outbound_channels_by_intercept_scid . get_mut ( & intercept_scid)
530
530
{
531
531
match jit_channel. channel_ready ( ) {
532
- Ok ( ( htlcs, total_amt_to_forward_msat) ) => {
533
- let amounts_to_forward_msat = calculate_amount_to_forward_per_htlc (
534
- & htlcs,
535
- total_amt_to_forward_msat,
536
- ) ;
532
+ Ok ( ( htlcs, opening_fee_msat) ) => {
533
+ let amounts_to_forward_msat =
534
+ calculate_amount_to_forward_per_htlc ( & htlcs, opening_fee_msat) ;
537
535
538
536
for ( intercept_id, amount_to_forward_msat) in
539
537
amounts_to_forward_msat
@@ -759,40 +757,38 @@ where
759
757
}
760
758
761
759
fn calculate_amount_to_forward_per_htlc (
762
- htlcs : & [ InterceptedHTLC ] , total_amt_to_forward_msat : u64 ,
760
+ htlcs : & [ InterceptedHTLC ] , total_fee_msat : u64 ,
763
761
) -> Vec < ( InterceptId , u64 ) > {
764
762
// TODO: we should eventually make sure the HTLCs are all above ChannelDetails::next_outbound_minimum_msat
765
- let total_received_msat : u64 =
763
+ let total_expected_outbound_msat : u64 =
766
764
htlcs. iter ( ) . map ( |htlc| htlc. expected_outbound_amount_msat ) . sum ( ) ;
765
+ if total_fee_msat > total_expected_outbound_msat {
766
+ debug_assert ! ( false , "Fee is larger than the total expected outbound amount." ) ;
767
+ return Vec :: new ( ) ;
768
+ }
767
769
768
- match total_received_msat. checked_sub ( total_amt_to_forward_msat) {
769
- Some ( total_fee_msat) => {
770
- let mut fee_remaining_msat = total_fee_msat;
771
-
772
- let mut per_htlc_forwards = vec ! [ ] ;
773
-
774
- for ( index, htlc) in htlcs. iter ( ) . enumerate ( ) {
775
- let proportional_fee_amt_msat =
776
- total_fee_msat * ( htlc. expected_outbound_amount_msat / total_received_msat) ;
777
-
778
- let mut actual_fee_amt_msat =
779
- core:: cmp:: min ( fee_remaining_msat, proportional_fee_amt_msat) ;
780
- fee_remaining_msat -= actual_fee_amt_msat;
770
+ let mut fee_remaining_msat = total_fee_msat;
771
+ let mut per_htlc_forwards = vec ! [ ] ;
772
+ for ( index, htlc) in htlcs. iter ( ) . enumerate ( ) {
773
+ let proportional_fee_amt_msat = ( total_fee_msat as u128
774
+ * htlc. expected_outbound_amount_msat as u128
775
+ / total_expected_outbound_msat as u128 ) as u64 ;
781
776
782
- if index == htlcs. len ( ) - 1 {
783
- actual_fee_amt_msat += fee_remaining_msat;
784
- }
777
+ let mut actual_fee_amt_msat = core:: cmp:: min ( fee_remaining_msat, proportional_fee_amt_msat) ;
778
+ actual_fee_amt_msat =
779
+ core:: cmp:: min ( actual_fee_amt_msat, htlc. expected_outbound_amount_msat ) ;
780
+ fee_remaining_msat -= actual_fee_amt_msat;
785
781
786
- let amount_to_forward_msat =
787
- htlc. expected_outbound_amount_msat . saturating_sub ( actual_fee_amt_msat) ;
782
+ if index == htlcs. len ( ) - 1 {
783
+ actual_fee_amt_msat += fee_remaining_msat;
784
+ }
788
785
789
- per_htlc_forwards . push ( ( htlc . intercept_id , amount_to_forward_msat) )
790
- }
786
+ let amount_to_forward_msat =
787
+ htlc . expected_outbound_amount_msat . saturating_sub ( actual_fee_amt_msat ) ;
791
788
792
- per_htlc_forwards
793
- }
794
- None => Vec :: new ( ) ,
789
+ per_htlc_forwards. push ( ( htlc. intercept_id , amount_to_forward_msat) )
795
790
}
791
+ per_htlc_forwards
796
792
}
797
793
798
794
#[ cfg( test) ]
@@ -812,7 +808,7 @@ mod tests {
812
808
813
809
proptest ! {
814
810
#[ test]
815
- fn test_calculate_amount_to_forward ( ( o_0, o_1, o_2, total_amt_to_forward_msat ) in arb_forward_amounts( ) ) {
811
+ fn proptest_calculate_amount_to_forward ( ( o_0, o_1, o_2, total_fee_msat ) in arb_forward_amounts( ) ) {
816
812
let htlcs = vec![
817
813
InterceptedHTLC {
818
814
intercept_id: InterceptId ( [ 0 ; 32 ] ) ,
@@ -828,10 +824,10 @@ mod tests {
828
824
} ,
829
825
] ;
830
826
831
- let result = calculate_amount_to_forward_per_htlc( & htlcs, total_amt_to_forward_msat ) ;
827
+ let result = calculate_amount_to_forward_per_htlc( & htlcs, total_fee_msat ) ;
832
828
let total_received_msat = o_0 + o_1 + o_2;
833
829
834
- if total_received_msat < total_amt_to_forward_msat {
830
+ if total_received_msat < total_fee_msat {
835
831
assert_eq!( result. len( ) , 0 ) ;
836
832
} else {
837
833
assert_ne!( result. len( ) , 0 ) ;
@@ -843,16 +839,42 @@ mod tests {
843
839
assert!( result[ 2 ] . 1 <= o_2) ;
844
840
845
841
let result_sum = result. iter( ) . map( |( _, f) | f) . sum:: <u64 >( ) ;
846
- assert! ( result_sum >= total_amt_to_forward_msat ) ;
847
- let five_pct = result_sum as f32 * 0.1 ;
848
- let fair_share_0 = ( ( o_0 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max ( o_0 as f32 ) ;
842
+ assert_eq! ( total_received_msat - result_sum , total_fee_msat ) ;
843
+ let five_pct = result_sum as f32 * 0.05 ;
844
+ let fair_share_0 = ( o_0 as f32 / total_received_msat as f32 ) * result_sum as f32 ;
849
845
assert!( result[ 0 ] . 1 as f32 <= fair_share_0 + five_pct) ;
850
- let fair_share_1 = ( ( o_1 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max ( o_1 as f32 ) ;
846
+ let fair_share_1 = ( o_1 as f32 / total_received_msat as f32 ) * result_sum as f32 ;
851
847
assert!( result[ 1 ] . 1 as f32 <= fair_share_1 + five_pct) ;
852
- let fair_share_2 = ( ( o_2 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max ( o_2 as f32 ) ;
848
+ let fair_share_2 = ( o_2 as f32 / total_received_msat as f32 ) * result_sum as f32 ;
853
849
assert!( result[ 2 ] . 1 as f32 <= fair_share_2 + five_pct) ;
854
850
}
855
-
856
851
}
857
852
}
853
+
854
+ #[ test]
855
+ fn test_calculate_amount_to_forward ( ) {
856
+ let htlcs = vec ! [
857
+ InterceptedHTLC {
858
+ intercept_id: InterceptId ( [ 0 ; 32 ] ) ,
859
+ expected_outbound_amount_msat: 2 ,
860
+ } ,
861
+ InterceptedHTLC {
862
+ intercept_id: InterceptId ( [ 1 ; 32 ] ) ,
863
+ expected_outbound_amount_msat: 6 ,
864
+ } ,
865
+ InterceptedHTLC {
866
+ intercept_id: InterceptId ( [ 2 ; 32 ] ) ,
867
+ expected_outbound_amount_msat: 2 ,
868
+ } ,
869
+ ] ;
870
+ let result = calculate_amount_to_forward_per_htlc ( & htlcs, 5 ) ;
871
+ assert_eq ! (
872
+ result,
873
+ vec![
874
+ ( htlcs[ 0 ] . intercept_id, 1 ) ,
875
+ ( htlcs[ 1 ] . intercept_id, 3 ) ,
876
+ ( htlcs[ 2 ] . intercept_id, 1 ) ,
877
+ ]
878
+ ) ;
879
+ }
858
880
}
0 commit comments