Skip to content

Commit 34cb6a9

Browse files
committed
graph/db: impl GraphSession
1 parent 933ab3c commit 34cb6a9

File tree

2 files changed

+82
-12
lines changed

2 files changed

+82
-12
lines changed

graph/db/sql_store.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2674,6 +2674,63 @@ func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
26742674
return isClosed, nil
26752675
}
26762676

2677+
// GraphSession will provide the call-back with access to a NodeTraverser
2678+
// instance which can be used to perform queries against the channel graph.
2679+
//
2680+
// NOTE: part of the V1Store interface.
2681+
func (s *SQLStore) GraphSession(cb func(graph NodeTraverser) error) error {
2682+
var ctx = context.TODO()
2683+
2684+
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
2685+
return cb(newSQLNodeTraverser(db, s.cfg.ChainHash))
2686+
}, sqldb.NoOpReset)
2687+
}
2688+
2689+
// sqlNodeTraverser implements the NodeTraverser interface but with a backing
2690+
// read only transaction for a consistent view of the graph.
2691+
type sqlNodeTraverser struct {
2692+
db SQLQueries
2693+
chain chainhash.Hash
2694+
}
2695+
2696+
// A compile-time assertion to ensure that sqlNodeTraverser implements the
2697+
// NodeTraverser interface.
2698+
var _ NodeTraverser = (*sqlNodeTraverser)(nil)
2699+
2700+
// newSQLNodeTraverser creates a new instance of the sqlNodeTraverser.
2701+
func newSQLNodeTraverser(db SQLQueries,
2702+
chain chainhash.Hash) *sqlNodeTraverser {
2703+
2704+
return &sqlNodeTraverser{
2705+
db: db,
2706+
chain: chain,
2707+
}
2708+
}
2709+
2710+
// ForEachNodeDirectedChannel calls the callback for every channel of the given
2711+
// node.
2712+
//
2713+
// NOTE: Part of the NodeTraverser interface.
2714+
func (s *sqlNodeTraverser) ForEachNodeDirectedChannel(nodePub route.Vertex,
2715+
cb func(channel *DirectedChannel) error) error {
2716+
2717+
ctx := context.TODO()
2718+
2719+
return forEachNodeDirectedChannel(ctx, s.db, nodePub, cb)
2720+
}
2721+
2722+
// FetchNodeFeatures returns the features of the given node. If the node is
2723+
// unknown, assume no additional features are supported.
2724+
//
2725+
// NOTE: Part of the NodeTraverser interface.
2726+
func (s *sqlNodeTraverser) FetchNodeFeatures(nodePub route.Vertex) (
2727+
*lnwire.FeatureVector, error) {
2728+
2729+
ctx := context.TODO()
2730+
2731+
return fetchNodeFeatures(ctx, s.db, nodePub)
2732+
}
2733+
26772734
// forEachNodeDirectedChannel iterates through all channels of a given
26782735
// node, executing the passed callback on the directed edge representing the
26792736
// channel and its incoming policy. If the node is not found, no error is

