Skip to content

Commit b2a3dc1

Browse files
committed
include alternative check that includes subrelations
1 parent aaecc50 commit b2a3dc1

File tree

2 files changed

+212
-15
lines changed

2 files changed

+212
-15
lines changed

pkg/schema/type_check.go

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,30 @@ import (
1010

1111
const ellipsesRelation = "..."
1212

13-
// GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
13+
// GetRecursiveTypesForRelation returns, for a given definition and relation, are the potential
1414
// subject definition names of that relation.
15+
func (ts *TypeSystem) GetRecursiveTypesForRelation(ctx context.Context, defName string, relationName string) ([]string, error) {
16+
seen := mapz.NewSet[string]()
17+
set, err := ts.getTypesForRelationInternal(ctx, defName, relationName, seen, false)
18+
if err != nil {
19+
return nil, err
20+
}
21+
return set.AsSlice(), nil
22+
}
23+
24+
// GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
25+
// subject definition names of that relation, as well as any relation subtypes (eg, `group#member`) that may occur.
1526
func (ts *TypeSystem) GetRecursiveSubtypesForRelation(ctx context.Context, defName string, relationName string) ([]string, error) {
1627
seen := mapz.NewSet[string]()
17-
set, err := ts.getTypesForRelationInternal(ctx, defName, relationName, seen)
28+
set, err := ts.getTypesForRelationInternal(ctx, defName, relationName, seen, true)
1829
if err != nil {
1930
return nil, err
2031
}
2132
return set.AsSlice(), nil
33+
2234
}
2335

24-
func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName string, relationName string, seen *mapz.Set[string]) (*mapz.Set[string], error) {
36+
func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName string, relationName string, seen *mapz.Set[string], addRelations bool) (*mapz.Set[string], error) {
2537
id := fmt.Sprint(defName, "#", relationName)
2638
if seen.Has(id) {
2739
return nil, nil
@@ -36,20 +48,23 @@ func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName s
3648
return nil, asTypeError(NewRelationNotFoundErr(defName, relationName))
3749
}
3850
if rel.TypeInformation != nil {
39-
return ts.getTypesForInfo(ctx, defName, rel.TypeInformation, seen)
51+
return ts.getTypesForInfo(ctx, defName, rel.TypeInformation, seen, addRelations)
4052
} else if rel.UsersetRewrite != nil {
41-
return ts.getTypesForRewrite(ctx, defName, rel.UsersetRewrite, seen)
53+
return ts.getTypesForRewrite(ctx, defName, rel.UsersetRewrite, seen, addRelations)
4254
}
4355
return nil, asTypeError(NewMissingAllowedRelationsErr(defName, relationName))
4456
}
4557

