Skip to content

Commit f986e41

Browse files
committed
♻️ Switch KeycloakClient to use UidCache for generating missing UIDs
1 parent 38a28e4 commit f986e41

File tree

3 files changed

+64
-86
lines changed

3 files changed

+64
-86
lines changed

apricot/apricot_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
port: int,
2323
*,
2424
debug: bool = False,
25-
enable_mirrored_groups: bool,
25+
enable_mirrored_groups: bool = True,
2626
redis_host: str | None = None,
2727
redis_port: int | None = None,
2828
**kwargs: Any,

apricot/cache/uid_cache.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,31 @@ def _get_max_uid(self, category: str | None) -> int:
7777
keys = self.keys()
7878
values = [*self.values(keys), -999]
7979
return max(values)
80+
81+
def overwrite_group_uid(self, identifier: str, uid: int) -> None:
82+
"""
83+
Set UID for a group, overwriting the existing value if there is one
84+
85+
@param identifier: Identifier for group
86+
@param uid: Desired UID
87+
"""
88+
return self.overwrite_uid(identifier, category="group", uid=uid)
89+
90+
def overwrite_user_uid(self, identifier: str, uid: int) -> None:
91+
"""
92+
Get UID for a user, constructing one if necessary
93+
94+
@param identifier: Identifier for user
95+
@param uid: Desired UID
96+
"""
97+
return self.overwrite_uid(identifier, category="user", uid=uid)
98+
99+
def overwrite_uid(self, identifier: str, category: str, uid: int) -> None:
100+
"""
101+
Set UID, overwriting the existing one if necessary.
102+
103+
@param identifier: Identifier for object
104+
@param category: Category the object belongs to
105+
@param uid: Desired UID
106+
"""
107+
self.set(f"{category}-{identifier}", uid)

apricot/oauth/keycloak_client.py

Lines changed: 35 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,6 @@
55
from .oauth_client import OAuthClient
66

77

8-
def get_single_value_attribute(
9-
obj: JSONDict, key: str, default: str | None = None
10-
) -> Any:
11-
for part in key.split("."):
12-
obj = obj.get(part) # type: ignore
13-
if obj is None:
14-
return default
15-
if isinstance(obj, list):
16-
try:
17-
return next(iter(obj))
18-
except StopIteration:
19-
pass
20-
else:
21-
return obj
22-
return default
23-
24-
258
class KeycloakClient(OAuthClient):
269
"""OAuth client for the Keycloak backend."""
2710

@@ -62,41 +45,25 @@ def groups(self) -> list[JSONDict]:
6245
if len(data) != self.max_rows:
6346
break
6447

