Skip to content

Commit 9b86ee5

Browse files
committed
graph+autopilot: let autopilot use new graph ForEachNode method
Which passes a NodeRTx to the call-back instead of a `kvdb.RTx`.
1 parent 14cedef commit 9b86ee5

File tree

6 files changed

+96
-67
lines changed

6 files changed

+96
-67
lines changed

autopilot/graph.go

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"github.com/btcsuite/btcd/btcutil"
1111
graphdb "github.com/lightningnetwork/lnd/graph/db"
1212
"github.com/lightningnetwork/lnd/graph/db/models"
13-
"github.com/lightningnetwork/lnd/kvdb"
1413
"github.com/lightningnetwork/lnd/lnwire"
1514
"github.com/lightningnetwork/lnd/routing/route"
1615
)
@@ -51,11 +50,7 @@ func ChannelGraphFromDatabase(db *graphdb.ChannelGraph) ChannelGraph {
5150
// channeldb.LightningNode. The wrapper method implement the autopilot.Node
5251
// interface.
5352
type dbNode struct {
54-
db *graphdb.ChannelGraph
55-
56-
tx kvdb.RTx
57-
58-
node *models.LightningNode
53+
tx graphdb.NodeRTx
5954
}
6055

6156
// A compile time assertion to ensure dbNode meets the autopilot.Node
@@ -68,15 +63,15 @@ var _ Node = (*dbNode)(nil)
6863
//
6964
// NOTE: Part of the autopilot.Node interface.
7065
func (d *dbNode) PubKey() [33]byte {
71-
return d.node.PubKeyBytes
66+
return d.tx.Node().PubKeyBytes
7267
}
7368

7469
// Addrs returns a slice of publicly reachable public TCP addresses that the
7570
// peer is known to be listening on.
7671
//
7772
// NOTE: Part of the autopilot.Node interface.
7873
func (d *dbNode) Addrs() []net.Addr {
79-
return d.node.Addresses
74+
return d.tx.Node().Addresses
8075
}
8176

8277
// ForEachChannel is a higher-order function that will be used to iterate
@@ -86,43 +81,35 @@ func (d *dbNode) Addrs() []net.Addr {
8681
//
8782
// NOTE: Part of the autopilot.Node interface.
8883
func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
89-
return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes,
90-
func(tx kvdb.RTx, ei *models.ChannelEdgeInfo, ep,
91-
_ *models.ChannelEdgePolicy) error {
92-
93-
// Skip channels for which no outgoing edge policy is
94-
// available.
95-
//
96-
// TODO(joostjager): Ideally the case where channels
97-
// have a nil policy should be supported, as autopilot
98-
// is not looking at the policies. For now, it is not
99-
// easily possible to get a reference to the other end
100-
// LightningNode object without retrieving the policy.
101-
if ep == nil {
102-
return nil
103-
}
84+
return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep,
85+
_ *models.ChannelEdgePolicy) error {
86+
87+
// Skip channels for which no outgoing edge policy is available.
88+
//
89+
// TODO(joostjager): Ideally the case where channels have a nil
90+
// policy should be supported, as autopilot is not looking at
91+
// the policies. For now, it is not easily possible to get a
92+
// reference to the other end LightningNode object without
93+
// retrieving the policy.
94+
if ep == nil {
95+
return nil
96+
}
10497

105-
node, err := d.db.FetchLightningNodeTx(
106-
tx, ep.ToNode,
107-
)
108-
if err != nil {
109-
return err
110-
}
98+
node, err := d.tx.FetchNode(ep.ToNode)
99+
if err != nil {
100+
return err
101+
}
111102

112-
edge := ChannelEdge{
113-
ChanID: lnwire.NewShortChanIDFromInt(
114-
ep.ChannelID,
115-
),
116-
Capacity: ei.Capacity,
117-
Peer: &dbNode{
118-
tx: tx,
119-
db: d.db,
120-
node: node,
121-
},
122-
}
103+
edge := ChannelEdge{
104+
ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID),
105+
Capacity: ei.Capacity,
106+
Peer: &dbNode{
107+
tx: node,
108+
},
109+
}
123110

124-
return cb(edge)
125-
})
111+
return cb(edge)
112+
})
126113
}
127114