46-
func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *corev1.TypeInformation, seen *mapz.Set[string]) (*mapz.Set[string], error) {
58+
func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *corev1.TypeInformation, seen *mapz.Set[string], addRelations bool) (*mapz.Set[string], error) {
4759
out := mapz.NewSet[string]()
4860
for _, dr := range rel.GetAllowedDirectRelations() {
4961
if dr.GetRelation() == ellipsesRelation {
5062
out.Add(dr.GetNamespace())
5163
} else if dr.GetRelation() != "" {
52-
rest, err := ts.getTypesForRelationInternal(ctx, dr.GetNamespace(), dr.GetRelation(), seen)
64+
if addRelations {
65+
out.Add(fmt.Sprintf("%s#%s", dr.GetNamespace(), dr.GetRelation()))
66+
}
67+
rest, err := ts.getTypesForRelationInternal(ctx, dr.GetNamespace(), dr.GetRelation(), seen, addRelations)
5368
if err != nil {
5469
return nil, err
5570
}
@@ -62,7 +77,7 @@ func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *
6277
return out, nil
6378
}
6479

65-
func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, rel *corev1.UsersetRewrite, seen *mapz.Set[string]) (*mapz.Set[string], error) {
80+
func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, rel *corev1.UsersetRewrite, seen *mapz.Set[string], addRelations bool) (*mapz.Set[string], error) {
6681
out := mapz.NewSet[string]()
6782

6883
// We're finding the union of all the things touched, regardless.
@@ -74,39 +89,39 @@ func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, re
7489
}
7590
for _, child := range op.GetChild() {
7691
if computed := child.GetComputedUserset(); computed != nil {
77-
set, err := ts.getTypesForRelationInternal(ctx, defName, computed.GetRelation(), seen)
92+
set, err := ts.getTypesForRelationInternal(ctx, defName, computed.GetRelation(), seen, addRelations)
7893
if err != nil {
7994
return nil, err
8095
}
8196
out.Merge(set)
8297
}
8398
if rewrite := child.GetUsersetRewrite(); rewrite != nil {
84-
sub, err := ts.getTypesForRewrite(ctx, defName, rewrite, seen)
99+
sub, err := ts.getTypesForRewrite(ctx, defName, rewrite, seen, addRelations)
85100
if err != nil {
86101
return nil, err
87102
}
88103
out.Merge(sub)
89104
}
90105
if userset := child.GetTupleToUserset(); userset != nil {
91-
set, err := ts.getTypesForRelationInternal(ctx, defName, userset.GetTupleset().GetRelation(), seen)
106+
set, err := ts.getTypesForRelationInternal(ctx, defName, userset.GetTupleset().GetRelation(), seen, addRelations)
92107
if err != nil {
93108
return nil, err
94109
}
95110
for _, s := range set.AsSlice() {
96-
targets, err := ts.getTypesForRelationInternal(ctx, s, userset.GetComputedUserset().GetRelation(), seen)
111+
targets, err := ts.getTypesForRelationInternal(ctx, s, userset.GetComputedUserset().GetRelation(), seen, addRelations)
97112
if err != nil {
98113
return nil, err
99114
}
100115
out.Merge(targets)
101116
}
102117
}
103118
if functioned := child.GetFunctionedTupleToUserset(); functioned != nil {
104-
set, err := ts.getTypesForRelationInternal(ctx, defName, functioned.GetTupleset().GetRelation(), seen)
119+
set, err := ts.getTypesForRelationInternal(ctx, defName, functioned.GetTupleset().GetRelation(), seen, addRelations)
105120
if err != nil {
106121
return nil, err
107122
}
108123
for _, s := range set.AsSlice() {
109-
targets, err := ts.getTypesForRelationInternal(ctx, s, functioned.GetComputedUserset().GetRelation(), seen)
124+
targets, err := ts.getTypesForRelationInternal(ctx, s, functioned.GetComputedUserset().GetRelation(), seen, addRelations)
110125
if err != nil {
111126
return nil, err
112127
}

pkg/schema/type_check_test.go

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
1111
)
1212

13-
func TestTypechecking(t *testing.T) {
13+
func TestTypecheckingJustTypes(t *testing.T) {
1414
t.Parallel()
1515
type testcase struct {
1616
name string
@@ -171,6 +171,188 @@ func TestTypechecking(t *testing.T) {
171171
}, compiler.AllowUnprefixedObjectType())
172172
require.NoError(t, err)
173173

174+
res := ResolverForCompiledSchema(*schema)
175+
ts := NewTypeSystem(res)
176+
for _, resource := range schema.ObjectDefinitions {
177+
for _, relation := range resource.Relation {
178+
types, err := ts.GetRecursiveTypesForRelation(context.Background(), resource.Name, relation.Name)
179+
require.NoError(t, err)
180+
181+
rel := resource.Name + "#" + relation.Name
182+
expected, ok := tc.expected[rel]
183+
require.True(t, ok, fmt.Sprintf("expected %v to be in %v", rel, tc.expected))
184+
require.Len(t, types, len(expected), rel)
185+
186+
for _, typ := range types {
187+
require.Contains(t, expected, typ, fmt.Sprintf("expected %v to be in %v", typ, expected))
188+
}
189+
}
190+
}
191+
})
192+
}
193+
}
194+
195+
func TestTypecheckingWithSubrelations(t *testing.T) {
196+
t.Parallel()
197+
type testcase struct {
198+
name string
199+
schemaText string
200+
expected map[string][]string
201+
}
202+
tcs := []testcase{
203+
{
204+
name: "basic arrow",
205+
schemaText: `
206+
definition user {}
207+
208+
definition organization {
209+
relation member: user
210+
}
211+
212+
definition resource {
213+
relation org: organization
214+
relation viewer: user
215+
permission view = org->member + viewer
216+
}
217+
`,
218+
expected: map[string][]string{
219+
"organization#member": {"user"},
220+
"resource#viewer": {"user"},
221+
"resource#org": {"organization"},
222+
"resource#view": {"user"},
223+
},
224+
},
225+
{
226+
name: "multi-type arrow",
227+
schemaText: `
228+
definition user {}
229+
230+
definition organization {
231+
relation member: user
232+
}
233+
234+
definition resource {
235+
relation org: organization
236+
relation viewer: user
237+
permission view = org + viewer
238+
}
239+
`,
240+
expected: map[string][]string{
241+
"organization#member": {"user"},
242+
"resource#viewer": {"user"},
243+
"resource#org": {"organization"},
244+
"resource#view": {"organization", "user"},
245+
},
246+
},
247+
{
248+
name: "functional",
249+
schemaText: `
250+
definition user {}
251+
252+
definition organization {
253+
relation member: user
254+
}
255+
256+
definition resource {
257+
relation org: organization
258+
permission view = org.all(member)
259+
}
260+
`,
261+
expected: map[string][]string{
262+
"organization#member": {"user"},
263+
"resource#viewer": {"user"},
264+
"resource#org": {"organization"},
265+
"resource#view": {"user"},
266+
},
267+
},
268+
{
269+
name: "multi-type rel",
270+
schemaText: `
271+
definition user {}
272+
273+
definition organization {
274+
relation member: user
275+
}
276+
277+
definition resource {
278+
relation viewer: user | organization
279+
}
280+
`,
281+
expected: map[string][]string{
282+
"organization#member": {"user"},
283+
"resource#viewer": {"user", "organization"},
284+
},
285+
},
286+
{
287+
name: "subrel",
288+
schemaText: `
289+
definition user {}
290+
291+
definition organization {
292+
relation member: user
293+
}
294+
295+
definition resource {
296+
relation viewer: organization#member
297+
}
298+
`,
299+
expected: map[string][]string{
300+
"organization#member": {"user"},
301+
"resource#viewer": {"user", "organization#member"},
302+
},
303+
},
304+
{
305+
name: "wildcard",
306+
schemaText: `
307+
definition user {}
308+
309+
definition organization {
310+
relation member: user:*
311+
}
312+
313+
definition resource {
314+
relation viewer: organization#member
315+
}
316+
`,
317+
expected: map[string][]string{
318+
"organization#member": {"user"},
319+
"resource#viewer": {"user", "organization#member"},
320+
},
321+
},
322+
{
323+
name: "banned",
324+
schemaText: `
325+
definition user {}
326+
327+
definition organization {
328+
relation member: user
329+
}
330+
331+
definition resource {
332+
relation viewer: organization#member
333+
relation banned: user
334+
permission view = viewer - banned
335+
}
336+
`,
337+
expected: map[string][]string{
338+
"organization#member": {"user"},
339+
"resource#viewer": {"user", "organization#member"},
340+
"resource#banned": {"user"},
341+
"resource#view": {"user", "organization#member"},
342+
},
343+
},
344+
}
345+
for _, tc := range tcs {
346+
t.Run(tc.name, func(t *testing.T) {
347+
tc := tc
348+
t.Parallel()
349+
350+
schema, err := compiler.Compile(compiler.InputSchema{
351+
Source: "",
352+
SchemaString: tc.schemaText,
353+
}, compiler.AllowUnprefixedObjectType())
354+
require.NoError(t, err)
355+
174356
res := ResolverForCompiledSchema(*schema)
175357
ts := NewTypeSystem(res)
176358
for _, resource := range schema.ObjectDefinitions {

0 commit comments

Comments
 (0)