Skip to content

Commit e4137a3

Browse files
committed
graph/db: fix ChanUpdate message/channel flag bug
Here we start using the newly added message_flags and channel_flags columns of the channel_policies table. The test added previoulsy to demonstrate the bug is now updated to show that the bug has been fixed.
1 parent 4a05e5a commit e4137a3

File tree

3 files changed

+47
-47
lines changed

3 files changed

+47
-47
lines changed

graph/db/graph_test.go

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -909,10 +909,6 @@ func TestEdgeInfoUpdates(t *testing.T) {
909909
}
910910

911911
// TestEdgePolicyCRUD tests basic CRUD operations for edge policies.
912-
//
913-
// NOTE: this currently demonstrates a bug in the SQL backend where
914-
// Channel Flags and Message Flags are not properly stored. This will be fixed
915-
// in an upcoming commit.
916912
func TestEdgePolicyCRUD(t *testing.T) {
917913
t.Parallel()
918914
ctx := context.Background()
@@ -925,7 +921,7 @@ func TestEdgePolicyCRUD(t *testing.T) {
925921
// Create an edge. Don't add it to the DB yet.
926922
edgeInfo, edge1, edge2 := createChannelEdge(node1, node2)
927923

928-
updateAndAssertPolicies := func(expErr bool) {
924+
updateAndAssertPolicies := func() {
929925
// Make copies of the policies before calling UpdateEdgePolicy
930926
// to avoid any data race's that can occur during async calls
931927
// that UpdateEdgePolicy may trigger.
@@ -945,17 +941,6 @@ func TestEdgePolicyCRUD(t *testing.T) {
945941
policy1 *models.ChannelEdgePolicy,
946942
policy2 *models.ChannelEdgePolicy) error {
947943

948-
if expErr {
949-
require.Error(
950-
t, compareEdgePolicies(edge1, policy1),
951-
)
952-
require.Error(
953-
t, compareEdgePolicies(edge2, policy2),
954-
)
955-
956-
return nil
957-
}
958-
959944
require.NoError(t, compareEdgePolicies(edge1, policy1))
960945
require.NoError(t, compareEdgePolicies(edge2, policy2))
961946

@@ -971,12 +956,12 @@ func TestEdgePolicyCRUD(t *testing.T) {
971956
// Now add the edge.
972957
require.NoError(t, graph.AddChannelEdge(ctx, edgeInfo))
973958

974-
updateAndAssertPolicies(false)
959+
updateAndAssertPolicies()
975960

976961
// Update one of the edges to have no extra opaque data.
977962
edge1.ExtraOpaqueData = nil
978963

979-
updateAndAssertPolicies(false)
964+
updateAndAssertPolicies()
980965

981966
// Update one of the edges to have ChannelFlags include a bit unknown
982967
// to us.
@@ -986,12 +971,7 @@ func TestEdgePolicyCRUD(t *testing.T) {
986971
// us.
987972
edge2.MessageFlags |= 1 << 4
988973

989-
// NOTE: If the backend is SQL, then we expect an error here as
990-
// there is currently a bug in the SQL backend where
991-
// ChannelFlags and MessageFlags are not properly stored. This will
992-
// be fixed in an upcoming commit.
993-
_, isSQLImp := graph.V1Store.(*SQLStore)
994-
updateAndAssertPolicies(isSQLImp)
974+
updateAndAssertPolicies()
995975
}
996976

997977
func assertNodeInCache(t *testing.T, g *ChannelGraph, n *models.LightningNode,

graph/db/sql_store.go

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,7 +1271,7 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
12711271
var pol1, pol2 *models.CachedEdgePolicy
12721272
if dbPol1 != nil {
12731273
policy1, err := buildChanPolicy(
1274-
*dbPol1, edge.ChannelID, nil, node2, true,
1274+
*dbPol1, edge.ChannelID, nil, node2,
12751275
)
12761276
if err != nil {
12771277
return err
@@ -1281,7 +1281,7 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
12811281
}
12821282
if dbPol2 != nil {
12831283
policy2, err := buildChanPolicy(
1284-
*dbPol2, edge.ChannelID, nil, node1, false,
1284+
*dbPol2, edge.ChannelID, nil, node1,
12851285
)
12861286
if err != nil {
12871287
return err
@@ -2900,7 +2900,7 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
29002900
var p1, p2 *models.CachedEdgePolicy
29012901
if dbPol1 != nil {
29022902
policy1, err := buildChanPolicy(
2903-
*dbPol1, edge.ChannelID, nil, node2, true,
2903+
*dbPol1, edge.ChannelID, nil, node2,
29042904
)
29052905
if err != nil {
29062906
return err
@@ -2910,7 +2910,7 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
29102910
}
29112911
if dbPol2 != nil {
29122912
policy2, err := buildChanPolicy(
2913-
*dbPol2, edge.ChannelID, nil, node1, false,
2913+
*dbPol2, edge.ChannelID, nil, node1,
29142914
)
29152915
if err != nil {
29162916
return err
@@ -3138,6 +3138,8 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
31383138
Valid: edge.MessageFlags.HasMaxHtlc(),
31393139
Int64: int64(edge.MaxHTLC),
31403140
},
3141+
MessageFlags: sqldb.SQLInt16(edge.MessageFlags),
3142+
ChannelFlags: sqldb.SQLInt16(edge.ChannelFlags),
31413143
InboundBaseFeeMsat: inboundBase,
31423144
InboundFeeRateMilliMsat: inboundRate,
31433145
Signature: edge.SigBytes,
@@ -4135,15 +4137,15 @@ func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
41354137
var pol1, pol2 *models.ChannelEdgePolicy
41364138
if dbPol1 != nil {
41374139
pol1, err = buildChanPolicy(
4138-
*dbPol1, channelID, dbPol1Extras, node2, true,
4140+
*dbPol1, channelID, dbPol1Extras, node2,
41394141
)
41404142
if err != nil {
41414143
return nil, nil, err
41424144
}
41434145
}
41444146
if dbPol2 != nil {
41454147
pol2, err = buildChanPolicy(
4146-
*dbPol2, channelID, dbPol2Extras, node1, false,
4148+
*dbPol2, channelID, dbPol2Extras, node1,
41474149
)
41484150
if err != nil {
41494151
return nil, nil, err
@@ -4156,28 +4158,15 @@ func getAndBuildChanPolicies(ctx context.Context, db SQLQueries,
41564158
// buildChanPolicy builds a models.ChannelEdgePolicy instance from the
41574159
// provided sqlc.ChannelPolicy and other required information.
41584160
func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
4159-
extras map[uint64][]byte, toNode route.Vertex,
4160-
isNode1 bool) (*models.ChannelEdgePolicy, error) {
4161+
extras map[uint64][]byte,
4162+
toNode route.Vertex) (*models.ChannelEdgePolicy, error) {
41614163

41624164
recs, err := lnwire.CustomRecords(extras).Serialize()
41634165
if err != nil {
41644166
return nil, fmt.Errorf("unable to serialize extra signed "+
41654167
"fields: %w", err)
41664168
}
41674169

4168-
var msgFlags lnwire.ChanUpdateMsgFlags
4169-
if dbPolicy.MaxHtlcMsat.Valid {
4170-
msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
4171-
}
4172-
4173-
var chanFlags lnwire.ChanUpdateChanFlags
4174-
if !isNode1 {
4175-
chanFlags |= lnwire.ChanUpdateDirection
4176-
}
4177-
if dbPolicy.Disabled.Bool {
4178-
chanFlags |= lnwire.ChanUpdateDisabled
4179-
}
4180-
41814170
var inboundFee fn.Option[lnwire.Fee]
41824171
if dbPolicy.InboundFeeRateMilliMsat.Valid ||
41834172
dbPolicy.InboundBaseFeeMsat.Valid {
@@ -4194,8 +4183,12 @@ func buildChanPolicy(dbPolicy sqlc.ChannelPolicy, channelID uint64,
41944183
LastUpdate: time.Unix(
41954184
dbPolicy.LastUpdate.Int64, 0,
41964185
),
4197-
MessageFlags: msgFlags,
4198-
ChannelFlags: chanFlags,
4186+
MessageFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateMsgFlags](
4187+
dbPolicy.MessageFlags,
4188+
),
4189+
ChannelFlags: sqldb.ExtractSqlInt16[lnwire.ChanUpdateChanFlags](
4190+
dbPolicy.ChannelFlags,
4191+
),
41994192
TimeLockDelta: uint16(dbPolicy.Timelock),
42004193
MinHTLC: lnwire.MilliSatoshi(
42014194
dbPolicy.MinHtlcMsat,
@@ -4259,6 +4252,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
42594252
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
42604253
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
42614254
Disabled: r.Policy1Disabled,
4255+
MessageFlags: r.Policy1MessageFlags,
4256+
ChannelFlags: r.Policy1ChannelFlags,
42624257
Signature: r.Policy1Signature,
42634258
}
42644259
}
@@ -4277,6 +4272,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
42774272
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
42784273
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
42794274
Disabled: r.Policy2Disabled,
4275+
MessageFlags: r.Policy2MessageFlags,
4276+
ChannelFlags: r.Policy2ChannelFlags,
42804277
Signature: r.Policy2Signature,
42814278
}
42824279
}
@@ -4299,6 +4296,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
42994296
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
43004297
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
43014298
Disabled: r.Policy1Disabled,
4299+
MessageFlags: r.Policy1MessageFlags,
4300+
ChannelFlags: r.Policy1ChannelFlags,
43024301
Signature: r.Policy1Signature,
43034302
}
43044303
}
@@ -4317,6 +4316,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43174316
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
43184317
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
43194318
Disabled: r.Policy2Disabled,
4319+
MessageFlags: r.Policy2MessageFlags,
4320+
ChannelFlags: r.Policy2ChannelFlags,
43204321
Signature: r.Policy2Signature,
43214322
}
43224323
}
@@ -4339,6 +4340,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43394340
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
43404341
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
43414342
Disabled: r.Policy1Disabled,
4343+
MessageFlags: r.Policy1MessageFlags,
4344+
ChannelFlags: r.Policy1ChannelFlags,
43424345
Signature: r.Policy1Signature,
43434346
}
43444347
}
@@ -4357,6 +4360,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43574360
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
43584361
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
43594362
Disabled: r.Policy2Disabled,
4363+
MessageFlags: r.Policy2MessageFlags,
4364+
ChannelFlags: r.Policy2ChannelFlags,
43604365
Signature: r.Policy2Signature,
43614366
}
43624367
}
@@ -4379,6 +4384,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43794384
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
43804385
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
43814386
Disabled: r.Policy1Disabled,
4387+
MessageFlags: r.Policy1MessageFlags,
4388+
ChannelFlags: r.Policy1ChannelFlags,
43824389
Signature: r.Policy1Signature,
43834390
}
43844391
}
@@ -4397,6 +4404,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
43974404
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
43984405
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
43994406
Disabled: r.Policy2Disabled,
4407+
MessageFlags: r.Policy2MessageFlags,
4408+
ChannelFlags: r.Policy2ChannelFlags,
44004409
Signature: r.Policy2Signature,
44014410
}
44024411
}
@@ -4419,6 +4428,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
44194428
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
44204429
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
44214430
Disabled: r.Policy1Disabled,
4431+
MessageFlags: r.Policy1MessageFlags,
4432+
ChannelFlags: r.Policy1ChannelFlags,
44224433
Signature: r.Policy1Signature,
44234434
}
44244435
}
@@ -4437,6 +4448,8 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
44374448
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
44384449
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
44394450
Disabled: r.Policy2Disabled,
4451+
MessageFlags: r.Policy2MessageFlags,
4452+
ChannelFlags: r.Policy2ChannelFlags,
44404453
Signature: r.Policy2Signature,
44414454
}
44424455
}

sqldb/sqlutils.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,10 @@ func SQLTime(t time.Time) sql.NullTime {
6868
Valid: true,
6969
}
7070
}
71+
72+
// ExtractSqlInt16 turns a NullInt16 into a numerical type. This can be useful
73+
// when reading directly from the database, as this function handles extracting
74+
// the inner value from the "option"-like struct.
75+
func ExtractSqlInt16[T constraints.Integer](num sql.NullInt16) T {
76+
return T(num.Int16)
77+
}

0 commit comments

Comments
 (0)