@@ -10,18 +10,30 @@ import (
10
10
11
11
const ellipsesRelation = "..."
12
12
13
- // GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
13
+ // GetRecursiveTypesForRelation returns, for a given definition and relation, are the potential
14
14
// subject definition names of that relation.
15
+ func (ts * TypeSystem ) GetRecursiveTypesForRelation (ctx context.Context , defName string , relationName string ) ([]string , error ) {
16
+ seen := mapz .NewSet [string ]()
17
+ set , err := ts .getTypesForRelationInternal (ctx , defName , relationName , seen , false )
18
+ if err != nil {
19
+ return nil , err
20
+ }
21
+ return set .AsSlice (), nil
22
+ }
23
+
24
+ // GetRecursiveSubtypesForRelation returns, for a given definition and relation, are the potential
25
+ // subject definition names of that relation, as well as any relation subtypes (eg, `group#member`) that may occur.
15
26
func (ts * TypeSystem ) GetRecursiveSubtypesForRelation (ctx context.Context , defName string , relationName string ) ([]string , error ) {
16
27
seen := mapz .NewSet [string ]()
17
- set , err := ts .getTypesForRelationInternal (ctx , defName , relationName , seen )
28
+ set , err := ts .getTypesForRelationInternal (ctx , defName , relationName , seen , true )
18
29
if err != nil {
19
30
return nil , err
20
31
}
21
32
return set .AsSlice (), nil
33
+
22
34
}
23
35
24
- func (ts * TypeSystem ) getTypesForRelationInternal (ctx context.Context , defName string , relationName string , seen * mapz.Set [string ]) (* mapz.Set [string ], error ) {
36
+ func (ts * TypeSystem ) getTypesForRelationInternal (ctx context.Context , defName string , relationName string , seen * mapz.Set [string ], addRelations bool ) (* mapz.Set [string ], error ) {
25
37
id := fmt .Sprint (defName , "#" , relationName )
26
38
if seen .Has (id ) {
27
39
return nil , nil
@@ -36,20 +48,23 @@ func (ts *TypeSystem) getTypesForRelationInternal(ctx context.Context, defName s
36
48
return nil , asTypeError (NewRelationNotFoundErr (defName , relationName ))
37
49
}
38
50
if rel .TypeInformation != nil {
39
- return ts .getTypesForInfo (ctx , defName , rel .TypeInformation , seen )
51
+ return ts .getTypesForInfo (ctx , defName , rel .TypeInformation , seen , addRelations )
40
52
} else if rel .UsersetRewrite != nil {
41
- return ts .getTypesForRewrite (ctx , defName , rel .UsersetRewrite , seen )
53
+ return ts .getTypesForRewrite (ctx , defName , rel .UsersetRewrite , seen , addRelations )
42
54
}
43
55
return nil , asTypeError (NewMissingAllowedRelationsErr (defName , relationName ))
44
56
}
45
57
46
- func (ts * TypeSystem ) getTypesForInfo (ctx context.Context , defName string , rel * corev1.TypeInformation , seen * mapz.Set [string ]) (* mapz.Set [string ], error ) {
58
+ func (ts * TypeSystem ) getTypesForInfo (ctx context.Context , defName string , rel * corev1.TypeInformation , seen * mapz.Set [string ], addRelations bool ) (* mapz.Set [string ], error ) {
47
59
out := mapz .NewSet [string ]()
48
60
for _ , dr := range rel .GetAllowedDirectRelations () {
49
61
if dr .GetRelation () == ellipsesRelation {
50
62
out .Add (dr .GetNamespace ())
51
63
} else if dr .GetRelation () != "" {
52
- rest , err := ts .getTypesForRelationInternal (ctx , dr .GetNamespace (), dr .GetRelation (), seen )
64
+ if addRelations {
65
+ out .Add (fmt .Sprintf ("%s#%s" , dr .GetNamespace (), dr .GetRelation ()))
66
+ }
67
+ rest , err := ts .getTypesForRelationInternal (ctx , dr .GetNamespace (), dr .GetRelation (), seen , addRelations )
53
68
if err != nil {
54
69
return nil , err
55
70
}
@@ -62,7 +77,7 @@ func (ts *TypeSystem) getTypesForInfo(ctx context.Context, defName string, rel *
62
77
return out , nil
63
78
}
64
79
65
- func (ts * TypeSystem ) getTypesForRewrite (ctx context.Context , defName string , rel * corev1.UsersetRewrite , seen * mapz.Set [string ]) (* mapz.Set [string ], error ) {
80
+ func (ts * TypeSystem ) getTypesForRewrite (ctx context.Context , defName string , rel * corev1.UsersetRewrite , seen * mapz.Set [string ], addRelations bool ) (* mapz.Set [string ], error ) {
66
81
out := mapz .NewSet [string ]()
67
82
68
83
// We're finding the union of all the things touched, regardless.
@@ -74,39 +89,39 @@ func (ts *TypeSystem) getTypesForRewrite(ctx context.Context, defName string, re
74
89
}
75
90
for _ , child := range op .GetChild () {
76
91
if computed := child .GetComputedUserset (); computed != nil {
77
- set , err := ts .getTypesForRelationInternal (ctx , defName , computed .GetRelation (), seen )
92
+ set , err := ts .getTypesForRelationInternal (ctx , defName , computed .GetRelation (), seen , addRelations )
78
93
if err != nil {
79
94
return nil , err
80
95
}
81
96
out .Merge (set )
82
97
}
83
98
if rewrite := child .GetUsersetRewrite (); rewrite != nil {
84
- sub , err := ts .getTypesForRewrite (ctx , defName , rewrite , seen )
99
+ sub , err := ts .getTypesForRewrite (ctx , defName , rewrite , seen , addRelations )
85
100
if err != nil {
86
101
return nil , err
87
102
}
88
103
out .Merge (sub )
89
104
}
90
105
if userset := child .GetTupleToUserset (); userset != nil {
91
- set , err := ts .getTypesForRelationInternal (ctx , defName , userset .GetTupleset ().GetRelation (), seen )
106
+ set , err := ts .getTypesForRelationInternal (ctx , defName , userset .GetTupleset ().GetRelation (), seen , addRelations )
92
107
if err != nil {
93
108
return nil , err
94
109
}
95
110
for _ , s := range set .AsSlice () {
96
- targets , err := ts .getTypesForRelationInternal (ctx , s , userset .GetComputedUserset ().GetRelation (), seen )
111
+ targets , err := ts .getTypesForRelationInternal (ctx , s , userset .GetComputedUserset ().GetRelation (), seen , addRelations )
97
112
if err != nil {
98
113
return nil , err
99
114
}
100
115
out .Merge (targets )
101
116
}
102
117
}
103
118
if functioned := child .GetFunctionedTupleToUserset (); functioned != nil {
104
- set , err := ts .getTypesForRelationInternal (ctx , defName , functioned .GetTupleset ().GetRelation (), seen )
119
+ set , err := ts .getTypesForRelationInternal (ctx , defName , functioned .GetTupleset ().GetRelation (), seen , addRelations )
105
120
if err != nil {
106
121
return nil , err
107
122
}
108
123
for _ , s := range set .AsSlice () {
109
- targets , err := ts .getTypesForRelationInternal (ctx , s , functioned .GetComputedUserset ().GetRelation (), seen )
124
+ targets , err := ts .getTypesForRelationInternal (ctx , s , functioned .GetComputedUserset ().GetRelation (), seen , addRelations )
110
125
if err != nil {
111
126
return nil , err
112
127
}
0 commit comments