128115
// ForEachNode is a higher-order function that should be called once for each
@@ -131,20 +118,16 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
131118
//
132119
// NOTE: Part of the autopilot.ChannelGraph interface.
133120
func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error {
134-
return d.db.ForEachNode(func(tx kvdb.RTx,
135-
n *models.LightningNode) error {
136-
121+
return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error {
137122
// We'll skip over any node that doesn't have any advertised
138123
// addresses. As we won't be able to reach them to actually
139124
// open any channels.
140-
if len(n.Addresses) == 0 {
125+
if len(nodeTx.Node().Addresses) == 0 {
141126
return nil
142127
}
143128

144129
node := &dbNode{
145-
db: d.db,
146-
tx: tx,
147-
node: n,
130+
tx: nodeTx,
148131
}
149132
return cb(node)
150133
})

autopilot/prefattach_test.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -514,18 +514,18 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey,
514514
return &ChannelEdge{
515515
ChanID: chanID,
516516
Capacity: capacity,
517-
Peer: &dbNode{
518-
db: d.db,
517+
Peer: &dbNode{tx: &testNodeTx{
518+
db: d,
519519
node: vertex1,
520-
},
520+
}},
521521
},
522522
&ChannelEdge{
523523
ChanID: chanID,
524524
Capacity: capacity,
525-
Peer: &dbNode{
526-
db: d.db,
525+
Peer: &dbNode{tx: &testNodeTx{
526+
db: d,
527527
node: vertex2,
528-
},
528+
}},
529529
},
530530
nil
531531
}
@@ -702,3 +702,37 @@ func (m *memChannelGraph) addRandNode() (*btcec.PublicKey, error) {
702702

703703
return newPub, nil
704704
}
705+
706+
type testNodeTx struct {
707+
db *testDBGraph
708+
node *models.LightningNode
709+
}
710+
711+
func (t *testNodeTx) Node() *models.LightningNode {
712+
return t.node
713+
}
714+
715+
func (t *testNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo,
716+
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
717+
718+
return t.db.db.ForEachNodeChannel(t.node.PubKeyBytes, func(_ kvdb.RTx,
719+
edge *models.ChannelEdgeInfo, policy1,
720+
policy2 *models.ChannelEdgePolicy) error {
721+
722+
return f(edge, policy1, policy2)
723+
})
724+
}
725+
726+
func (t *testNodeTx) FetchNode(pub route.Vertex) (graphdb.NodeRTx, error) {
727+
node, err := t.db.db.FetchLightningNode(pub)
728+
if err != nil {
729+
return nil, err
730+
}
731+
732+
return &testNodeTx{
733+
db: t.db,
734+
node: node,
735+
}, nil
736+
}
737+
738+
var _ graphdb.NodeRTx = (*testNodeTx)(nil)

graph/db/graph.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,9 +588,9 @@ func (c *ChannelGraph) FetchNodeFeatures(
588588
}
589589
}
590590

591-
// ForEachNodeCached is similar to ForEachNode, but it utilizes the channel
591+
// ForEachNodeCached is similar to forEachNode, but it utilizes the channel
592592
// graph cache instead. Note that this doesn't return all the information the
593-
// regular ForEachNode method does.
593+
// regular forEachNode method does.
594594
//
595595
// NOTE: The callback contents MUST not be modified.
596596
func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex,
@@ -604,7 +604,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex,
604604
// We'll iterate over each node, then the set of channels for each
605605
// node, and construct a similar callback functiopn signature as the
606606
// main funcotin expects.
607-
return c.ForEachNode(func(tx kvdb.RTx,
607+
return c.forEachNode(func(tx kvdb.RTx,
608608
node *models.LightningNode) error {
609609

610610
channels := make(map[uint64]*DirectedChannel)
@@ -716,11 +716,25 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
716716
// ForEachNode iterates through all the stored vertices/nodes in the graph,
717717
// executing the passed callback with each node encountered. If the callback
718718
// returns an error, then the transaction is aborted and the iteration stops
719+
// early. Any operations performed on the NodeTx passed to the call-back are
720+
// executed under the same read transaction and so, methods on the NodeTx object
721+
// _MUST_ only be called from within the call-back.
722+
func (c *ChannelGraph) ForEachNode(cb func(tx NodeRTx) error) error {
723+
return c.forEachNode(func(tx kvdb.RTx,
724+
node *models.LightningNode) error {
725+
726+
return cb(newChanGraphNodeTx(tx, c, node))
727+
})
728+
}
729+
730+
// forEachNode iterates through all the stored vertices/nodes in the graph,
731+
// executing the passed callback with each node encountered. If the callback
732+
// returns an error, then the transaction is aborted and the iteration stops
719733
// early.
720734
//
721735
// TODO(roasbeef): add iterator interface to allow for memory efficient graph
722736
// traversal when graph gets mega
723-
func (c *ChannelGraph) ForEachNode(
737+
func (c *ChannelGraph) forEachNode(
724738
cb func(kvdb.RTx, *models.LightningNode) error) error {
725739

726740
traversal := func(tx kvdb.RTx) error {

graph/db/graph_cache_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func TestGraphCacheAddNode(t *testing.T) {
121121
assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy)
122122

123123
// Now that we've inserted two nodes into the graph, check that
124-
// we'll recover the same set of channels during ForEachNode.
124+
// we'll recover the same set of channels during forEachNode.
125125
nodes := make(map[route.Vertex]struct{})
126126
chans := make(map[uint64]struct{})
127127
_ = cache.ForEachNode(func(node route.Vertex,

graph/db/graph_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ func TestGraphTraversalCacheable(t *testing.T) {
10921092
// Create a map of all nodes with the iteration we know works (because
10931093
// it is tested in another test).
10941094
nodeMap := make(map[route.Vertex]struct{})
1095-
err = graph.ForEachNode(
1095+
err = graph.forEachNode(
10961096
func(tx kvdb.RTx, n *models.LightningNode) error {
10971097
nodeMap[n.PubKeyBytes] = struct{}{}
10981098

@@ -1217,7 +1217,7 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes,
12171217

12181218
// Iterate over each node as returned by the graph, if all nodes are
12191219
// reached, then the map created above should be empty.
1220-
err := graph.ForEachNode(
1220+
err := graph.forEachNode(
12211221
func(_ kvdb.RTx, node *models.LightningNode) error {
12221222
delete(nodeIndex, node.Alias)
12231223
return nil
@@ -1329,7 +1329,7 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) {
13291329

13301330
func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) {
13311331
numNodes := 0
1332-
err := graph.ForEachNode(
1332+
err := graph.forEachNode(
13331333
func(_ kvdb.RTx, _ *models.LightningNode) error {
13341334
numNodes++
13351335
return nil

rpcserver.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6533,10 +6533,8 @@ func (r *rpcServer) DescribeGraph(ctx context.Context,
65336533
// First iterate through all the known nodes (connected or unconnected
65346534
// within the graph), collating their current state into the RPC
65356535
// response.
6536-
err := graph.ForEachNode(func(_ kvdb.RTx,
6537-
node *models.LightningNode) error {
6538-
6539-
lnNode := marshalNode(node)
6536+
err := graph.ForEachNode(func(nodeTx graphdb.NodeRTx) error {
6537+
lnNode := marshalNode(nodeTx.Node())
65406538

65416539
resp.Nodes = append(resp.Nodes, lnNode)
65426540

0 commit comments

Comments
 (0)