Skip to content

Commit dbaedfb

Browse files
committed
Add support for keycloak and option to disable group-of-groups
1 parent 44d6263 commit dbaedfb

13 files changed

+274
-26
lines changed

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,28 @@ Do this as follows:
151151
- `Microsoft Graph` > `GroupMember.Read.All` (application)
152152
- `Microsoft Graph` > `User.Read.All` (delegated)
153153
- Select this and click the `Grant admin consent` button (otherwise manual consent is needed from each user)
154+
155+
156+
### Keycloak
157+
158+
You will need to use the following command line arguments:
159+
160+
```bash
161+
--backend Keycloak --keycloak-base-url "<your hostname>/<path to keycloak>" --keycloak-realm "<your realm>"
162+
```
163+
164+
You will need to register an application to interact with `Keycloak`.
165+
Do this as follows:
166+
167+
- Create a new `Client` in your `Keycloak` instance.
168+
- Set the name to whatever you choose (e.g. `apricot`)
169+
- Enable `Client authentication`
170+
- Enable the following authentication flows and disable the rest:
171+
- Direct access grants
172+
- Service account roles
173+
- Under `Credentials` copy `client secret`
174+
- Under `Service account roles`:
175+
- Ensure that the following role are assigned
176+
- `realm-management` > `view-users`
177+
- `realm-management` > `manage-users`
178+
- `realm-management` > `query-groups`

apricot/apricot_server.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import sys
23
from typing import Any, cast
34

