Skip to content

Commit 2aa3962

Browse files
committed
policylist: add emma-proofing for entity maps
1 parent f3411ea commit 2aa3962

File tree

2 files changed

+39
-30
lines changed

2 files changed

+39
-30
lines changed

policylist/list.go

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package policylist
22

33
import (
4+
"slices"
45
"sync"
56
"time"
67

@@ -26,8 +27,8 @@ type dplNode struct {
2627
type List struct {
2728
matchDuration prometheus.Observer
2829
byStateKey map[string]*dplNode
29-
byEntity map[string]*dplNode
30-
byEntityHash map[[util.HashSize]byte]*dplNode
30+
byEntity map[string][]*Policy
31+
byEntityHash map[[util.HashSize]byte][]*Policy
3132
dynamicHead *dplNode
3233
lock sync.RWMutex
3334
}
@@ -36,8 +37,8 @@ func NewList(roomID id.RoomID, entityType string) *List {
3637
return &List{
3738
matchDuration: matchDuration.WithLabelValues(roomID.String(), entityType),
3839
byStateKey: make(map[string]*dplNode),
39-
byEntity: make(map[string]*dplNode),
40-
byEntityHash: make(map[[util.HashSize]byte]*dplNode),
40+
byEntity: make(map[string][]*Policy),
41+
byEntityHash: make(map[[util.HashSize]byte][]*Policy),
4142
}
4243
}
4344

@@ -66,6 +67,29 @@ func (l *List) removeFromLinkedList(node *dplNode) {
6667
}
6768
}
6869

70+
func deletePolicyFromStaticMap[T comparable](m map[T][]*Policy, key T, policy *Policy) {
71+
existing, ok := m[key]
72+
if ok {
73+
modified := slices.DeleteFunc(existing, func(item *Policy) bool {
74+
return item == policy
75+
})
76+
if len(modified) == 0 {
77+
delete(m, key)
78+
} else if len(existing) != len(modified) {
79+
m[key] = modified
80+
}
81+
}
82+
}
83+
84+
func (l *List) removeFromStaticMaps(policy *Policy) {
85+
if policy.Entity != "" {
86+
deletePolicyFromStaticMap(l.byEntity, policy.Entity, policy)
87+
}
88+
if policy.EntityHash != nil {
89+
deletePolicyFromStaticMap(l.byEntityHash, *policy.EntityHash, policy)
90+
}
91+
}
92+
6993
func (l *List) Add(value *Policy) (*Policy, bool) {
7094
l.lock.Lock()
7195
defer l.lock.Unlock()
@@ -82,21 +106,16 @@ func (l *List) Add(value *Policy) (*Policy, bool) {
82106
}
83107
// There's an existing event with the same state key, but the entity changed, remove the old node.
84108
l.removeFromLinkedList(existing)
85-
if existing.Entity != "" {
86-
delete(l.byEntity, existing.Entity)
87-
}
88-
if existing.EntityHash != nil {
89-
delete(l.byEntityHash, *existing.EntityHash)
90-
}
109+
l.removeFromStaticMaps(existing.Policy)
91110
}
92111
node := &dplNode{Policy: value}
93112
l.byStateKey[value.StateKey] = node
94113
if !value.Ignored {
95114
if value.Entity != "" {
96-
l.byEntity[value.Entity] = node
115+
l.byEntity[value.Entity] = append(l.byEntity[value.Entity], value)
97116
}
98117
if value.EntityHash != nil {
99-
l.byEntityHash[*value.EntityHash] = node
118+
l.byEntityHash[*value.EntityHash] = append(l.byEntityHash[*value.EntityHash], value)
100119
}
101120
}
102121
if _, isStatic := value.Pattern.(glob.ExactGlob); value.Entity != "" && !isStatic && !value.Ignored {
@@ -117,14 +136,7 @@ func (l *List) Remove(eventType event.Type, stateKey string) *Policy {
117136
defer l.lock.Unlock()
118137
if value, ok := l.byStateKey[stateKey]; ok && eventType == value.Type {
119138
l.removeFromLinkedList(value)
120-
if entValue, ok := l.byEntity[value.Entity]; ok && entValue == value && value.Entity != "" {
121-
delete(l.byEntity, value.Entity)
122-
}
123-
if value.EntityHash != nil {
124-
if entHashValue, ok := l.byEntityHash[*value.EntityHash]; ok && entHashValue == value {
125-
delete(l.byEntityHash, *value.EntityHash)
126-
}
127-
}
139+
l.removeFromStaticMaps(value.Policy)
128140
delete(l.byStateKey, stateKey)
129141
return value.Policy
130142
}
@@ -151,13 +163,13 @@ func (l *List) Match(entity string) (output Match) {
151163
start := time.Now()
152164
exactMatch, ok := l.byEntity[entity]
153165
if ok {
154-
output = Match{exactMatch.Policy}
166+
output = append(output, exactMatch...)
155167
}
156168
if value, ok := l.byEntityHash[util.SHA256String(entity)]; ok {
157-
output = append(output, value.Policy)
169+
output = append(output, value...)
158170
}
159171
for item := l.dynamicHead; item != nil; item = item.next {
160-
if !item.Ignored && item.Pattern.Match(entity) && item != exactMatch {
172+
if !item.Ignored && item.Pattern.Match(entity) && !slices.Contains(output, item.Policy) {
161173
output = append(output, item.Policy)
162174
}
163175
}
@@ -172,21 +184,18 @@ func (l *List) MatchExact(entity string) (output Match) {
172184
l.lock.RLock()
173185
defer l.lock.RUnlock()
174186
if value, ok := l.byEntity[entity]; ok {
175-
output = Match{value.Policy}
187+
output = append(output, value...)
176188
}
177189
if value, ok := l.byEntityHash[util.SHA256String(entity)]; ok {
178-
output = append(output, value.Policy)
190+
output = append(output, value...)
179191
}
180192
return
181193
}
182194

183195
func (l *List) MatchHash(hash [util.HashSize]byte) (output Match) {
184196
l.lock.RLock()
185197
defer l.lock.RUnlock()
186-
if value, ok := l.byEntityHash[hash]; ok {
187-
output = Match{value.Policy}
188-
}
189-
return
198+
return slices.Clone(l.byEntityHash[hash])
190199
}
191200

192201
func (l *List) Search(patternString string, pattern glob.Glob) (output Match) {

policylist/store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func (s *Store) compileList(listIDs []id.RoomID, listGetter func(*Room) *List) (
201201
rules := listGetter(list)
202202
rules.lock.RLock()
203203
for _, policy := range rules.byEntity {
204-
output[policy.Entity] = policy.Policy
204+
output[policy[0].Entity] = policy[0]
205205
}
206206
rules.lock.RUnlock()
207207
}

0 commit comments

Comments
 (0)