Skip to content

Commit ff84fa1

Browse files
committed
graph/db+sqldb: impl ForEachNodeCached and ForEachChannel
Which let's us run `TestGraphTraversal` against our SQL backends.
1 parent 6aa2933 commit ff84fa1

File tree

5 files changed

+462
-1
lines changed

5 files changed

+462
-1
lines changed

graph/db/graph_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,7 @@ func TestForEachSourceNodeChannel(t *testing.T) {
12771277
func TestGraphTraversal(t *testing.T) {
12781278
t.Parallel()
12791279

1280-
graph := MakeTestGraph(t)
1280+
graph := MakeTestGraphNew(t)
12811281

12821282
// We'd like to test some of the graph traversal capabilities within
12831283
// the DB, so we'll create a series of fake nodes to insert into the

graph/db/sql_store.go

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ type SQLQueries interface {
9090
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
9191
HighestSCID(ctx context.Context, version int16) ([]byte, error)
9292
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
93+
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
9394
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
9495

9596
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
@@ -1044,6 +1045,223 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime,
10441045
return edges, nil
10451046
}
10461047

1048+
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
1049+
// data to the call-back.
1050+
//
1051+
// NOTE: The callback contents MUST not be modified.
1052+
//
1053+
// NOTE: part of the V1Store interface.
1054+
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
1055+
chans map[uint64]*DirectedChannel) error) error {
1056+
1057+
var ctx = context.TODO()
1058+
1059+
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1060+
return forEachNodeCacheable(ctx, db, func(nodeID int64,
1061+
nodePub route.Vertex) error {
1062+
1063+
features, err := getNodeFeatures(ctx, db, nodeID)
1064+
if err != nil {
1065+
return fmt.Errorf("unable to fetch "+
1066+
"node(id=%d) features: %w", nodeID, err)
1067+
}
1068+
1069+
toNodeCallback := func() route.Vertex {
1070+
return nodePub
1071+
}
1072+
1073+
rows, err := db.ListChannelsByNodeID(
1074+
ctx, sqlc.ListChannelsByNodeIDParams{
1075+
Version: int16(ProtocolV1),
1076+
NodeID1: nodeID,
1077+
},
1078+
)
1079+
if err != nil {
1080+
return fmt.Errorf("unable to fetch channels "+
1081+
"of node(id=%d): %w", nodeID, err)
1082+
}
1083+
1084+
channels := make(map[uint64]*DirectedChannel, len(rows))
1085+
for _, row := range rows {
1086+
node1, node2, err := buildNodeVertices(
1087+
row.Node1Pubkey, row.Node2Pubkey,
1088+
)
1089+
if err != nil {
1090+
return err
1091+
}
1092+
1093+
e, err := getAndBuildEdgeInfo(
1094+
ctx, db, s.cfg.ChainHash,
1095+
row.Channel.ID, row.Channel, node1,
1096+
node2,
1097+
)
1098+
if err != nil {
1099+
return fmt.Errorf("unable to build "+
1100+
"channel info: %w", err)
1101+
}
1102+
1103+
dbPol1, dbPol2, err := extractChannelPolicies(
1104+
row,
1105+
)
1106+
if err != nil {
1107+
return fmt.Errorf("unable to "+
1108+
"extract channel "+
1109+
"policies: %w", err)
1110+
}
1111+
1112+
p1, p2, err := getAndBuildChanPolicies(
1113+
ctx, db, dbPol1, dbPol2, e.ChannelID,
1114+
node1, node2,
1115+
)
1116+
if err != nil {
1117+
return fmt.Errorf("unable to "+
1118+
"build channel policies: %w",
1119+
err)
1120+
}
1121+
1122+
// Determine the outgoing and incoming policy
1123+
// for this channel and node combo.
1124+
outPolicy, inPolicy := p1, p2
1125+
if p1 != nil && p1.ToNode == nodePub {
1126+
outPolicy, inPolicy = p2, p1
1127+
} else if p2 != nil && p2.ToNode != nodePub {
1128+
outPolicy, inPolicy = p2, p1
1129+
}
1130+
1131+
var cachedInPolicy *models.CachedEdgePolicy
1132+
if inPolicy != nil {
1133+
cachedInPolicy = models.NewCachedPolicy(
1134+
p2,
1135+
)
1136+
cachedInPolicy.ToNodePubKey =
1137+
toNodeCallback
1138+
cachedInPolicy.ToNodeFeatures =
1139+
features
1140+
}
1141+
1142+
var inboundFee lnwire.Fee
1143+
outPolicy.InboundFee.WhenSome(
1144+
func(fee lnwire.Fee) {
1145+
inboundFee = fee
1146+
},
1147+
)
1148+
1149+
directedChannel := &DirectedChannel{
1150+
ChannelID: e.ChannelID,
1151+
IsNode1: nodePub ==
1152+
e.NodeKey1Bytes,
1153+
OtherNode: e.NodeKey2Bytes,
1154+
Capacity: e.Capacity,
1155+
OutPolicySet: p1 != nil,
1156+
InPolicy: cachedInPolicy,
1157+
InboundFee: inboundFee,
1158+
}
1159+
1160+
if nodePub == e.NodeKey2Bytes {
1161+
directedChannel.OtherNode =
1162+
e.NodeKey1Bytes
1163+
}
1164+
1165+
channels[e.ChannelID] = directedChannel
1166+
}
1167+
1168+
return cb(nodePub, channels)
1169+
})
1170+
}, sqldb.NoOpReset)
1171+
}
1172+
1173+
// ForEachChannel iterates through all the channel edges stored within the
1174+
// graph and invokes the passed callback for each edge. The callback takes two
1175+
// edges as since this is a directed graph, both the in/out edges are visited.
1176+
// If the callback returns an error, then the transaction is aborted and the
1177+
// iteration stops early.
1178+
//
1179+
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
1180+
// for that particular channel edge routing policy will be passed into the
1181+
// callback.
1182+
//
1183+
// NOTE: part of the V1Store interface.
1184+
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
1185+
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
1186+
1187+
ctx := context.TODO()
1188+
1189+
handleChannel := func(db SQLQueries,
1190+
row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
1191+
1192+
node1, node2, err := buildNodeVertices(
1193+
row.Node1Pubkey, row.Node2Pubkey,
1194+
)
1195+
if err != nil {
1196+
return fmt.Errorf("unable to build node vertices: %w",
1197+
err)
1198+
}
1199+
1200+
edge, err := getAndBuildEdgeInfo(
1201+
ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
1202+
node1, node2,
1203+
)
1204+
if err != nil {
1205+
return fmt.Errorf("unable to build channel info: %w",
1206+
err)
1207+
}
1208+
1209+
dbPol1, dbPol2, err := extractChannelPolicies(row)
1210+
if err != nil {
1211+
return fmt.Errorf("unable to extract channel "+
1212+
"policies: %w", err)
1213+
}
1214+
1215+
p1, p2, err := getAndBuildChanPolicies(
1216+
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
1217+
)
1218+
if err != nil {
1219+
return fmt.Errorf("unable to build channel "+
1220+
"policies: %w", err)
1221+
}
1222+
1223+
err = cb(edge, p1, p2)
1224+
if err != nil {
1225+
return fmt.Errorf("callback failed for channel "+
1226+
"id=%d: %w", edge.ChannelID, err)
1227+
}
1228+
1229+
return nil
1230+
}
1231+
1232+
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1233+
var lastID int64
1234+
for {
1235+
//nolint:ll
1236+
rows, err := db.ListChannelsWithPoliciesPaginated(
1237+
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
1238+
Version: int16(ProtocolV1),
1239+
ID: lastID,
1240+
Limit: pageSize,
1241+
},
1242+
)
1243+
if err != nil {
1244+
return err
1245+
}
1246+
1247+
if len(rows) == 0 {
1248+
break
1249+
}
1250+
1251+
for _, row := range rows {
1252+
err := handleChannel(db, row)
1253+
if err != nil {
1254+
return err
1255+
}
1256+
1257+
lastID = row.Channel.ID
1258+
}
1259+
}
1260+
1261+
return nil
1262+
}, sqldb.NoOpReset)
1263+
}
1264+
10471265
// forEachNodeDirectedChannel iterates through all channels of a given
10481266
// node, executing the passed callback on the directed edge representing the
10491267
// channel and its incoming policy. If the node is not found, no error is
@@ -2525,6 +2743,46 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
25252743
}
25262744
}
25272745

