Skip to content

Commit 52d478e

Browse files
committed
include alternative check that includes subrelations
1 parent aaecc50 commit 52d478e

File tree

2 files changed

+211
-15
lines changed

2 files changed

+211
-15
lines changed

pkg/schema/type_check.go

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,29 @@ import (
1010

1111
const ellipsesRelation = "..."
1212

13-
// GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
13+
// GetRecursiveTypesForRelation returns, for a given definition and relation, are the potential
1414
// subject definition names of that relation.
15+
func (ts *TypeSystem) GetRecursiveTypesForRelation(ctx context.Context, defName string, relationName string) ([]string, error) {
16+
seen := mapz.NewSet[string]()
17+
set, err := ts.getTypesForRelationInternal(ctx, defName, relationName, seen, false)
18+
if err != nil {
19+
return nil, err
20+
}
21+
return set.AsSlice(), nil
22+
}
23+
24+
// GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
25+
// subject definition names of that relation, as well as any relation subtypes (eg, `group#member`) that may occur.
1526
func (ts *TypeSystem) GetRecursiveSubtypesForRelation(ctx context.Context, defName string, relationName string) ([]string, error) {
1627
seen := mapz.NewSet[string]()
17-
set, err := ts.getTypesForRelationInternal(ctx, defName, relationName, seen)
28+
set, err := ts.getTypesForRelationInternal(ctx, defName, relationName, seen, true)
1829
if err != nil {
1930
return nil, err
2031
}
2132
return set.AsSlice(), nil
2233
}
2334

24-
func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName string, relationName string, seen *mapz.Set[string]) (*mapz.Set[string], error) {
35+
func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName string, relationName string, seen *mapz.Set[string], addRelations bool) (*mapz.Set[string], error) {
2536
id := fmt.Sprint(defName, "#", relationName)
2637
if seen.Has(id) {
2738
return nil, nil
@@ -36,20 +47,23 @@ func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName s
3647
return nil, asTypeError(NewRelationNotFoundErr(defName, relationName))
3748
}
3849
if rel.TypeInformation != nil {
39-
return ts.getTypesForInfo(ctx, defName, rel.TypeInformation, seen)
50+
return ts.getTypesForInfo(ctx, defName, rel.TypeInformation, seen, addRelations)
4051
} else if rel.UsersetRewrite != nil {
41-
return ts.getTypesForRewrite(ctx, defName, rel.UsersetRewrite, seen)
52+
return ts.getTypesForRewrite(ctx, defName, rel.UsersetRewrite, seen, addRelations)
4253
}
4354
return nil, asTypeError(NewMissingAllowedRelationsErr(defName, relationName))
4455
}
4556