65-
group_data = sorted(
66-
group_data,
67-
key=lambda group: int(
68-
get_single_value_attribute(
69-
group, "attributes.gid", default="9999999999"
70-
),
71-
base=10,
72-
),
73-
)
74-
75-
next_gid = max(
76-
*(
77-
int(
78-
get_single_value_attribute(
79-
group, "attributes.gid", default="-1"
80-
),
81-
base=10,
48+
# Ensure that gid attribute exists for all groups
49+
for group_dict in group_data:
50+
group_dict["attributes"] = group_dict.get("attributes", {})
51+
if "gid" not in group_dict["attributes"]:
52+
group_dict["attributes"]["gid"] = None
53+
# If group_gid exists then set the cache to the same value
54+
# This ensures that any groups without a `gid` attribute will receive a
55+
# UID that does not overlap with existing groups
56+
if group_gid := group_dict["attributes"]["gid"]:
57+
self.uid_cache.overwrite_group_uid(
58+
group_dict["id"], int(group_gid, 10)
8259
)
83-
+ 1
84-
for group in group_data
85-
),
86-
3000,
87-
)
8860

61+
# Read group attributes
8962
for group_dict in group_data:
90-
group_gid = get_single_value_attribute(
91-
group_dict, "attributes.gid", default=None
92-
)
93-
if group_gid:
94-
group_gid = int(group_gid, 10)
95-
if not group_gid:
96-
group_gid = next_gid
97-
next_gid += 1
98-
group_dict["attributes"] = group_dict.get("attributes", {})
99-
group_dict["attributes"]["gid"] = [str(group_gid)]
63+
if not group_dict["attributes"]["gid"]:
64+
group_dict["attributes"]["gid"] = [
65+
str(self.uid_cache.get_group_uid(group_dict["id"]))
66+
]
10067
self.request(
10168
f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}",
10269
method="PUT",
@@ -105,7 +72,7 @@ def groups(self) -> list[JSONDict]:
10572
attributes: JSONDict = {}
10673
attributes["cn"] = group_dict.get("name", None)
10774
attributes["description"] = group_dict.get("id", None)
108-
attributes["gidNumber"] = group_gid
75+
attributes["gidNumber"] = group_dict["attributes"]["gid"]
10976
attributes["oauth_id"] = group_dict.get("id", None)
11077
# Add membership attributes
11178
members = self.query(
@@ -132,44 +99,27 @@ def users(self) -> list[JSONDict]:
13299
if len(data) != self.max_rows:
133100
break
134101

135-
user_data = sorted(
136-
user_data,
137-
key=lambda user: int(
138-
get_single_value_attribute(
139-
user, "attributes.uid", default="9999999999"
140-
),
141-
base=10,
142-
),
143-
)
144-
145-
next_uid = max(
146-
*(
147-
int(
148-
get_single_value_attribute(
149-
user, "attributes.uid", default="-1"
150-
),
151-
base=10,
102+
# Ensure that uid attribute exists for all users
103+
for user_dict in user_data:
104+
user_dict["attributes"] = user_dict.get("attributes", {})
105+
if "uid" not in user_dict["attributes"]:
106+
user_dict["attributes"]["uid"] = None
107+
# If user_uid exists then set the cache to the same value.
108+
# This ensures that any groups without a `gid` attribute will receive a
109+
# UID that does not overlap with existing groups
110+
if user_uid := user_dict["attributes"]["uid"]:
111+
self.uid_cache.overwrite_user_uid(
112+
user_dict["id"], int(user_uid, 10)
152113
)
153-
+ 1
154-
for user in user_data
155-
),
156-
3000,
157-
)
158114

115+
# Read user attributes
159116
for user_dict in sorted(
160117
user_data, key=lambda user: user["createdTimestamp"]
161118
):
162-
user_uid = get_single_value_attribute(
163-
user_dict, "attributes.uid", default=None
164-
)
165-
if user_uid:
166-
user_uid = int(user_uid, base=10)
167-
if not user_uid:
168-
user_uid = next_uid
169-
next_uid += 1
170-
171-
user_dict["attributes"] = user_dict.get("attributes", {})
172-
user_dict["attributes"]["uid"] = [str(user_uid)]
119+
if not user_dict["attributes"]["uid"]:
120+
user_dict["attributes"]["uid"] = [
121+
str(self.uid_cache.get_user_uid(user_dict["id"]))
122+
]
173123
self.request(
174124
f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}",
175125
method="PUT",
@@ -189,12 +139,12 @@ def users(self) -> list[JSONDict]:
189139
attributes["displayName"] = full_name
190140
attributes["mail"] = user_dict.get("email")
191141
attributes["description"] = ""
192-
attributes["gidNumber"] = user_uid
142+
attributes["gidNumber"] = user_dict["attributes"]["uid"]
193143
attributes["givenName"] = first_name if first_name else ""
194144
attributes["homeDirectory"] = f"/home/{username}" if username else None
195145
attributes["oauth_id"] = user_dict.get("id", None)
196146
attributes["sn"] = last_name if last_name else ""
197-
attributes["uidNumber"] = user_uid
147+
attributes["uidNumber"] = user_dict["attributes"]["uid"]
198148
output.append(attributes)
199149
except KeyError:
200150
pass

0 commit comments

Comments
 (0)