Skip to content

Commit 5a8606c

Browse files
authored
Merge pull request #10007 from ellemouton/chanUpdateBitFields
graph/db: explicitly store bitfields for channel_update message & channel flags
2 parents 0ca9123 + c5b2e4e commit 5a8606c

File tree

9 files changed

+250
-41
lines changed

9 files changed

+250
-41
lines changed

graph/db/graph_test.go

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func createLightningNode(priv *btcec.PrivateKey) *models.LightningNode {
7777
n := &models.LightningNode{
7878
HaveNodeAnnouncement: true,
7979
AuthSigBytes: testSig.Serialize(),
80-
LastUpdate: time.Unix(nextUpdateTime(), 0),
80+
LastUpdate: nextUpdateTime(),
8181
Color: color.RGBA{1, 2, 3, 0},
8282
Alias: "kek" + hex.EncodeToString(pub),
8383
Features: testFeatures,
@@ -784,7 +784,7 @@ func createChannelEdge(node1, node2 *models.LightningNode,
784784
edge1 := &models.ChannelEdgePolicy{
785785
SigBytes: testSig.Serialize(),
786786
ChannelID: chanID,
787-
LastUpdate: time.Unix(433453, 0),
787+
LastUpdate: nextUpdateTime(),
788788
MessageFlags: 1,
789789
ChannelFlags: 0,
790790
TimeLockDelta: 99,
@@ -798,7 +798,7 @@ func createChannelEdge(node1, node2 *models.LightningNode,
798798
edge2 := &models.ChannelEdgePolicy{
799799
SigBytes: testSig.Serialize(),
800800
ChannelID: chanID,
801-
LastUpdate: time.Unix(124234, 0),
801+
LastUpdate: nextUpdateTime(),
802802
MessageFlags: 1,
803803
ChannelFlags: 1,
804804
TimeLockDelta: 99,
@@ -908,6 +908,72 @@ func TestEdgeInfoUpdates(t *testing.T) {
908908
assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo)
909909
}
910910

911+
// TestEdgePolicyCRUD tests basic CRUD operations for edge policies.
912+
func TestEdgePolicyCRUD(t *testing.T) {
913+
t.Parallel()
914+
ctx := context.Background()
915+
916+
graph := MakeTestGraph(t)
917+
918+
node1 := createTestVertex(t)
919+
node2 := createTestVertex(t)
920+
921+
// Create an edge. Don't add it to the DB yet.
922+
edgeInfo, edge1, edge2 := createChannelEdge(node1, node2)
923+
924+
updateAndAssertPolicies := func() {
925+
// Make copies of the policies before calling UpdateEdgePolicy
926+
// to avoid any data race's that can occur during async calls
927+
// that UpdateEdgePolicy may trigger.
928+
edge1 := copyEdgePolicy(edge1)
929+
edge2 := copyEdgePolicy(edge2)
930+
931+
edge1.LastUpdate = nextUpdateTime()
932+
edge2.LastUpdate = nextUpdateTime()
933+
934+
require.NoError(t, graph.UpdateEdgePolicy(ctx, edge1))
935+
require.NoError(t, graph.UpdateEdgePolicy(ctx, edge2))
936+
937+
// Use the ForEachChannel method to fetch the policies and
938+
// assert that the deserialized policies match the original
939+
// ones.
940+
err := graph.ForEachChannel(func(info *models.ChannelEdgeInfo,
941+
policy1 *models.ChannelEdgePolicy,
942+
policy2 *models.ChannelEdgePolicy) error {
943+
944+
require.NoError(t, compareEdgePolicies(edge1, policy1))
945+
require.NoError(t, compareEdgePolicies(edge2, policy2))
946+
947+
return nil
948+
})
949+
require.NoError(t, err)
950+
}
951+
952+
// Make sure inserting the policy at this point, before the edge info
953+
// is added, will fail.
954+
require.ErrorIs(t, graph.UpdateEdgePolicy(ctx, edge1), ErrEdgeNotFound)
955+
956+
// Now add the edge.
957+
require.NoError(t, graph.AddChannelEdge(ctx, edgeInfo))
958+
959+
updateAndAssertPolicies()
960+
961+
// Update one of the edges to have no extra opaque data.
962+
edge1.ExtraOpaqueData = nil
963+
964+
updateAndAssertPolicies()
965+
966+
// Update one of the edges to have ChannelFlags include a bit unknown
967+
// to us.
968+
edge1.ChannelFlags |= 1 << 6
969+
970+
// Update the other edge to have MessageFlags include a bit unknown to
971+
// us.
972+
edge2.MessageFlags |= 1 << 4
973+
974+
updateAndAssertPolicies()
975+
}
976+
911977
func assertNodeInCache(t *testing.T, g *ChannelGraph, n *models.LightningNode,
912978
expectedFeatures *lnwire.FeatureVector) {
913979

@@ -3455,13 +3521,13 @@ var (
34553521
updateTimeMu sync.Mutex
34563522
)
34573523

3458-
func nextUpdateTime() int64 {
3524+
func nextUpdateTime() time.Time {
34593525
updateTimeMu.Lock()
34603526
defer updateTimeMu.Unlock()
34613527

34623528
updateTime++
34633529

3464-
return updateTime
3530+
return time.Unix(updateTime, 0)
34653531
}
34663532

34673533
// TestNodeIsPublic ensures that we properly detect nodes that are seen as
@@ -3506,7 +3572,7 @@ func TestNodeIsPublic(t *testing.T) {
35063572
graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph}
35073573
for _, graph := range graphs {
35083574
for _, node := range nodes {
3509-
node.LastUpdate = time.Unix(nextUpdateTime(), 0)
3575+
node.LastUpdate = nextUpdateTime()
35103576
err := graph.AddLightningNode(ctx, node)
35113577
require.NoError(t, err)
35123578
}
@@ -3621,7 +3687,7 @@ func TestDisabledChannelIDs(t *testing.T) {
36213687

36223688
// Adding a new channel edge to the graph.
36233689
edgeInfo, edge1, edge2 := createChannelEdge(node1, node2)
3624-
node2.LastUpdate = time.Unix(nextUpdateTime(), 0)
3690+
node2.LastUpdate = nextUpdateTime()
36253691
if err := graph.AddLightningNode(ctx, node2); err != nil {
36263692
t.Fatalf("unable to add node: %v", err)
36273693
}
@@ -4114,7 +4180,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) {
41144180

41154181
// Make sure inserting the policy at this point, before the edge info
41164182
// is added, will fail.
4117-
require.Error(t, ErrEdgeNotFound, graph.UpdateEdgePolicy(ctx, edge1))
4183+
require.ErrorIs(t, graph.UpdateEdgePolicy(ctx, edge1), ErrEdgeNotFound)
41184184

41194185
// Add the edge info.
41204186
require.NoError(t, graph.AddChannelEdge(ctx, edgeInfo))

graph/db/sql_store.go

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,7 +1271,7 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
12711271
var pol1, pol2 *models.CachedEdgePolicy
12721272
if dbPol1 != nil {
12731273
policy1, err := buildChanPolicy(
1274-
*dbPol1, edge.ChannelID, nil, node2, true,
1274+
*dbPol1, edge.ChannelID, nil, node2,
12751275
)
12761276
if err != nil {
12771277
return err
@@ -1281,7 +1281,7 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
12811281
}
12821282
if dbPol2 != nil {
12831283
policy2, err := buildChanPolicy(
1284-
*dbPol2, edge.ChannelID, nil, node1, false,
1284+
*dbPol2, edge.ChannelID, nil, node1,
12851285
)
12861286
if err != nil {
12871287
return err
@@ -2900,7 +2900,7 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
29002900
var p1, p2 *models.CachedEdgePolicy
29012901
if dbPol1 != nil {
29022902
policy1, err := buildChanPolicy(
2903-
*dbPol1, edge.ChannelID, nil, node2, true,
2903+
*dbPol1, edge.ChannelID, nil, node2,
29042904
)
29052905
if err != nil {
29062906
return err
@@ -2910,7 +2910,7 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
29102910
}
29112911
if dbPol2 != nil {
29122912
policy2, err := buildChanPolicy(
2913-
*dbPol2, edge.ChannelID, nil, node1, false,
2913+
*dbPol2, edge.ChannelID, nil, node1,
29142914
)
29152915
if err != nil {
29162916
return err
@@ -3138,6 +3138,8 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
31383138
Valid: edge.MessageFlags.HasMaxHtlc(),
31393139
Int64: int64(edge.MaxHTLC),
31403140
},
3141+
MessageFlags: sqldb.SQLInt16(edge.MessageFlags),
3142+
ChannelFlags: sqldb.SQLInt16(edge.ChannelFlags),
31413143
InboundBaseFeeMsat: inboundBase,
31423144
InboundFeeRateMilliMsat: inboundRate,
31433145
Signature: edge.SigBytes,
@@ -4135,15 +4137,15 @@ func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
41354137
var pol1, pol2 *models.ChannelEdgePolicy
41364138
if dbPol1 != nil {
41374139
pol1, err = buildChanPolicy(
4138-
*dbPol1, channelID, dbPol1Extras, node2, true,
4140+
*dbPol1, channelID, dbPol1Extras, node2,
41394141
)
41404142
if err != nil {
41414143
return nil, nil, err
41424144
}
41434145
}
41444146
if dbPol2 != nil {
41454147
pol2, err = buildChanPolicy(
4146-
*dbPol2, channelID, dbPol2Extras, node1, false,
4148+
*dbPol2, channelID, dbPol2Extras, node1,
41474149
)
41484150
if err != nil {
41494151
return nil, nil, err
@@ -4156,28 +4158,15 @@ func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
41564158
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
41574159
// provided sqlc.ChannelPolicy and other required information.
41584160
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4159-
extras map[uint64][]byte, toNode route.Vertex,
4160-
isNode1 bool) (*models.ChannelEdgePolicy, error) {
4161+
extras map[uint64][]byte,
4162+
toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
41614163

41624164
recs, err := lnwire.CustomRecords(extras).Serialize()
41634165
if err != nil {
41644166
return nil, fmt.Errorf("unable to serialize extra signed "+
41654167
"fields: %w", err)
41664168
}
41674169

4168-
var msgFlags lnwire.ChanUpdateMsgFlags
4169-
if dbPolicy.MaxHtlcMsat.Valid {
4170-
msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
4171-
}
4172-
4173-
var chanFlags lnwire.ChanUpdateChanFlags
4174-
if !isNode1 {
4175-
chanFlags |= lnwire.ChanUpdateDirection
4176-
}
4177-
if dbPolicy.Disabled.Bool {
4178-
chanFlags |= lnwire.ChanUpdateDisabled
4179-
}
4180-
41814170
var inboundFee fn.Option[lnwire.Fee]
41824171
if dbPolicy.InboundFeeRateMilliMsat.Valid ||
41834172
dbPolicy.InboundBaseFeeMsat.Valid {
@@ -4194,8 +4183,12 @@ func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
41944183
LastUpdate: time.Unix(
41954184
dbPolicy.LastUpdate.Int64, 0,
41964185
),
4197-
MessageFlags: msgFlags,
4198-
ChannelFlags: chanFlags,
4186+
MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
4187+
dbPolicy.MessageFlags,
4188+
),
4189+
ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
4190+
dbPolicy.ChannelFlags,
4191+
),
41994192
TimeLockDelta: uint16(dbPolicy.Timelock),
42004193
MinHTLC: lnwire.MilliSatoshi(
42014194
dbPolicy.MinHtlcMsat,
@@ -4259,6 +4252,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
42594252
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
42604253
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
42614254
Disabled: r.Policy1Disabled,
4255+
MessageFlags: r.Policy1MessageFlags,
4256+
ChannelFlags: r.Policy1ChannelFlags,
42624257
Signature: r.Policy1Signature,
42634258
}
42644259
}
@@ -4277,6 +4272,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
42774272
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
42784273
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
42794274
Disabled: r.Policy2Disabled,
4275+
MessageFlags: r.Policy2MessageFlags,
4276+
ChannelFlags: r.Policy2ChannelFlags,
42804277
Signature: r.Policy2Signature,
42814278
}
42824279
}
@@ -4299,6 +4296,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
42994296
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
43004297
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
43014298
Disabled: r.Policy1Disabled,
4299+
MessageFlags: r.Policy1MessageFlags,
4300+
ChannelFlags: r.Policy1ChannelFlags,
43024301
Signature: r.Policy1Signature,
43034302
}
43044303
}
@@ -4317,6 +4316,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43174316
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
43184317
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
43194318
Disabled: r.Policy2Disabled,
4319+
MessageFlags: r.Policy2MessageFlags,
4320+
ChannelFlags: r.Policy2ChannelFlags,
43204321
Signature: r.Policy2Signature,
43214322
}
43224323
}
@@ -4339,6 +4340,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43394340
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
43404341
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
43414342
Disabled: r.Policy1Disabled,
4343+
MessageFlags: r.Policy1MessageFlags,
4344+
ChannelFlags: r.Policy1ChannelFlags,
43424345
Signature: r.Policy1Signature,
43434346
}
43444347
}
@@ -4357,6 +4360,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43574360
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
43584361
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
43594362
Disabled: r.Policy2Disabled,
4363+
MessageFlags: r.Policy2MessageFlags,
4364+
ChannelFlags: r.Policy2ChannelFlags,
43604365
Signature: r.Policy2Signature,
43614366
}
43624367
}
@@ -4379,6 +4384,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43794384
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
43804385
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
43814386
Disabled: r.Policy1Disabled,
4387+
MessageFlags: r.Policy1MessageFlags,
4388+
ChannelFlags: r.Policy1ChannelFlags,
43824389
Signature: r.Policy1Signature,
43834390
}
43844391
}
@@ -4397,6 +4404,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43974404
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
43984405
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
43994406
Disabled: r.Policy2Disabled,
4407+
MessageFlags: r.Policy2MessageFlags,
4408+
ChannelFlags: r.Policy2ChannelFlags,
44004409
Signature: r.Policy2Signature,
44014410
}
44024411
}
@@ -4419,6 +4428,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
44194428
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
44204429
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
44214430
Disabled: r.Policy1Disabled,
4431+
MessageFlags: r.Policy1MessageFlags,
4432+
ChannelFlags: r.Policy1ChannelFlags,
44224433
Signature: r.Policy1Signature,
44234434
}
44244435
}
@@ -4437,6 +4448,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
44374448
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
44384449
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
44394450
Disabled: r.Policy2Disabled,
4451+
MessageFlags: r.Policy2MessageFlags,
4452+
ChannelFlags: r.Policy2ChannelFlags,
44404453
Signature: r.Policy2Signature,
44414454
}
44424455
}

graph/db/test_postgres.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ import (
1212
)
1313

1414
// NewTestDB is a helper function that creates a SQLStore backed by a postgres
15-
// database for testing. At the moment, it embeds a KVStore but once the
16-
// SQLStore fully implements the V1Store interface, the KVStore will be removed.
15+
// database for testing.
1716
func NewTestDB(t testing.TB) V1Store {
1817
pgFixture := sqldb.NewTestPgFixture(
1918
t, sqldb.DefaultPostgresFixtureLifetime,

graph/db/test_sqlite.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ import (
1212
)
1313

1414
// NewTestDB is a helper function that creates a SQLStore backed by a sqlite
15-
// database for testing. At the moment, it embeds a KVStore but once the
16-
// SQLStore fully implements the V1Store interface, the KVStore will be removed.
15+
// database for testing.
1716
func NewTestDB(t testing.TB) V1Store {
1817
db := sqldb.NewTestSqliteDB(t).BaseDB
1918

0 commit comments

Comments
 (0)