@@ -10,18 +10,29 @@ import (
10
10
11
11
const ellipsesRelation = "..."
12
12
13
- // GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
13
+ // GetRecursiveTypesForRelation returns, for a given definition and relation, are the potential
14
14
// subject definition names of that relation.
15
+ func (ts * TypeSystem ) GetRecursiveTypesForRelation (ctx context.Context , defName string , relationName string ) ([]string , error ) {
16
+ seen := mapz .NewSet [string ]()
17
+ set , err := ts .getTypesForRelationInternal (ctx , defName , relationName , seen , false )
18
+ if err != nil {
19
+ return nil , err
20
+ }
21
+ return set .AsSlice (), nil
22
+ }
23
+
24
+ // GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
25
+ // subject definition names of that relation, as well as any relation subtypes (eg, `group#member`) that may occur.
15
26
func (ts * TypeSystem ) GetRecursiveSubtypesForRelation (ctx context.Context , defName string , relationName string ) ([]string , error ) {
16
27
seen := mapz .NewSet [string ]()
17
- set , err := ts .getTypesForRelationInternal (ctx , defName , relationName , seen )
28
+ set , err := ts .getTypesForRelationInternal (ctx , defName , relationName , seen , true )
18
29
if err != nil {
19
30
return nil , err
20
31
}
21
32
return set .AsSlice (), nil
22
33
}
23
34
24
- func (ts * TypeSystem ) getTypesForRelationInternal (ctx context.Context , defName string , relationName string , seen * mapz.Set [string ]) (* mapz.Set [string ], error ) {
35
+ func (ts * TypeSystem ) getTypesForRelationInternal (ctx context.Context , defName string , relationName string , seen * mapz.Set [string ], addRelations bool ) (* mapz.Set [string ], error ) {
25
36
id := fmt .Sprint (defName , "#" , relationName )
26
37
if seen .Has (id ) {
27
38
return nil , nil
@@ -36,20 +47,23 @@ func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName s
36
47
return nil , asTypeError (NewRelationNotFoundErr (defName , relationName ))
37
48
}
38
49
if rel .TypeInformation != nil {
39
- return ts .getTypesForInfo (ctx , defName , rel .TypeInformation , seen )
50
+ return ts .getTypesForInfo (ctx , defName , rel .TypeInformation , seen , addRelations )
40
51
} else if rel .UsersetRewrite != nil {
41
- return ts .getTypesForRewrite (ctx , defName , rel .UsersetRewrite , seen )
52
+ return ts .getTypesForRewrite (ctx , defName , rel .UsersetRewrite , seen , addRelations )
42
53
}
43
54
return nil , asTypeError (NewMissingAllowedRelationsErr (defName , relationName ))
44
55
}
45
56
46
- func (ts * TypeSystem ) getTypesForInfo (ctx context.Context , defName string , rel * corev1.TypeInformation , seen * mapz.Set [string ]) (* mapz.Set [string ], error ) {
57
+ func (ts * TypeSystem ) getTypesForInfo (ctx context.Context , defName string , rel * corev1.TypeInformation , seen * mapz.Set [string ], addRelations bool ) (* mapz.Set [string ], error ) {
47
58
out := mapz .NewSet [string ]()
48
59
for _ , dr := range rel .GetAllowedDirectRelations () {
49
60
if dr .GetRelation () == ellipsesRelation {
50
61
out .Add (dr .GetNamespace ())
51
62
} else if dr .GetRelation () != "" {
52
- rest , err := ts .getTypesForRelationInternal (ctx , dr .GetNamespace (), dr .GetRelation (), seen )
63
+ if addRelations {
64
+ out .Add (fmt .Sprintf ("%s#%s" , dr .GetNamespace (), dr .GetRelation ()))
65
+ }
66
+ rest , err := ts .getTypesForRelationInternal (ctx , dr .GetNamespace (), dr .GetRelation (), seen , addRelations )
53
67
if err != nil {
54
68
return nil , err
55
69
}
@@ -62,7 +76,7 @@ func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *
62
76
return out , nil
63
77
}
64
78
65
- func (ts * TypeSystem ) getTypesForRewrite (ctx context.Context , defName string , rel * corev1.UsersetRewrite , seen * mapz.Set [string ]) (* mapz.Set [string ], error ) {
79
+ func (ts * TypeSystem ) getTypesForRewrite (ctx context.Context , defName string , rel * corev1.UsersetRewrite , seen * mapz.Set [string ], addRelations bool ) (* mapz.Set [string ], error ) {
66
80
out := mapz .NewSet [string ]()
67
81
68
82
// We're finding the union of all the things touched, regardless.
@@ -74,39 +88,39 @@ func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, re
74
88
}
75
89
for _ , child := range op .GetChild () {
76
90
if computed := child .GetComputedUserset (); computed != nil {
77
- set , err := ts .getTypesForRelationInternal (ctx , defName , computed .GetRelation (), seen )
91
+ set , err := ts .getTypesForRelationInternal (ctx , defName , computed .GetRelation (), seen , addRelations )
78
92
if err != nil {
79
93
return nil , err
80
94
}
81
95
out .Merge (set )
82
96
}
83
97
if rewrite := child .GetUsersetRewrite (); rewrite != nil {
84
- sub , err := ts .getTypesForRewrite (ctx , defName , rewrite , seen )
98
+ sub , err := ts .getTypesForRewrite (ctx , defName , rewrite , seen , addRelations )
85
99
if err != nil {
86
100
return nil , err
87
101
}
88
102
out .Merge (sub )
89
103
}
90
104
if userset := child .GetTupleToUserset (); userset != nil {
91
- set , err := ts .getTypesForRelationInternal (ctx , defName , userset .GetTupleset ().GetRelation (), seen )
105
+ set , err := ts .getTypesForRelationInternal (ctx , defName , userset .GetTupleset ().GetRelation (), seen , addRelations )
92
106
if err != nil {
93
107
return nil , err
94
108
}
95
109
for _ , s := range set .AsSlice () {
96
- targets , err := ts .getTypesForRelationInternal (ctx , s , userset .GetComputedUserset ().GetRelation (), seen )
110
+ targets , err := ts .getTypesForRelationInternal (ctx , s , userset .GetComputedUserset ().GetRelation (), seen , addRelations )
97
111
if err != nil {
98
112
return nil , err
99
113
}
100
114
out .Merge (targets )
101
115
}
102
116
}
103
117
if functioned := child .GetFunctionedTupleToUserset (); functioned != nil {
104
- set , err := ts .getTypesForRelationInternal (ctx , defName , functioned .GetTupleset ().GetRelation (), seen )
118
+ set , err := ts .getTypesForRelationInternal (ctx , defName , functioned .GetTupleset ().GetRelation (), seen , addRelations )
105
119
if err != nil {
106
120
return nil , err
107
121
}
108
122
for _ , s := range set .AsSlice () {
109
- targets , err := ts .getTypesForRelationInternal (ctx , s , functioned .GetComputedUserset ().GetRelation (), seen )
123
+ targets , err := ts .getTypesForRelationInternal (ctx , s , functioned .GetComputedUserset ().GetRelation (), seen , addRelations )
110
124
if err != nil {
111
125
return nil , err
112
126
}
0 commit comments