Skip to content

Commit fde86d3

Browse files
committed
accounts+mid: support for error handlers in accounts
We also want to be able to see via the trace logs if a request has errored. So first we need to update the framework to be ready for this.
1 parent 0bac0aa commit fde86d3

File tree

6 files changed

+123
-8
lines changed

6 files changed

+123
-8
lines changed

accounts/checkers.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,30 @@ func NewAccountChecker(service Service,
398398
}
399399
}
400400

401+
// handleErrorResponse passes an error to a checker if the checker has
402+
// registered an error handler.
403+
func (a *AccountChecker) handleErrorResponse(ctx context.Context,
404+
fullUri string, parsedErr error) (error, error) {
405+
406+
// If we don't have a handler for the URI, it means we don't support
407+
// that RPC.
408+
checker, ok := a.checkers[fullUri]
409+
if !ok {
410+
return nil, ErrNotSupportedWithAccounts
411+
}
412+
413+
newErr, err := checker.HandleErrorResponse(ctx, parsedErr)
414+
if err != nil {
415+
return nil, err
416+
}
417+
418+
if newErr != nil {
419+
parsedErr = newErr
420+
}
421+
422+
return parsedErr, nil
423+
}
424+
401425
// checkIncomingRequest makes sure the type of incoming call is supported and
402426
// if it is, that it is allowed with the current account balance.
403427
func (a *AccountChecker) checkIncomingRequest(ctx context.Context,

accounts/checkers_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ func (m *mockService) AssociatePayment(id AccountID, paymentHash lntypes.Hash,
8080
return nil
8181
}
8282

83+
func (m *mockService) PaymentErrored(id AccountID, hash lntypes.Hash) error {
84+
return nil
85+
}
86+
8387
func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash,
8488
amt lnwire.MilliSatoshi) error {
8589

accounts/interceptor.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,30 @@ func (s *InterceptorService) Intercept(ctx context.Context,
126126

127127
// Parse and possibly manipulate outgoing responses.
128128
case *lnrpc.RPCMiddlewareRequest_Response:
129+
if r.Response.IsError {
130+
parsedErr := mid.ParseResponseErr(r.Response.Serialized)
131+
132+
replacementErr, err := s.checkers.handleErrorResponse(
133+
ctx, r.Response.MethodFullUri, parsedErr,
134+
)
135+
if err != nil {
136+
return mid.RPCErr(req, err)
137+
}
138+
139+
// No error occurred but the response error should be
140+
// replaced with the given custom error. Wrap it in the
141+
// correct RPC response of the interceptor now.
142+
if replacementErr != nil {
143+
return mid.RPCErrReplacement(
144+
req, replacementErr,
145+
)
146+
}
147+
148+
// No error and no replacement, just return an empty
149+
// response of the correct type.
150+
return mid.RPCOk(req)
151+
}
152+
129153
msg, err := parseRPCMessage(r.Response)
130154
if err != nil {
131155
return mid.RPCErr(req, err)

accounts/interface.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,11 @@ type Service interface {
258258
AssociatePayment(id AccountID, paymentHash lntypes.Hash,
259259
fullAmt lnwire.MilliSatoshi) error
260260

261+
// PaymentErrored removes a pending payment from the accounts
262+
// registered payment list. This should only ever be called if we are
263+
// sure that the payment request errored out.
264+
PaymentErrored(id AccountID, hash lntypes.Hash) error
265+
261266
RequestValuesStore
262267
}
263268

accounts/service.go

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,45 @@ func (s *InterceptorService) AssociateInvoice(id AccountID,
437437
return s.store.UpdateAccount(account)
438438
}
439439

440+
// PaymentErrored removes a pending payment from the account's registered
441+
// payment list. This should only ever be called if we are sure that the payment
442+
// request errored out.
443+
func (s *InterceptorService) PaymentErrored(id AccountID,
444+
hash lntypes.Hash) error {
445+
446+
s.Lock()
447+
defer s.Unlock()
448+
449+
// If we have already started tracking this payment, then RemovePayment
450+
// should have been called instead.
451+
_, ok := s.pendingPayments[hash]
452+
if ok {
453+
return fmt.Errorf("cannot disassociate payment if tracking " +
454+
"has already started")
455+
}
456+
457+
account, err := s.store.Account(id)
458+
if err != nil {
459+
return err
460+
}
461+
462+
// Check that this payment is actually associated with this account.
463+
_, ok = account.Payments[hash]
464+
if !ok {
465+
return fmt.Errorf("payment with hash %s is not associated "+
466+
"with this account", hash)
467+
}
468+
469+
// Delete the payment and update the persisted account.
470+
delete(account.Payments, hash)
471+
472+
if err := s.store.UpdateAccount(account); err != nil {
473+
return fmt.Errorf("error updating account: %w", err)
474+
}
475+
476+
return nil
477+
}
478+
440479
// AssociatePayment associates a payment (hash) with the given account,
441480
// ensuring that the payment will be tracked for a user when LiT is
442481
// restarted.
@@ -451,11 +490,26 @@ func (s *InterceptorService) AssociatePayment(id AccountID,
451490
return err
452491
}
453492

454-
// If the payment is already associated with the account, we don't need
455-
// to associate it again.
493+
// Check if this payment is associated with the account already.
456494
_, ok := account.Payments[paymentHash]
457495
if ok {
458-
return nil
496+
// We do not allow another payment to the same hash if the
497+
// payment is already in-flight or succeeded. This mitigates a
498+
// user being able to launch a second RPC-erring payment with
499+
// the same hash that would remove the payment from being
500+
// tracked. Note that this prevents launching multipart
501+
// payments, but allows retrying a payment if it has failed.
502+
if account.Payments[paymentHash].Status !=
503+
lnrpc.Payment_FAILED {
504+
505+
return fmt.Errorf("payment with hash %s is already in "+
506+
"flight or succeeded (status %v)", paymentHash,
507+
account.Payments[paymentHash].Status)
508+
}
509+
510+
// Otherwise, we fall through to correctly update the payment
511+
// amount, in case we have a zero-amount invoice that is
512+
// retried.
459513
}
460514

461515
// Associate the payment with the account and store it.

rpcmiddleware/interface.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ var (
3535

3636
// PassThroughErrorHandler is an ErrorHandler that does not modify an
3737
// error and instead just passes it through.
38-
PassThroughErrorHandler ErrorHandler = func(error) (error, error) {
38+
PassThroughErrorHandler ErrorHandler = func(context.Context, error) (
39+
error, error) {
40+
3941
return nil, nil
4042
}
4143

@@ -81,7 +83,7 @@ type messageHandler func(context.Context, proto.Message) (proto.Message, error)
8183
// pass through the error unchanged (=return nil, nil), replace the error with
8284
// a different one (=return non-nil error, nil error) or abort by returning a
8385
// non-nil error.
84-
type ErrorHandler func(respErr error) (error, error)
86+
type ErrorHandler func(ctx context.Context, respErr error) (error, error)
8587

8688
// RoundTripChecker is a type that represents a basic request/response round
8789
// trip checker.
@@ -115,7 +117,7 @@ type RoundTripChecker interface {
115117
// The handler can pass through the error (=return nil, nil), replace
116118
// the response error with a new one (=return non-nil error, nil) or
117119
// abort by returning a non nil error (=return nil, non-nil error).
118-
HandleErrorResponse(error) (error, error)
120+
HandleErrorResponse(context.Context, error) (error, error)
119121
}
120122

121123
// DefaultChecker is the default implementation of a round trip checker.
@@ -171,8 +173,10 @@ func (r *DefaultChecker) HandleResponse(ctx context.Context,
171173
// The handler can pass through the error (=return nil, nil), replace
172174
// the response error with a new one (=return non-nil error, nil) or
173175
// abort by returning a non nil error (=return nil, non-nil error).
174-
func (r *DefaultChecker) HandleErrorResponse(respErr error) (error, error) {
175-
return r.errorHandler(respErr)
176+
func (r *DefaultChecker) HandleErrorResponse(ctx context.Context,
177+
respErr error) (error, error) {
178+
179+
return r.errorHandler(ctx, respErr)
176180
}
177181

178182
// NewPassThrough returns a round trip checker that allows the incoming request

0 commit comments

Comments
 (0)