Skip to content

Commit a27bd69

Browse files
authored
Merge pull request #9956 from ellemouton/chanGraphContext
multi: add `context.Context` param to some `graphdb.V1Store` methods
2 parents c1740c1 + e724e1c commit a27bd69

27 files changed

+306
-222
lines changed

autopilot/prefattach_test.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) {
395395
func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey,
396396
capacity btcutil.Amount) (*ChannelEdge, *ChannelEdge, error) {
397397

398+
ctx := context.Background()
399+
398400
fetchNode := func(pub *btcec.PublicKey) (*models.LightningNode, error) {
399401
if pub != nil {
400402
vertex, err := route.NewVertexFromBytes(
@@ -404,7 +406,7 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey,
404406
return nil, err
405407
}
406408

407-
dbNode, err := d.db.FetchLightningNode(vertex)
409+
dbNode, err := d.db.FetchLightningNode(ctx, vertex)
408410
switch {
409411
case errors.Is(err, graphdb.ErrGraphNodeNotFound):
410412
fallthrough
@@ -422,7 +424,9 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey,
422424
AuthSigBytes: testSig.Serialize(),
423425
}
424426
graphNode.AddPubKey(pub)
425-
err := d.db.AddLightningNode(graphNode)
427+
err := d.db.AddLightningNode(
428+
context.Background(), graphNode,
429+
)
426430
if err != nil {
427431
return nil, err
428432
}
@@ -450,7 +454,9 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey,
450454
AuthSigBytes: testSig.Serialize(),
451455
}
452456
dbNode.AddPubKey(nodeKey)
453-
if err := d.db.AddLightningNode(dbNode); err != nil {
457+
if err := d.db.AddLightningNode(
458+
context.Background(), dbNode,
459+
); err != nil {
454460
return nil, err
455461
}
456462

@@ -554,7 +560,8 @@ func (d *testDBGraph) addRandNode() (*btcec.PublicKey, error) {
554560
AuthSigBytes: testSig.Serialize(),
555561
}
556562
dbNode.AddPubKey(nodeKey)
557-
if err := d.db.AddLightningNode(dbNode); err != nil {
563+
err = d.db.AddLightningNode(context.Background(), dbNode)
564+
if err != nil {
558565
return nil, err
559566
}
560567

@@ -732,7 +739,7 @@ func (t *testNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo,
732739
}
733740

734741
func (t *testNodeTx) FetchNode(pub route.Vertex) (graphdb.NodeRTx, error) {
735-
node, err := t.db.db.FetchLightningNode(pub)
742+
node, err := t.db.db.FetchLightningNode(context.Background(), pub)
736743
if err != nil {
737744
return nil, err
738745
}

chanbackup/backup.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package chanbackup
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/btcsuite/btcd/wire"
@@ -24,15 +25,17 @@ type LiveChannelSource interface {
2425
// passed open channel. The backup includes all information required to restore
2526
// the channel, as well as addressing information so we can find the peer and
2627
// reconnect to them to initiate the protocol.
27-
func assembleChanBackup(addrSource channeldb.AddrSource,
28+
func assembleChanBackup(ctx context.Context, addrSource channeldb.AddrSource,
2829
openChan *channeldb.OpenChannel) (*Single, error) {
2930

3031
log.Debugf("Crafting backup for ChannelPoint(%v)",
3132
openChan.FundingOutpoint)
3233

3334
// First, we'll query the channel source to obtain all the addresses
3435
// that are associated with the peer for this channel.
35-
known, nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub)
36+
known, nodeAddrs, err := addrSource.AddrsForNode(
37+
ctx, openChan.IdentityPub,
38+
)
3639
if err != nil {
3740
return nil, err
3841
}
@@ -90,7 +93,8 @@ func buildCloseTxInputs(
9093
// FetchBackupForChan attempts to create a plaintext static channel backup for
9194
// the target channel identified by its channel point. If we're unable to find
9295
// the target channel, then an error will be returned.
93-
func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
96+
func FetchBackupForChan(ctx context.Context, chanPoint wire.OutPoint,
97+
chanSource LiveChannelSource,
9498
addrSource channeldb.AddrSource) (*Single, error) {
9599

96100
// First, we'll query the channel source to see if the channel is known
@@ -104,7 +108,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
104108

105109
// Once we have the target channel, we can assemble the backup using
106110
// the source to obtain any extra information that we may need.
107-
staticChanBackup, err := assembleChanBackup(addrSource, targetChan)
111+
staticChanBackup, err := assembleChanBackup(ctx, addrSource, targetChan)
108112
if err != nil {
109113
return nil, fmt.Errorf("unable to create chan backup: %w", err)
110114
}
@@ -114,7 +118,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
114118

115119
// FetchStaticChanBackups will return a plaintext static channel back up for
116120
// all known active/open channels within the passed channel source.
117-
func FetchStaticChanBackups(chanSource LiveChannelSource,
121+
func FetchStaticChanBackups(ctx context.Context, chanSource LiveChannelSource,
118122
addrSource channeldb.AddrSource) ([]Single, error) {
119123

120124
// First, we'll query the backup source for information concerning all
@@ -129,7 +133,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource,
129133
// channel.
130134
staticChanBackups := make([]Single, 0, len(openChans))
131135
for _, openChan := range openChans {
132-
chanBackup, err := assembleChanBackup(addrSource, openChan)
136+
chanBackup, err := assembleChanBackup(ctx, addrSource, openChan)
133137
if err != nil {
134138
return nil, err
135139
}

chanbackup/backup_test.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package chanbackup
22

33
import (
4+
"context"
45
"fmt"
56
"net"
67
"testing"
@@ -61,8 +62,8 @@ func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []ne
6162
m.addrs[nodeKey] = addrs
6263
}
6364

64-
func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) (bool,
65-
[]net.Addr, error) {
65+
func (m *mockChannelSource) AddrsForNode(_ context.Context,
66+
nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
6667

6768
if m.failQuery {
6869
return false, nil, fmt.Errorf("fail")
@@ -120,7 +121,8 @@ func TestFetchBackupForChan(t *testing.T) {
120121
}
121122
for i, testCase := range testCases {
122123
_, err := FetchBackupForChan(
123-
testCase.chanPoint, chanSource, chanSource,
124+
context.Background(), testCase.chanPoint, chanSource,
125+
chanSource,
124126
)
125127
switch {
126128
// If this is a valid test case, and we failed, then we'll
@@ -141,6 +143,7 @@ func TestFetchBackupForChan(t *testing.T) {
141143
// channel source for all channels and construct a Single for each channel.
142144
func TestFetchStaticChanBackups(t *testing.T) {
143145
t.Parallel()
146+
ctx := context.Background()
144147

145148
// First, we'll make the set of channels that we want to seed the
146149
// channel source with. Both channels will be fully populated in the
@@ -162,7 +165,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
162165
// With the channel source populated, we'll now attempt to create a set
163166
// of backups for all the channels. This should succeed, as all items
164167
// are populated within the channel source.
165-
backups, err := FetchStaticChanBackups(chanSource, chanSource)
168+
backups, err := FetchStaticChanBackups(ctx, chanSource, chanSource)
166169
require.NoError(t, err, "unable to create chan back ups")
167170

168171
if len(backups) != numChans {
@@ -177,7 +180,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
177180
copy(n[:], randomChan2.IdentityPub.SerializeCompressed())
178181
delete(chanSource.addrs, n)
179182

180-
_, err = FetchStaticChanBackups(chanSource, chanSource)
183+
_, err = FetchStaticChanBackups(ctx, chanSource, chanSource)
181184
if err == nil {
182185
t.Fatalf("query with incomplete information should fail")
183186
}
@@ -186,7 +189,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
186189
// source at all, then we'll fail as well.
187190
chanSource = newMockChannelSource()
188191
chanSource.failQuery = true
189-
_, err = FetchStaticChanBackups(chanSource, chanSource)
192+
_, err = FetchStaticChanBackups(ctx, chanSource, chanSource)
190193
if err == nil {
191194
t.Fatalf("query should fail")
192195
}

chanbackup/pubsub.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package chanbackup
22

33
import (
44
"bytes"
5+
"context"
56
"fmt"
67
"net"
78
"os"
@@ -81,7 +82,8 @@ type ChannelNotifier interface {
8182
// synchronization point to ensure that the chanbackup.SubSwapper does
8283
// not miss any channel open or close events in the period between when
8384
// it's created, and when it requests the channel subscription.
84-
SubscribeChans(map[wire.OutPoint]struct{}) (*ChannelSubscription, error)
85+
SubscribeChans(context.Context,
86+
map[wire.OutPoint]struct{}) (*ChannelSubscription, error)
8587
}
8688

8789
// SubSwapper subscribes to new updates to the open channel state, and then
@@ -119,16 +121,17 @@ type SubSwapper struct {
119121
// set of channels, and the required interfaces to be notified of new channel
120122
// updates, pack a multi backup, and swap the current best backup from its
121123
// storage location.
122-
func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier,
123-
keyRing keychain.KeyRing, backupSwapper Swapper) (*SubSwapper, error) {
124+
func NewSubSwapper(ctx context.Context, startingChans []Single,
125+
chanNotifier ChannelNotifier, keyRing keychain.KeyRing,
126+
backupSwapper Swapper) (*SubSwapper, error) {
124127

125128
// First, we'll subscribe to the latest set of channel updates given
126129
// the set of channels we already know of.
127130
knownChans := make(map[wire.OutPoint]struct{})
128131
for _, chanBackup := range startingChans {
129132
knownChans[chanBackup.FundingOutpoint] = struct{}{}
130133
}
131-
chanEvents, err := chanNotifier.SubscribeChans(knownChans)
134+
chanEvents, err := chanNotifier.SubscribeChans(ctx, knownChans)
132135
if err != nil {
133136
return nil, err
134137
}

chanbackup/pubsub_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package chanbackup
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67
"time"
@@ -62,8 +63,8 @@ func newMockChannelNotifier() *mockChannelNotifier {
6263
}
6364
}
6465

65-
func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
66-
*ChannelSubscription, error) {
66+
func (m *mockChannelNotifier) SubscribeChans(_ context.Context,
67+
_ map[wire.OutPoint]struct{}) (*ChannelSubscription, error) {
6768

6869
if m.fail {
6970
return nil, fmt.Errorf("fail")
@@ -80,6 +81,7 @@ func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
8081
// channel subscription, then the entire sub-swapper will fail to start.
8182
func TestNewSubSwapperSubscribeFail(t *testing.T) {
8283
t.Parallel()
84+
ctx := context.Background()
8385

8486
keyRing := &lnencrypt.MockKeyRing{}
8587

@@ -88,10 +90,8 @@ func TestNewSubSwapperSubscribeFail(t *testing.T) {
8890
fail: true,
8991
}
9092

91-
_, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper)
92-
if err == nil {
93-
t.Fatalf("expected fail due to lack of subscription")
94-
}
93+
_, err := NewSubSwapper(ctx, nil, &chanNotifier, keyRing, &swapper)
94+
require.Error(t, err)
9595
}
9696

9797
func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper,
@@ -158,7 +158,9 @@ func TestSubSwapperIdempotentStartStop(t *testing.T) {
158158
var chanNotifier mockChannelNotifier
159159

160160
swapper := newMockSwapper(keyRing)
161-
subSwapper, err := NewSubSwapper(nil, &chanNotifier, keyRing, swapper)
161+
subSwapper, err := NewSubSwapper(
162+
context.Background(), nil, &chanNotifier, keyRing, swapper,
163+
)
162164
require.NoError(t, err, "unable to init subSwapper")
163165

164166
if err := subSwapper.Start(); err != nil {
@@ -224,7 +226,8 @@ func TestSubSwapperUpdater(t *testing.T) {
224226
// With our channel set created, we'll make a fresh sub swapper
225227
// instance to begin our test.
226228
subSwapper, err := NewSubSwapper(
227-
initialChanSet, chanNotifier, keyRing, swapper,
229+
context.Background(), initialChanSet, chanNotifier, keyRing,
230+
swapper,
228231
)
229232
require.NoError(t, err, "unable to make swapper")
230233
if err := subSwapper.Start(); err != nil {

channel_notifier.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package lnd
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/btcsuite/btcd/wire"
@@ -31,7 +32,8 @@ type channelNotifier struct {
3132
// the channel subscription.
3233
//
3334
// NOTE: This is part of the chanbackup.ChannelNotifier interface.
34-
func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{}) (
35+
func (c *channelNotifier) SubscribeChans(ctx context.Context,
36+
startingChans map[wire.OutPoint]struct{}) (
3537
*chanbackup.ChannelSubscription, error) {
3638

3739
ltndLog.Infof("Channel backup proxy channel notifier starting")
@@ -46,7 +48,7 @@ func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{
4648
// confirmed channels.
4749
sendChanOpenUpdate := func(newOrPendingChan *channeldb.OpenChannel) {
4850
_, nodeAddrs, err := c.addrs.AddrsForNode(
49-
newOrPendingChan.IdentityPub,
51+
ctx, newOrPendingChan.IdentityPub,
5052
)
5153
if err != nil {
5254
pub := newOrPendingChan.IdentityPub

channeldb/addr_source.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package channeldb
22

33
import (
4+
"context"
45
"errors"
56
"net"
67

@@ -13,7 +14,8 @@ type AddrSource interface {
1314
// AddrsForNode returns all known addresses for the target node public
1415
// key. The returned boolean must indicate if the given node is unknown
1516
// to the backing source.
16-
AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error)
17+
AddrsForNode(ctx context.Context,
18+
nodePub *btcec.PublicKey) (bool, []net.Addr, error)
1719
}
1820

1921
// multiAddrSource is an implementation of AddrSource which gathers all the
@@ -38,8 +40,8 @@ func NewMultiAddrSource(sources ...AddrSource) AddrSource {
3840
// node.
3941
//
4042
// NOTE: this implements the AddrSource interface.
41-
func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool,
42-
[]net.Addr, error) {
43+
func (c *multiAddrSource) AddrsForNode(ctx context.Context,
44+
nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
4345

4446
if len(c.sources) == 0 {
4547
return false, nil, errors.New("no address sources")
@@ -55,7 +57,7 @@ func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool,
5557
// Iterate over all the address sources and query each one for the
5658
// addresses it has for the node in question.
5759
for _, src := range c.sources {
58-
isKnown, addrs, err := src.AddrsForNode(nodePub)
60+
isKnown, addrs, err := src.AddrsForNode(ctx, nodePub)
5961
if err != nil {
6062
return false, nil, err
6163
}

0 commit comments

Comments
 (0)