Skip to content

Commit 4bde8e2

Browse files
committed
graph/db: refactor and clean-up
Refactor channelIDToBytes to return a slice instead of an 8 byte array so that we dont need to use `[:]` everywhere. Also make sure we are using this helper everywhere.
1 parent 2310756 commit 4bde8e2

File tree

1 file changed

+31
-39
lines changed

1 file changed

+31
-39
lines changed

graph/db/sql_store.go

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,8 +1452,8 @@ func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
14521452
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
14531453
dbChans, err := db.GetPublicV1ChannelsBySCID(
14541454
ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
1455-
StartScid: chanIDStart[:],
1456-
EndScid: chanIDEnd[:],
1455+
StartScid: chanIDStart,
1456+
EndScid: chanIDEnd,
14571457
},
14581458
)
14591459
if err != nil {
@@ -1560,7 +1560,7 @@ func (s *SQLStore) MarkEdgeZombie(chanID uint64,
15601560
return db.UpsertZombieChannel(
15611561
ctx, sqlc.UpsertZombieChannelParams{
15621562
Version: int16(ProtocolV1),
1563-
Scid: chanIDB[:],
1563+
Scid: chanIDB,
15641564
NodeKey1: pubKey1[:],
15651565
NodeKey2: pubKey2[:],
15661566
},
@@ -1592,7 +1592,7 @@ func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
15921592
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
15931593
res, err := db.DeleteZombieChannel(
15941594
ctx, sqlc.DeleteZombieChannelParams{
1595-
Scid: chanIDB[:],
1595+
Scid: chanIDB,
15961596
Version: int16(ProtocolV1),
15971597
},
15981598
)
@@ -1644,7 +1644,7 @@ func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
16441644
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
16451645
zombie, err := db.GetZombieChannel(
16461646
ctx, sqlc.GetZombieChannelParams{
1647-
Scid: chanIDB[:],
1647+
Scid: chanIDB,
16481648
Version: int16(ProtocolV1),
16491649
},
16501650
)
@@ -1723,7 +1723,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
17231723

17241724
row, err := db.GetChannelBySCIDWithPolicies(
17251725
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
1726-
Scid: chanIDB[:],
1726+
Scid: chanIDB,
17271727
Version: int16(ProtocolV1),
17281728
},
17291729
)
@@ -1786,7 +1786,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
17861786
err = db.UpsertZombieChannel(
17871787
ctx, sqlc.UpsertZombieChannelParams{
17881788
Version: int16(ProtocolV1),
1789-
Scid: chanIDB[:],
1789+
Scid: chanIDB,
17901790
NodeKey1: nodeKey1[:],
17911791
NodeKey2: nodeKey2[:],
17921792
},
@@ -1833,14 +1833,12 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
18331833
ctx = context.TODO()
18341834
edge *models.ChannelEdgeInfo
18351835
policy1, policy2 *models.ChannelEdgePolicy
1836+
chanIDB = channelIDToBytes(chanID)
18361837
)
18371838
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1838-
var chanIDB [8]byte
1839-
byteOrder.PutUint64(chanIDB[:], chanID)
1840-
18411839
row, err := db.GetChannelBySCIDWithPolicies(
18421840
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
1843-
Scid: chanIDB[:],
1841+
Scid: chanIDB,
18441842
Version: int16(ProtocolV1),
18451843
},
18461844
)
@@ -1849,7 +1847,7 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
18491847
// index.
18501848
zombie, err := db.GetZombieChannel(
18511849
ctx, sqlc.GetZombieChannelParams{
1852-
Scid: chanIDB[:],
1850+
Scid: chanIDB,
18531851
Version: int16(ProtocolV1),
18541852
},
18551853
)
@@ -2033,21 +2031,19 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
20332031
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
20342032
}
20352033

