From 8b0b18583b93e155331a0a60fb2f631508c8cb90 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 1 Apr 2025 19:15:24 -0700 Subject: [PATCH 1/2] challenger: add new abstracted InvoiceStateStore This will be used to simplify some of the logic the upcoming commit. We retain the existing condition variable usage, while also adding some methods that will be of use for the upcoming async background load. --- challenger/invoice_store.go | 344 ++++++++++++++++++++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 challenger/invoice_store.go diff --git a/challenger/invoice_store.go b/challenger/invoice_store.go new file mode 100644 index 0000000..fe30658 --- /dev/null +++ b/challenger/invoice_store.go @@ -0,0 +1,344 @@ +package challenger + +import ( + "fmt" + "sync" + "sync/atomic" // Import atomic package + "time" + + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lntypes" +) + +// InvoiceStateStore manages the state of invoices in a thread-safe manner and +// allows goroutines to wait for specific states or initial load completion. +type InvoiceStateStore struct { + // states holds the last known state for invoices. + states map[lntypes.Hash]lnrpc.Invoice_InvoiceState + + // mtx guards access to states and initialLoadComplete. + mtx sync.Mutex + + // cond is used to signal waiters when states is updated or when the + // initial load completes. + cond *sync.Cond + + // initialLoadComplete is true once the initial fetching of all + // historical invoices is done. Use atomic for lock-free reads/writes. + initialLoadComplete atomic.Bool + + // quit channel signals the store that the challenger is shutting down. + // Waiters should abort if this channel is closed. + quit <-chan struct{} +} + +// NewInvoiceStateStore creates a new instance of InvoiceStateStore. The quit +// channel should be the challenger's main quit channel. +func NewInvoiceStateStore(quit <-chan struct{}) *InvoiceStateStore { + s := &InvoiceStateStore{ + states: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), + quit: quit, + } + + // Initialize cond with the store's mutex. + s.cond = sync.NewCond(&s.mtx) + + return s +} + +// SetState adds or updates the state for a given invoice hash. It notifies any +// waiting goroutines about the change. +func (s *InvoiceStateStore) SetState(hash lntypes.Hash, + state lnrpc.Invoice_InvoiceState) { + + s.mtx.Lock() + defer s.mtx.Unlock() + + // Only broadcast if the state actually changes or is new. + currentState, exists := s.states[hash] + if !exists || currentState != state { + s.states[hash] = state + + // Signal potential waiters. + s.cond.Broadcast() + } +} + +// DeleteState removes an invoice state from the store, typically used for +// irrelevant (canceled/expired) invoices. It notifies any waiting goroutines +// about the change. +func (s *InvoiceStateStore) DeleteState(hash lntypes.Hash) { + s.mtx.Lock() + defer s.mtx.Unlock() + + // Only broadcast if the state actually existed. + if _, exists := s.states[hash]; exists { + delete(s.states, hash) + + // Signal potential waiters. + s.cond.Broadcast() + } +} + +// GetState retrieves the current state for a given invoice hash. 
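+// The boolean return value reports whether any state is currently being
+// tracked for the given hash.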
+func (s *InvoiceStateStore) GetState(hash lntypes.Hash, +) (lnrpc.Invoice_InvoiceState, bool) { + + s.mtx.Lock() + defer s.mtx.Unlock() + + state, exists := s.states[hash] + return state, exists +} + +// MarkInitialLoadComplete sets the initialLoadComplete flag to true atomically +// and broadcasts on the condition variable to wake up any waiting goroutines. +func (s *InvoiceStateStore) MarkInitialLoadComplete() { + // Check atomically first to potentially avoid locking and broadcasting. + if s.initialLoadComplete.Load() { + // Already marked so we can return early. + return + } + + // Grab the lock now to ensure we can use the condition variable safely. + s.mtx.Lock() + defer s.mtx.Unlock() + + // Double-check under lock in case another goroutine just did it. + if !s.initialLoadComplete.Load() { + s.initialLoadComplete.Store(true) + + // Wake up everyone waiting. + s.cond.Broadcast() + log.Infof("Invoice store marked initial load as complete.") + } +} + +// IsInitialLoadComplete checks atomically if the initial historical invoice +// load has finished. +func (s *InvoiceStateStore) IsInitialLoadComplete() bool { + return s.initialLoadComplete.Load() +} + +// waitForCondition blocks until the provided condition function returns true, a +// timeout occurs, or the quit signal is received. The mutex `s.mtx` MUST be +// held by the caller when calling this function. The mutex will be unlocked +// while waiting and re-locked before returning. It returns an error if the +// timeout is reached or the quit signal is received. +func (s *InvoiceStateStore) waitForCondition(condition func() bool, + timeout time.Duration, timeoutMsg string) error { + + // Check condition immediately before waiting. + if condition() { + return nil + } + + // Start the timeout timer. + timer := time.NewTimer(timeout) + defer timer.Stop() + + // Channel to signal when the condition is met or quit signal is + // received. + waitDone := make(chan struct{}) + + // Goroutine to wait on the condition variable. + go func() { + // Re-acquire lock for cond.Wait + s.mtx.Lock() + for !condition() { + // Check quit signal before waiting indefinitely. + select { + case <-s.quit: + s.mtx.Unlock() + close(waitDone) + return + default: + } + + // Wait for the condition to be signaled. + s.cond.Wait() + } + s.mtx.Unlock() + close(waitDone) + }() + + // Unlock to allow the waiting goroutine to acquire it. We expect the + // caller to already have held the lock. + s.mtx.Unlock() + + // Wait for either the condition to be met, timeout, or quit signal. + var errResult error + select { + case <-waitDone: + // Condition met or quit signal received by waiter. + if !timer.Stop() { + // Timer already fired and channel might contain value, + // drain it. Use a select to prevent blocking if the + // channel is empty. + select { + case <-timer.C: + default: + } + } + + // Re-check quit signal after waitDone is closed. + select { + case <-s.quit: + log.Warnf("waitForCondition: Shutdown signal received " + + "while condition was being met.") + + errResult = fmt.Errorf("challenger shutting down") + + default: + // Condition was met successfully. + errResult = nil + } + + case <-timer.C: + // Timeout expired. + log.Warnf("waitForCondition: %s (timeout: %v)", timeoutMsg, + timeout) + errResult = fmt.Errorf("%s", timeoutMsg) + + // We need to signal the waiting goroutine to stop, best way is via + // quit channel, but we don't control that. 
The waiting goroutine will
+ // eventually observe the condition once it changes, or exit when the
+ // quit signal is received.
+
+ case <-s.quit:
+ // Shutdown signal received while waiting for timer/condition.
+ log.Warnf("waitForCondition: Shutdown signal received.")
+
+ timer.Stop()
+ errResult = fmt.Errorf("challenger shutting down")
+ }
+
+ // Re-acquire lock before returning, as expected by the caller.
+ s.mtx.Lock()
+ return errResult
+}
+
+// WaitForState blocks until the specified invoice hash reaches the
+// desiredState or a timeout occurs. It first waits for the initial historical
+// invoice load to complete if necessary. initialLoadTimeout applies only if
+// waiting for the initial load. requestTimeout applies when waiting for the
+// specific invoice state change.
+func (s *InvoiceStateStore) WaitForState(hash lntypes.Hash,
+ desiredState lnrpc.Invoice_InvoiceState, initialLoadTimeout time.Duration,
+ requestTimeout time.Duration) error {
+
+ // Check to see if we need to wait for the initial load to complete.
+ if !s.initialLoadComplete.Load() {
+ log.Debugf("WaitForState: Initial load not complete, waiting "+
+ "up to %v for hash %v...",
+ initialLoadTimeout, hash)
+
+ initialLoadCondition := func() bool {
+ return s.initialLoadComplete.Load()
+ }
+
+ timeoutMsg := fmt.Sprintf("timed out waiting for initial "+
+ "invoice load after %v", initialLoadTimeout)
+
+ // waitForCondition expects the mutex to be held by the caller
+ // and returns with it re-acquired, so release it again once
+ // the wait is over.
+ s.mtx.Lock()
+ err := s.waitForCondition(
+ initialLoadCondition, initialLoadTimeout, timeoutMsg,
+ )
+ s.mtx.Unlock()
+ if err != nil {
+ log.Warnf("WaitForState: Error waiting for initial "+
+ "load for hash %v: %v", hash, err)
+ return err
+ }
+
+ log.Debugf("WaitForState: Initial load completed for hash %v",
+ hash)
+ }
+
+ // We'll first check to see if the state is already where we need it to
+ // be. The states map must only be accessed while holding the mutex.
+ s.mtx.Lock()
+ currentState, hasInvoice := s.states[hash]
+ if hasInvoice && currentState == desiredState {
+ s.mtx.Unlock()
+ log.Debugf("WaitForState: Hash %v already in desired state %v.",
+ hash, desiredState)
+ return nil
+ }
+
+ // If not, then we'll wait in the background for the condition to be
+ // met.
+ log.Debugf("WaitForState: Waiting up to %v for hash %v to reach "+
+ "state %v...", requestTimeout, hash, desiredState)
+
+ specificStateCondition := func() bool {
+ // Re-check state within the condition function under lock.
+ st, exists := s.states[hash]
+ return exists && st == desiredState
+ }
+
+ timeoutMsg := fmt.Sprintf("timed out waiting for state %v after %v",
+ desiredState, requestTimeout)
+
+ // We'll wait for the invoice to reach the desired state. The mutex is
+ // still held here, as required by waitForCondition, and is re-acquired
+ // by the time the call returns.
+ err := s.waitForCondition(
+ specificStateCondition, requestTimeout, timeoutMsg,
+ )
+ if err != nil {
+ // If we timed out, provide a more specific error message based
+ // on the final state.
+ finalState, finalExists := s.states[hash]
+ s.mtx.Unlock()
+
+ if err.Error() == timeoutMsg {
+ log.Warnf("WaitForState: Timed out after %v waiting "+
+ "for hash %v state %v. Final state: %v, "+
+ "exists: %v", requestTimeout, hash,
+ desiredState, finalState, finalExists)
+
+ if !finalExists {
+ return fmt.Errorf("no active or settled "+
+ "invoice found for hash=%v after "+
+ "timeout", hash)
+ }
+
+ return fmt.Errorf("invoice status %v not %v before "+
+ "timeout for hash=%v", finalState,
+ desiredState, hash)
+ }
+
+ // Otherwise, it was likely a shutdown error.
+ log.Warnf("WaitForState: Error waiting for specific "+
+ "state for hash %v: %v", hash, err)
+ return err
+ }
+ s.mtx.Unlock()
+
+ // Condition was met successfully.
+ log.Debugf("WaitForState: Hash %v reached desired state %v.",
+ hash, desiredState)
+ return nil
+}
+
+// WaitForInitialLoad blocks until the initial historical invoice load has
+// completed, or a timeout occurs.
+func (s *InvoiceStateStore) WaitForInitialLoad(timeout time.Duration) error {
+ // Check if already complete.
+ if s.initialLoadComplete.Load() {
+ return nil
+ }
+
+ log.Debugf("WaitForInitialLoad: Initial load not complete, waiting "+
+ "up to %v...", timeout)
+
+ initialLoadCondition := func() bool {
+ // Atomic read, no lock needed for this condition check.
+ return s.initialLoadComplete.Load()
+ }
+ timeoutMsg := fmt.Sprintf("timed out waiting for initial invoice "+
+ "load after %v", timeout)
+
+ // waitForCondition expects the mutex to be held by the caller and
+ // returns with it re-acquired, so make sure it is released again on
+ // every return path.
+ s.mtx.Lock()
+ defer s.mtx.Unlock()
+
+ // Wait for the condition using the helper.
+ err := s.waitForCondition(initialLoadCondition, timeout, timeoutMsg)
+ if err != nil {
+ log.Warnf("WaitForInitialLoad: Error waiting: %v", err)
+
+ // Return the error, which is either a timeout or a shutdown.
+ return err
+ }
+
+ log.Debugf("WaitForInitialLoad: Initial load completed.")
+ return nil
+}

From 3833f4794e0cf0612f062d09106e0a0e9f8a4abb Mon Sep 17 00:00:00 2001
From: Olaoluwa Osuntokun
Date: Tue, 1 Apr 2025 19:16:41 -0700
Subject: [PATCH 2/2] challenger: refactor LND invoice loading and state
 management

This commit refactors the LND challenger's invoice handling mechanism to
improve performance, reliability, and resource usage, especially for nodes
with a large number of historical invoices.

Previously, the challenger attempted to load all historical invoices in a
single `ListInvoices` call during startup. This could lead to long startup
times, high memory consumption, and potential timeouts or failures for nodes
with extensive invoice history. Additionally, the state management relied on
a mutex and condition variable, which could be complex to manage correctly.

The key changes include:

- **Concurrent Loading and Subscription:** The challenger now starts two
  background goroutines concurrently: one to load historical invoices and
  another to subscribe to new invoice updates using the latest known indices.

- **Paginated Historical Loading:** Historical invoices are now fetched in
  batches using `ListInvoices` with `IndexOffset` (see the sketch below). The
  batch size is configurable via the new `InvoiceBatchSize` config option
  (defaulting to 1000) and passed down from the main Aperture config.

- **InvoiceStateStore:** A new `InvoiceStateStore` type is introduced to
  manage the invoice states map. This store handles thread-safe access,
  tracks the completion of the initial historical load, and provides a
  `WaitForState` method that correctly handles waiting for the initial load
  before checking for a specific invoice state.

- **Improved Shutdown/Cancellation:** Context cancellation and handling of
  the quit signal are improved throughout the loading and subscription
  processes to ensure cleaner shutdowns.

- **Refactored VerifyInvoiceStatus:** This method now delegates waiting
  logic to the `InvoiceStateStore`, simplifying the challenger code and
  ensuring it waits for the initial load if necessary.
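
For reference, the paginated load in `loadHistoricalInvoices` boils down to
the following pattern. This is an illustrative sketch only: the real code adds
shutdown checks, logging and error reporting via the challenger's error
channel, and it talks to the abstract `InvoiceClient` rather than the raw
`lnrpc.LightningClient` assumed here.

```go
func loadAll(ctx context.Context, client lnrpc.LightningClient,
	batchSize int, track func(*lnrpc.Invoice)) error {

	indexOffset := uint64(0)
	for {
		resp, err := client.ListInvoices(ctx, &lnrpc.ListInvoiceRequest{
			IndexOffset:    indexOffset,
			NumMaxInvoices: uint64(batchSize),
		})
		if err != nil {
			return err
		}

		// Hand every invoice of this batch to the caller.
		for _, inv := range resp.Invoices {
			track(inv)
		}

		// A short (or empty) batch means the end of the invoice
		// history has been reached.
		if len(resp.Invoices) < batchSize {
			return nil
		}

		// Continue the next batch right after the last invoice we
		// have seen.
		indexOffset = resp.LastIndexOffset
	}
}
```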
--- .gitignore | 4 +- aperture.go | 7 +- challenger/lnc.go | 9 +- challenger/lnd.go | 457 ++++++++++++++++++++++++++++------------- challenger/lnd_test.go | 45 ++-- 5 files changed, 351 insertions(+), 171 deletions(-) diff --git a/.gitignore b/.gitignore index f2b8ad6..a58ceec 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,6 @@ cmd/aperture/aperture # misc -.vscode \ No newline at end of file +.vscode +.aider* +CONVENTIONS.md diff --git a/aperture.go b/aperture.go index 5fbb4dd..f5ced9b 100644 --- a/aperture.go +++ b/aperture.go @@ -336,7 +336,8 @@ func (a *Aperture) Start(errChan chan error) error { } a.challenger, err = challenger.NewLNCChallenger( - session, lncStore, genInvoiceReq, errChan, + session, lncStore, a.cfg.InvoiceBatchSize, + genInvoiceReq, errChan, ) if err != nil { return fmt.Errorf("unable to start lnc "+ @@ -359,8 +360,8 @@ func (a *Aperture) Start(errChan chan error) error { } a.challenger, err = challenger.NewLndChallenger( - client, genInvoiceReq, context.Background, - errChan, + client, a.cfg.InvoiceBatchSize, genInvoiceReq, + context.Background, errChan, ) if err != nil { return err diff --git a/challenger/lnc.go b/challenger/lnc.go index d30891f..fa13d0a 100644 --- a/challenger/lnc.go +++ b/challenger/lnc.go @@ -19,7 +19,7 @@ type LNCChallenger struct { // NewLNCChallenger creates a new challenger that uses the given LNC session to // connect to an lnd backend to create payment challenges. func NewLNCChallenger(session *lnc.Session, lncStore lnc.Store, - genInvoiceReq InvoiceRequestGenerator, + invoiceBatchSize int, genInvoiceReq InvoiceRequestGenerator, errChan chan<- error) (*LNCChallenger, error) { nodeConn, err := lnc.NewNodeConn(session, lncStore) @@ -34,16 +34,13 @@ func NewLNCChallenger(session *lnc.Session, lncStore lnc.Store, } lndChallenger, err := NewLndChallenger( - client, genInvoiceReq, nodeConn.CtxFunc, errChan, + client, invoiceBatchSize, genInvoiceReq, nodeConn.CtxFunc, errChan, ) if err != nil { return nil, err } - err = lndChallenger.Start() - if err != nil { - return nil, err - } + lndChallenger.Start() return &LNCChallenger{ lndChallenger: lndChallenger, diff --git a/challenger/lnd.go b/challenger/lnd.go index 00bd53e..f18f7a4 100644 --- a/challenger/lnd.go +++ b/challenger/lnd.go @@ -4,26 +4,38 @@ import ( "context" "fmt" "io" - "math" "strings" "sync" "time" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" ) +const ( + // defaultListInvoicesBatchSize is the default number of invoices to fetch + // in each ListInvoices call during the historical load. + defaultListInvoicesBatchSize = 1000 +) + +var ( + // defaultInitialLoadTimeout is the maximum time we wait for the initial + // batch of invoices to be loaded from lnd before allowing state checks + // to proceed or fail. + defaultInitialLoadTimeout = 10 * time.Second +) + // LndChallenger is a challenger that uses an lnd backend to create new LSAT // payment challenges. type LndChallenger struct { client InvoiceClient clientCtx func() context.Context genInvoiceReq InvoiceRequestGenerator + batchSize int // Added batchSize - invoiceStates map[lntypes.Hash]lnrpc.Invoice_InvoiceState - invoicesMtx *sync.Mutex + invoiceStore *InvoiceStateStore invoicesCancel func() - invoicesCond *sync.Cond errChan chan<- error @@ -37,7 +49,7 @@ var _ Challenger = (*LndChallenger)(nil) // NewLndChallenger creates a new challenger that uses the given connection to // an lnd backend to create payment challenges. 
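+// The batchSize argument controls how many invoices are requested per
+// ListInvoices call while loading historical invoices; a value of zero or
+// less falls back to defaultListInvoicesBatchSize.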
-func NewLndChallenger(client InvoiceClient, +func NewLndChallenger(client InvoiceClient, batchSize int, genInvoiceReq InvoiceRequestGenerator, ctxFunc func() context.Context, errChan chan<- error) (*LndChallenger, error) { @@ -52,110 +64,310 @@ func NewLndChallenger(client InvoiceClient, return nil, fmt.Errorf("genInvoiceReq cannot be nil") } - invoicesMtx := &sync.Mutex{} - challenger := &LndChallenger{ - client: client, - clientCtx: ctxFunc, - genInvoiceReq: genInvoiceReq, - invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), - invoicesMtx: invoicesMtx, - invoicesCond: sync.NewCond(invoicesMtx), - quit: make(chan struct{}), - errChan: errChan, + // Use default batch size if zero or negative is provided. + if batchSize <= 0 { + batchSize = defaultListInvoicesBatchSize } - err := challenger.Start() - if err != nil { - return nil, fmt.Errorf("unable to start challenger: %w", err) + quitChan := make(chan struct{}) + challenger := &LndChallenger{ + client: client, + clientCtx: ctxFunc, + genInvoiceReq: genInvoiceReq, + batchSize: batchSize, + invoiceStore: NewInvoiceStateStore(quitChan), + invoicesCancel: func() {}, + quit: quitChan, + errChan: errChan, } + // Start the background loading/subscription process. + challenger.Start() + return challenger, nil } -// Start starts the challenger's main work which is to keep track of all -// invoices and their states. For that the backing lnd node is queried for all -// invoices on startup and the a subscription to all subsequent invoice updates -// is created. -func (l *LndChallenger) Start() error { - // These are the default values for the subscription. In case there are - // no invoices yet, this will instruct lnd to just send us all updates. - // If there are existing invoices, these indices will be updated to - // reflect the latest known invoices. +// Start launches the background process to load historical invoices and +// subscribe to new invoice updates concurrently. This method returns +// immediately. +func (l *LndChallenger) Start() { + log.Infof("Starting LND challenger background tasks...") + + // Use a short timeout context for this initial call. + ctxIdx, cancelIdx := context.WithTimeout(l.clientCtx(), 30*time.Second) + defer cancelIdx() + addIndex := uint64(0) settleIndex := uint64(0) - // Get a list of all existing invoices on startup and add them to our - // cache. We need to keep track of all invoices, even quite old ones to - // make sure tokens are valid. But to save space we only keep track of - // an invoice's state. - ctx := l.clientCtx() - invoiceResp, err := l.client.ListInvoices( - ctx, &lnrpc.ListInvoiceRequest{ - NumMaxInvoices: math.MaxUint64, + log.Debugf("Querying latest invoice indices...") + latestInvoiceResp, err := l.client.ListInvoices( + ctxIdx, &lnrpc.ListInvoiceRequest{ + NumMaxInvoices: 1, // Only need the latest one + Reversed: true, }, ) if err != nil { - return err + // Don't fail startup entirely, just log and proceed with 0 + // indices. The historical load will catch up. + log.Errorf("Failed to get latest invoice indices, "+ + "subscribing from beginning (error: %v)", err) + } else if len(latestInvoiceResp.Invoices) > 0 { + // Indices are only meaningful if we actually got an invoice. 
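+ // Subscribing from these indices means the subscription only
+ // delivers events that happen after this point, while the
+ // paginated historical load back-fills everything older.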
+ latestInvoice := latestInvoiceResp.Invoices[0] + addIndex = latestInvoice.AddIndex + settleIndex = latestInvoice.SettleIndex + + log.Infof("Latest indices found: add=%d, settle=%d", + addIndex, settleIndex) + } else { + log.Infof("No existing invoices found, subscribing " + + "from beginning.") } - // Advance our indices to the latest known one so we'll only receive - // updates for new invoices and/or newly settled invoices. - l.invoicesMtx.Lock() - for _, invoice := range invoiceResp.Invoices { - // Some invoices like AMP invoices may not have a payment hash - // populated. - if invoice.RHash == nil { - continue + cancelIdx() + + // We'll launch our first goroutine to load the historical invoices in + // the background. + l.wg.Add(1) + go l.loadHistoricalInvoices() + + // We'll launch our second goroutine to subscribe to new invoices in to + // populate the invoice store with new updates. + l.wg.Add(1) + go l.subscribeToInvoices(addIndex, settleIndex) + + log.Infof("LND challenger background tasks launched.") +} + +// loadHistoricalInvoices fetches all past invoices relevant using pagination +// and updates the invoice store. It marks the initial load complete upon +// finishing. This runs in a goroutine. +func (l *LndChallenger) loadHistoricalInvoices() { + defer l.wg.Done() + + log.Infof("Starting historical invoice loading "+ + "(batch size %d)...", l.batchSize) + + // Use a background context for the potentially long-running list calls. + // Allow it to be cancelled by Stop() via the main quit channel. + ctxList, cancelList := context.WithCancel(l.clientCtx()) + defer cancelList() + + // Goroutine to cancel the list context if quit signal is received. + go func() { + select { + case <-l.quit: + log.Warnf("Shutdown signal received, cancelling " + + "historical invoice list context.") + cancelList() + + case <-ctxList.Done(): } + }() + + startTime := time.Now() + numInvoicesLoaded := 0 + indexOffset := uint64(0) + + for { + // Check for shutdown signal before each batch. + select { + case <-l.quit: + log.Warnf("Shutdown signal received during " + + "historical invoice loading.") - if invoice.AddIndex > addIndex { - addIndex = invoice.AddIndex + // Mark load complete anyway so waiters don't block + // indefinitely. + l.invoiceStore.MarkInitialLoadComplete() + return + default: } - if invoice.SettleIndex > settleIndex { - settleIndex = invoice.SettleIndex + + log.Debugf("Querying invoices batch starting from "+ + "index %d", indexOffset) + + req := &lnrpc.ListInvoiceRequest{ + IndexOffset: indexOffset, + NumMaxInvoices: uint64(l.batchSize), } - hash, err := lntypes.MakeHash(invoice.RHash) + + invoiceResp, err := l.client.ListInvoices(ctxList, req) + if err != nil { - l.invoicesMtx.Unlock() - return fmt.Errorf("error parsing invoice hash: %v", err) + // If context was cancelled by shutdown, it's not a + // fatal startup error. + if strings.Contains(err.Error(), context.Canceled.Error()) { + log.Warnf("Historical invoice loading " + + "cancelled by shutdown.") + + l.invoiceStore.MarkInitialLoadComplete() + return + } + log.Errorf("Failed to list invoices batch "+ + "(offset %d): %v", indexOffset, err) + + // Signal fatal error to the main application. + select { + case l.errChan <- fmt.Errorf("failed historical "+ + "invoice load batch: %w", err): + case <-l.quit: // Don't block if shutting down + } + + // Mark load complete on error so waiters don't block + // indefinitely. + l.invoiceStore.MarkInitialLoadComplete() + return } - // Don't track the state of canceled or expired invoices. 
- if invoiceIrrelevant(invoice) {
- continue
+ // Process the received batch.
+ invoicesInBatch := len(invoiceResp.Invoices)
+
+ log.Debugf("Received %d invoices in batch (offset %d)",
+ invoicesInBatch, indexOffset)
+
+ log.Tracef("Loading invoices: %v", spew.Sdump(invoiceResp.Invoices))
+
+ for _, invoice := range invoiceResp.Invoices {
+ // Some invoices like AMP invoices may not have a
+ // payment hash populated.
+ if invoice.RHash == nil {
+ continue
+ }
+
+ hash, err := lntypes.MakeHash(invoice.RHash)
+ if err != nil {
+ log.Errorf("Error parsing invoice hash "+
+ "during initial load: %v. Skipping "+
+ "invoice.", err)
+ continue
+ }
+
+ // Don't track the state of irrelevant invoices.
+ if invoiceIrrelevant(invoice) {
+ continue
+ }
+
+ l.invoiceStore.SetState(hash, invoice.State)
+ numInvoicesLoaded++
+ }
+
+ // If this batch was empty or less than max, we're done with
+ // history. LND documentation suggests LastIndexOffset is the
+ // index of the *last* invoice returned. If no invoices
+ // returned, break. If NumMaxInvoices was returned, continue
+ // from LastIndexOffset. If < NumMaxInvoices returned, we are
+ // also done.
+ if invoicesInBatch == 0 || invoicesInBatch < l.batchSize {
+ log.Debugf("Last batch processed (%d invoices), "+
+ "stopping pagination.", invoicesInBatch)
+ break
 }
- l.invoiceStates[hash] = invoice.State
+
+ // Prepare for the next batch.
+ indexOffset = invoiceResp.LastIndexOffset
+ log.Debugf("Processed batch, %d invoices loaded so "+
+ "far. Next index offset: %d",
+ numInvoicesLoaded, indexOffset)
 }
- l.invoicesMtx.Unlock()
+ loadDuration := time.Since(startTime)
+
+ log.Infof("Finished historical invoice loading. Loaded %d "+
+ "relevant invoices in %v.", numInvoicesLoaded,
+ loadDuration)
+
+ // Mark the initial load as complete *only after* all pages are
+ // processed.
+ l.invoiceStore.MarkInitialLoadComplete()
+}
+
+// subscribeToInvoices sets up the invoice subscription stream and starts the
+// reader goroutine. This runs in a goroutine managed by Start.
+func (l *LndChallenger) subscribeToInvoices(addIndex, settleIndex uint64) {
+ defer l.wg.Done()
+
+ // We need a separate context for the subscription stream, managed by
+ // invoicesCancel.
+ ctxSub, cancelSub := context.WithCancel(l.clientCtx())
+ defer func() {
+ // Only call cancelSub if l.invoicesCancel hasn't been assigned
+ // yet (meaning subscription failed or shutdown happened before
+ // success). If l.invoicesCancel was assigned, Stop() will
+ // handle cancellation.
+ if l.invoicesCancel == nil {
+ cancelSub()
+ }
+ }()
+
+ // Check for immediate shutdown before attempting subscription.
+ select {
+ case <-l.quit:
+ log.Warnf("Shutdown signal received before starting " +
+ "invoice subscription.")
+ return
+ default:
+ }
+
+ log.Infof("Attempting to subscribe to invoice updates starting "+
+ "from add_index=%d, settle_index=%d", addIndex, settleIndex)
 subscriptionResp, err := l.client.SubscribeInvoices(
- ctxc, &lnrpc.InvoiceSubscription{
+ ctxSub, &lnrpc.InvoiceSubscription{
 AddIndex: addIndex,
 SettleIndex: settleIndex,
 },
 )
 if err != nil {
- cancel()
- return err
+ // If context was cancelled by shutdown, it's not a fatal error.
+ if strings.Contains(err.Error(), context.Canceled.Error()) { + log.Warnf("Invoice subscription cancelled " + + "during setup by shutdown.") + return + } + + log.Errorf("Failed to subscribe to invoices: %v", err) + select { + case l.errChan <- fmt.Errorf("failed invoice "+ + "subscription: %w", err): + + case <-l.quit: + } + return } + // Store the cancel function *only after* SubscribeInvoices succeeds. + l.invoicesCancel = cancelSub + + log.Infof("Successfully subscribed to invoice updates.") + + // Start the goroutine to read from the subscription stream. Add to + // WaitGroup *before* launching. This WG count belongs to the + // readInvoiceStream lifecycle, managed by this parent goroutine. l.wg.Add(1) go func() { + // Ensure Done is called regardless of how readInvoiceStream + // exits. defer l.wg.Done() - defer cancel() + // Ensure the subscription context is cancelled when this reader + // goroutine exits. Calling the stored l.invoicesCancel ensures + // Stop() can also cancel it. + defer l.invoicesCancel() l.readInvoiceStream(subscriptionResp) }() - return nil + log.Infof("Invoice subscription reader started.") + + // Keep this goroutine alive until quit signal to manage + // readInvoiceStream. + <-l.quit + log.Infof("Invoice subscription manager shutting down.") } // readInvoiceStream reads the invoice update messages sent on the stream until // the stream is aborted or the challenger is shutting down. +// This runs in a goroutine managed by subscribeToInvoices. func (l *LndChallenger) readInvoiceStream( stream lnrpc.Lightning_SubscribeInvoicesClient) { @@ -225,26 +437,29 @@ func (l *LndChallenger) readInvoiceStream( return } - l.invoicesMtx.Lock() if invoiceIrrelevant(invoice) { // Don't keep the state of canceled or expired invoices. - delete(l.invoiceStates, hash) + l.invoiceStore.DeleteState(hash) } else { - l.invoiceStates[hash] = invoice.State + l.invoiceStore.SetState(hash, invoice.State) } - - // Before releasing the lock, notify our conditions that listen - // for updates on the invoice state. - l.invoicesCond.Broadcast() - l.invoicesMtx.Unlock() } } // Stop shuts down the challenger. func (l *LndChallenger) Stop() { - l.invoicesCancel() + log.Infof("Stopping LND challenger...") + // Signal all goroutines to exit. close(l.quit) + + // Cancel the subscription context if it exists and was set. + // invoicesCancel is initialized to a no-op, so safe to call always. + l.invoicesCancel() + + // Wait for all background goroutines (loadHistorical, subscribeToInvoices, + // and readInvoiceStream) to finish. l.wg.Wait() + log.Infof("LND challenger stopped.") } // NewChallenge creates a new LSAT payment challenge, returning a payment @@ -278,88 +493,40 @@ func (l *LndChallenger) NewChallenge(price int64) (string, lntypes.Hash, return response.PaymentRequest, paymentHash, nil } -// VerifyInvoiceStatus checks that an invoice identified by a payment -// hash has the desired status. To make sure we don't fail while the -// invoice update is still on its way, we try several times until either -// the desired status is set or the given timeout is reached. +// VerifyInvoiceStatus checks that an invoice identified by a payment hash has +// the desired status. It waits until the desired status is reached or the +// given timeout occurs. It also handles waiting for the initial invoice load +// if necessary. // // NOTE: This is part of the auth.InvoiceChecker interface. 
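+//
+// For example, a caller (such as aperture's auth package) can wait for a paid
+// challenge to settle with a call along the lines of:
+//
+//	err := challenger.VerifyInvoiceStatus(
+//		hash, lnrpc.Invoice_SETTLED, 10*time.Second,
+//	)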
func (l *LndChallenger) VerifyInvoiceStatus(hash lntypes.Hash, state lnrpc.Invoice_InvoiceState, timeout time.Duration) error { - // Prevent the challenger to be shut down while we're still waiting for - // status updates. + // Prevent the challenger from being shut down while we're still waiting + // for status updates. Add to WG *before* calling wait. l.wg.Add(1) defer l.wg.Done() - var ( - condWg sync.WaitGroup - doneChan = make(chan struct{}) - timeoutReached bool - hasInvoice bool - invoiceState lnrpc.Invoice_InvoiceState - ) - - // First of all, spawn a goroutine that will signal us on timeout. - // Otherwise if a client subscribes to an update on an invoice that - // never arrives, and there is no other activity, it would block - // forever in the condition. - condWg.Add(1) - go func() { - defer condWg.Done() - - select { - case <-doneChan: - case <-time.After(timeout): - case <-l.quit: - } - - l.invoicesCond.L.Lock() - timeoutReached = true - l.invoicesCond.Broadcast() - l.invoicesCond.L.Unlock() - }() - - // Now create the main goroutine that blocks until an update is received - // on the condition. - condWg.Add(1) - go func() { - defer condWg.Done() - l.invoicesCond.L.Lock() - - // Block here until our condition is met or the allowed time is - // up. The Wait() will return whenever a signal is broadcast. - invoiceState, hasInvoice = l.invoiceStates[hash] - for !(hasInvoice && invoiceState == state) && !timeoutReached { - l.invoicesCond.Wait() - - // The Wait() above has re-acquired the lock so we can - // safely access the states map. - invoiceState, hasInvoice = l.invoiceStates[hash] - } - - // We're now done. - l.invoicesCond.L.Unlock() - close(doneChan) - }() - - // Wait until we're either done or timed out. - condWg.Wait() - - // Interpret the result so we can return a more descriptive error than - // just "failed". - switch { - case !hasInvoice: - return fmt.Errorf("no active or settled invoice found for "+ - "hash=%v", hash) - - case invoiceState != state: - return fmt.Errorf("invoice status not correct before timeout, "+ - "hash=%v, status=%v", hash, invoiceState) - + // Check for immediate shutdown signal before potentially blocking. + select { + case <-l.quit: + return fmt.Errorf("challenger shutting down") default: - return nil } + + // Delegate the waiting logic to the invoice store. + // We use a default timeout for the initial load wait, and the provided + // timeout for the specific state wait. + err := l.invoiceStore.WaitForState( + hash, state, defaultInitialLoadTimeout, timeout, + ) + if err != nil { + // Add context to the error message. 
+ return fmt.Errorf("error verifying invoice status for hash %v "+
+ "(target state %v): %w", hash, state, err)
+ }
+
+ return nil
 }
 
 // invoiceIrrelevant returns true if an invoice is nil, canceled or non-settled
diff --git a/challenger/lnd_test.go b/challenger/lnd_test.go
index 82f7620..ebe3db3 100644
--- a/challenger/lnd_test.go
+++ b/challenger/lnd_test.go
@@ -98,18 +98,19 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) {
 return newInvoice(lntypes.ZeroHash, 99, lnrpc.Invoice_OPEN), nil
 }
 
- invoicesMtx := &sync.Mutex{}
+
 mainErrChan := make(chan error)
- return &LndChallenger{
+ quitChan := make(chan struct{})
+ challenger := &LndChallenger{
 client: mockClient,
 clientCtx: context.Background,
 genInvoiceReq: genInvoiceReq,
- invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState),
- quit: make(chan struct{}),
- invoicesMtx: invoicesMtx,
- invoicesCond: sync.NewCond(invoicesMtx),
+ invoiceStore: NewInvoiceStateStore(quitChan),
+ quit: quitChan,
 errChan: mainErrChan,
- }, mockClient, mainErrChan
+ }
+
+ return challenger, mockClient, mainErrChan
 }
 
 func newInvoice(hash lntypes.Hash, addIndex uint64,
@@ -131,7 +132,7 @@ func TestLndChallenger(t *testing.T) {
 // First of all, test that the NewLndChallenger doesn't allow a nil
 // invoice generator function.
 errChan := make(chan error)
- _, err := NewLndChallenger(nil, nil, nil, errChan)
+ _, err := NewLndChallenger(nil, 0, nil, nil, errChan)
 require.Error(t, err)
 
 // Now mock the lnd backend and create a challenger instance that we can
@@ -149,10 +150,16 @@ func TestLndChallenger(t *testing.T) {
 // Now we already have an invoice in our lnd mock. When starting the
 // challenger, we should have that invoice in the cache and a
 // subscription that only starts at our faked addIndex.
- err = c.Start()
+ // Start the background load and subscription, then wait for the initial
+ // invoice load to complete before making any assertions.
+ c.Start()
+ err = c.invoiceStore.WaitForInitialLoad(3 * time.Second)
 require.NoError(t, err)
- require.Equal(t, 1, len(c.invoiceStates))
- require.Equal(t, lnrpc.Invoice_OPEN, c.invoiceStates[lntypes.ZeroHash])
+
+ // Verify the initial state through the exported API rather than by
+ // reaching into the store directly.
 require.Equal(t, uint64(99), invoiceMock.lastAddIndex)
 require.NoError(t, c.VerifyInvoiceStatus(
 lntypes.ZeroHash, lnrpc.Invoice_OPEN, defaultTimeout,
@@ -223,16 +230,20 @@ func TestLndChallenger(t *testing.T) {
 }
 
 // Finally test that if an error occurs in the invoice subscription the
- // challenger reports it on the main error channel to cause a shutdown
- // of aperture. The mock's error channel is buffered so we can send
- // directly.
- invoiceMock.errChan <- fmt.Errorf("an expected error")
+ // challenger reports it on the main error channel to cause a shutdown.
+ // The mock's error channel is buffered so we can send directly.
+ expectedErr := fmt.Errorf("an expected error")
+ invoiceMock.errChan <- expectedErr
 
 select {
 case err := <-mainErrChan:
- require.Error(t, err)
+ require.ErrorIs(t, err, expectedErr)
 
 // Make sure that the goroutine exited.
 done := make(chan struct{})
 go func() {
 c.wg.Wait()
 done <- struct{}{}
@@ -249,6 +260,8 @@ func TestLndChallenger(t *testing.T) {
 t.Fatalf("error not received on main chan before the timeout")
 }
 
+ // Stop the mock invoice client first.
 invoiceMock.stop()
+ // Then shut down the challenger itself.
 c.Stop()
 }