1
1
package loopin
2
2
3
3
import (
4
+ "bytes"
4
5
"context"
5
6
"errors"
6
7
"fmt"
@@ -93,6 +94,11 @@ type StaticAddressLoopIn struct {
93
94
// swap.
94
95
DepositOutpoints []string
95
96
97
+ // SelectedAmount is the amount that the user selected for the swap. If
98
+ // the user did not select an amount, the amount of all deposits is
99
+ // used.
100
+ SelectedAmount btcutil.Amount
101
+
96
102
// state is the current state of the swap.
97
103
state fsm.StateType
98
104
@@ -283,14 +289,25 @@ func (l *StaticAddressLoopIn) createHtlcTx(chainParams *chaincfg.Params,
283
289
})
284
290
}
285
291
292
+ // Determine the swap amount. If the user selected a specific amount, we
293
+ // use that and use the difference to the total deposit amount as the
294
+ // change.
295
+ var (
296
+ swapAmt = l .TotalDepositAmount ()
297
+ changeAmount btcutil.Amount
298
+ )
299
+ if l .SelectedAmount > 0 {
300
+ swapAmt = l .SelectedAmount
301
+ changeAmount = l .TotalDepositAmount () - l .SelectedAmount
302
+ }
303
+
286
304
// Calculate htlc tx fee for server provided fee rate.
287
- weight := l .htlcWeight ()
305
+ hasChange := changeAmount > 0
306
+ weight := l .htlcWeight (hasChange )
288
307
fee := feeRate .FeeForWeight (weight )
289
308
290
309
// Check if the server breaches our fee limits.
291
- amt := float64 (l .TotalDepositAmount ())
292
- feeLimit := btcutil .Amount (amt * maxFeePercentage )
293
-
310
+ feeLimit := btcutil .Amount (float64 (swapAmt ) * maxFeePercentage )
294
311
if fee > feeLimit {
295
312
return nil , fmt .Errorf ("htlc tx fee %v exceeds max fee %v" ,
296
313
fee , feeLimit )
@@ -308,12 +325,20 @@ func (l *StaticAddressLoopIn) createHtlcTx(chainParams *chaincfg.Params,
308
325
309
326
// Create the sweep output
310
327
sweepOutput := & wire.TxOut {
311
- Value : int64 (l . TotalDepositAmount ()) - int64 ( fee ),
328
+ Value : int64 (swapAmt - fee ),
312
329
PkScript : pkscript ,
313
330
}
314
331
315
332
msgTx .AddTxOut (sweepOutput )
316
333
334
+ // We expect change to be sent back to our static address output script.
335
+ if changeAmount > 0 {
336
+ msgTx .AddTxOut (& wire.TxOut {
337
+ Value : int64 (changeAmount ),
338
+ PkScript : l .AddressParams .PkScript ,
339
+ })
340
+ }
341
+
317
342
return msgTx , nil
318
343
}
319
344
@@ -325,7 +350,7 @@ func (l *StaticAddressLoopIn) isHtlcTimedOut(height int32) bool {
325
350
}
326
351
327
352
// htlcWeight returns the weight for the htlc transaction.
328
- func (l * StaticAddressLoopIn ) htlcWeight () lntypes.WeightUnit {
353
+ func (l * StaticAddressLoopIn ) htlcWeight (hasChange bool ) lntypes.WeightUnit {
329
354
var weightEstimator input.TxWeightEstimator
330
355
for i := 0 ; i < len (l .Deposits ); i ++ {
331
356
weightEstimator .AddTaprootKeySpendInput (
@@ -335,6 +360,10 @@ func (l *StaticAddressLoopIn) htlcWeight() lntypes.WeightUnit {
335
360
336
361
weightEstimator .AddP2WSHOutput ()
337
362
363
+ if hasChange {
364
+ weightEstimator .AddP2TROutput ()
365
+ }
366
+
338
367
return weightEstimator .Weight ()
339
368
}
340
369
@@ -373,11 +402,25 @@ func (l *StaticAddressLoopIn) createHtlcSweepTx(ctx context.Context,
373
402
return nil , err
374
403
}
375
404
405
+ // Check if the htlc tx has a change output. If so we need to select the
406
+ // non-change output index to construct the sweep with.
407
+ htlcInputIndex := uint32 (0 )
408
+ if len (htlcTx .TxOut ) == 2 {
409
+ // If the first htlc tx output matches our static address
410
+ // script we need to select the second output to sweep from.
411
+ if bytes .Equal (
412
+ htlcTx .TxOut [0 ].PkScript , l .AddressParams .PkScript ,
413
+ ) {
414
+
415
+ htlcInputIndex = 1
416
+ }
417
+ }
418
+
376
419
// Add the htlc input.
377
420
sweepTx .AddTxIn (& wire.TxIn {
378
421
PreviousOutPoint : wire.OutPoint {
379
422
Hash : htlcTx .TxHash (),
380
- Index : 0 ,
423
+ Index : htlcInputIndex ,
381
424
},
382
425
SignatureScript : htlc .SigScript ,
383
426
Sequence : htlc .SuccessSequence (),
0 commit comments