Skip to content

Commit c99f2dc

Browse files
authored
Merge pull request #1045 from ellemouton/sessionIDFromCtx
multi: extract session ID from context
2 parents 329b939 + a89b350 commit c99f2dc

File tree

8 files changed

+168
-8
lines changed

8 files changed

+168
-8
lines changed

firewall/privacy_mapper.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,11 @@ func (p *PrivacyMapper) Intercept(ctx context.Context,
106106
"interception request: %v", err)
107107
}
108108

109-
sessionID, err := session.IDFromMacaroon(ri.Macaroon)
109+
sessionID, err := ri.SessionID.UnwrapOrErr(
110+
fmt.Errorf("no session ID found in request info"),
111+
)
110112
if err != nil {
111-
return nil, fmt.Errorf("could not extract ID from macaroon")
113+
return nil, err
112114
}
113115

114116
log.Tracef("PrivacyMapper: Intercepting %v", ri)

firewall/privacy_mapper_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/lightningnetwork/lnd/lnrpc"
1414
"github.com/lightningnetwork/lnd/rpcperms"
1515
"github.com/stretchr/testify/require"
16+
"google.golang.org/grpc/metadata"
1617
"google.golang.org/protobuf/proto"
1718
"gopkg.in/macaroon-bakery.v2/bakery"
1819
"gopkg.in/macaroon.v2"
@@ -907,6 +908,9 @@ func TestPrivacyMapper(t *testing.T) {
907908
rawMsg, err := proto.Marshal(test.msg)
908909
require.NoError(t, err)
909910

911+
md := make(metadata.MD)
912+
session.AddToGRPCMetadata(md, sessionID)
913+
910914
interceptReq := &rpcperms.InterceptionRequest{
911915
Type: test.msgType,
912916
Macaroon: mac,
@@ -916,6 +920,7 @@ func TestPrivacyMapper(t *testing.T) {
916920
ProtoTypeName: string(
917921
proto.MessageName(test.msg),
918922
),
923+
CtxMetadataPairs: md,
919924
}
920925

921926
mwReq, err := interceptReq.ToRPC(1, 2)
@@ -1006,6 +1011,9 @@ func TestPrivacyMapper(t *testing.T) {
10061011
amounts := make([]uint64, numSamples)
10071012
timestamps := make([]uint64, numSamples)
10081013

1014+
md := make(metadata.MD)
1015+
session.AddToGRPCMetadata(md, sessionID)
1016+
10091017
for i := 0; i < numSamples; i++ {
10101018
interceptReq := &rpcperms.InterceptionRequest{
10111019
Type: rpcperms.TypeResponse,
@@ -1016,6 +1024,7 @@ func TestPrivacyMapper(t *testing.T) {
10161024
ProtoTypeName: string(
10171025
proto.MessageName(msg),
10181026
),
1027+
CtxMetadataPairs: md,
10191028
}
10201029

10211030
mwReq, err := interceptReq.ToRPC(1, 2)

firewall/request_info.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import (
44
"fmt"
55
"strings"
66

7+
"github.com/lightninglabs/lightning-terminal/session"
8+
"github.com/lightningnetwork/lnd/fn"
79
"github.com/lightningnetwork/lnd/lnrpc"
10+
"google.golang.org/grpc/metadata"
811
"gopkg.in/macaroon.v2"
912
)
1013

@@ -25,6 +28,7 @@ const (
2528
// RequestInfo stores the parsed representation of an incoming RPC middleware
2629
// request.
2730
type RequestInfo struct {
31+
SessionID fn.Option[session.ID]
2832
MsgID uint64
2933
RequestID uint64
3034
MWRequestType string
@@ -76,8 +80,22 @@ func NewInfoFromRequest(req *lnrpc.RPCMiddlewareRequest) (*RequestInfo, error) {
7680
return nil, fmt.Errorf("invalid request type: %T", t)
7781
}
7882

83+
md := make(metadata.MD)
84+
for k, vs := range req.MetadataPairs {
85+
for _, v := range vs.Values {
86+
md.Append(k, v)
87+
}
88+
}
89+
90+
sessionID, err := session.FromGRPCMetadata(md)
91+
if err != nil {
92+
return nil, fmt.Errorf("error extracting session ID "+
93+
"from request: %v", err)
94+
}
95+
7996
ri.MsgID = req.MsgId
8097
ri.RequestID = req.RequestId
98+
ri.SessionID = sessionID
8199

82100
// If there is no macaroon in the request, then there is nothing left
83101
// to parse.

firewall/request_logger.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo,
194194
}
195195

196196
actionReq := &firewalldb.AddActionReq{
197+
SessionID: ri.SessionID,
197198
MacaroonIdentifier: macaroonID,
198199
RPCMethod: ri.URI,
199200
}

firewall/rule_enforcer.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,11 @@ func (r *RuleEnforcer) Intercept(ctx context.Context,
237237
func (r *RuleEnforcer) handleRequest(ctx context.Context,
238238
ri *RequestInfo) (proto.Message, error) {
239239

240-
sessionID, err := session.IDFromMacaroon(ri.Macaroon)
240+
sessionID, err := ri.SessionID.UnwrapOrErr(
241+
fmt.Errorf("no session ID found in request info"),
242+
)
241243
if err != nil {
242-
return nil, fmt.Errorf("could not extract ID from macaroon")
244+
return nil, err
243245
}
244246

245247
rules, err := r.collectEnforcers(ctx, ri, sessionID)

session/context.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package session
2+
3+
import (
4+
"encoding/hex"
5+
"fmt"
6+
7+
"github.com/lightningnetwork/lnd/fn"
8+
"google.golang.org/grpc/metadata"
9+
)
10+
11+
// contextKey is a struct that is used as a key for storing session IDs
12+
// in a context. Using this unexported type prevents collisions with other
13+
// context keys that may be used in the same context. However, this only
14+
// applies if the context is passed around in the same binary and not if the
15+
// value is converted to grpc metadata and sent over the wire. In that case,
16+
// we need to use a string key to avoid collisions with other metadata keys.
17+
type contextKey struct {
18+
name string
19+
}
20+
21+
// sessionIDCtxKey is the context key used to store the session ID in
22+
// a context. The key is a string to avoid collisions with other context values
23+
// that may also be included in grpc metadata which is why we add the 'lit'
24+
// prefix.
25+
var sessionIDCtxKey = contextKey{"lit_session_id"}
26+
27+
// FromGRPCMetadata extracts the session ID from the given gRPC metadata kv
28+
// pairs if one is found.
29+
func FromGRPCMetadata(md metadata.MD) (fn.Option[ID], error) {
30+
val := md.Get(sessionIDCtxKey.name)
31+
if len(val) == 0 {
32+
return fn.None[ID](), nil
33+
}
34+
35+
if len(val) != 1 {
36+
return fn.None[ID](), fmt.Errorf("more than one session ID "+
37+
"found in gRPC metadata: %v", val)
38+
}
39+
40+
b, err := hex.DecodeString(val[0])
41+
if err != nil {
42+
return fn.None[ID](), err
43+
}
44+
45+
sessID, err := IDFromBytes(b)
46+
if err != nil {
47+
return fn.None[ID](), err
48+
}
49+
50+
return fn.Some(sessID), nil
51+
}
52+
53+
// AddToGRPCMetadata adds the session ID to the given gRPC metadata kv pairs.
54+
// The session ID is encoded as a hex string.
55+
func AddToGRPCMetadata(md metadata.MD, id ID) {
56+
md.Set(sessionIDCtxKey.name, hex.EncodeToString(id[:]))
57+
}

session/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ import (
1818

1919
type sessionID [33]byte
2020

21-
type GRPCServerCreator func(opts ...grpc.ServerOption) *grpc.Server
21+
type GRPCServerCreator func(sessionID ID,
22+
opts ...grpc.ServerOption) *grpc.Server
2223

2324
type mailboxSession struct {
2425
server *grpc.Server
@@ -70,7 +71,7 @@ func (m *mailboxSession) start(session *Session,
7071
}
7172

7273
noiseConn := mailbox.NewNoiseGrpcConn(keys)
73-
m.server = serverCreator(grpc.Creds(noiseConn))
74+
m.server = serverCreator(session.ID, grpc.Creds(noiseConn))
7475

7576
m.wg.Add(1)
7677
go m.run(mailboxServer)

session_rpcserver.go

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/lightningnetwork/lnd/fn"
2727
"github.com/lightningnetwork/lnd/macaroons"
2828
"google.golang.org/grpc"
29+
"google.golang.org/grpc/metadata"
2930
"gopkg.in/macaroon-bakery.v2/bakery"
3031
"gopkg.in/macaroon-bakery.v2/bakery/checkers"
3132
"gopkg.in/macaroon.v2"
@@ -77,10 +78,23 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
7778
// actual mailbox server that spins up the Terminal Connect server
7879
// interface.
7980
server := session.NewServer(
80-
func(opts ...grpc.ServerOption) *grpc.Server {
81-
allOpts := append(cfg.grpcOptions, opts...)
81+
func(id session.ID, opts ...grpc.ServerOption) *grpc.Server {
82+
// Add the session ID injector interceptors first so
83+
// that the session ID is available in the context of
84+
// all interceptors that come after.
85+
allOpts := []grpc.ServerOption{
86+
addSessionIDToStreamCtx(id),
87+
addSessionIDToUnaryCtx(id),
88+
}
89+
90+
allOpts = append(allOpts, cfg.grpcOptions...)
91+
allOpts = append(allOpts, opts...)
92+
93+
// Construct the gRPC server with the options.
8294
grpcServer := grpc.NewServer(allOpts...)
8395

96+
// Register various grpc servers with the LNC session
97+
// server.
8498
cfg.registerGrpcServers(grpcServer)
8599

86100
return grpcServer
@@ -94,6 +108,62 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
94108
}, nil
95109
}
96110

111+
// wrappedServerStream is a wrapper around the grpc.ServerStream that allows us
112+
// to set a custom context. This is needed since the stream handler function
113+
// doesn't take a context as an argument, but rather has a Context method on the
114+
// handler itself. So we use this custom wrapper to override this method.
115+
type wrappedServerStream struct {
116+
grpc.ServerStream
117+
ctx context.Context
118+
}
119+
120+
// Context returns the context of the stream.
121+
//
122+
// NOTE: This implements the grpc.ServerStream Context method.
123+
func (w *wrappedServerStream) Context() context.Context {
124+
return w.ctx
125+
}
126+
127+
// addSessionIDToStreamCtx is a gRPC stream interceptor that adds the given
128+
// session ID to the context of the stream. This allows us to access the
129+
// session ID later on for any gRPC calls made through this stream.
130+
func addSessionIDToStreamCtx(id session.ID) grpc.ServerOption {
131+
return grpc.StreamInterceptor(func(srv any, ss grpc.ServerStream,
132+
info *grpc.StreamServerInfo,
133+
handler grpc.StreamHandler) error {
134+
135+
md, _ := metadata.FromIncomingContext(ss.Context())
136+
mdCopy := md.Copy()
137+
session.AddToGRPCMetadata(mdCopy, id)
138+
139+
// Wrap the original stream with our custom context.
140+
wrapped := &wrappedServerStream{
141+
ServerStream: ss,
142+
ctx: metadata.NewIncomingContext(
143+
ss.Context(), mdCopy,
144+
),
145+
}
146+
147+
return handler(srv, wrapped)
148+
})
149+
}
150+
151+
// addSessionIDToUnaryCtx is a gRPC unary interceptor that adds the given
152+
// session ID to the context of the unary call. This allows us to access the
153+
// session ID later on for any gRPC calls made through this context.
154+
func addSessionIDToUnaryCtx(id session.ID) grpc.ServerOption {
155+
return grpc.UnaryInterceptor(func(ctx context.Context, req any,
156+
info *grpc.UnaryServerInfo,
157+
handler grpc.UnaryHandler) (resp any, err error) {
158+
159+
md, _ := metadata.FromIncomingContext(ctx)
160+
mdCopy := md.Copy()
161+
session.AddToGRPCMetadata(mdCopy, id)
162+
163+
return handler(metadata.NewIncomingContext(ctx, mdCopy), req)
164+
})
165+
}
166+
97167
// start all the components necessary for the sessionRpcServer to start serving
98168
// requests. This includes resuming all non-revoked sessions.
99169
func (s *sessionRpcServer) start(ctx context.Context) error {

0 commit comments

Comments
 (0)