2746+
return policy1, policy2, nil
2747+
2748+
case sqlc.ListChannelsWithPoliciesPaginatedRow:
2749+
if r.Policy1ID.Valid {
2750+
policy1 = &sqlc.ChannelPolicy{
2751+
ID: r.Policy1ID.Int64,
2752+
Version: r.Policy1Version.Int16,
2753+
ChannelID: r.Channel.ID,
2754+
NodeID: r.Policy1NodeID.Int64,
2755+
Timelock: r.Policy1Timelock.Int32,
2756+
FeePpm: r.Policy1FeePpm.Int64,
2757+
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
2758+
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
2759+
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
2760+
LastUpdate: r.Policy1LastUpdate,
2761+
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
2762+
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
2763+
Disabled: r.Policy1Disabled,
2764+
Signature: r.Policy1Signature,
2765+
}
2766+
}
2767+
if r.Policy2ID.Valid {
2768+
policy2 = &sqlc.ChannelPolicy{
2769+
ID: r.Policy2ID.Int64,
2770+
Version: r.Policy2Version.Int16,
2771+
ChannelID: r.Channel.ID,
2772+
NodeID: r.Policy2NodeID.Int64,
2773+
Timelock: r.Policy2Timelock.Int32,
2774+
FeePpm: r.Policy2FeePpm.Int64,
2775+
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
2776+
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
2777+
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
2778+
LastUpdate: r.Policy2LastUpdate,
2779+
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
2780+
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
2781+
Disabled: r.Policy2Disabled,
2782+
Signature: r.Policy2Signature,
2783+
}
2784+
}
2785+
25282786
return policy1, policy2, nil
25292787
default:
25302788
return nil, nil, fmt.Errorf("unexpected row type in "+

0 commit comments

Comments
 (0)