routing/pathfind_test.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ func makeTestGraph(t *testing.T, useCache bool) (*graphdb.ChannelGraph,
160160
kvdb.Backend, error) {
161161

162162
// Create channelgraph for the first time.
163-
graph := graphdb.MakeTestGraph(t, graphdb.WithUseGraphCache(useCache))
163+
graph := graphdb.MakeTestGraphNew(
164+
t, graphdb.WithUseGraphCache(useCache),
165+
)
164166
require.NoError(t, graph.Start())
165167
t.Cleanup(func() {
166168
require.NoError(t, graph.Stop())
@@ -289,6 +291,11 @@ func parseTestGraph(t *testing.T, useCache bool, path string) (
289291
}
290292

291293
source = dbNode
294+
295+
// If this is the source node, we don't have to call
296+
// AddLightningNode below since we will call
297+
// SetSourceNode later.
298+
continue
292299
}
293300

294301
// With the node fully parsed, add it as a vertex within the
@@ -540,8 +547,8 @@ func createTestGraphFromChannels(t *testing.T, useCache bool,
540547
privKeyMap := make(map[string]*btcec.PrivateKey)
541548

542549
nodeIndex := byte(0)
543-
addNodeWithAlias := func(alias string, features *lnwire.FeatureVector) (
544-
*models.LightningNode, error) {
550+
addNodeWithAlias := func(alias string,
551+
features *lnwire.FeatureVector) error {
545552

546553
keyBytes := []byte{
547554
0, 0, 0, 0, 0, 0, 0, 0,
@@ -571,29 +578,29 @@ func createTestGraphFromChannels(t *testing.T, useCache bool,
571578

572579
// With the node fully parsed, add it as a vertex within the
573580
// graph.
574-
if err := graph.AddLightningNode(ctx, dbNode); err != nil {
575-
return nil, err
581+
if alias == source {
582+
err = graph.SetSourceNode(ctx, dbNode)
583+
require.NoError(t, err)
584+
} else {
585+
err := graph.AddLightningNode(ctx, dbNode)
586+
require.NoError(t, err)
576587
}
577588

578589
aliasMap[alias] = dbNode.PubKeyBytes
579590
nodeIndex++
580591

581-
return dbNode, nil
592+
return nil
582593
}
583594

584595
// Add the source node.
585-
dbNode, err := addNodeWithAlias(
596+
err = addNodeWithAlias(
586597
source, lnwire.NewFeatureVector(
587598
lnwire.NewRawFeatureVector(sourceFeatureBits...),
588599
lnwire.Features,
589600
),
590601
)
591602
require.NoError(t, err)
592603

593-
if err = graph.SetSourceNode(ctx, dbNode); err != nil {
594-
return nil, err
595-
}
596-
597604
// Initialize variable that keeps track of the next channel id to assign
598605
// if none is specified.
599606
nextUnassignedChannelID := uint64(100000)
@@ -611,7 +618,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool,
611618
features =
612619
node.testChannelPolicy.Features
613620
}
614-
_, err := addNodeWithAlias(
621+
err := addNodeWithAlias(
615622
node.Alias, features,
616623
)
617624
if err != nil {
@@ -2157,6 +2164,7 @@ func runRouteFailMaxHTLC(t *testing.T, useCache bool) {
21572164
require.NoError(t, err, "unable to fetch channel edges by ID")
21582165
midEdge.MessageFlags = 1
21592166
midEdge.MaxHTLC = payAmt - 1
2167+
midEdge.LastUpdate = midEdge.LastUpdate.Add(time.Second)
21602168
err = graph.UpdateEdgePolicy(context.Background(), midEdge)
21612169
require.NoError(t, err)
21622170

@@ -2199,10 +2207,12 @@ func runRouteFailDisabledEdge(t *testing.T, useCache bool) {
21992207
_, e1, e2, err := graph.graph.FetchChannelEdgesByID(roasToPham)
22002208
require.NoError(t, err, "unable to fetch edge")
22012209
e1.ChannelFlags |= lnwire.ChanUpdateDisabled
2210+
e1.LastUpdate = e1.LastUpdate.Add(time.Second)
22022211
if err := graph.graph.UpdateEdgePolicy(ctx, e1); err != nil {
22032212
t.Fatalf("unable to update edge: %v", err)
22042213
}
22052214
e2.ChannelFlags |= lnwire.ChanUpdateDisabled
2215+
e2.LastUpdate = e2.LastUpdate.Add(time.Second)
22062216
if err := graph.graph.UpdateEdgePolicy(ctx, e2); err != nil {
22072217
t.Fatalf("unable to update edge: %v", err)
22082218
}
@@ -2220,6 +2230,7 @@ func runRouteFailDisabledEdge(t *testing.T, useCache bool) {
22202230
_, e, _, err := graph.graph.FetchChannelEdgesByID(phamToSophon)
22212231
require.NoError(t, err, "unable to fetch edge")
22222232
e.ChannelFlags |= lnwire.ChanUpdateDisabled
2233+
e.LastUpdate = e.LastUpdate.Add(time.Second)
22232234
if err := graph.graph.UpdateEdgePolicy(ctx, e); err != nil {
22242235
t.Fatalf("unable to update edge: %v", err)
22252236
}
@@ -2302,10 +2313,12 @@ func runPathSourceEdgesBandwidth(t *testing.T, useCache bool) {
23022313
_, e1, e2, err := graph.graph.FetchChannelEdgesByID(roasToSongoku)
23032314
require.NoError(t, err, "unable to fetch edge")
23042315
e1.ChannelFlags |= lnwire.ChanUpdateDisabled
2316+
e1.LastUpdate = e1.LastUpdate.Add(time.Second)
23052317
if err := graph.graph.UpdateEdgePolicy(ctx, e1); err != nil {
23062318
t.Fatalf("unable to update edge: %v", err)
23072319
}
23082320
e2.ChannelFlags |= lnwire.ChanUpdateDisabled
2321+
e2.LastUpdate = e2.LastUpdate.Add(time.Second)
23092322
if err := graph.graph.UpdateEdgePolicy(ctx, e2); err != nil {
23102323
t.Fatalf("unable to update edge: %v", err)
23112324
}

0 commit comments

Comments
 (0)