5
5
from .oauth_client import OAuthClient
6
6
7
7
8
- def get_single_value_attribute (
9
- obj : JSONDict , key : str , default : str | None = None
10
- ) -> Any :
11
- for part in key .split ("." ):
12
- obj = obj .get (part ) # type: ignore
13
- if obj is None :
14
- return default
15
- if isinstance (obj , list ):
16
- try :
17
- return next (iter (obj ))
18
- except StopIteration :
19
- pass
20
- else :
21
- return obj
22
- return default
23
-
24
-
25
8
class KeycloakClient (OAuthClient ):
26
9
"""OAuth client for the Keycloak backend."""
27
10
@@ -62,41 +45,25 @@ def groups(self) -> list[JSONDict]:
62
45
if len (data ) != self .max_rows :
63
46
break
64
47
65
- group_data = sorted (
66
- group_data ,
67
- key = lambda group : int (
68
- get_single_value_attribute (
69
- group , "attributes.gid" , default = "9999999999"
70
- ),
71
- base = 10 ,
72
- ),
73
- )
74
-
75
- next_gid = max (
76
- * (
77
- int (
78
- get_single_value_attribute (
79
- group , "attributes.gid" , default = "-1"
80
- ),
81
- base = 10 ,
48
+ # Ensure that gid attribute exists for all groups
49
+ for group_dict in group_data :
50
+ group_dict ["attributes" ] = group_dict .get ("attributes" , {})
51
+ if "gid" not in group_dict ["attributes" ]:
52
+ group_dict ["attributes" ]["gid" ] = None
53
+ # If group_gid exists then set the cache to the same value
54
+ # This ensures that any groups without a `gid` attribute will receive a
55
+ # UID that does not overlap with existing groups
56
+ if group_gid := group_dict ["attributes" ]["gid" ]:
57
+ self .uid_cache .overwrite_group_uid (
58
+ group_dict ["id" ], int (group_gid , 10 )
82
59
)
83
- + 1
84
- for group in group_data
85
- ),
86
- 3000 ,
87
- )
88
60
61
+ # Read group attributes
89
62
for group_dict in group_data :
90
- group_gid = get_single_value_attribute (
91
- group_dict , "attributes.gid" , default = None
92
- )
93
- if group_gid :
94
- group_gid = int (group_gid , 10 )
95
- if not group_gid :
96
- group_gid = next_gid
97
- next_gid += 1
98
- group_dict ["attributes" ] = group_dict .get ("attributes" , {})
99
- group_dict ["attributes" ]["gid" ] = [str (group_gid )]
63
+ if not group_dict ["attributes" ]["gid" ]:
64
+ group_dict ["attributes" ]["gid" ] = [
65
+ str (self .uid_cache .get_group_uid (group_dict ["id" ]))
66
+ ]
100
67
self .request (
101
68
f"{ self .base_url } /admin/realms/{ self .realm } /groups/{ group_dict ['id' ]} " ,
102
69
method = "PUT" ,
@@ -105,7 +72,7 @@ def groups(self) -> list[JSONDict]:
105
72
attributes : JSONDict = {}
106
73
attributes ["cn" ] = group_dict .get ("name" , None )
107
74
attributes ["description" ] = group_dict .get ("id" , None )
108
- attributes ["gidNumber" ] = group_gid
75
+ attributes ["gidNumber" ] = group_dict [ "attributes" ][ "gid" ]
109
76
attributes ["oauth_id" ] = group_dict .get ("id" , None )
110
77
# Add membership attributes
111
78
members = self .query (
@@ -132,44 +99,27 @@ def users(self) -> list[JSONDict]:
132
99
if len (data ) != self .max_rows :
133
100
break
134
101
135
- user_data = sorted (
136
- user_data ,
137
- key = lambda user : int (
138
- get_single_value_attribute (
139
- user , "attributes.uid" , default = "9999999999"
140
- ),
141
- base = 10 ,
142
- ),
143
- )
144
-
145
- next_uid = max (
146
- * (
147
- int (
148
- get_single_value_attribute (
149
- user , "attributes.uid" , default = "-1"
150
- ),
151
- base = 10 ,
102
+ # Ensure that uid attribute exists for all users
103
+ for user_dict in user_data :
104
+ user_dict ["attributes" ] = user_dict .get ("attributes" , {})
105
+ if "uid" not in user_dict ["attributes" ]:
106
+ user_dict ["attributes" ]["uid" ] = None
107
+ # If user_uid exists then set the cache to the same value.
108
+ # This ensures that any groups without a `gid` attribute will receive a
109
+ # UID that does not overlap with existing groups
110
+ if user_uid := user_dict ["attributes" ]["uid" ]:
111
+ self .uid_cache .overwrite_user_uid (
112
+ user_dict ["id" ], int (user_uid , 10 )
152
113
)
153
- + 1
154
- for user in user_data
155
- ),
156
- 3000 ,
157
- )
158
114
115
+ # Read user attributes
159
116
for user_dict in sorted (
160
117
user_data , key = lambda user : user ["createdTimestamp" ]
161
118
):
162
- user_uid = get_single_value_attribute (
163
- user_dict , "attributes.uid" , default = None
164
- )
165
- if user_uid :
166
- user_uid = int (user_uid , base = 10 )
167
- if not user_uid :
168
- user_uid = next_uid
169
- next_uid += 1
170
-
171
- user_dict ["attributes" ] = user_dict .get ("attributes" , {})
172
- user_dict ["attributes" ]["uid" ] = [str (user_uid )]
119
+ if not user_dict ["attributes" ]["uid" ]:
120
+ user_dict ["attributes" ]["uid" ] = [
121
+ str (self .uid_cache .get_user_uid (user_dict ["id" ]))
122
+ ]
173
123
self .request (
174
124
f"{ self .base_url } /admin/realms/{ self .realm } /users/{ user_dict ['id' ]} " ,
175
125
method = "PUT" ,
@@ -189,12 +139,12 @@ def users(self) -> list[JSONDict]:
189
139
attributes ["displayName" ] = full_name
190
140
attributes ["mail" ] = user_dict .get ("email" )
191
141
attributes ["description" ] = ""
192
- attributes ["gidNumber" ] = user_uid
142
+ attributes ["gidNumber" ] = user_dict [ "attributes" ][ "uid" ]
193
143
attributes ["givenName" ] = first_name if first_name else ""
194
144
attributes ["homeDirectory" ] = f"/home/{ username } " if username else None
195
145
attributes ["oauth_id" ] = user_dict .get ("id" , None )
196
146
attributes ["sn" ] = last_name if last_name else ""
197
- attributes ["uidNumber" ] = user_uid
147
+ attributes ["uidNumber" ] = user_dict [ "attributes" ][ "uid" ]
198
148
output .append (attributes )
199
149
except KeyError :
200
150
pass
0 commit comments