Skip to content

Commit 4fad4a7

Browse files
committed
graph/db+sqldb: implement FetchChannelEdgesByOutpoint/SCID
And run `TestEdgeInsertionDeletion` against our SQL backends.
1 parent 4335d9c commit 4fad4a7

File tree

5 files changed

+412
-2
lines changed

5 files changed

+412
-2
lines changed

graph/db/graph_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ func TestEdgeInsertionDeletion(t *testing.T) {
392392
t.Parallel()
393393
ctx := context.Background()
394394

395-
graph := MakeTestGraph(t)
395+
graph := MakeTestGraphNew(t)
396396

397397
// We'd like to test the insertion/deletion of edges, so we create two
398398
// vertexes to connect.

graph/db/sql_store.go

Lines changed: 203 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ type SQLQueries interface {
9595
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
9696
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
9797
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
98+
GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
9899
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
99100
DeleteChannel(ctx context.Context, id int64) error
100101

@@ -118,6 +119,7 @@ type SQLQueries interface {
118119
GetZombieChannel(ctx context.Context, arg sqlc.GetZombieChannelParams) (sqlc.ZombieChannel, error)
119120
CountZombieChannels(ctx context.Context, version int16) (int64, error)
120121
DeleteZombieChannel(ctx context.Context, arg sqlc.DeleteZombieChannelParams) (sql.Result, error)
122+
IsZombieChannel(ctx context.Context, arg sqlc.IsZombieChannelParams) (bool, error)
121123
}
122124

123125
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
@@ -1671,6 +1673,166 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
16711673
return deleted, nil
16721674
}
16731675

1676+
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
1677+
// channel identified by the channel ID. If the channel can't be found, then
1678+
// ErrEdgeNotFound is returned. A struct which houses the general information
1679+
// for the channel itself is returned as well as two structs that contain the
1680+
// routing policies for the channel in either direction.
1681+
//
1682+
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
1683+
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
1684+
// the ChannelEdgeInfo will only include the public keys of each node.
1685+
//
1686+
// NOTE: part of the V1Store interface.
1687+
func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
1688+
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1689+
*models.ChannelEdgePolicy, error) {
1690+
1691+
var (
1692+
ctx = context.TODO()
1693+
edge *models.ChannelEdgeInfo
1694+
policy1, policy2 *models.ChannelEdgePolicy
1695+
)
1696+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1697+
var chanIDB [8]byte
1698+
byteOrder.PutUint64(chanIDB[:], chanID)
1699+
1700+
row, err := db.GetChannelBySCIDWithPolicies(
1701+
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
1702+
Scid: chanIDB[:],
1703+
Version: int16(ProtocolV1),
1704+
},
1705+
)
1706+
if errors.Is(err, sql.ErrNoRows) {
1707+
// First check if this edge is perhaps in the zombie
1708+
// index.
1709+
isZombie, err := db.IsZombieChannel(
1710+
ctx, sqlc.IsZombieChannelParams{
1711+
Scid: chanIDB[:],
1712+
Version: int16(ProtocolV1),
1713+
},
1714+
)
1715+
if err != nil {
1716+
return fmt.Errorf("unable to check if "+
1717+
"channel is zombie: %w", err)
1718+
} else if isZombie {
1719+
return ErrZombieEdge
1720+
}
1721+
1722+
return ErrEdgeNotFound
1723+
} else if err != nil {
1724+
return fmt.Errorf("unable to fetch channel: %w", err)
1725+
}
1726+
1727+
node1, node2, err := buildNodeVertices(
1728+
row.Node.PubKey, row.Node_2.PubKey,
1729+
)
1730+
if err != nil {
1731+
return err
1732+
}
1733+
1734+
edge, err = getAndBuildEdgeInfo(
1735+
ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
1736+
node1, node2,
1737+
)
1738+
if err != nil {
1739+
return fmt.Errorf("unable to build channel info: %w",
1740+
err)
1741+
}
1742+
1743+
dbPol1, dbPol2, err := extractChannelPolicies(row)
1744+
if err != nil {
1745+
return fmt.Errorf("unable to extract channel "+
1746+
"policies: %w", err)
1747+
}
1748+
1749+
policy1, policy2, err = getAndBuildChanPolicies(
1750+
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
1751+
)
1752+
if err != nil {
1753+
return fmt.Errorf("unable to build channel "+
1754+
"policies: %w", err)
1755+
}
1756+
1757+
return nil
1758+
}, sqldb.NoOpReset)
1759+
if err != nil {
1760+
return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
1761+
err)
1762+
}
1763+
1764+
return edge, policy1, policy2, nil
1765+
}
1766+
1767+
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
1768+
// the channel identified by the funding outpoint. If the channel can't be
1769+
// found, then ErrEdgeNotFound is returned. A struct which houses the general
1770+
// information for the channel itself is returned as well as two structs that
1771+
// contain the routing policies for the channel in either direction.
1772+
//
1773+
// NOTE: part of the V1Store interface.
1774+
func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
1775+
*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
1776+
*models.ChannelEdgePolicy, error) {
1777+
1778+
var (
1779+
ctx = context.TODO()
1780+
edge *models.ChannelEdgeInfo
1781+
policy1, policy2 *models.ChannelEdgePolicy
1782+
)
1783+
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1784+
row, err := db.GetChannelByOutpointWithPolicies(
1785+
ctx, sqlc.GetChannelByOutpointWithPoliciesParams{
1786+
Outpoint: op.String(),
1787+
Version: int16(ProtocolV1),
1788+
},
1789+
)
1790+
if errors.Is(err, sql.ErrNoRows) {
1791+
return ErrEdgeNotFound
1792+
} else if err != nil {
1793+
return fmt.Errorf("unable to fetch channel: %w", err)
1794+
}
1795+
1796+
node1, node2, err := buildNodeVertices(
1797+
row.Node1Pubkey, row.Node2Pubkey,
1798+
)
1799+
if err != nil {
1800+
return err
1801+
}
1802+
1803+
edge, err = getAndBuildEdgeInfo(
1804+
ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
1805+
node1, node2,
1806+
)
1807+
if err != nil {
1808+
return fmt.Errorf("unable to build channel info: %w",
1809+
err)
1810+
}
1811+
1812+
dbPol1, dbPol2, err := extractChannelPolicies(row)
1813+
if err != nil {
1814+
return fmt.Errorf("unable to extract channel "+
1815+
"policies: %w", err)
1816+
}
1817+
1818+
policy1, policy2, err = getAndBuildChanPolicies(
1819+
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
1820+
)
1821+
if err != nil {
1822+
return fmt.Errorf("unable to build channel "+
1823+
"policies: %w", err)
1824+
}
1825+
1826+
return nil
1827+
}, sqldb.NoOpReset)
1828+
if err != nil {
1829+
return nil, nil, nil, fmt.Errorf("could not fetch channel: %w",
1830+
err)
1831+
}
1832+
1833+
return edge, policy1, policy2, nil
1834+
}
1835+
16741836
// forEachNodeDirectedChannel iterates through all channels of a given
16751837
// node, executing the passed callback on the directed edge representing the
16761838
// channel and its incoming policy. If the node is not found, no error is
@@ -3066,12 +3228,52 @@ func buildNodes(ctx context.Context, db SQLQueries, dbNode1,
30663228
// information. It returns two policies, which may be nil if the policy
30673229
// information is not present in the row.
30683230
//
3069-
//nolint:ll,dupl
3231+
//nolint:ll,dupl,funlen
30703232
func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
30713233
error) {
30723234

30733235
var policy1, policy2 *sqlc.ChannelPolicy
30743236
switch r := row.(type) {
3237+
case sqlc.GetChannelBySCIDWithPoliciesRow:
3238+
if r.Policy1ID.Valid {
3239+
policy1 = &sqlc.ChannelPolicy{
3240+
ID: r.Policy1ID.Int64,
3241+
Version: r.Policy1Version.Int16,
3242+
ChannelID: r.Channel.ID,
3243+
NodeID: r.Policy1NodeID.Int64,
3244+
Timelock: r.Policy1Timelock.Int32,
3245+
FeePpm: r.Policy1FeePpm.Int64,
3246+
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
3247+
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
3248+
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
3249+
LastUpdate: r.Policy1LastUpdate,
3250+
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
3251+
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
3252+
Disabled: r.Policy1Disabled,
3253+
Signature: r.Policy1Signature,
3254+
}
3255+
}
3256+
if r.Policy2ID.Valid {
3257+
policy2 = &sqlc.ChannelPolicy{
3258+
ID: r.Policy2ID.Int64,
3259+
Version: r.Policy2Version.Int16,
3260+
ChannelID: r.Channel.ID,
3261+
NodeID: r.Policy2NodeID.Int64,
3262+
Timelock: r.Policy2Timelock.Int32,
3263+
FeePpm: r.Policy2FeePpm.Int64,
3264+
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
3265+
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
3266+
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
3267+
LastUpdate: r.Policy2LastUpdate,
3268+
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
3269+
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
3270+
Disabled: r.Policy2Disabled,
3271+
Signature: r.Policy2Signature,
3272+
}
3273+
}
3274+
3275+
return policy1, policy2, nil
3276+
30753277
case sqlc.GetChannelsByPolicyLastUpdateRangeRow:
30763278
if r.Policy1ID.Valid {
30773279
policy1 = &sqlc.ChannelPolicy{

0 commit comments

Comments
 (0)