46-
func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *corev1.TypeInformation, seen *mapz.Set[string]) (*mapz.Set[string], error) {
57+
func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *corev1.TypeInformation, seen *mapz.Set[string], addRelations bool) (*mapz.Set[string], error) {
4758
out := mapz.NewSet[string]()
4859
for _, dr := range rel.GetAllowedDirectRelations() {
4960
if dr.GetRelation() == ellipsesRelation {
5061
out.Add(dr.GetNamespace())
5162
} else if dr.GetRelation() != "" {
52-
rest, err := ts.getTypesForRelationInternal(ctx, dr.GetNamespace(), dr.GetRelation(), seen)
63+
if addRelations {
64+
out.Add(fmt.Sprintf("%s#%s", dr.GetNamespace(), dr.GetRelation()))
65+
}
66+
rest, err := ts.getTypesForRelationInternal(ctx, dr.GetNamespace(), dr.GetRelation(), seen, addRelations)
5367
if err != nil {
5468
return nil, err
5569
}
@@ -62,7 +76,7 @@ func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *
6276
return out, nil
6377
}
6478

65-
func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, rel *corev1.UsersetRewrite, seen *mapz.Set[string]) (*mapz.Set[string], error) {
79+
func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, rel *corev1.UsersetRewrite, seen *mapz.Set[string], addRelations bool) (*mapz.Set[string], error) {
6680
out := mapz.NewSet[string]()
6781

6882
// We're finding the union of all the things touched, regardless.
@@ -74,39 +88,39 @@ func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, re
7488
}
7589
for _, child := range op.GetChild() {
7690
if computed := child.GetComputedUserset(); computed != nil {
77-
set, err := ts.getTypesForRelationInternal(ctx, defName, computed.GetRelation(), seen)
91+
set, err := ts.getTypesForRelationInternal(ctx, defName, computed.GetRelation(), seen, addRelations)
7892
if err != nil {
7993
return nil, err
8094
}
8195
out.Merge(set)
8296
}
8397
if rewrite := child.GetUsersetRewrite(); rewrite != nil {
84-
sub, err := ts.getTypesForRewrite(ctx, defName, rewrite, seen)
98+
sub, err := ts.getTypesForRewrite(ctx, defName, rewrite, seen, addRelations)
8599
if err != nil {
86100
return nil, err
87101
}
88102
out.Merge(sub)
89103
}
90104
if userset := child.GetTupleToUserset(); userset != nil {
91-
set, err := ts.getTypesForRelationInternal(ctx, defName, userset.GetTupleset().GetRelation(), seen)
105+
set, err := ts.getTypesForRelationInternal(ctx, defName, userset.GetTupleset().GetRelation(), seen, addRelations)
92106
if err != nil {
93107
return nil, err
94108
}
95109
for _, s := range set.AsSlice() {
96-
targets, err := ts.getTypesForRelationInternal(ctx, s, userset.GetComputedUserset().GetRelation(), seen)
110+
targets, err := ts.getTypesForRelationInternal(ctx, s, userset.GetComputedUserset().GetRelation(), seen, addRelations)
97111
if err != nil {
98112
return nil, err
99113
}
100114
out.Merge(targets)
101115
}
102116
}
103117
if functioned := child.GetFunctionedTupleToUserset(); functioned != nil {
104-
set, err := ts.getTypesForRelationInternal(ctx, defName, functioned.GetTupleset().GetRelation(), seen)
118+
set, err := ts.getTypesForRelationInternal(ctx, defName, functioned.GetTupleset().GetRelation(), seen, addRelations)
105119
if err != nil {
106120
return nil, err
107121
}
108122
for _, s := range set.AsSlice() {
109-
targets, err := ts.getTypesForRelationInternal(ctx, s, functioned.GetComputedUserset().GetRelation(), seen)
123+
targets, err := ts.getTypesForRelationInternal(ctx, s, functioned.GetComputedUserset().GetRelation(), seen, addRelations)
110124
if err != nil {
111125
return nil, err
112126
}

pkg/schema/type_check_test.go

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
1111
)
1212

13-
func TestTypechecking(t *testing.T) {
13+
func TestTypecheckingJustTypes(t *testing.T) {
1414
t.Parallel()
1515
type testcase struct {
1616
name string
@@ -171,6 +171,188 @@ func TestTypechecking(t *testing.T) {
171171
}, compiler.AllowUnprefixedObjectType())
172172
require.NoError(t, err)
173173

174+
res := ResolverForCompiledSchema(*schema)
175+
ts := NewTypeSystem(res)
176+
for _, resource := range schema.ObjectDefinitions {
177+
for _, relation := range resource.Relation {
178+
types, err := ts.GetRecursiveTypesForRelation(context.Background(), resource.Name, relation.Name)
179+
require.NoError(t, err)
180+
181+
rel := resource.Name + "#" + relation.Name
182+
expected, ok := tc.expected[rel]
183+
require.True(t, ok, fmt.Sprintf("expected %v to be in %v", rel, tc.expected))
184+
require.Len(t, types, len(expected), rel)
185+
186+
for _, typ := range types {
187+
require.Contains(t, expected, typ, fmt.Sprintf("expected %v to be in %v", typ, expected))
188+
}
189+
}
190+
}
191+
})
192+
}
193+
}
194+
195+
func TestTypecheckingWithSubrelations(t *testing.T) {
196+
t.Parallel()
197+
type testcase struct {
198+
name string
199+
schemaText string
200+
expected map[string][]string
201+
}
202+
tcs := []testcase{
203+
{
204+
name: "basic arrow",
205+
schemaText: `
206+
definition user {}
207+
208+
definition organization {
209+
relation member: user
210+
}
211+
212+
definition resource {
213+
relation org: organization
214+
relation viewer: user
215+
permission view = org->member + viewer
216+
}
217+
`,
218+
expected: map[string][]string{
219+
"organization#member": {"user"},
220+
"resource#viewer": {"user"},
221+
"resource#org": {"organization"},
222+
"resource#view": {"user"},
223+
},
224+
},
225+
{
226+
name: "multi-type arrow",
227+
schemaText: `
228+
definition user {}
229+
230+
definition organization {
231+
relation member: user
232+
}
233+
234+
definition resource {
235+
relation org: organization
236+
relation viewer: user
237+
permission view = org + viewer
238+
}
239+
`,
240+
expected: map[string][]string{
241+
"organization#member": {"user"},
242+
"resource#viewer": {"user"},
243+
"resource#org": {"organization"},
244+
"resource#view": {"organization", "user"},
245+
},
246+
},
247+
{
248+
name: "functional",
249+
schemaText: `
250+
definition user {}
251+
252+
definition organization {
253+
relation member: user
254+
}
255+
256+
definition resource {
257+
relation org: organization
258+
permission view = org.all(member)
259+
}
260+
`,
261+
expected: map[string][]string{
262+
"organization#member": {"user"},
263+
"resource#viewer": {"user"},
264+
"resource#org": {"organization"},
265+
"resource#view": {"user"},
266+
},
267+
},
268+
{
269+
name: "multi-type rel",
270+
schemaText: `
271+
definition user {}
272+
273+
definition organization {
274+
relation member: user
275+
}
276+
277+
definition resource {
278+
relation viewer: user | organization
279+
}
280+
`,
281+
expected: map[string][]string{
282+
"organization#member": {"user"},
283+
"resource#viewer": {"user", "organization"},
284+
},
285+
},
286+
{
287+
name: "subrel",
288+
schemaText: `
289+
definition user {}
290+
291+
definition organization {
292+
relation member: user
293+
}
294+
295+
definition resource {
296+
relation viewer: organization#member
297+
}
298+
`,
299+
expected: map[string][]string{
300+
"organization#member": {"user"},
301+
"resource#viewer": {"user", "organization#member"},
302+
},
303+
},
304+
{
305+
name: "wildcard",
306+
schemaText: `
307+
definition user {}
308+
309+
definition organization {
310+
relation member: user:*
311+
}
312+
313+
definition resource {
314+
relation viewer: organization#member
315+
}
316+
`,
317+
expected: map[string][]string{
318+
"organization#member": {"user"},
319+
"resource#viewer": {"user", "organization#member"},
320+
},
321+
},
322+
{
323+
name: "banned",
324+
schemaText: `
325+
definition user {}
326+
327+
definition organization {
328+
relation member: user
329+
}
330+
331+
definition resource {
332+
relation viewer: organization#member
333+
relation banned: user
334+
permission view = viewer - banned
335+
}
336+
`,
337+
expected: map[string][]string{
338+
"organization#member": {"user"},
339+
"resource#viewer": {"user", "organization#member"},
340+
"resource#banned": {"user"},
341+
"resource#view": {"user", "organization#member"},
342+
},
343+
},
344+
}
345+
for _, tc := range tcs {
346+
t.Run(tc.name, func(t *testing.T) {
347+
tc := tc
348+
t.Parallel()
349+
350+
schema, err := compiler.Compile(compiler.InputSchema{
351+
Source: "",
352+
SchemaString: tc.schemaText,
353+
}, compiler.AllowUnprefixedObjectType())
354+
require.NoError(t, err)
355+
174356
res := ResolverForCompiledSchema(*schema)
175357
ts := NewTypeSystem(res)
176358
for _, resource := range schema.ObjectDefinitions {

0 commit comments

Comments
 (0)