Skip to content

Commit d8a12a5

Browse files
authored
Merge pull request #10011 from ellemouton/graphRefactor
refactor+graph/db: refactor preparations required for incoming SQL migration code
2 parents cd7fa63 + 9284938 commit d8a12a5

File tree

2 files changed

+95
-74
lines changed

2 files changed

+95
-74
lines changed

graph/db/kv_store.go

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ func (c channelMapKey) String() string {
246246

247247
// getChannelMap loads all channel edge policies from the database and stores
248248
// them in a map.
249-
func (c *KVStore) getChannelMap(edges kvdb.RBucket) (
249+
func getChannelMap(edges kvdb.RBucket) (
250250
map[channelMapKey]*models.ChannelEdgePolicy, error) {
251251

252252
// Create a map to store all channel edge policies.
@@ -407,15 +407,30 @@ func (c *KVStore) AddrsForNode(ctx context.Context,
407407
func (c *KVStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
408408
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
409409

410-
return c.db.View(func(tx kvdb.RTx) error {
410+
return forEachChannel(c.db, cb)
411+
}
412+
413+
// forEachChannel iterates through all the channel edges stored within the
414+
// graph and invokes the passed callback for each edge. The callback takes two
415+
// edges as since this is a directed graph, both the in/out edges are visited.
416+
// If the callback returns an error, then the transaction is aborted and the
417+
// iteration stops early.
418+
//
419+
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
420+
// for that particular channel edge routing policy will be passed into the
421+
// callback.
422+
func forEachChannel(db kvdb.Backend, cb func(*models.ChannelEdgeInfo,
423+
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
424+
425+
return db.View(func(tx kvdb.RTx) error {
411426
edges := tx.ReadBucket(edgeBucket)
412427
if edges == nil {
413428
return ErrGraphNoEdgesFound
414429
}
415430

416431
// First, load all edges in memory indexed by node and channel
417432
// id.
418-
channelMap, err := c.getChannelMap(edges)
433+
channelMap, err := getChannelMap(edges)
419434
if err != nil {
420435
return err
421436
}
@@ -479,7 +494,7 @@ func (c *KVStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
479494

480495
// First, load all edges in memory indexed by node and channel
481496
// id.
482-
channelMap, err := c.getChannelMap(edges)
497+
channelMap, err := getChannelMap(edges)
483498
if err != nil {
484499
return err
485500
}
@@ -658,7 +673,7 @@ func (c *KVStore) ForEachNodeCached(cb func(node route.Vertex,
658673
// We'll iterate over each node, then the set of channels for each
659674
// node, and construct a similar callback functiopn signature as the
660675
// main funcotin expects.
661-
return c.forEachNode(func(tx kvdb.RTx,
676+
return forEachNode(c.db, func(tx kvdb.RTx,
662677
node *models.LightningNode) error {
663678

664679
channels := make(map[uint64]*DirectedChannel)
@@ -774,7 +789,7 @@ func (c *KVStore) DisabledChannelIDs() ([]uint64, error) {
774789
// executed under the same read transaction and so, methods on the NodeTx object
775790
// _MUST_ only be called from within the call-back.
776791
func (c *KVStore) ForEachNode(cb func(tx NodeRTx) error) error {
777-
return c.forEachNode(func(tx kvdb.RTx,
792+
return forEachNode(c.db, func(tx kvdb.RTx,
778793
node *models.LightningNode) error {
779794

780795
return cb(newChanGraphNodeTx(tx, c, node))
@@ -788,7 +803,7 @@ func (c *KVStore) ForEachNode(cb func(tx NodeRTx) error) error {
788803
//
789804
// TODO(roasbeef): add iterator interface to allow for memory efficient graph
790805
// traversal when graph gets mega.
791-
func (c *KVStore) forEachNode(
806+
func forEachNode(db kvdb.Backend,
792807
cb func(kvdb.RTx, *models.LightningNode) error) error {
793808

794809
traversal := func(tx kvdb.RTx) error {
@@ -819,7 +834,7 @@ func (c *KVStore) forEachNode(
819834
})
820835
}
821836

822-
return kvdb.View(c.db, traversal, func() {})
837+
return kvdb.View(db, traversal, func() {})
823838
}
824839

825840
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
@@ -866,19 +881,23 @@ func (c *KVStore) ForEachNodeCacheable(cb func(route.Vertex,
866881
// as the center node within a star-graph. This method may be used to kick off
867882
// a path finding algorithm in order to explore the reachability of another
868883
// node based off the source node.
869-
func (c *KVStore) SourceNode(_ context.Context) (*models.LightningNode,
870-
error) {
884+
func (c *KVStore) SourceNode(_ context.Context) (*models.LightningNode, error) {
885+
return sourceNode(c.db)
886+
}
871887

888+
// sourceNode fetches the source node of the graph. The source node is treated
889+
// as the center node within a star-graph.
890+
func sourceNode(db kvdb.Backend) (*models.LightningNode, error) {
872891
var source *models.LightningNode
873-
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
892+
err := kvdb.View(db, func(tx kvdb.RTx) error {
874893
// First grab the nodes bucket which stores the mapping from
875894
// pubKey to node information.
876895
nodes := tx.ReadBucket(nodeBucket)
877896
if nodes == nil {
878897
return ErrGraphNotFound
879898
}
880899

881-
node, err := c.sourceNode(nodes)
900+
node, err := sourceNodeWithTx(nodes)
882901
if err != nil {
883902
return err
884903
}
@@ -895,13 +914,11 @@ func (c *KVStore) SourceNode(_ context.Context) (*models.LightningNode,
895914
return source, nil
896915
}
897916

898-
// sourceNode uses an existing database transaction and returns the source node
899-
// of the graph. The source node is treated as the center node within a
917+
// sourceNodeWithTx uses an existing database transaction and returns the source
918+
// node of the graph. The source node is treated as the center node within a
900919
// star-graph. This method may be used to kick off a path finding algorithm in
901920
// order to explore the reachability of another node based off the source node.
902-
func (c *KVStore) sourceNode(nodes kvdb.RBucket) (*models.LightningNode,
903-
error) {
904-
921+
func sourceNodeWithTx(nodes kvdb.RBucket) (*models.LightningNode, error) {
905922
selfPub := nodes.Get(sourceKey)
906923
if selfPub == nil {
907924
return nil, ErrSourceNodeNotSet
@@ -1554,7 +1571,7 @@ func (c *KVStore) pruneGraphNodes(nodes kvdb.RwBucket,
15541571

15551572
// We'll retrieve the graph's source node to ensure we don't remove it
15561573
// even if it no longer has any open channels.
1557-
sourceNode, err := c.sourceNode(nodes)
1574+
sourceNode, err := sourceNodeWithTx(nodes)
15581575
if err != nil {
15591576
return nil, err
15601577
}
@@ -3240,7 +3257,7 @@ func (c *KVStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
32403257
return ErrGraphNotFound
32413258
}
32423259

3243-
node, err := c.sourceNode(nodes)
3260+
node, err := sourceNodeWithTx(nodes)
32443261
if err != nil {
32453262
return err
32463263
}

0 commit comments

Comments
 (0)