Skip to content

Commit e724e1c

Browse files
committed
multi: thread context through to AddrsForNode
1 parent d1fa570 commit e724e1c

15 files changed

+88
-64
lines changed

chanbackup/backup.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package chanbackup
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/btcsuite/btcd/wire"
@@ -24,15 +25,17 @@ type LiveChannelSource interface {
2425
// passed open channel. The backup includes all information required to restore
2526
// the channel, as well as addressing information so we can find the peer and
2627
// reconnect to them to initiate the protocol.
27-
func assembleChanBackup(addrSource channeldb.AddrSource,
28+
func assembleChanBackup(ctx context.Context, addrSource channeldb.AddrSource,
2829
openChan *channeldb.OpenChannel) (*Single, error) {
2930

3031
log.Debugf("Crafting backup for ChannelPoint(%v)",
3132
openChan.FundingOutpoint)
3233

3334
// First, we'll query the channel source to obtain all the addresses
3435
// that are associated with the peer for this channel.
35-
known, nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub)
36+
known, nodeAddrs, err := addrSource.AddrsForNode(
37+
ctx, openChan.IdentityPub,
38+
)
3639
if err != nil {
3740
return nil, err
3841
}
@@ -90,7 +93,8 @@ func buildCloseTxInputs(
9093
// FetchBackupForChan attempts to create a plaintext static channel backup for
9194
// the target channel identified by its channel point. If we're unable to find
9295
// the target channel, then an error will be returned.
93-
func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
96+
func FetchBackupForChan(ctx context.Context, chanPoint wire.OutPoint,
97+
chanSource LiveChannelSource,
9498
addrSource channeldb.AddrSource) (*Single, error) {
9599

96100
// First, we'll query the channel source to see if the channel is known
@@ -104,7 +108,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
104108

105109
// Once we have the target channel, we can assemble the backup using
106110
// the source to obtain any extra information that we may need.
107-
staticChanBackup, err := assembleChanBackup(addrSource, targetChan)
111+
staticChanBackup, err := assembleChanBackup(ctx, addrSource, targetChan)
108112
if err != nil {
109113
return nil, fmt.Errorf("unable to create chan backup: %w", err)
110114
}
@@ -114,7 +118,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
114118

115119
// FetchStaticChanBackups will return a plaintext static channel back up for
116120
// all known active/open channels within the passed channel source.
117-
func FetchStaticChanBackups(chanSource LiveChannelSource,
121+
func FetchStaticChanBackups(ctx context.Context, chanSource LiveChannelSource,
118122
addrSource channeldb.AddrSource) ([]Single, error) {
119123

120124
// First, we'll query the backup source for information concerning all
@@ -129,7 +133,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource,
129133
// channel.
130134
staticChanBackups := make([]Single, 0, len(openChans))
131135
for _, openChan := range openChans {
132-
chanBackup, err := assembleChanBackup(addrSource, openChan)
136+
chanBackup, err := assembleChanBackup(ctx, addrSource, openChan)
133137
if err != nil {
134138
return nil, err
135139
}

chanbackup/backup_test.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package chanbackup
22

33
import (
4+
"context"
45
"fmt"
56
"net"
67
"testing"
@@ -61,8 +62,8 @@ func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []ne
6162
m.addrs[nodeKey] = addrs
6263
}
6364

64-
func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) (bool,
65-
[]net.Addr, error) {
65+
func (m *mockChannelSource) AddrsForNode(_ context.Context,
66+
nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
6667

6768
if m.failQuery {
6869
return false, nil, fmt.Errorf("fail")
@@ -120,7 +121,8 @@ func TestFetchBackupForChan(t *testing.T) {
120121
}
121122
for i, testCase := range testCases {
122123
_, err := FetchBackupForChan(
123-
testCase.chanPoint, chanSource, chanSource,
124+
context.Background(), testCase.chanPoint, chanSource,
125+
chanSource,
124126
)
125127
switch {
126128
// If this is a valid test case, and we failed, then we'll
@@ -141,6 +143,7 @@ func TestFetchBackupForChan(t *testing.T) {
141143
// channel source for all channels and construct a Single for each channel.
142144
func TestFetchStaticChanBackups(t *testing.T) {
143145
t.Parallel()
146+
ctx := context.Background()
144147

145148
// First, we'll make the set of channels that we want to seed the
146149
// channel source with. Both channels will be fully populated in the
@@ -162,7 +165,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
162165
// With the channel source populated, we'll now attempt to create a set
163166
// of backups for all the channels. This should succeed, as all items
164167
// are populated within the channel source.
165-
backups, err := FetchStaticChanBackups(chanSource, chanSource)
168+
backups, err := FetchStaticChanBackups(ctx, chanSource, chanSource)
166169
require.NoError(t, err, "unable to create chan back ups")
167170

168171
if len(backups) != numChans {
@@ -177,7 +180,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
177180
copy(n[:], randomChan2.IdentityPub.SerializeCompressed())
178181
delete(chanSource.addrs, n)
179182

180-
_, err = FetchStaticChanBackups(chanSource, chanSource)
183+
_, err = FetchStaticChanBackups(ctx, chanSource, chanSource)
181184
if err == nil {
182185
t.Fatalf("query with incomplete information should fail")
183186
}
@@ -186,7 +189,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
186189
// source at all, then we'll fail as well.
187190
chanSource = newMockChannelSource()
188191
chanSource.failQuery = true
189-
_, err = FetchStaticChanBackups(chanSource, chanSource)
192+
_, err = FetchStaticChanBackups(ctx, chanSource, chanSource)
190193
if err == nil {
191194
t.Fatalf("query should fail")
192195
}

chanbackup/pubsub.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package chanbackup
22

33
import (
44
"bytes"
5+
"context"
56
"fmt"
67
"net"
78
"os"
@@ -81,7 +82,8 @@ type ChannelNotifier interface {
8182
// synchronization point to ensure that the chanbackup.SubSwapper does
8283
// not miss any channel open or close events in the period between when
8384
// it's created, and when it requests the channel subscription.
84-
SubscribeChans(map[wire.OutPoint]struct{}) (*ChannelSubscription, error)
85+
SubscribeChans(context.Context,
86+
map[wire.OutPoint]struct{}) (*ChannelSubscription, error)
8587
}
8688

8789
// SubSwapper subscribes to new updates to the open channel state, and then
@@ -119,16 +121,17 @@ type SubSwapper struct {
119121
// set of channels, and the required interfaces to be notified of new channel
120122
// updates, pack a multi backup, and swap the current best backup from its
121123
// storage location.
122-
func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier,
123-
keyRing keychain.KeyRing, backupSwapper Swapper) (*SubSwapper, error) {
124+
func NewSubSwapper(ctx context.Context, startingChans []Single,
125+
chanNotifier ChannelNotifier, keyRing keychain.KeyRing,
126+
backupSwapper Swapper) (*SubSwapper, error) {
124127

125128
// First, we'll subscribe to the latest set of channel updates given
126129
// the set of channels we already know of.
127130
knownChans := make(map[wire.OutPoint]struct{})
128131
for _, chanBackup := range startingChans {
129132
knownChans[chanBackup.FundingOutpoint] = struct{}{}
130133
}
131-
chanEvents, err := chanNotifier.SubscribeChans(knownChans)
134+
chanEvents, err := chanNotifier.SubscribeChans(ctx, knownChans)
132135
if err != nil {
133136
return nil, err
134137
}

chanbackup/pubsub_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package chanbackup
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67
"time"
@@ -62,8 +63,8 @@ func newMockChannelNotifier() *mockChannelNotifier {
6263
}
6364
}
6465

65-
func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
66-
*ChannelSubscription, error) {
66+
func (m *mockChannelNotifier) SubscribeChans(_ context.Context,
67+
_ map[wire.OutPoint]struct{}) (*ChannelSubscription, error) {
6768

6869
if m.fail {
6970
return nil, fmt.Errorf("fail")
@@ -80,6 +81,7 @@ func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
8081
// channel subscription, then the entire sub-swapper will fail to start.
8182
func TestNewSubSwapperSubscribeFail(t *testing.T) {
8283
t.Parallel()
84+
ctx := context.Background()
8385

8486
keyRing := &lnencrypt.MockKeyRing{}
8587

@@ -88,10 +90,8 @@ func TestNewSubSwapperSubscribeFail(t *testing.T) {
8890
fail: true,
8991
}
9092

91-
_, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper)
92-
if err == nil {
93-
t.Fatalf("expected fail due to lack of subscription")
94-
}
93+
_, err := NewSubSwapper(ctx, nil, &chanNotifier, keyRing, &swapper)
94+
require.Error(t, err)
9595
}
9696

9797
func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper,
@@ -158,7 +158,9 @@ func TestSubSwapperIdempotentStartStop(t *testing.T) {
158158
var chanNotifier mockChannelNotifier
159159

160160
swapper := newMockSwapper(keyRing)
161-
subSwapper, err := NewSubSwapper(nil, &chanNotifier, keyRing, swapper)
161+
subSwapper, err := NewSubSwapper(
162+
context.Background(), nil, &chanNotifier, keyRing, swapper,
163+
)
162164
require.NoError(t, err, "unable to init subSwapper")
163165

164166
if err := subSwapper.Start(); err != nil {
@@ -224,7 +226,8 @@ func TestSubSwapperUpdater(t *testing.T) {
224226
// With our channel set created, we'll make a fresh sub swapper
225227
// instance to begin our test.
226228
subSwapper, err := NewSubSwapper(
227-
initialChanSet, chanNotifier, keyRing, swapper,
229+
context.Background(), initialChanSet, chanNotifier, keyRing,
230+
swapper,
228231
)
229232
require.NoError(t, err, "unable to make swapper")
230233
if err := subSwapper.Start(); err != nil {

channel_notifier.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package lnd
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/btcsuite/btcd/wire"
@@ -31,7 +32,8 @@ type channelNotifier struct {
3132
// the channel subscription.
3233
//
3334
// NOTE: This is part of the chanbackup.ChannelNotifier interface.
34-
func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{}) (
35+
func (c *channelNotifier) SubscribeChans(ctx context.Context,
36+
startingChans map[wire.OutPoint]struct{}) (
3537
*chanbackup.ChannelSubscription, error) {
3638

3739
ltndLog.Infof("Channel backup proxy channel notifier starting")
@@ -46,7 +48,7 @@ func (c *channelNotifier) SubscribeChans(startingChans map[wire.OutPoint]struct{
4648
// confirmed channels.
4749
sendChanOpenUpdate := func(newOrPendingChan *channeldb.OpenChannel) {
4850
_, nodeAddrs, err := c.addrs.AddrsForNode(
49-
newOrPendingChan.IdentityPub,
51+
ctx, newOrPendingChan.IdentityPub,
5052
)
5153
if err != nil {
5254
pub := newOrPendingChan.IdentityPub

channeldb/addr_source.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package channeldb
22

33
import (
4+
"context"
45
"errors"
56
"net"
67

@@ -13,7 +14,8 @@ type AddrSource interface {
1314
// AddrsForNode returns all known addresses for the target node public
1415
// key. The returned boolean must indicate if the given node is unknown
1516
// to the backing source.
16-
AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error)
17+
AddrsForNode(ctx context.Context,
18+
nodePub *btcec.PublicKey) (bool, []net.Addr, error)
1719
}
1820

1921
// multiAddrSource is an implementation of AddrSource which gathers all the
@@ -38,8 +40,8 @@ func NewMultiAddrSource(sources ...AddrSource) AddrSource {
3840
// node.
3941
//
4042
// NOTE: this implements the AddrSource interface.
41-
func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool,
42-
[]net.Addr, error) {
43+
func (c *multiAddrSource) AddrsForNode(ctx context.Context,
44+
nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
4345

4446
if len(c.sources) == 0 {
4547
return false, nil, errors.New("no address sources")
@@ -55,7 +57,7 @@ func (c *multiAddrSource) AddrsForNode(nodePub *btcec.PublicKey) (bool,
5557
// Iterate over all the address sources and query each one for the
5658
// addresses it has for the node in question.
5759
for _, src := range c.sources {
58-
isKnown, addrs, err := src.AddrsForNode(nodePub)
60+
isKnown, addrs, err := src.AddrsForNode(ctx, nodePub)
5961
if err != nil {
6062
return false, nil, err
6163
}

channeldb/addr_source_test.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package channeldb
22

33
import (
4+
"context"
45
"net"
56
"testing"
67

@@ -19,6 +20,7 @@ var (
1920
// deduplicates the results of a set of AddrSource implementations.
2021
func TestMultiAddrSource(t *testing.T) {
2122
t.Parallel()
23+
ctx := context.Background()
2224

2325
var pk1 = newTestPubKey(t)
2426

@@ -35,12 +37,12 @@ func TestMultiAddrSource(t *testing.T) {
3537
})
3638

3739
// Let source 1 know of 2 addresses (addr 1 and 2) for node 1.
38-
src1.On("AddrsForNode", pk1).Return(
40+
src1.On("AddrsForNode", ctx, pk1).Return(
3941
true, []net.Addr{addr1, addr2}, nil,
4042
).Once()
4143

4244
// Let source 2 know of 2 addresses (addr 2 and 3) for node 1.
43-
src2.On("AddrsForNode", pk1).Return(
45+
src2.On("AddrsForNode", ctx, pk1).Return(
4446
true, []net.Addr{addr2, addr3}, nil,
4547
[]net.Addr{addr2, addr3}, nil,
4648
).Once()
@@ -51,7 +53,7 @@ func TestMultiAddrSource(t *testing.T) {
5153

5254
// Query it for the addresses known for node 1. The results
5355
// should contain addr 1, 2 and 3.
54-
known, addrs, err := multiSrc.AddrsForNode(pk1)
56+
known, addrs, err := multiSrc.AddrsForNode(ctx, pk1)
5557
require.NoError(t, err)
5658
require.True(t, known)
5759
require.ElementsMatch(t, addrs, []net.Addr{addr1, addr2, addr3})
@@ -70,18 +72,18 @@ func TestMultiAddrSource(t *testing.T) {
7072
})
7173

7274
// Let source 1 know of address 1 for node 1.
73-
src1.On("AddrsForNode", pk1).Return(
75+
src1.On("AddrsForNode", ctx, pk1).Return(
7476
true, []net.Addr{addr1}, nil,
7577
).Once()
76-
src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once()
78+
src2.On("AddrsForNode", ctx, pk1).Return(false, nil, nil).Once()
7779

7880
// Create a multi-addr source that consists of both source 1
7981
// and 2.
8082
multiSrc := NewMultiAddrSource(src1, src2)
8183

8284
// Query it for the addresses known for node 1. The results
8385
// should contain addr 1.
84-
known, addrs, err := multiSrc.AddrsForNode(pk1)
86+
known, addrs, err := multiSrc.AddrsForNode(ctx, pk1)
8587
require.NoError(t, err)
8688
require.True(t, known)
8789
require.ElementsMatch(t, addrs, []net.Addr{addr1})
@@ -103,13 +105,13 @@ func TestMultiAddrSource(t *testing.T) {
103105
// and 2. Neither source known of node 1.
104106
multiSrc := NewMultiAddrSource(src1, src2)
105107

106-
src1.On("AddrsForNode", pk1).Return(false, nil, nil).Once()
107-
src2.On("AddrsForNode", pk1).Return(false, nil, nil).Once()
108+
src1.On("AddrsForNode", ctx, pk1).Return(false, nil, nil).Once()
109+
src2.On("AddrsForNode", ctx, pk1).Return(false, nil, nil).Once()
108110

109111
// Query it for the addresses known for node 1. It should return
110112
// false to indicate that the node is unknown to all backing
111113
// sources.
112-
known, addrs, err := multiSrc.AddrsForNode(pk1)
114+
known, addrs, err := multiSrc.AddrsForNode(ctx, pk1)
113115
require.NoError(t, err)
114116
require.False(t, known)
115117
require.Empty(t, addrs)
@@ -127,10 +129,10 @@ func newMockAddrSource(t *testing.T) *mockAddrSource {
127129
return &mockAddrSource{t: t}
128130
}
129131

130-
func (m *mockAddrSource) AddrsForNode(pub *btcec.PublicKey) (bool, []net.Addr,
131-
error) {
132+
func (m *mockAddrSource) AddrsForNode(ctx context.Context,
133+
pub *btcec.PublicKey) (bool, []net.Addr, error) {
132134

133-
args := m.Called(pub)
135+
args := m.Called(ctx, pub)
134136
if args.Get(1) == nil {
135137
return args.Bool(0), nil, args.Error(2)
136138
}

0 commit comments

Comments
 (0)