Skip to content

Commit 54cf58b

Browse files
authored
Merge pull request #932 from ellemouton/context
multi: thread contexts through properly
2 parents 0fa5112 + 76250ca commit 54cf58b

31 files changed

+390
-357
lines changed

accounts/checkers_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,8 @@ func TestSendPaymentCalls(t *testing.T) {
499499

500500
func testSendPayment(t *testing.T, uri string) {
501501
var (
502-
parentCtx = context.Background()
503-
zeroFee = &lnrpc.FeeLimit{Limit: &lnrpc.FeeLimit_Fixed{
502+
ctx = context.Background()
503+
zeroFee = &lnrpc.FeeLimit{Limit: &lnrpc.FeeLimit_Fixed{
504504
Fixed: 0,
505505
}}
506506
requestID uint64
@@ -520,7 +520,7 @@ func testSendPayment(t *testing.T, uri string) {
520520
service, err := NewService(t.TempDir(), errFunc)
521521
require.NoError(t, err)
522522

523-
err = service.Start(lndMock, routerMock, chainParams)
523+
err = service.Start(ctx, lndMock, routerMock, chainParams)
524524
require.NoError(t, err)
525525

526526
assertBalance := func(id AccountID, expectedBalance int64) {
@@ -533,7 +533,7 @@ func testSendPayment(t *testing.T, uri string) {
533533

534534
// This should error because there is no account in the context.
535535
err = service.checkers.checkIncomingRequest(
536-
parentCtx, uri, &lnrpc.SendRequest{},
536+
ctx, uri, &lnrpc.SendRequest{},
537537
)
538538
require.ErrorContains(t, err, "no account found in context")
539539

@@ -543,7 +543,7 @@ func testSendPayment(t *testing.T, uri string) {
543543
)
544544
require.NoError(t, err)
545545

546-
ctxWithAcct := AddAccountToContext(parentCtx, acct)
546+
ctxWithAcct := AddAccountToContext(ctx, acct)
547547

548548
// This should error because there is no request ID in the context.
549549
err = service.checkers.checkIncomingRequest(
@@ -552,7 +552,7 @@ func testSendPayment(t *testing.T, uri string) {
552552
require.ErrorContains(t, err, "no request ID found in context")
553553

554554
reqID1 := nextRequestID()
555-
ctx := AddRequestIDToContext(ctxWithAcct, reqID1)
555+
ctx = AddRequestIDToContext(ctxWithAcct, reqID1)
556556

557557
// This should error because no payment hash is provided.
558558
err = service.checkers.checkIncomingRequest(
@@ -698,7 +698,7 @@ func testSendPayment(t *testing.T, uri string) {
698698
func TestSendPaymentV2(t *testing.T) {
699699
var (
700700
uri = "/routerrpc.Router/SendPaymentV2"
701-
parentCtx = context.Background()
701+
ctx = context.Background()
702702
requestID uint64
703703
)
704704

@@ -716,7 +716,7 @@ func TestSendPaymentV2(t *testing.T) {
716716
service, err := NewService(t.TempDir(), errFunc)
717717
require.NoError(t, err)
718718

719-
err = service.Start(lndMock, routerMock, chainParams)
719+
err = service.Start(ctx, lndMock, routerMock, chainParams)
720720
require.NoError(t, err)
721721

722722
assertBalance := func(id AccountID, expectedBalance int64) {
@@ -729,7 +729,7 @@ func TestSendPaymentV2(t *testing.T) {
729729

730730
// This should error because there is no account in the context.
731731
err = service.checkers.checkIncomingRequest(
732-
parentCtx, uri, &routerrpc.SendPaymentRequest{},
732+
ctx, uri, &routerrpc.SendPaymentRequest{},
733733
)
734734
require.ErrorContains(t, err, "no account found in context")
735735

@@ -739,7 +739,7 @@ func TestSendPaymentV2(t *testing.T) {
739739
)
740740
require.NoError(t, err)
741741

742-
ctxWithAcct := AddAccountToContext(parentCtx, acct)
742+
ctxWithAcct := AddAccountToContext(ctx, acct)
743743

744744
// This should error because there is no request ID in the context.
745745
err = service.checkers.checkIncomingRequest(
@@ -748,7 +748,7 @@ func TestSendPaymentV2(t *testing.T) {
748748
require.ErrorContains(t, err, "no request ID found in context")
749749

750750
reqID1 := nextRequestID()
751-
ctx := AddRequestIDToContext(ctxWithAcct, reqID1)
751+
ctx = AddRequestIDToContext(ctxWithAcct, reqID1)
752752

753753
// This should error because no payment hash is provided.
754754
err = service.checkers.checkIncomingRequest(
@@ -885,7 +885,7 @@ func TestSendPaymentV2(t *testing.T) {
885885
func TestSendToRouteV2(t *testing.T) {
886886
var (
887887
uri = "/routerrpc.Router/SendToRouteV2"
888-
parentCtx = context.Background()
888+
ctx = context.Background()
889889
requestID uint64
890890
)
891891

@@ -903,7 +903,7 @@ func TestSendToRouteV2(t *testing.T) {
903903
service, err := NewService(t.TempDir(), errFunc)
904904
require.NoError(t, err)
905905

906-
err = service.Start(lndMock, routerMock, chainParams)
906+
err = service.Start(ctx, lndMock, routerMock, chainParams)
907907
require.NoError(t, err)
908908

909909
assertBalance := func(id AccountID, expectedBalance int64) {
@@ -916,7 +916,7 @@ func TestSendToRouteV2(t *testing.T) {
916916

917917
// This should error because there is no account in the context.
918918
err = service.checkers.checkIncomingRequest(
919-
parentCtx, uri, &routerrpc.SendToRouteRequest{},
919+
ctx, uri, &routerrpc.SendToRouteRequest{},
920920
)
921921
require.ErrorContains(t, err, "no account found in context")
922922

@@ -926,7 +926,7 @@ func TestSendToRouteV2(t *testing.T) {
926926
)
927927
require.NoError(t, err)
928928

929-
ctxWithAcct := AddAccountToContext(parentCtx, acct)
929+
ctxWithAcct := AddAccountToContext(ctx, acct)
930930

931931
// This should error because there is no request ID in the context.
932932
err = service.checkers.checkIncomingRequest(
@@ -935,7 +935,7 @@ func TestSendToRouteV2(t *testing.T) {
935935
require.ErrorContains(t, err, "no request ID found in context")
936936

937937
reqID1 := nextRequestID()
938-
ctx := AddRequestIDToContext(ctxWithAcct, reqID1)
938+
ctx = AddRequestIDToContext(ctxWithAcct, reqID1)
939939

940940
// This should error because no payment hash is provided.
941941
err = service.checkers.checkIncomingRequest(

accounts/service.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/btcsuite/btcd/chaincfg"
1111
"github.com/lightninglabs/lndclient"
12+
"github.com/lightninglabs/taproot-assets/fn"
1213
"github.com/lightningnetwork/lnd/channeldb"
1314
invpkg "github.com/lightningnetwork/lnd/invoices"
1415
"github.com/lightningnetwork/lnd/lnrpc"
@@ -55,7 +56,7 @@ type InterceptorService struct {
5556
routerClient lndclient.RouterClient
5657

5758
mainCtx context.Context
58-
contextCancel context.CancelFunc
59+
contextCancel fn.Option[context.CancelFunc]
5960

6061
requestMtx sync.Mutex
6162
checkers *AccountChecker
@@ -85,12 +86,8 @@ func NewService(dir string,
8586
return nil, err
8687
}
8788

88-
mainCtx, contextCancel := context.WithCancel(context.Background())
89-
9089
return &InterceptorService{
9190
store: accountStore,
92-
mainCtx: mainCtx,
93-
contextCancel: contextCancel,
9491
invoiceToAccount: make(map[lntypes.Hash]AccountID),
9592
pendingPayments: make(map[lntypes.Hash]*trackedPayment),
9693
requestValuesStore: newRequestValuesStore(),
@@ -101,9 +98,14 @@ func NewService(dir string,
10198
}
10299

103100
// Start starts the account service and its interceptor capability.
104-
func (s *InterceptorService) Start(lightningClient lndclient.LightningClient,
101+
func (s *InterceptorService) Start(ctx context.Context,
102+
lightningClient lndclient.LightningClient,
105103
routerClient lndclient.RouterClient, params *chaincfg.Params) error {
106104

105+
mainCtx, contextCancel := context.WithCancel(ctx)
106+
s.mainCtx = mainCtx
107+
s.contextCancel = fn.Some(contextCancel)
108+
107109
s.routerClient = routerClient
108110
s.checkers = NewAccountChecker(s, params)
109111

@@ -180,7 +182,7 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient,
180182
s.wg.Add(1)
181183
go func() {
182184
defer s.wg.Done()
183-
defer s.contextCancel()
185+
defer contextCancel()
184186

185187
for {
186188
select {
@@ -235,9 +237,8 @@ func (s *InterceptorService) Stop() error {
235237
s.requestMtx.Lock()
236238
defer s.requestMtx.Unlock()
237239

238-
s.contextCancel()
240+
s.contextCancel.WhenSome(func(fn context.CancelFunc) { fn() })
239241
close(s.quit)
240-
241242
s.wg.Wait()
242243

243244
return s.store.Close()

accounts/service_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,10 @@ func TestAccountService(t *testing.T) {
838838
}
839839

840840
// Any errors during startup expected?
841-
err = service.Start(lndMock, routerMock, chainParams)
841+
err = service.Start(
842+
context.Background(), lndMock, routerMock,
843+
chainParams,
844+
)
842845
if tc.startupErr != "" {
843846
require.ErrorContains(tt, err, tc.startupErr)
844847

autopilotserver/client.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
"github.com/btcsuite/btcd/btcec/v2"
1515
"github.com/lightninglabs/lightning-terminal/autopilotserverrpc"
16+
"github.com/lightninglabs/taproot-assets/fn"
1617
"github.com/lightningnetwork/lnd/tor"
1718
"google.golang.org/grpc"
1819
"google.golang.org/grpc/credentials"
@@ -88,8 +89,9 @@ type Client struct {
8889

8990
featurePerms *featurePerms
9091

91-
quit chan struct{}
92-
wg sync.WaitGroup
92+
quit chan struct{}
93+
wg sync.WaitGroup
94+
cancel fn.Option[context.CancelFunc]
9395
}
9496

9597
type session struct {
@@ -124,16 +126,19 @@ func NewClient(cfg *Config) (Autopilot, error) {
124126
}
125127

126128
// Start kicks off all the goroutines required by the Client.
127-
func (c *Client) Start(opts ...func(cfg *Config)) error {
129+
func (c *Client) Start(ctx context.Context, opts ...func(cfg *Config)) error {
128130
var startErr error
129131
c.start.Do(func() {
130132
log.Infof("Starting Autopilot Client")
131133

134+
ctx, cancel := context.WithCancel(ctx)
135+
c.cancel = fn.Some(cancel)
136+
132137
for _, o := range opts {
133138
o(c.cfg)
134139
}
135140

136-
version, err := c.getMinVersion(context.Background())
141+
version, err := c.getMinVersion(ctx)
137142
if err != nil {
138143
startErr = err
139144
return
@@ -154,8 +159,8 @@ func (c *Client) Start(opts ...func(cfg *Config)) error {
154159
}
155160

156161
c.wg.Add(2)
157-
go c.activateSessionsForever()
158-
go c.updateFeaturePermsForever()
162+
go c.activateSessionsForever(ctx)
163+
go c.updateFeaturePermsForever(ctx)
159164
})
160165

161166
return startErr
@@ -164,6 +169,7 @@ func (c *Client) Start(opts ...func(cfg *Config)) error {
164169
// Stop cleans up any resources or goroutines managed by the Client.
165170
func (c *Client) Stop() {
166171
c.stop.Do(func() {
172+
c.cancel.WhenSome(func(fn context.CancelFunc) { fn() })
167173
close(c.quit)
168174
c.wg.Wait()
169175
})
@@ -222,10 +228,9 @@ func (c *Client) SessionRevoked(ctx context.Context, pubKey *btcec.PublicKey) {
222228

223229
// activateSessionsForever periodically ensures that each of our active
224230
// autopilot sessions are known by the autopilot to be active.
225-
func (c *Client) activateSessionsForever() {
231+
func (c *Client) activateSessionsForever(ctx context.Context) {
226232
defer c.wg.Done()
227233

228-
ctx := context.Background()
229234
ticker := time.NewTicker(c.cfg.PingCadence)
230235
defer ticker.Stop()
231236

@@ -273,10 +278,9 @@ func (c *Client) activateSessionsForever() {
273278
// feature permissions list.
274279
//
275280
// NOTE: this MUST be called in a goroutine.
276-
func (c *Client) updateFeaturePermsForever() {
281+
func (c *Client) updateFeaturePermsForever(ctx context.Context) {
277282
defer c.wg.Done()
278283

279-
ctx := context.Background()
280284
ticker := time.NewTicker(time.Second)
281285
defer ticker.Stop()
282286

autopilotserver/client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func TestAutopilotClient(t *testing.T) {
3232
PingCadence: time.Second,
3333
})
3434
require.NoError(t, err)
35-
require.NoError(t, client.Start())
35+
require.NoError(t, client.Start(ctx))
3636
t.Cleanup(client.Stop)
3737

3838
privKey, err := btcec.NewPrivateKey()

autopilotserver/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ type Autopilot interface {
4545
SessionRevoked(ctx context.Context, key *btcec.PublicKey)
4646

4747
// Start kicks off the goroutines of the client.
48-
Start(opts ...func(cfg *Config)) error
48+
Start(ctx context.Context, opts ...func(cfg *Config)) error
4949

5050
// Stop cleans up any resources held by the client.
5151
Stop()

0 commit comments

Comments
 (0)