Skip to content

Commit 7bb4d3b

Browse files
committed
session_rpcserver: pass all known pairs to RealToPseudo
In this commit, we keep track of all known privacy map pairs for a session along with any new pairs to be persisted.
1 parent 8b52899 commit 7bb4d3b

File tree

1 file changed

+75
-49
lines changed

1 file changed

+75
-49
lines changed

session_rpcserver.go

Lines changed: 75 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -838,12 +838,71 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
838838
return nil, fmt.Errorf("expiry must be in the future")
839839
}
840840

841+
// If the privacy mapper is being used for this session, then we need
842+
// to keep track of all our known privacy map pairs for this session
843+
// along with any new pairs that we need to persist.
841844
var (
842845
privacy = !req.NoPrivacyMapper
843-
privacyMapPairs = make(map[string]string)
844846
knownPrivMapPairs = firewalldb.NewPrivacyMapPairs(nil)
847+
newPrivMapPairs = make(map[string]string)
845848
)
846849

850+
// If a previous session ID has been set to link this new one to, we
851+
// first check if we have the referenced session, and we make sure it
852+
// has been revoked.
853+
var (
854+
linkedGroupID *session.ID
855+
linkedGroupSession *session.Session
856+
)
857+
if len(req.LinkedGroupId) != 0 {
858+
var groupID session.ID
859+
copy(groupID[:], req.LinkedGroupId)
860+
861+
// Check that the group actually does exist.
862+
groupSess, err := s.cfg.db.GetSessionByID(groupID)
863+
if err != nil {
864+
return nil, err
865+
}
866+
867+
// Ensure that the linked session is in fact the first session
868+
// in its group.
869+
if groupSess.ID != groupSess.GroupID {
870+
return nil, fmt.Errorf("can not link to session "+
871+
"%x since it is not the first in the session "+
872+
"group %x", groupSess.ID, groupSess.GroupID)
873+
}
874+
875+
// Now we need to check that all the sessions in the group are
876+
// no longer active.
877+
ok, err := s.cfg.db.CheckSessionGroupPredicate(
878+
groupID, func(s *session.Session) bool {
879+
return s.State == session.StateRevoked ||
880+
s.State == session.StateExpired
881+
},
882+
)
883+
if err != nil {
884+
return nil, err
885+
}
886+
887+
if !ok {
888+
return nil, fmt.Errorf("a linked session in group "+
889+
"%x is still active", groupID)
890+
}
891+
892+
linkedGroupID = &groupID
893+
linkedGroupSession = groupSess
894+
895+
privDB := s.cfg.privMap(groupID)
896+
err = privDB.View(func(tx firewalldb.PrivacyMapTx) error {
897+
knownPrivMapPairs, err = tx.FetchAllPairs()
898+
899+
return err
900+
})
901+
if err != nil {
902+
return nil, err
903+
}
904+
}
905+
847906
// First need to fetch all the perms that need to be baked into this
848907
// mac based on the features.
849908
allFeatures, err := s.cfg.autopilot.ListFeatures(ctx)
@@ -892,8 +951,21 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
892951
return nil, err
893952
}
894953

954+
// Store the new privacy map pairs in
955+
// the newPrivMap pairs map so that
956+
// they are later persisted to the real
957+
// priv map db.
895958
for k, v := range privMapPairs {
896-
privacyMapPairs[k] = v
959+
newPrivMapPairs[k] = v
960+
}
961+
962+
// Also add the new pairs to the known
963+
// set of pairs.
964+
err = knownPrivMapPairs.Add(
965+
privMapPairs,
966+
)
967+
if err != nil {
968+
return nil, err
897969
}
898970
}
899971

@@ -1017,52 +1089,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
10171089
caveats = append(caveats, firewall.MetaPrivacyCaveat)
10181090
}
10191091

1020-
// If a previous session ID has been set to link this new one to, we
1021-
// first check if we have the referenced session, and we make sure it
1022-
// has been revoked.
1023-
var (
1024-
linkedGroupID *session.ID
1025-
linkedGroupSession *session.Session
1026-
)
1027-
if len(req.LinkedGroupId) != 0 {
1028-
var groupID session.ID
1029-
copy(groupID[:], req.LinkedGroupId)
1030-
1031-
// Check that the group actually does exist.
1032-
groupSess, err := s.cfg.db.GetSessionByID(groupID)
1033-
if err != nil {
1034-
return nil, err
1035-
}
1036-
1037-
// Ensure that the linked session is in fact the first session
1038-
// in its group.
1039-
if groupSess.ID != groupSess.GroupID {
1040-
return nil, fmt.Errorf("can not link to session "+
1041-
"%x since it is not the first in the session "+
1042-
"group %x", groupSess.ID, groupSess.GroupID)
1043-
}
1044-
1045-
// Now we need to check that all the sessions in the group are
1046-
// no longer active.
1047-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
1048-
groupID, func(s *session.Session) bool {
1049-
return s.State == session.StateRevoked ||
1050-
s.State == session.StateExpired
1051-
},
1052-
)
1053-
if err != nil {
1054-
return nil, err
1055-
}
1056-
1057-
if !ok {
1058-
return nil, fmt.Errorf("a linked session in group "+
1059-
"%x is still active", groupID)
1060-
}
1061-
1062-
linkedGroupID = &groupID
1063-
linkedGroupSession = groupSess
1064-
}
1065-
10661092
s.sessRegMu.Lock()
10671093
defer s.sessRegMu.Unlock()
10681094

@@ -1101,7 +1127,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
11011127
// Register all the privacy map pairs for this session ID.
11021128
privDB := s.cfg.privMap(sess.GroupID)
11031129
err = privDB.Update(func(tx firewalldb.PrivacyMapTx) error {
1104-
for r, p := range privacyMapPairs {
1130+
for r, p := range newPrivMapPairs {
11051131
err := tx.NewPair(r, p)
11061132
if err != nil {
11071133
return err

0 commit comments

Comments
 (0)