@@ -19,6 +20,7 @@ def __init__(
1920
client_secret: str,
2021
domain: str,
2122
port: int,
23+
enable_group_of_groups: bool,
2224
*,
2325
debug: bool = False,
2426
redis_host: str | None = None,
@@ -45,12 +47,14 @@ def __init__(
4547
try:
4648
if self.debug:
4749
log.msg(f"Creating an OAuthClient for {backend}.")
48-
oauth_client = OAuthClientMap[backend](
50+
oauth_backend = OAuthClientMap[backend]
51+
oauth_backend_args = inspect.getfullargspec(oauth_backend.__init__).args
52+
oauth_client = oauth_backend(
4953
client_id=client_id,
5054
client_secret=client_secret,
5155
debug=debug,
5256
uid_cache=uid_cache,
53-
**kwargs,
57+
**{k: v for k, v in kwargs.items() if k in oauth_backend_args},
5458
)
5559
except Exception as exc:
5660
msg = f"Could not construct an OAuth client for the '{backend}' backend.\n{exc!s}"
@@ -59,7 +63,7 @@ def __init__(
5963
# Create an LDAPServerFactory
6064
if self.debug:
6165
log.msg("Creating an LDAPServerFactory.")
62-
factory = OAuthLDAPServerFactory(domain, oauth_client)
66+
factory = OAuthLDAPServerFactory(domain, oauth_client, enable_group_of_groups)
6367

6468
# Attach a listening endpoint
6569
if self.debug:

apricot/ldap/oauth_ldap_server_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88

99

1010
class OAuthLDAPServerFactory(ServerFactory):
11-
def __init__(self, domain: str, oauth_client: OAuthClient):
11+
def __init__(self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool):
1212
"""
1313
Initialise an LDAPServerFactory
1414
1515
@param oauth_client: An OAuth client used to construct the LDAP tree
1616
"""
1717
# Create an LDAP lookup tree
18-
self.adaptor = OAuthLDAPTree(domain, oauth_client)
18+
self.adaptor = OAuthLDAPTree(domain, oauth_client, enable_group_of_groups)
1919

2020
def __repr__(self) -> str:
2121
return f"{self.__class__.__name__} using adaptor {self.adaptor}"

apricot/ldap/oauth_ldap_tree.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class OAuthLDAPTree:
1515

1616
def __init__(
17-
self, domain: str, oauth_client: OAuthClient, refresh_interval: int = 60
17+
self, domain: str, oauth_client: OAuthClient, enable_group_of_groups: bool, refresh_interval: int = 60
1818
) -> None:
1919
"""
2020
Initialise an OAuthLDAPTree
@@ -29,6 +29,7 @@ def __init__(
2929
self.oauth_client = oauth_client
3030
self.refresh_interval = refresh_interval
3131
self.root_: OAuthLDAPEntry | None = None
32+
self.enable_group_of_groups = enable_group_of_groups
3233

3334
@property
3435
def dn(self) -> DistinguishedName:
@@ -47,7 +48,7 @@ def root(self) -> OAuthLDAPEntry:
4748
):
4849
# Update users and groups from the OAuth server
4950
log.msg("Retrieving OAuth data.")
50-
oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client)
51+
oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client, self.enable_group_of_groups)
5152

5253
# Create a root node for the tree
5354
log.msg("Rebuilding LDAP tree.")

apricot/models/ldap_attribute_adaptor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def __init__(self, attributes: dict[Any, Any]) -> None:
88
self.attributes = {
99
str(k): list(map(str, v)) if isinstance(v, list) else [str(v)]
1010
for k, v in attributes.items()
11+
if v is not None
1112
}
1213

1314
@property

apricot/models/ldap_inetorgperson.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from .ldap_organizational_person import LDAPOrganizationalPerson
24

35

@@ -15,6 +17,7 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson):
1517
displayName: str # noqa: N815
1618
givenName: str # noqa: N815
1719
sn: str
20+
mail: Optional[str] = None
1821

1922
def names(self) -> list[str]:
2023
return [*super().names(), "inetOrgPerson"]

apricot/oauth/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from apricot.types import LDAPAttributeDict, LDAPControlTuple
22

33
from .enums import OAuthBackend
4+
from .keycloak_client import KeycloakClient
45
from .microsoft_entra_client import MicrosoftEntraClient
56
from .oauth_client import OAuthClient
67
from .oauth_data_adaptor import OAuthDataAdaptor
78

8-
OAuthClientMap = {OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient}
9+
OAuthClientMap = {
10+
OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient,
11+
OAuthBackend.KEYCLOAK: KeycloakClient,
12+
}
913

1014
__all__ = [
1115
"LDAPAttributeDict",

apricot/oauth/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ class OAuthBackend(str, Enum):
55
"""Available OAuth backends."""
66

77
MICROSOFT_ENTRA = "MicrosoftEntra"
8+
KEYCLOAK = "Keycloak"

apricot/oauth/keycloak_client.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from typing import Any, cast
2+
3+
from apricot.types import JSONDict
4+
5+
from .oauth_client import OAuthClient
6+
7+
8+
def get_single_value_attribute(obj: JSONDict, key: str, default=None) -> Any:
9+
for part in key.split("."):
10+
obj = obj.get(part)
11+
if obj is None:
12+
return default
13+
if isinstance(obj, list):
14+
try:
15+
return next(iter(obj))
16+
except StopIteration:
17+
pass
18+
else:
19+
return obj
20+
return default
21+
22+
23+
class KeycloakClient(OAuthClient):
24+
"""OAuth client for the Keycloak backend."""
25+
26+
max_rows = 100
27+
28+
def __init__(
29+
self,
30+
keycloak_base_url: str,
31+
keycloak_realm: str,
32+
**kwargs: Any,
33+
):
34+
self.base_url = keycloak_base_url
35+
self.realm = keycloak_realm
36+
37+
redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL
38+
scopes = [] # this is the default scope
39+
token_url = (
40+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token"
41+
)
42+
43+
super().__init__(
44+
redirect_uri=redirect_uri, scopes=scopes, token_url=token_url, **kwargs,
45+
)
46+
47+
def extract_token(self, json_response: JSONDict) -> str:
48+
return str(json_response["access_token"])
49+
50+
def groups(self) -> list[JSONDict]:
51+
output = []
52+
try:
53+
group_data = []
54+
while data := self.query(
55+
f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false"
56+
):
57+
group_data.extend(data)
58+
if len(data) != self.max_rows:
59+
break
60+
61+
group_data = sorted(group_data, key=lambda g: int(get_single_value_attribute(g, "attributes.gid", default="9999999999"), 10))
62+
63+
next_gid = max(
64+
*(
65+
int(get_single_value_attribute(g, "attributes.gid", default="-1"), 10)+1
66+
for g in group_data
67+
),
68+
3000
69+
)
70+
71+
for group_dict in cast(
72+
list[JSONDict],
73+
group_data,
74+
):
75+
group_gid = get_single_value_attribute(group_dict, "attributes.gid", default=None)
76+
if group_gid:
77+
group_gid = int(group_gid, 10)
78+
if not group_gid:
79+
group_gid = next_gid
80+
next_gid += 1
81+
group_dict["attributes"] = group_dict.get("attributes", {})
82+
group_dict["attributes"]["gid"] = [str(group_gid)]
83+
self.request(
84+
f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}",
85+
method="PUT",
86+
json=group_dict
87+
)
88+
attributes: JSONDict = {}
89+
attributes["cn"] = group_dict.get("name", None)
90+
attributes["description"] = group_dict.get("id", None)
91+
attributes["gidNumber"] = group_gid
92+
attributes["oauth_id"] = group_dict.get("id", None)
93+
# Add membership attributes
94+
members = self.query(
95+
f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members"
96+
)
97+
attributes["memberUid"] = [
98+
user["username"]
99+
for user in cast(list[JSONDict], members)
100+
]
101+
output.append(attributes)
102+
except KeyError:
103+
pass
104+
return output
105+
106+
def users(self) -> list[JSONDict]:
107+
output = []
108+
try:
109+
user_data = []
110+
while data := self.query(
111+
f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false"
112+
):
113+
user_data.extend(data)
114+
if len(data) != self.max_rows:
115+
break
116+
117+
user_data = sorted(user_data, key=lambda u: int(get_single_value_attribute(u, "attributes.uid", default="9999999999"), 10))
118+
119+
next_uid = max(
120+
*(
121+
int(get_single_value_attribute(g, "attributes.uid", default="-1"), 10)+1
122+
for g in user_data
123+
),
124+
3000
125+
)
126+
127+
for user_dict in cast(
128+
list[JSONDict],
129+
sorted(user_data, key=lambda user: user["createdTimestamp"]),
130+
):
131+
user_uid = get_single_value_attribute(user_dict, "attributes.uid", default=None)
132+
if user_uid:
133+
user_uid = int(user_uid, 10)
134+
if not user_uid:
135+
user_uid = next_uid
136+
next_uid += 1
137+
138+
user_dict["attributes"] = user_dict.get("attributes", {})
139+
user_dict["attributes"]["uid"] = [str(user_uid)]
140+
self.request(
141+
f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}",
142+
method="PUT",
143+
json=user_dict
144+
)
145+
# Get user attributes
146+
first_name = user_dict.get("firstName", None)
147+
last_name = user_dict.get("lastName", None)
148+
full_name = " ".join(filter(lambda x: x, [first_name, last_name])) or None
149+
username = user_dict.get("username")
150+
attributes: JSONDict = {}
151+
attributes["cn"] = username
152+
attributes["uid"] = username
153+
attributes["oauth_username"] = username
154+
attributes["displayName"] = full_name
155+
attributes["mail"] = user_dict.get("email")
156+
attributes["description"] = ""
157+
attributes["gidNumber"] = user_uid
158+
attributes["givenName"] = first_name if first_name else ""
159+
attributes["homeDirectory"] = f"/home/{username}" if username else None
160+
attributes["oauth_id"] = user_dict.get("id", None)
161+
attributes["sn"] = last_name if last_name else ""
162+
attributes["uidNumber"] = user_uid
163+
output.append(attributes)
164+
except KeyError:
165+
pass
166+
return output

apricot/oauth/oauth_client.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from abc import ABC, abstractmethod
3+
from http import HTTPStatus
34
from typing import Any
45

56
import requests
@@ -116,9 +117,7 @@ def query(self, url: str) -> dict[str, Any]:
116117
def query_(url: str) -> requests.Response:
117118
return self.session_application.get( # type: ignore[no-any-return]
118119
url=url,
119-
headers={"Authorization": f"Bearer {self.bearer_token}"},
120-
client_id=self.session_application._client.client_id,
121-
client_secret=self.client_secret,
120+
headers={"Authorization": f"Bearer {self.bearer_token}"}
122121
)
123122

124123
try:
@@ -130,6 +129,28 @@ def query_(url: str) -> requests.Response:
130129
result = query_(url)
131130
return result.json() # type: ignore
132131

132+
def request(self, *args, method="GET", **kwargs) -> dict[str, Any]:
133+
"""
134+
Make a query against the OAuth backend
135+
"""
136+
137+
def query_(*args, **kwargs) -> requests.Response:
138+
return self.session_application.request( # type: ignore[no-any-return]
139+
method,
140+
*args, **kwargs,
141+
headers={"Authorization": f"Bearer {self.bearer_token}"}
142+
)
143+
144+
try:
145+
result = query_(*args, **kwargs)
146+
result.raise_for_status()
147+
except (TokenExpiredError, requests.exceptions.HTTPError):
148+
log.msg("Authentication token has expired.")
149+
self.bearer_token_ = None
150+
result = query_( *args, **kwargs)
151+
if result.status_code != HTTPStatus.NO_CONTENT:
152+
return result.json() # type: ignore
153+
133154
def verify(self, username: str, password: str) -> bool:
134155
"""
135156
Verify username and password by attempting to authenticate against the OAuth backend.

0 commit comments

Comments
 (0)