@@ -41,6 +41,7 @@ type InvoiceClient interface {
41
41
// payment challenges.
42
42
type LndChallenger struct {
43
43
client InvoiceClient
44
+ clientCtx func () context.Context
44
45
genInvoiceReq InvoiceRequestGenerator
45
46
46
47
invoiceStates map [lntypes.Hash ]lnrpc.Invoice_InvoiceState
@@ -69,15 +70,23 @@ const (
69
70
// an lnd backend to create payment challenges.
70
71
func NewLndChallenger (client InvoiceClient ,
71
72
genInvoiceReq InvoiceRequestGenerator ,
73
+ ctxFunc func () context.Context ,
72
74
errChan chan <- error ) (* LndChallenger , error ) {
73
75
76
+ // Make sure we have a valid context function. This will be called to
77
+ // create a new context for each call to the lnd client.
78
+ if ctxFunc == nil {
79
+ ctxFunc = context .Background
80
+ }
81
+
74
82
if genInvoiceReq == nil {
75
83
return nil , fmt .Errorf ("genInvoiceReq cannot be nil" )
76
84
}
77
85
78
86
invoicesMtx := & sync.Mutex {}
79
87
return & LndChallenger {
80
88
client : client ,
89
+ clientCtx : ctxFunc ,
81
90
genInvoiceReq : genInvoiceReq ,
82
91
invoiceStates : make (map [lntypes.Hash ]lnrpc.Invoice_InvoiceState ),
83
92
invoicesMtx : invoicesMtx ,
@@ -103,8 +112,9 @@ func (l *LndChallenger) Start() error {
103
112
// cache. We need to keep track of all invoices, even quite old ones to
104
113
// make sure tokens are valid. But to save space we only keep track of
105
114
// an invoice's state.
115
+ ctx := l .clientCtx ()
106
116
invoiceResp , err := l .client .ListInvoices (
107
- context . Background () , & lnrpc.ListInvoiceRequest {
117
+ ctx , & lnrpc.ListInvoiceRequest {
108
118
NumMaxInvoices : math .MaxUint64 ,
109
119
},
110
120
)
@@ -137,7 +147,7 @@ func (l *LndChallenger) Start() error {
137
147
l .invoicesMtx .Unlock ()
138
148
139
149
// We need to be able to cancel any subscription we make.
140
- ctxc , cancel := context .WithCancel (context . Background ())
150
+ ctxc , cancel := context .WithCancel (l . clientCtx ())
141
151
l .invoicesCancel = cancel
142
152
143
153
subscriptionResp , err := l .client .SubscribeInvoices (
@@ -261,12 +271,14 @@ func (l *LndChallenger) NewChallenge(price int64) (string, lntypes.Hash, error)
261
271
log .Errorf ("Error generating invoice request: %v" , err )
262
272
return "" , lntypes .ZeroHash , err
263
273
}
264
- ctx := context .Background ()
274
+
275
+ ctx := l .clientCtx ()
265
276
response , err := l .client .AddInvoice (ctx , invoice )
266
277
if err != nil {
267
278
log .Errorf ("Error adding invoice: %v" , err )
268
279
return "" , lntypes .ZeroHash , err
269
280
}
281
+
270
282
paymentHash , err := lntypes .MakeHash (response .RHash )
271
283
if err != nil {
272
284
log .Errorf ("Error parsing payment hash: %v" , err )
0 commit comments