2034+
chanIDB := channelIDToBytes(chanID)
20362035
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
2037-
var chanIDB [8]byte
2038-
byteOrder.PutUint64(chanIDB[:], chanID)
2039-
20402036
channel, err := db.GetChannelBySCID(
20412037
ctx, sqlc.GetChannelBySCIDParams{
2042-
Scid: chanIDB[:],
2038+
Scid: chanIDB,
20432039
Version: int16(ProtocolV1),
20442040
},
20452041
)
20462042
if errors.Is(err, sql.ErrNoRows) {
20472043
// Check if it is a zombie channel.
20482044
isZombie, err = db.IsZombieChannel(
20492045
ctx, sqlc.IsZombieChannelParams{
2050-
Scid: chanIDB[:],
2046+
Scid: chanIDB,
20512047
Version: int16(ProtocolV1),
20522048
},
20532049
)
@@ -2179,15 +2175,14 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
21792175
)
21802176
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
21812177
for _, chanID := range chanIDs {
2182-
var chanIDB [8]byte
2183-
byteOrder.PutUint64(chanIDB[:], chanID)
2178+
chanIDB := channelIDToBytes(chanID)
21842179

21852180
// TODO(elle): potentially optimize this by using
21862181
// sqlc.slice() once that works for both SQLite and
21872182
// Postgres.
21882183
row, err := db.GetChannelBySCIDWithPolicies(
21892184
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
2190-
Scid: chanIDB[:],
2185+
Scid: chanIDB,
21912186
Version: int16(ProtocolV1),
21922187
},
21932188
)
@@ -2270,16 +2265,15 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
22702265
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
22712266
for _, chanInfo := range chansInfo {
22722267
channelID := chanInfo.ShortChannelID.ToUint64()
2273-
var chanIDB [8]byte
2274-
byteOrder.PutUint64(chanIDB[:], channelID)
2268+
chanIDB := channelIDToBytes(channelID)
22752269

22762270
// TODO(elle): potentially optimize this by using
22772271
// sqlc.slice() once that works for both SQLite and
22782272
// Postgres.
22792273
_, err := db.GetChannelBySCID(
22802274
ctx, sqlc.GetChannelBySCIDParams{
22812275
Version: int16(ProtocolV1),
2282-
Scid: chanIDB[:],
2276+
Scid: chanIDB,
22832277
},
22842278
)
22852279
if err == nil {
@@ -2291,7 +2285,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
22912285

22922286
isZombie, err := db.IsZombieChannel(
22932287
ctx, sqlc.IsZombieChannelParams{
2294-
Scid: chanIDB[:],
2288+
Scid: chanIDB,
22952289
Version: int16(ProtocolV1),
22962290
},
22972291
)
@@ -2609,18 +2603,16 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
26092603
endShortChanID = aliasmgr.StartingAlias
26102604

26112605
removedChans []*models.ChannelEdgeInfo
2612-
)
26132606

2614-
var chanIDStart [8]byte
2615-
byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
2616-
var chanIDEnd [8]byte
2617-
byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
2607+
chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
2608+
chanIDEnd = channelIDToBytes(endShortChanID.ToUint64())
2609+
)
26182610

26192611
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
26202612
rows, err := db.GetChannelsBySCIDRange(
26212613
ctx, sqlc.GetChannelsBySCIDRangeParams{
2622-
StartScid: chanIDStart[:],
2623-
EndScid: chanIDEnd[:],
2614+
StartScid: chanIDStart,
2615+
EndScid: chanIDEnd,
26242616
},
26252617
)
26262618
if err != nil {
@@ -2688,7 +2680,7 @@ func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
26882680
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
26892681
res, err := db.AddV1ChannelProof(
26902682
ctx, sqlc.AddV1ChannelProofParams{
2691-
Scid: scidBytes[:],
2683+
Scid: scidBytes,
26922684
Node1Signature: proof.NodeSig1Bytes,
26932685
Node2Signature: proof.NodeSig2Bytes,
26942686
Bitcoin1Signature: proof.BitcoinSig1Bytes,
@@ -2734,7 +2726,7 @@ func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
27342726
)
27352727

27362728
return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
2737-
return db.InsertClosedChannel(ctx, chanIDB[:])
2729+
return db.InsertClosedChannel(ctx, chanIDB)
27382730
}, sqldb.NoOpReset)
27392731
}
27402732

@@ -2750,7 +2742,7 @@ func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
27502742
)
27512743
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
27522744
var err error
2753-
isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
2745+
isClosed, err = db.IsClosedChannel(ctx, chanIDB)
27542746
if err != nil {
27552747
return fmt.Errorf("unable to fetch closed channel: %w",
27562748
err)
@@ -3077,7 +3069,7 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
30773069
// abort the transaction which would abort the entire batch.
30783070
dbChan, err := tx.GetChannelAndNodesBySCID(
30793071
ctx, sqlc.GetChannelAndNodesBySCIDParams{
3080-
Scid: chanIDB[:],
3072+
Scid: chanIDB,
30813073
Version: int16(ProtocolV1),
30823074
},
30833075
)
@@ -3779,7 +3771,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
37793771
// batch of transactions.
37803772
_, err := db.GetChannelBySCID(
37813773
ctx, sqlc.GetChannelBySCIDParams{
3782-
Scid: chanIDB[:],
3774+
Scid: chanIDB,
37833775
Version: int16(ProtocolV1),
37843776
},
37853777
)
@@ -3808,7 +3800,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
38083800

38093801
createParams := sqlc.CreateChannelParams{
38103802
Version: int16(ProtocolV1),
3811-
Scid: chanIDB[:],
3803+
Scid: chanIDB,
38123804
NodeID1: node1DBID,
38133805
NodeID2: node2DBID,
38143806
Outpoint: edge.ChannelPoint.String(),
@@ -4455,9 +4447,9 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
44554447

44564448
// channelIDToBytes converts a channel ID (SCID) to a byte array
44574449
// representation.
4458-
func channelIDToBytes(channelID uint64) [8]byte {
4450+
func channelIDToBytes(channelID uint64) []byte {
44594451
var chanIDB [8]byte
44604452
byteOrder.PutUint64(chanIDB[:], channelID)
44614453

4462-
return chanIDB
4454+
return chanIDB[:]
44634455
}

0 commit comments

Comments
 (0)