diff --git a/.gitignore b/.gitignore index f2b8ad6..a58ceec 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,6 @@ cmd/aperture/aperture # misc -.vscode \ No newline at end of file +.vscode +.aider* +CONVENTIONS.md diff --git a/aperture.go b/aperture.go index 5fbb4dd..f5ced9b 100644 --- a/aperture.go +++ b/aperture.go @@ -336,7 +336,8 @@ func (a *Aperture) Start(errChan chan error) error { } a.challenger, err = challenger.NewLNCChallenger( - session, lncStore, genInvoiceReq, errChan, + session, lncStore, a.cfg.InvoiceBatchSize, + genInvoiceReq, errChan, ) if err != nil { return fmt.Errorf("unable to start lnc "+ @@ -359,8 +360,8 @@ func (a *Aperture) Start(errChan chan error) error { } a.challenger, err = challenger.NewLndChallenger( - client, genInvoiceReq, context.Background, - errChan, + client, a.cfg.InvoiceBatchSize, genInvoiceReq, + context.Background, errChan, ) if err != nil { return err diff --git a/challenger/invoice_store.go b/challenger/invoice_store.go new file mode 100644 index 0000000..fe30658 --- /dev/null +++ b/challenger/invoice_store.go @@ -0,0 +1,344 @@
+package challenger
+
+import (
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/lightningnetwork/lnd/lnrpc"
+	"github.com/lightningnetwork/lnd/lntypes"
+)
+
+// InvoiceStateStore is a mutex-protected cache of the last known state of each
+// tracked invoice. Goroutines can block on it until an invoice reaches a given
+// state or until the initial historical load has been marked complete.
+type InvoiceStateStore struct {
+	// states maps a payment hash to the most recently recorded invoice
+	// state.
+	states map[lntypes.Hash]lnrpc.Invoice_InvoiceState
+
+	// mtx protects states and is the lock that cond is built on.
+	mtx sync.Mutex
+
+	// cond is broadcast on whenever states changes or the initial load is
+	// marked complete.
+	cond *sync.Cond
+
+	// initialLoadComplete becomes true once the initial load of historical
+	// invoices has been marked complete. It is atomic so that it can be
+	// read without taking mtx.
+	initialLoadComplete atomic.Bool
+
+	// quit is closed when the owning challenger shuts down. Waiters abort
+	// once it is closed.
+	quit <-chan struct{}
+}
+
+// NewInvoiceStateStore returns an empty InvoiceStateStore. The quit channel
+// should be the challenger's main quit channel.
+func NewInvoiceStateStore(quit <-chan struct{}) *InvoiceStateStore {
+	store := new(InvoiceStateStore)
+	store.states = make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState)
+	store.quit = quit
+
+	// The condition variable shares the store's own mutex.
+	store.cond = sync.NewCond(&store.mtx)
+
+	return store
+}
+
+// SetState records state as the latest state for hash. Waiters are only woken
+// if this changes what the store holds, i.e. the hash is new or its stored
+// state differs.
+func (s *InvoiceStateStore) SetState(hash lntypes.Hash,
+	state lnrpc.Invoice_InvoiceState) {
+
+	s.mtx.Lock()
+	defer s.mtx.Unlock()
+
+	// Nothing to do if we already hold exactly this state.
+	if known, ok := s.states[hash]; ok && known == state {
+		return
+	}
+
+	s.states[hash] = state
+	s.cond.Broadcast()
+}
+
+// DeleteState drops the entry for hash, typically because the invoice became
+// irrelevant (canceled/expired). Waiters are only woken if an entry was
+// actually removed.
+func (s *InvoiceStateStore) DeleteState(hash lntypes.Hash) {
+	s.mtx.Lock()
+	defer s.mtx.Unlock()
+
+	// Unknown hash, so there is no change to announce.
+	if _, ok := s.states[hash]; !ok {
+		return
+	}
+
+	delete(s.states, hash)
+	s.cond.Broadcast()
+}
+
+// GetState returns the stored state for hash and whether the hash is known to
+// the store.
+func (s *InvoiceStateStore) GetState(hash lntypes.Hash,
+) (lnrpc.Invoice_InvoiceState, bool) {
+
+	s.mtx.Lock()
+	state, found := s.states[hash]
+	s.mtx.Unlock()
+
+	return state, found
+}
+
+// MarkInitialLoadComplete sets the initialLoadComplete flag to true atomically
+// and broadcasts on the condition variable to wake up any waiting goroutines.
+func (s *InvoiceStateStore) MarkInitialLoadComplete() { + // Check atomically first to potentially avoid locking and broadcasting. + if s.initialLoadComplete.Load() { + // Already marked so we can return early. + return + } + + // Grab the lock now to ensure we can use the condition variable safely. + s.mtx.Lock() + defer s.mtx.Unlock() + + // Double-check under lock in case another goroutine just did it. + if !s.initialLoadComplete.Load() { + s.initialLoadComplete.Store(true) + + // Wake up everyone waiting. + s.cond.Broadcast() + log.Infof("Invoice store marked initial load as complete.") + } +} + +// IsInitialLoadComplete checks atomically if the initial historical invoice +// load has finished. +func (s *InvoiceStateStore) IsInitialLoadComplete() bool { + return s.initialLoadComplete.Load() +} + +// waitForCondition blocks until the provided condition function returns true, a +// timeout occurs, or the quit signal is received. The mutex `s.mtx` MUST be +// held by the caller when calling this function. The mutex will be unlocked +// while waiting and re-locked before returning. It returns an error if the +// timeout is reached or the quit signal is received. +func (s *InvoiceStateStore) waitForCondition(condition func() bool, + timeout time.Duration, timeoutMsg string) error { + + // Check condition immediately before waiting. + if condition() { + return nil + } + + // Start the timeout timer. + timer := time.NewTimer(timeout) + defer timer.Stop() + + // Channel to signal when the condition is met or quit signal is + // received. + waitDone := make(chan struct{}) + + // Goroutine to wait on the condition variable. + go func() { + // Re-acquire lock for cond.Wait + s.mtx.Lock() + for !condition() { + // Check quit signal before waiting indefinitely. + select { + case <-s.quit: + s.mtx.Unlock() + close(waitDone) + return + default: + } + + // Wait for the condition to be signaled. 
+ s.cond.Wait() + } + s.mtx.Unlock() + close(waitDone) + }() + + // Unlock to allow the waiting goroutine to acquire it. We expect the + // caller to already have held the lock. + s.mtx.Unlock() + + // Wait for either the condition to be met, timeout, or quit signal. + var errResult error + select { + case <-waitDone: + // Condition met or quit signal received by waiter. + if !timer.Stop() { + // Timer already fired and channel might contain value, + // drain it. Use a select to prevent blocking if the + // channel is empty. + select { + case <-timer.C: + default: + } + } + + // Re-check quit signal after waitDone is closed. + select { + case <-s.quit: + log.Warnf("waitForCondition: Shutdown signal received " + + "while condition was being met.") + + errResult = fmt.Errorf("challenger shutting down") + + default: + // Condition was met successfully. + errResult = nil + } + + case <-timer.C: + // Timeout expired. + log.Warnf("waitForCondition: %s (timeout: %v)", timeoutMsg, + timeout) + errResult = fmt.Errorf("%s", timeoutMsg) + + // We need to signal the waiting goroutine to stop, best way is via + // quit channel, but we don't control that. The waiting goroutine will + // eventually see the condition is true (if it changes later) or hit the + // quit signal. + + case <-s.quit: + // Shutdown signal received while waiting for timer/condition. + log.Warnf("waitForCondition: Shutdown signal received.") + + timer.Stop() + errResult = fmt.Errorf("challenger shutting down") + } + + // Re-acquire lock before returning, as expected by the caller. + s.mtx.Lock() + return errResult +} + +// WaitForState blocks until the specified invoice hash reaches the desiredState +// or a timeout occurs. It first waits for the initial historical invoice load +// to complete if necessary. initialLoadTimeout applies only if waiting for the +// initial load. requestTimeout applies when waiting for the specific invoice +// state change. 
+func (s *InvoiceStateStore) WaitForState(hash lntypes.Hash, + desiredState lnrpc.Invoice_InvoiceState, initialLoadTimeout time.Duration, + requestTimeout time.Duration) error { + + // Check to see if we need to wait for the initial load to complete. + if !s.initialLoadComplete.Load() { + log.Debugf("WaitForState: Initial load not complete, waiting "+ + "up to %v for hash %v...", + initialLoadTimeout, hash) + + initialLoadCondition := func() bool { + return s.initialLoadComplete.Load() + } + + timeoutMsg := fmt.Sprintf("timed out waiting for initial "+ + "invoice load after %v", initialLoadTimeout) + + err := s.waitForCondition( + initialLoadCondition, initialLoadTimeout, timeoutMsg, + ) + if err != nil { + log.Warnf("WaitForState: Error waiting for initial "+ + "load for hash %v: %v", hash, err) + return err + } + + log.Debugf("WaitForState: Initial load completed for hash %v", + hash) + } + + // We'll first check to see if the state is already where we need it to + // be. + currentState, hasInvoice := s.states[hash] + if hasInvoice && currentState == desiredState { + log.Debugf("WaitForState: Hash %v already in desired state %v.", + hash, desiredState) + return nil + } + + // If not, then we'll wait in the background for the condition to be + // met. + log.Debugf("WaitForState: Waiting up to %v for hash %v to reach "+ + "state %v...", requestTimeout, hash, desiredState) + + specificStateCondition := func() bool { + // Re-check state within the condition function under lock. + st, exists := s.states[hash] + return exists && st == desiredState + } + + timeoutMsg := fmt.Sprintf("timed out waiting for state %v after %v", + desiredState, requestTimeout) + + // We'll wait for the invoice to reach the desired state. + err := s.waitForCondition( + specificStateCondition, requestTimeout, timeoutMsg, + ) + if err != nil { + // If we timed out, provide a more specific error message based + // on the final state. 
+ finalState, finalExists := s.states[hash] + if err.Error() == timeoutMsg { + log.Warnf("WaitForState: Timed out after %v waiting "+ + "for hash %v state %v. Final state: %v, "+ + "exists: %v", requestTimeout, hash, + desiredState, finalState, finalExists) + + if !finalExists { + return fmt.Errorf("no active or settled "+ + "invoice found for hash=%v after "+ + "timeout", hash) + } + + return fmt.Errorf("invoice status %v not %v before "+ + "timeout for hash=%v", finalState, + desiredState, hash) + } + + // Otherwise, it was likely a shutdown error. + log.Warnf("WaitForState: Error waiting for specific "+ + "state for hash %v: %v", hash, err) + return err + } + + // Condition was met successfully. + log.Debugf("WaitForState: Hash %v reached desired state %v.", + hash, desiredState) + return nil +} + +// WaitForInitialLoad blocks until the initial historical invoice load has +// completed, or a timeout occurs. +func (s *InvoiceStateStore) WaitForInitialLoad(timeout time.Duration) error { + // Check if already complete. + if s.initialLoadComplete.Load() { + return nil + } + + log.Debugf("WaitForInitialLoad: Initial load not complete, waiting up to %v...", + timeout) + + initialLoadCondition := func() bool { + // Atomic read, no lock needed for this condition check. + return s.initialLoadComplete.Load() + } + timeoutMsg := fmt.Sprintf("timed out waiting for initial invoice load after %v", timeout) + + s.mtx.Lock() + + // Wait for the condition using the helper. 
+ err := s.waitForCondition(initialLoadCondition, timeout, timeoutMsg) + if err != nil { + log.Warnf("WaitForInitialLoad: Error waiting: %v", err) + return err // Return error (timeout or shutdown) + } + + log.Debugf("WaitForInitialLoad: Initial load completed.") + return nil +} diff --git a/challenger/lnc.go b/challenger/lnc.go index d30891f..fa13d0a 100644 --- a/challenger/lnc.go +++ b/challenger/lnc.go @@ -19,7 +19,7 @@ type LNCChallenger struct { // NewLNCChallenger creates a new challenger that uses the given LNC session to // connect to an lnd backend to create payment challenges. func NewLNCChallenger(session *lnc.Session, lncStore lnc.Store, - genInvoiceReq InvoiceRequestGenerator, + invoiceBatchSize int, genInvoiceReq InvoiceRequestGenerator, errChan chan<- error) (*LNCChallenger, error) { nodeConn, err := lnc.NewNodeConn(session, lncStore) @@ -34,16 +34,13 @@ func NewLNCChallenger(session *lnc.Session, lncStore lnc.Store, } lndChallenger, err := NewLndChallenger( - client, genInvoiceReq, nodeConn.CtxFunc, errChan, + client, invoiceBatchSize, genInvoiceReq, nodeConn.CtxFunc, errChan, ) if err != nil { return nil, err } - err = lndChallenger.Start() - if err != nil { - return nil, err - } + lndChallenger.Start() return &LNCChallenger{ lndChallenger: lndChallenger, diff --git a/challenger/lnd.go b/challenger/lnd.go index 00bd53e..f18f7a4 100644 --- a/challenger/lnd.go +++ b/challenger/lnd.go @@ -4,26 +4,38 @@ import ( "context" "fmt" "io" - "math" "strings" "sync" "time" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" ) +const ( + // defaultListInvoicesBatchSize is the default number of invoices to fetch + // in each ListInvoices call during the historical load. + defaultListInvoicesBatchSize = 1000 +) + +var ( + // defaultInitialLoadTimeout is the maximum time we wait for the initial + // batch of invoices to be loaded from lnd before allowing state checks + // to proceed or fail. 
+ defaultInitialLoadTimeout = 10 * time.Second +) + // LndChallenger is a challenger that uses an lnd backend to create new LSAT // payment challenges. type LndChallenger struct { client InvoiceClient clientCtx func() context.Context genInvoiceReq InvoiceRequestGenerator + batchSize int // Added batchSize - invoiceStates map[lntypes.Hash]lnrpc.Invoice_InvoiceState - invoicesMtx *sync.Mutex + invoiceStore *InvoiceStateStore invoicesCancel func() - invoicesCond *sync.Cond errChan chan<- error @@ -37,7 +49,7 @@ var _ Challenger = (*LndChallenger)(nil) // NewLndChallenger creates a new challenger that uses the given connection to // an lnd backend to create payment challenges. -func NewLndChallenger(client InvoiceClient, +func NewLndChallenger(client InvoiceClient, batchSize int, genInvoiceReq InvoiceRequestGenerator, ctxFunc func() context.Context, errChan chan<- error) (*LndChallenger, error) { @@ -52,110 +64,310 @@ func NewLndChallenger(client InvoiceClient, return nil, fmt.Errorf("genInvoiceReq cannot be nil") } - invoicesMtx := &sync.Mutex{} - challenger := &LndChallenger{ - client: client, - clientCtx: ctxFunc, - genInvoiceReq: genInvoiceReq, - invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), - invoicesMtx: invoicesMtx, - invoicesCond: sync.NewCond(invoicesMtx), - quit: make(chan struct{}), - errChan: errChan, + // Use default batch size if zero or negative is provided. + if batchSize <= 0 { + batchSize = defaultListInvoicesBatchSize } - err := challenger.Start() - if err != nil { - return nil, fmt.Errorf("unable to start challenger: %w", err) + quitChan := make(chan struct{}) + challenger := &LndChallenger{ + client: client, + clientCtx: ctxFunc, + genInvoiceReq: genInvoiceReq, + batchSize: batchSize, + invoiceStore: NewInvoiceStateStore(quitChan), + invoicesCancel: func() {}, + quit: quitChan, + errChan: errChan, } + // Start the background loading/subscription process. 
+ challenger.Start() + return challenger, nil } -// Start starts the challenger's main work which is to keep track of all -// invoices and their states. For that the backing lnd node is queried for all -// invoices on startup and the a subscription to all subsequent invoice updates -// is created. -func (l *LndChallenger) Start() error { - // These are the default values for the subscription. In case there are - // no invoices yet, this will instruct lnd to just send us all updates. - // If there are existing invoices, these indices will be updated to - // reflect the latest known invoices. +// Start launches the background process to load historical invoices and +// subscribe to new invoice updates concurrently. This method returns +// immediately. +func (l *LndChallenger) Start() { + log.Infof("Starting LND challenger background tasks...") + + // Use a short timeout context for this initial call. + ctxIdx, cancelIdx := context.WithTimeout(l.clientCtx(), 30*time.Second) + defer cancelIdx() + addIndex := uint64(0) settleIndex := uint64(0) - // Get a list of all existing invoices on startup and add them to our - // cache. We need to keep track of all invoices, even quite old ones to - // make sure tokens are valid. But to save space we only keep track of - // an invoice's state. - ctx := l.clientCtx() - invoiceResp, err := l.client.ListInvoices( - ctx, &lnrpc.ListInvoiceRequest{ - NumMaxInvoices: math.MaxUint64, + log.Debugf("Querying latest invoice indices...") + latestInvoiceResp, err := l.client.ListInvoices( + ctxIdx, &lnrpc.ListInvoiceRequest{ + NumMaxInvoices: 1, // Only need the latest one + Reversed: true, }, ) if err != nil { - return err + // Don't fail startup entirely, just log and proceed with 0 + // indices. The historical load will catch up. 
+ log.Errorf("Failed to get latest invoice indices, "+ + "subscribing from beginning (error: %v)", err) + } else if len(latestInvoiceResp.Invoices) > 0 { + // Indices are only meaningful if we actually got an invoice. + latestInvoice := latestInvoiceResp.Invoices[0] + addIndex = latestInvoice.AddIndex + settleIndex = latestInvoice.SettleIndex + + log.Infof("Latest indices found: add=%d, settle=%d", + addIndex, settleIndex) + } else { + log.Infof("No existing invoices found, subscribing " + + "from beginning.") } - // Advance our indices to the latest known one so we'll only receive - // updates for new invoices and/or newly settled invoices. - l.invoicesMtx.Lock() - for _, invoice := range invoiceResp.Invoices { - // Some invoices like AMP invoices may not have a payment hash - // populated. - if invoice.RHash == nil { - continue + cancelIdx() + + // We'll launch our first goroutine to load the historical invoices in + // the background. + l.wg.Add(1) + go l.loadHistoricalInvoices() + + // We'll launch our second goroutine to subscribe to new invoices in to + // populate the invoice store with new updates. + l.wg.Add(1) + go l.subscribeToInvoices(addIndex, settleIndex) + + log.Infof("LND challenger background tasks launched.") +} + +// loadHistoricalInvoices fetches all past invoices relevant using pagination +// and updates the invoice store. It marks the initial load complete upon +// finishing. This runs in a goroutine. +func (l *LndChallenger) loadHistoricalInvoices() { + defer l.wg.Done() + + log.Infof("Starting historical invoice loading "+ + "(batch size %d)...", l.batchSize) + + // Use a background context for the potentially long-running list calls. + // Allow it to be cancelled by Stop() via the main quit channel. + ctxList, cancelList := context.WithCancel(l.clientCtx()) + defer cancelList() + + // Goroutine to cancel the list context if quit signal is received. 
+ go func() { + select { + case <-l.quit: + log.Warnf("Shutdown signal received, cancelling " + + "historical invoice list context.") + cancelList() + + case <-ctxList.Done(): } + }() + + startTime := time.Now() + numInvoicesLoaded := 0 + indexOffset := uint64(0) + + for { + // Check for shutdown signal before each batch. + select { + case <-l.quit: + log.Warnf("Shutdown signal received during " + + "historical invoice loading.") - if invoice.AddIndex > addIndex { - addIndex = invoice.AddIndex + // Mark load complete anyway so waiters don't block + // indefinitely. + l.invoiceStore.MarkInitialLoadComplete() + return + default: } - if invoice.SettleIndex > settleIndex { - settleIndex = invoice.SettleIndex + + log.Debugf("Querying invoices batch starting from "+ + "index %d", indexOffset) + + req := &lnrpc.ListInvoiceRequest{ + IndexOffset: indexOffset, + NumMaxInvoices: uint64(l.batchSize), } - hash, err := lntypes.MakeHash(invoice.RHash) + + invoiceResp, err := l.client.ListInvoices(ctxList, req) + if err != nil { - l.invoicesMtx.Unlock() - return fmt.Errorf("error parsing invoice hash: %v", err) + // If context was cancelled by shutdown, it's not a + // fatal startup error. + if strings.Contains(err.Error(), context.Canceled.Error()) { + log.Warnf("Historical invoice loading " + + "cancelled by shutdown.") + + l.invoiceStore.MarkInitialLoadComplete() + return + } + log.Errorf("Failed to list invoices batch "+ + "(offset %d): %v", indexOffset, err) + + // Signal fatal error to the main application. + select { + case l.errChan <- fmt.Errorf("failed historical "+ + "invoice load batch: %w", err): + case <-l.quit: // Don't block if shutting down + } + + // Mark load complete on error so waiters don't block + // indefinitely. + l.invoiceStore.MarkInitialLoadComplete() + return } - // Don't track the state of canceled or expired invoices. - if invoiceIrrelevant(invoice) { - continue + // Process the received batch. 
+ invoicesInBatch := len(invoiceResp.Invoices) + + log.Debugf("Received %d invoices in batch (offset %d)", + invoicesInBatch, indexOffset) + + fmt.Println("Loading incoies: ", spew.Sdump(invoiceResp.Invoices)) + + for _, invoice := range invoiceResp.Invoices { + // Some invoices like AMP invoices may not have a + // payment hash populated. + if invoice.RHash == nil { + continue + } + + hash, err := lntypes.MakeHash(invoice.RHash) + if err != nil { + log.Errorf("Error parsing invoice hash "+ + "during initial load: %v. Skipping "+ + "invoice.", err) + continue + } + + // Don't track the state of irrelevant invoices. + if invoiceIrrelevant(invoice) { + continue + } + + l.invoiceStore.SetState(hash, invoice.State) + numInvoicesLoaded++ + } + + // If this batch was empty or less than max, we're done with + // history. LND documentation suggests LastIndexOffset is the + // index of the *last* invoice returned. If no invoices + // returned, break. If NumMaxInvoices was returned, continue + // from LastIndexOffset. If < NumMaxInvoices returned, we are + // also done. + if invoicesInBatch == 0 || invoicesInBatch < l.batchSize { + log.Debugf("Last batch processed (%d invoices), "+ + "stopping pagination.", invoicesInBatch) + break } - l.invoiceStates[hash] = invoice.State + + // Prepare for the next batch. + indexOffset = invoiceResp.LastIndexOffset + log.Debugf("Processed batch, %d invoices loaded so "+ + "far. Next index offset: %d", + numInvoicesLoaded, indexOffset) } - l.invoicesMtx.Unlock() - // We need to be able to cancel any subscription we make. - ctxc, cancel := context.WithCancel(l.clientCtx()) - l.invoicesCancel = cancel + loadDuration := time.Since(startTime) + + log.Infof("Finished historical invoice loading. Loaded %d "+ + "relevant invoices in %v.", numInvoicesLoaded, + loadDuration) + + // Mark the initial load as complete *only after* all pages are + // processed. 
+ l.invoiceStore.MarkInitialLoadComplete() +} + +// subscribeToInvoices sets up the invoice subscription stream and starts the +// reader goroutine. This runs in a goroutine managed by Start. +func (l *LndChallenger) subscribeToInvoices(addIndex, settleIndex uint64) { + defer l.wg.Done() + + // We need a separate context for the subscription stream, managed by + // invoicesCancel. + ctxSub, cancelSub := context.WithCancel(l.clientCtx()) + defer func() { + // Only call cancelSub if l.invoicesCancel hasn't been assigned + // yet (meaning subscription failed or shutdown happened before + // success). If l.invoicesCancel was assigned, Stop() will + // handle cancellation. + if l.invoicesCancel == nil { + cancelSub() + } + }() + + // Check for immediate shutdown before attempting subscription. + select { + case <-l.quit: + log.Warnf("Shutdown signal received before starting " + + "invoice subscription.") + return + default: + } + + log.Infof("Attempting to subscribe to invoice updates starting "+ + "from add_index=%d, settle_index=%d", addIndex, settleIndex) subscriptionResp, err := l.client.SubscribeInvoices( - ctxc, &lnrpc.InvoiceSubscription{ + ctxSub, &lnrpc.InvoiceSubscription{ AddIndex: addIndex, SettleIndex: settleIndex, }, ) if err != nil { - cancel() - return err + // If context was cancelled by shutdown, it's not a fatal error. + if strings.Contains(err.Error(), context.Canceled.Error()) { + log.Warnf("Invoice subscription cancelled " + + "during setup by shutdown.") + return + } + + log.Errorf("Failed to subscribe to invoices: %v", err) + select { + case l.errChan <- fmt.Errorf("failed invoice "+ + "subscription: %w", err): + + case <-l.quit: + } + return } + // Store the cancel function *only after* SubscribeInvoices succeeds. + l.invoicesCancel = cancelSub + + log.Infof("Successfully subscribed to invoice updates.") + + // Start the goroutine to read from the subscription stream. Add to + // WaitGroup *before* launching. 
This WG count belongs to the + // readInvoiceStream lifecycle, managed by this parent goroutine. l.wg.Add(1) go func() { + // Ensure Done is called regardless of how readInvoiceStream + // exits. defer l.wg.Done() - defer cancel() + // Ensure the subscription context is cancelled when this reader + // goroutine exits. Calling the stored l.invoicesCancel ensures + // Stop() can also cancel it. + defer l.invoicesCancel() l.readInvoiceStream(subscriptionResp) }() - return nil + log.Infof("Invoice subscription reader started.") + + // Keep this goroutine alive until quit signal to manage + // readInvoiceStream. + <-l.quit + log.Infof("Invoice subscription manager shutting down.") } // readInvoiceStream reads the invoice update messages sent on the stream until // the stream is aborted or the challenger is shutting down. +// This runs in a goroutine managed by subscribeToInvoices. func (l *LndChallenger) readInvoiceStream( stream lnrpc.Lightning_SubscribeInvoicesClient) { @@ -225,26 +437,29 @@ func (l *LndChallenger) readInvoiceStream( return } - l.invoicesMtx.Lock() if invoiceIrrelevant(invoice) { // Don't keep the state of canceled or expired invoices. - delete(l.invoiceStates, hash) + l.invoiceStore.DeleteState(hash) } else { - l.invoiceStates[hash] = invoice.State + l.invoiceStore.SetState(hash, invoice.State) } - - // Before releasing the lock, notify our conditions that listen - // for updates on the invoice state. - l.invoicesCond.Broadcast() - l.invoicesMtx.Unlock() } } // Stop shuts down the challenger. func (l *LndChallenger) Stop() { - l.invoicesCancel() + log.Infof("Stopping LND challenger...") + // Signal all goroutines to exit. close(l.quit) + + // Cancel the subscription context if it exists and was set. + // invoicesCancel is initialized to a no-op, so safe to call always. + l.invoicesCancel() + + // Wait for all background goroutines (loadHistorical, subscribeToInvoices, + // and readInvoiceStream) to finish. 
l.wg.Wait() + log.Infof("LND challenger stopped.") } // NewChallenge creates a new LSAT payment challenge, returning a payment @@ -278,88 +493,40 @@ func (l *LndChallenger) NewChallenge(price int64) (string, lntypes.Hash, return response.PaymentRequest, paymentHash, nil } -// VerifyInvoiceStatus checks that an invoice identified by a payment -// hash has the desired status. To make sure we don't fail while the -// invoice update is still on its way, we try several times until either -// the desired status is set or the given timeout is reached. +// VerifyInvoiceStatus checks that an invoice identified by a payment hash has +// the desired status. It waits until the desired status is reached or the +// given timeout occurs. It also handles waiting for the initial invoice load +// if necessary. // // NOTE: This is part of the auth.InvoiceChecker interface. func (l *LndChallenger) VerifyInvoiceStatus(hash lntypes.Hash, state lnrpc.Invoice_InvoiceState, timeout time.Duration) error { - // Prevent the challenger to be shut down while we're still waiting for - // status updates. + // Prevent the challenger from being shut down while we're still waiting + // for status updates. Add to WG *before* calling wait. l.wg.Add(1) defer l.wg.Done() - var ( - condWg sync.WaitGroup - doneChan = make(chan struct{}) - timeoutReached bool - hasInvoice bool - invoiceState lnrpc.Invoice_InvoiceState - ) - - // First of all, spawn a goroutine that will signal us on timeout. - // Otherwise if a client subscribes to an update on an invoice that - // never arrives, and there is no other activity, it would block - // forever in the condition. - condWg.Add(1) - go func() { - defer condWg.Done() - - select { - case <-doneChan: - case <-time.After(timeout): - case <-l.quit: - } - - l.invoicesCond.L.Lock() - timeoutReached = true - l.invoicesCond.Broadcast() - l.invoicesCond.L.Unlock() - }() - - // Now create the main goroutine that blocks until an update is received - // on the condition. 
- condWg.Add(1) - go func() { - defer condWg.Done() - l.invoicesCond.L.Lock() - - // Block here until our condition is met or the allowed time is - // up. The Wait() will return whenever a signal is broadcast. - invoiceState, hasInvoice = l.invoiceStates[hash] - for !(hasInvoice && invoiceState == state) && !timeoutReached { - l.invoicesCond.Wait() - - // The Wait() above has re-acquired the lock so we can - // safely access the states map. - invoiceState, hasInvoice = l.invoiceStates[hash] - } - - // We're now done. - l.invoicesCond.L.Unlock() - close(doneChan) - }() - - // Wait until we're either done or timed out. - condWg.Wait() - - // Interpret the result so we can return a more descriptive error than - // just "failed". - switch { - case !hasInvoice: - return fmt.Errorf("no active or settled invoice found for "+ - "hash=%v", hash) - - case invoiceState != state: - return fmt.Errorf("invoice status not correct before timeout, "+ - "hash=%v, status=%v", hash, invoiceState) - + // Check for immediate shutdown signal before potentially blocking. + select { + case <-l.quit: + return fmt.Errorf("challenger shutting down") default: - return nil } + + // Delegate the waiting logic to the invoice store. + // We use a default timeout for the initial load wait, and the provided + // timeout for the specific state wait. + err := l.invoiceStore.WaitForState( + hash, state, defaultInitialLoadTimeout, timeout, + ) + if err != nil { + // Add context to the error message. 
+ return fmt.Errorf("error verifying invoice status for hash %v "+ + "(target state %v): %w", hash, state, err) + } + + return nil } // invoiceIrrelevant returns true if an invoice is nil, canceled or non-settled diff --git a/challenger/lnd_test.go b/challenger/lnd_test.go index 82f7620..ebe3db3 100644 --- a/challenger/lnd_test.go +++ b/challenger/lnd_test.go @@ -98,18 +98,19 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) { return newInvoice(lntypes.ZeroHash, 99, lnrpc.Invoice_OPEN), nil } - invoicesMtx := &sync.Mutex{} + mainErrChan := make(chan error) - return &LndChallenger{ + quitChan := make(chan struct{}) + challenger := &LndChallenger{ client: mockClient, clientCtx: context.Background, genInvoiceReq: genInvoiceReq, - invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), - quit: make(chan struct{}), - invoicesMtx: invoicesMtx, - invoicesCond: sync.NewCond(invoicesMtx), + invoiceStore: NewInvoiceStateStore(quitChan), + quit: quitChan, errChan: mainErrChan, - }, mockClient, mainErrChan + } + + return challenger, mockClient, mainErrChan } func newInvoice(hash lntypes.Hash, addIndex uint64, @@ -131,7 +132,7 @@ func TestLndChallenger(t *testing.T) { // First of all, test that the NewLndChallenger doesn't allow a nil // invoice generator function. errChan := make(chan error) - _, err := NewLndChallenger(nil, nil, nil, errChan) + _, err := NewLndChallenger(nil, 0, nil, nil, errChan) require.Error(t, err) // Now mock the lnd backend and create a challenger instance that we can @@ -149,10 +150,16 @@ func TestLndChallenger(t *testing.T) { // Now we already have an invoice in our lnd mock. When starting the // challenger, we should have that invoice in the cache and a // subscription that only starts at our faked addIndex. - err = c.Start() + // In the test setup, Start() is called after the challenger is created + // by newChallenger, which already pre-populates the store. 
+ // We'll call Start() again here to ensure the subscription logic runs. + c.Start() require.NoError(t, err) - require.Equal(t, 1, len(c.invoiceStates)) - require.Equal(t, lnrpc.Invoice_OPEN, c.invoiceStates[lntypes.ZeroHash]) + + // Wait for the invoices to be loaded. + c.invoiceStore.WaitForInitialLoad(time.Second * 3) + + // Verify the initial state using the public method, not direct access. require.Equal(t, uint64(99), invoiceMock.lastAddIndex) require.NoError(t, c.VerifyInvoiceStatus( lntypes.ZeroHash, lnrpc.Invoice_OPEN, defaultTimeout, @@ -223,16 +230,20 @@ func TestLndChallenger(t *testing.T) { } // Finally test that if an error occurs in the invoice subscription the - // challenger reports it on the main error channel to cause a shutdown - // of aperture. The mock's error channel is buffered so we can send - // directly. - invoiceMock.errChan <- fmt.Errorf("an expected error") + // challenger reports it on the main error channel to cause a shutdown. + // The mock's error channel is buffered so we can send directly. + expectedErr := fmt.Errorf("an expected error") + invoiceMock.errChan <- expectedErr select { case err := <-mainErrChan: - require.Error(t, err) + require.ErrorIs(t, err, expectedErr) // Check if it's the expected error // Make sure that the goroutine exited. done := make(chan struct{}) + require.Error(t, err) + + // Make sure that the goroutine exited. + done = make(chan struct{}) go func() { c.wg.Wait() done <- struct{}{} @@ -249,6 +260,8 @@ func TestLndChallenger(t *testing.T) { t.Fatalf("error not received on main chan before the timeout") } + // Stop the mock client first to close its quit channel used by the store invoiceMock.stop() + // Then stop the challenger c.Stop() }