Skip to content

Commit 8425f45

Browse files
committed
rules: add ChannelRestrict rule
1 parent 0629601 commit 8425f45

File tree

3 files changed

+555
-0
lines changed

3 files changed

+555
-0
lines changed

rules/channel_restrictions.go

Lines changed: 391 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,391 @@
1+
package rules
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
8+
"github.com/lightninglabs/lightning-terminal/firewalldb"
9+
"github.com/lightninglabs/lightning-terminal/litrpc"
10+
mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware"
11+
"github.com/lightningnetwork/lnd/lnrpc"
12+
"google.golang.org/protobuf/proto"
13+
)
14+
15+
var (
16+
// Compile-time checks to ensure that ChannelRestrict,
17+
// ChannelRestrictMgr and ChannelRestrictEnforcer implement the
18+
// appropriate Manager, Enforcer and Values interface.
19+
_ Manager = (*ChannelRestrictMgr)(nil)
20+
_ Enforcer = (*ChannelRestrictEnforcer)(nil)
21+
_ Values = (*ChannelRestrict)(nil)
22+
)
23+
24+
// ChannelRestrictName is the string identifier of the ChannelRestrict rule.
25+
const ChannelRestrictName = "channel-restriction"
26+
27+
// ChannelRestrictMgr manages the ChannelRestrict rule.
28+
type ChannelRestrictMgr struct {
29+
// here we can have known chanID to ChanOutpoint map (and vice versa)
30+
// then in NewEnforcer, if chan comes in we dont know of, then we
31+
// refresh the maps.
32+
33+
// chanIDToPoint is a map from channel ID to channel points for our
34+
// known set of channels.
35+
chanIDToPoint map[uint64]string
36+
37+
// chanPointToID is a map from channel point to channel ID's for our
38+
// known set of channels.
39+
chanPointToID map[string]uint64
40+
mu sync.Mutex
41+
}
42+
43+
// NewChannelRestrictMgr constructs a new instance of a ChannelRestrictMgr.
44+
func NewChannelRestrictMgr() *ChannelRestrictMgr {
45+
return &ChannelRestrictMgr{
46+
chanIDToPoint: make(map[uint64]string),
47+
chanPointToID: make(map[string]uint64),
48+
}
49+
}
50+
51+
// Stop cleans up the resources held by the manager.
52+
//
53+
// NOTE: This is part of the Manager interface.
54+
func (c *ChannelRestrictMgr) Stop() error {
55+
return nil
56+
}
57+
58+
// NewEnforcer constructs a new ChannelRestrict rule enforcer using the passed
59+
// values and config.
60+
//
61+
// NOTE: This is part of the Manager interface.
62+
func (c *ChannelRestrictMgr) NewEnforcer(cfg Config, values Values) (Enforcer,
63+
error) {
64+
65+
channels, ok := values.(*ChannelRestrict)
66+
if !ok {
67+
return nil, fmt.Errorf("values must be of type "+
68+
"ChannelRestrict, got %T", values)
69+
}
70+
71+
chanMap := make(map[uint64]bool, len(channels.DenyList))
72+
for _, chanID := range channels.DenyList {
73+
chanMap[chanID] = true
74+
if err := c.maybeUpdateChannelMaps(cfg, chanID); err != nil {
75+
return nil, err
76+
}
77+
}
78+
79+
return &ChannelRestrictEnforcer{
80+
mgr: c,
81+
ChannelRestrict: channels,
82+
channelMap: chanMap,
83+
}, nil
84+
}
85+
86+
// NewValueFromProto converts the given proto value into a ChannelRestrict Value
87+
// object.
88+
//
89+
// NOTE: This is part of the Manager interface.
90+
func (c *ChannelRestrictMgr) NewValueFromProto(v *litrpc.RuleValue) (Values,
91+
error) {
92+
93+
rv, ok := v.Value.(*litrpc.RuleValue_ChannelRestrict)
94+
if !ok {
95+
return nil, fmt.Errorf("incorrect RuleValue type")
96+
}
97+
98+
chanIDs := rv.ChannelRestrict.ChannelIds
99+
100+
if len(chanIDs) == 0 {
101+
return nil, fmt.Errorf("channel restrict list cannot be " +
102+
"empty. If no channel restrictions should be applied " +
103+
"then there is no need to add the rule")
104+
}
105+
106+
return &ChannelRestrict{
107+
DenyList: chanIDs,
108+
}, nil
109+
}
110+
111+
// EmptyValue returns a new ChannelRestrict instance.
112+
//
113+
// NOTE: This is part of the Manager interface.
114+
func (c *ChannelRestrictMgr) EmptyValue() Values {
115+
return &ChannelRestrict{}
116+
}
117+
118+
// maybeUpdateChannelMaps updates the ChannelRestrictMgrs set of known channels
119+
// iff the channel given by the caller is not found in the current map set.
120+
func (c *ChannelRestrictMgr) maybeUpdateChannelMaps(cfg Config,
121+
chanID uint64) error {
122+
123+
c.mu.Lock()
124+
defer c.mu.Unlock()
125+
126+
// If we already know of this channel, we don't need to go update our
127+
// maps.
128+
_, ok := c.chanIDToPoint[chanID]
129+
if ok {
130+
return nil
131+
}
132+
133+
// Fetch a list of our open channels from LND.
134+
lnd := cfg.GetLndClient()
135+
chans, err := lnd.ListChannels(context.Background(), false, false)
136+
if err != nil {
137+
return err
138+
}
139+
140+
var (
141+
found bool
142+
point string
143+
id uint64
144+
)
145+
146+
// Update our set of maps and also make sure that the channel specified
147+
// by the caller is valid given our set of open channels.
148+
for _, channel := range chans {
149+
point = channel.ChannelPoint
150+
id = channel.ChannelID
151+
152+
c.chanPointToID[point] = id
153+
c.chanIDToPoint[id] = point
154+
155+
if id == chanID {
156+
found = true
157+
}
158+
}
159+
160+
if !found {
161+
return fmt.Errorf("invalid channel ID")
162+
}
163+
164+
return nil
165+
}
166+
167+
func (c *ChannelRestrictMgr) getChannelID(point string) (uint64, bool) {
168+
c.mu.Lock()
169+
defer c.mu.Unlock()
170+
171+
id, ok := c.chanPointToID[point]
172+
return id, ok
173+
}
174+
175+
// ChannelRestrictEnforcer enforces requests and responses against a
176+
// ChannelRestrict rule.
177+
type ChannelRestrictEnforcer struct {
178+
mgr *ChannelRestrictMgr
179+
*ChannelRestrict
180+
channelMap map[uint64]bool
181+
}
182+
183+
// HandleRequest checks the validity of a request using the ChannelRestrict
184+
// rpcmiddleware.RoundTripCheckers.
185+
//
186+
// NOTE: this is part of the Enforcer interface.
187+
func (c *ChannelRestrictEnforcer) HandleRequest(ctx context.Context, uri string,
188+
msg proto.Message) (proto.Message, error) {
189+
190+
checkers := c.checkers()
191+
if checkers == nil {
192+
return nil, nil
193+
}
194+
195+
checker, ok := checkers[uri]
196+
if !ok {
197+
return nil, nil
198+
}
199+
200+
if !checker.HandlesRequest(msg.ProtoReflect().Type()) {
201+
return nil, fmt.Errorf("invalid implementation, checker for "+
202+
"URI %s does not accept request of type %v", uri,
203+
msg.ProtoReflect().Type())
204+
}
205+
206+
return checker.HandleRequest(ctx, msg)
207+
}
208+
209+
// HandleResponse handles a response using the ChannelRestrict
210+
// rpcmiddleware.RoundTripCheckers.
211+
//
212+
// NOTE: this is part of the Enforcer interface.
213+
func (c *ChannelRestrictEnforcer) HandleResponse(ctx context.Context, uri string,
214+
msg proto.Message) (proto.Message, error) {
215+
216+
checkers := c.checkers()
217+
if checkers == nil {
218+
return nil, nil
219+
}
220+
221+
checker, ok := checkers[uri]
222+
if !ok {
223+
return nil, nil
224+
}
225+
226+
if !checker.HandlesResponse(msg.ProtoReflect().Type()) {
227+
return nil, fmt.Errorf("invalid implementation, checker for "+
228+
"URI %s does not accept response of type %v", uri,
229+
msg.ProtoReflect().Type())
230+
}
231+
232+
return checker.HandleResponse(ctx, msg)
233+
}
234+
235+
// HandleErrorResponse handles and possible alters an error. This is a noop for
236+
// the ChannelRestrict rule.
237+
//
238+
// NOTE: this is part of the Enforcer interface.
239+
func (c *ChannelRestrictEnforcer) HandleErrorResponse(_ context.Context,
240+
_ string, _ error) (error, error) {
241+
242+
return nil, nil
243+
}
244+
245+
// checkers returns a map of URI to rpcmiddleware.RoundTripChecker which define
246+
// how the URI should be handled.
247+
func (c *ChannelRestrictEnforcer) checkers() map[string]mid.RoundTripChecker {
248+
return map[string]mid.RoundTripChecker{
249+
"/lnrpc.Lightning/UpdateChannelPolicy": mid.NewRequestChecker(
250+
&lnrpc.PolicyUpdateRequest{},
251+
&lnrpc.PolicyUpdateResponse{},
252+
func(ctx context.Context,
253+
r *lnrpc.PolicyUpdateRequest) error {
254+
255+
if r.GetGlobal() {
256+
return fmt.Errorf("cant apply call " +
257+
"to global scope when using " +
258+
"a channel restriction list")
259+
}
260+
261+
chanPoint := r.GetChanPoint()
262+
if chanPoint == nil {
263+
return fmt.Errorf("no channel point " +
264+
"specified")
265+
}
266+
267+
txid, err := lnrpc.GetChanPointFundingTxid(
268+
chanPoint,
269+
)
270+
if err != nil {
271+
return err
272+
}
273+
274+
index := chanPoint.GetOutputIndex()
275+
point := fmt.Sprintf(
276+
"%s:%d", txid.String(), index,
277+
)
278+
279+
id, ok := c.mgr.getChannelID(point)
280+
if !ok {
281+
return nil
282+
}
283+
284+
if c.channelMap[id] {
285+
return fmt.Errorf("illegal action on " +
286+
"channel in channel " +
287+
"restriction list")
288+
}
289+
290+
return nil
291+
},
292+
),
293+
}
294+
}
295+
296+
// ChannelRestrict is a rule prevents calls from acting upon a given set of
297+
// channels.
298+
type ChannelRestrict struct {
299+
// DenyList is a list of SCIDs that should not be acted upon by
300+
// any call.
301+
DenyList []uint64 `json:"channel_deny_list"`
302+
}
303+
304+
// VerifySane checks that the value of the values is ok given the min and max
305+
// allowed values. This is a noop for the ChannelRestrict rule.
306+
//
307+
// NOTE: this is part of the Values interface.
308+
func (c *ChannelRestrict) VerifySane(_, _ Values) error {
309+
return nil
310+
}
311+
312+
// RuleName returns the name of the rule that these values are to be used with.
313+
//
314+
// NOTE: this is part of the Values interface.
315+
func (c *ChannelRestrict) RuleName() string {
316+
return ChannelRestrictName
317+
}
318+
319+
// ToProto converts the rule Values to the litrpc counterpart.
320+
//
321+
// NOTE: this is part of the Values interface.
322+
func (c *ChannelRestrict) ToProto() *litrpc.RuleValue {
323+
return &litrpc.RuleValue{
324+
Value: &litrpc.RuleValue_ChannelRestrict{
325+
ChannelRestrict: &litrpc.ChannelRestrict{
326+
ChannelIds: c.DenyList,
327+
},
328+
},
329+
}
330+
}
331+
332+
// PseudoToReal assumes that the deny-list contains pseudo channel IDs and uses
333+
// these to check the privacy map db for the corresponding real channel IDs.
334+
// It constructs a new ChannelRestrict instance with these real channel IDs.
335+
//
336+
// NOTE: this is part of the Values interface.
337+
func (c *ChannelRestrict) PseudoToReal(db firewalldb.PrivacyMapDB) (Values,
338+
error) {
339+
340+
restrictList := make([]uint64, len(c.DenyList))
341+
err := db.View(func(tx firewalldb.PrivacyMapTx) error {
342+
for i, chanID := range c.DenyList {
343+
real, err := firewalldb.RevealUint64(tx, chanID)
344+
if err != nil {
345+
return err
346+
}
347+
348+
restrictList[i] = real
349+
}
350+
351+
return nil
352+
},
353+
)
354+
if err != nil {
355+
return nil, err
356+
}
357+
358+
return &ChannelRestrict{
359+
DenyList: restrictList,
360+
}, nil
361+
}
362+
363+
// RealToPseudo converts all the channel IDs into pseudo IDs.
364+
//
365+
// NOTE: this is part of the Values interface.
366+
func (c *ChannelRestrict) RealToPseudo() (Values, map[string]string, error) {
367+
pseudoIDs := make([]uint64, len(c.DenyList))
368+
privMapPairs := make(map[string]string)
369+
for i, c := range c.DenyList {
370+
// TODO(elle): check that this channel actually exists
371+
372+
chanID := firewalldb.Uint64ToStr(c)
373+
if pseudo, ok := privMapPairs[chanID]; ok {
374+
p, err := firewalldb.StrToUint64(pseudo)
375+
if err != nil {
376+
return nil, nil, err
377+
}
378+
379+
pseudoIDs[i] = p
380+
continue
381+
}
382+
383+
pseudoCp, pseudoCpStr := firewalldb.NewPseudoUint64()
384+
privMapPairs[chanID] = pseudoCpStr
385+
pseudoIDs[i] = pseudoCp
386+
}
387+
388+
return &ChannelRestrict{
389+
DenyList: pseudoIDs,
390+
}, privMapPairs, nil
391+
}

0 commit comments

Comments
 (0)