diff --git a/README.md b/README.md index f00251b..170917d 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Start the `Apricot` server on port 1389 by running: python run.py --client-id "" --client-secret "" --backend "" --port 1389 --domain "" --redis-host "" ``` -Alternatively, you can run in Docker by editing `docker/docker-compose.yaml` and running: +If you prefer to use Docker, you can edit `docker/docker-compose.yaml` and run: ```bash docker compose up @@ -67,7 +67,9 @@ member: ## Primary groups -Note that each user will have an associated group to act as its POSIX user primary group +:exclamation: You can disable the creation of primary groups with the `--disable-primary-groups` command line option :exclamation: + +Apricot creates an associated group for each user, which acts as its POSIX user primary group. For example: @@ -97,6 +99,8 @@ member: CN=sherlock.holmes,OU=users,DC= ## Mirrored groups +:exclamation: You can disable the creation of mirrored groups with the `--disable-mirrored-groups` command line option :exclamation: + Each group of users will have an associated group-of-groups where each user in the group will have its user primary group in the group-of-groups. Note that these groups-of-groups are **not** `posixGroup`s as POSIX does not allow nested groups. @@ -109,6 +113,7 @@ objectClass: posixGroup objectClass: top ... member: CN=sherlock.holmes,OU=users,DC= +... ``` will have an associated group-of-groups @@ -122,6 +127,32 @@ member: CN=sherlock.holmes,OU=groups,DC= ... ``` +This allows a user to make a request for "all primary user groups needed by members of group X" without getting a large number of primary user groups for unrelated users. 
To do this, you will need an LDAP request that looks like: + +```ldif +(&(objectClass=posixGroup)(|(CN=Detectives)(memberOf=Primary user groups for Detectives))) +``` + +which will return: + +```ldif +dn:CN=Detectives,OU=groups,DC= +objectClass: groupOfNames +objectClass: posixGroup +objectClass: top +... +member: CN=sherlock.holmes,OU=users,DC= +... + +dn: CN=sherlock.holmes,OU=groups,DC= +objectClass: groupOfNames +objectClass: posixGroup +objectClass: top +... +member: CN=sherlock.holmes,OU=users,DC= +... +``` + ## OpenID Connect Instructions for specific OpenID Connect backends below. @@ -146,8 +177,34 @@ Do this as follows: - Set the expiry time to whatever is relevant for your use-case - You **must** record the value of this secret at **creation time**, as it will not be visible later. - Under `API permissions`: - - Ensure that the following permissions are enabled + - Enable the following permissions: - `Microsoft Graph` > `User.Read.All` (application) - `Microsoft Graph` > `GroupMember.Read.All` (application) - `Microsoft Graph` > `User.Read.All` (delegated) - - Select this and click the `Grant admin consent` button (otherwise manual consent is needed from each user) + - Select this and click the `Grant admin consent` button (otherwise each user will need to manually consent) + +### Keycloak + +You will need to use the following command line arguments: + +```bash +--backend Keycloak --keycloak-base-url "/" --keycloak-realm "" +``` + +You will need to register an application to interact with `Keycloak`. +Do this as follows: + +- Create a new `Client` in your `Keycloak` instance. + - Set the name to whatever you choose (e.g. 
`apricot`) + - Enable `Client authentication` + - Enable the following authentication flows and disable the rest: + - Direct access grants + - Service account roles +- Under `Credentials` copy `client secret` +- Under `Service account roles`: + - Click on `Assign role` then `Filter by clients` + - Assign the following roles: + - `realm-management` > `view-users` + - `realm-management` > `manage-users` + - `realm-management` > `query-groups` + - `realm-management` > `query-users` diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 7163ffb..fa98c22 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -1,3 +1,4 @@ +import inspect import sys from typing import Any, cast @@ -21,6 +22,7 @@ def __init__( port: int, *, debug: bool = False, + enable_mirrored_groups: bool = True, redis_host: str | None = None, redis_port: int | None = None, **kwargs: Any, @@ -45,12 +47,16 @@ def __init__( try: if self.debug: log.msg(f"Creating an OAuthClient for {backend}.") - oauth_client = OAuthClientMap[backend]( + oauth_backend = OAuthClientMap[backend] + oauth_backend_args = inspect.getfullargspec( + oauth_backend.__init__ # type: ignore + ).args + oauth_client = oauth_backend( client_id=client_id, client_secret=client_secret, debug=debug, uid_cache=uid_cache, - **kwargs, + **{k: v for k, v in kwargs.items() if k in oauth_backend_args}, ) except Exception as exc: msg = f"Could not construct an OAuth client for the '{backend}' backend.\n{exc!s}" @@ -59,7 +65,9 @@ def __init__( # Create an LDAPServerFactory if self.debug: log.msg("Creating an LDAPServerFactory.") - factory = OAuthLDAPServerFactory(domain, oauth_client) + factory = OAuthLDAPServerFactory( + domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups + ) # Attach a listening endpoint if self.debug: diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 4a1d919..24ac506 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ 
-9,7 +9,7 @@ class RedisCache(UidCache): def __init__(self, redis_host: str, redis_port: int) -> None: self.redis_host = redis_host self.redis_port = redis_port - self.cache_: "redis.Redis[str]" | None = None + self.cache_: "redis.Redis[str]" | None = None # noqa: UP037 @property def cache(self) -> "redis.Redis[str]": diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index ab46029..eb9c729 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -77,3 +77,31 @@ def _get_max_uid(self, category: str | None) -> int: keys = self.keys() values = [*self.values(keys), -999] return max(values) + + def overwrite_group_uid(self, identifier: str, uid: int) -> None: + """ + Set UID for a group, overwriting the existing value if there is one + + @param identifier: Identifier for group + @param uid: Desired UID + """ + return self.overwrite_uid(identifier, category="group", uid=uid) + + def overwrite_user_uid(self, identifier: str, uid: int) -> None: + """ + Set UID for a user, overwriting the existing value if there is one + + @param identifier: Identifier for user + @param uid: Desired UID + """ + return self.overwrite_uid(identifier, category="user", uid=uid) + + def overwrite_uid(self, identifier: str, category: str, uid: int) -> None: + """ + Set UID, overwriting the existing one if necessary. 
+ + @param identifier: Identifier for object + @param category: Category the object belongs to + @param uid: Desired UID + """ + self.set(f"{category}-{identifier}", uid) diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 2890b35..bcabc6c 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -8,14 +8,18 @@ class OAuthLDAPServerFactory(ServerFactory): - def __init__(self, domain: str, oauth_client: OAuthClient): + def __init__( + self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool + ): """ Initialise an LDAPServerFactory @param oauth_client: An OAuth client used to construct the LDAP tree """ # Create an LDAP lookup tree - self.adaptor = OAuthLDAPTree(domain, oauth_client) + self.adaptor = OAuthLDAPTree( + domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups + ) def __repr__(self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 88333ec..d9eb133 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -14,7 +14,12 @@ class OAuthLDAPTree: def __init__( - self, domain: str, oauth_client: OAuthClient, refresh_interval: int = 60 + self, + domain: str, + oauth_client: OAuthClient, + *, + enable_mirrored_groups: bool, + refresh_interval: int = 60, ) -> None: """ Initialise an OAuthLDAPTree @@ -29,6 +34,7 @@ def __init__( self.oauth_client = oauth_client self.refresh_interval = refresh_interval self.root_: OAuthLDAPEntry | None = None + self.enable_mirrored_groups = enable_mirrored_groups @property def dn(self) -> DistinguishedName: @@ -47,7 +53,11 @@ def root(self) -> OAuthLDAPEntry: ): # Update users and groups from the OAuth server log.msg("Retrieving OAuth data.") - oauth_adaptor = OAuthDataAdaptor(self.domain, self.oauth_client) + oauth_adaptor = OAuthDataAdaptor( + self.domain, + 
self.oauth_client, + enable_mirrored_groups=self.enable_mirrored_groups, + ) # Create a root node for the tree log.msg("Rebuilding LDAP tree.") diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index dfd3bd1..40f986d 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -8,6 +8,7 @@ def __init__(self, attributes: dict[Any, Any]) -> None: self.attributes = { str(k): list(map(str, v)) if isinstance(v, list) else [str(v)] for k, v in attributes.items() + if v is not None } @property diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index fe86b8e..51e5cb5 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -12,9 +12,12 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): """ cn: str - displayName: str # noqa: N815 - givenName: str # noqa: N815 + displayName: str | None = None # noqa: N815 + employeeNumber: str | None = None # noqa: N815 + givenName: str | None = None # noqa: N815 sn: str + mail: str | None = None + telephoneNumber: str | None = None # noqa: N815 def names(self) -> list[str]: return [*super().names(), "inetOrgPerson"] diff --git a/apricot/oauth/__init__.py b/apricot/oauth/__init__.py index 0cd8aa5..c5d6268 100644 --- a/apricot/oauth/__init__.py +++ b/apricot/oauth/__init__.py @@ -1,11 +1,15 @@ from apricot.types import LDAPAttributeDict, LDAPControlTuple from .enums import OAuthBackend +from .keycloak_client import KeycloakClient from .microsoft_entra_client import MicrosoftEntraClient from .oauth_client import OAuthClient from .oauth_data_adaptor import OAuthDataAdaptor -OAuthClientMap = {OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient} +OAuthClientMap = { + OAuthBackend.MICROSOFT_ENTRA: MicrosoftEntraClient, + OAuthBackend.KEYCLOAK: KeycloakClient, +} __all__ = [ "LDAPAttributeDict", diff --git a/apricot/oauth/enums.py b/apricot/oauth/enums.py index 8675218..d9c356d 
100644 --- a/apricot/oauth/enums.py +++ b/apricot/oauth/enums.py @@ -5,3 +5,4 @@ class OAuthBackend(str, Enum): """Available OAuth backends.""" MICROSOFT_ENTRA = "MicrosoftEntra" + KEYCLOAK = "Keycloak" diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py new file mode 100644 index 0000000..5b584c7 --- /dev/null +++ b/apricot/oauth/keycloak_client.py @@ -0,0 +1,155 @@ +from typing import Any, cast + +from apricot.types import JSONDict + +from .oauth_client import OAuthClient + + +class KeycloakClient(OAuthClient): + """OAuth client for the Keycloak backend.""" + + max_rows = 100 + + def __init__( + self, + keycloak_base_url: str, + keycloak_realm: str, + **kwargs: Any, + ): + self.base_url = keycloak_base_url + self.realm = keycloak_realm + + redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL + scopes: list[str] = [] # this is the default scope + token_url = f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token" + + super().__init__( + redirect_uri=redirect_uri, + scopes=scopes, + token_url=token_url, + **kwargs, + ) + + def extract_token(self, json_response: JSONDict) -> str: + return str(json_response["access_token"]) + + def groups(self) -> list[JSONDict]: + output = [] + try: + group_data: list[JSONDict] = [] + while data := self.query( + f"{self.base_url}/admin/realms/{self.realm}/groups?first={len(group_data)}&max={self.max_rows}&briefRepresentation=false", + use_client_secret=False, + ): + group_data.extend(cast(list[JSONDict], data)) + if len(data) != self.max_rows: + break + + # Ensure that gid attribute exists for all groups + for group_dict in group_data: + group_dict["attributes"] = group_dict.get("attributes", {}) + if "gid" not in group_dict["attributes"]: + group_dict["attributes"]["gid"] = None + # If group_gid exists then set the cache to the same value + # This ensures that any groups without a `gid` attribute will receive a + # UID that does not overlap with existing groups + if 
(group_gid := group_dict["attributes"]["gid"]) and len( + group_dict["attributes"]["gid"] + ) == 1: + self.uid_cache.overwrite_group_uid( + group_dict["id"], int(group_gid[0], 10) + ) + + # Read group attributes + for group_dict in group_data: + if not group_dict["attributes"]["gid"]: + group_dict["attributes"]["gid"] = [ + str(self.uid_cache.get_group_uid(group_dict["id"])) + ] + self.request( + f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", + method="PUT", + json=group_dict, + ) + attributes: JSONDict = {} + attributes["cn"] = group_dict.get("name", None) + attributes["description"] = group_dict.get("id", None) + attributes["gidNumber"] = group_dict["attributes"]["gid"][0] + attributes["oauth_id"] = group_dict.get("id", None) + # Add membership attributes + members = self.query( + f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}/members", + use_client_secret=False, + ) + attributes["memberUid"] = [ + user["username"] for user in cast(list[JSONDict], members) + ] + output.append(attributes) + except KeyError: + pass + return output + + def users(self) -> list[JSONDict]: + output = [] + try: + user_data: list[JSONDict] = [] + while data := self.query( + f"{self.base_url}/admin/realms/{self.realm}/users?first={len(user_data)}&max={self.max_rows}&briefRepresentation=false", + use_client_secret=False, + ): + user_data.extend(cast(list[JSONDict], data)) + if len(data) != self.max_rows: + break + + # Ensure that uid attribute exists for all users + for user_dict in user_data: + user_dict["attributes"] = user_dict.get("attributes", {}) + if "uid" not in user_dict["attributes"]: + user_dict["attributes"]["uid"] = None + # If user_uid exists then set the cache to the same value. 
+ # This ensures that any users without a `uid` attribute will receive a + # UID that does not overlap with existing users + if (user_uid := user_dict["attributes"]["uid"]) and len( + user_dict["attributes"]["uid"] + ) == 1: + self.uid_cache.overwrite_user_uid( + user_dict["id"], int(user_uid[0], 10) + ) + + # Read user attributes + for user_dict in sorted( + user_data, key=lambda user: user["createdTimestamp"] + ): + if not user_dict["attributes"]["uid"]: + user_dict["attributes"]["uid"] = [ + str(self.uid_cache.get_user_uid(user_dict["id"])) + ] + self.request( + f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", + method="PUT", + json=user_dict, + ) + # Get user attributes + first_name = user_dict.get("firstName", None) + last_name = user_dict.get("lastName", None) + full_name = ( + " ".join(filter(lambda x: x, [first_name, last_name])) or None + ) + username = user_dict.get("username") + attributes: JSONDict = {} + attributes["cn"] = username + attributes["uid"] = username + attributes["oauth_username"] = username + attributes["displayName"] = full_name + attributes["mail"] = user_dict.get("email") + attributes["description"] = "" + attributes["gidNumber"] = user_dict["attributes"]["uid"][0] + attributes["givenName"] = first_name if first_name else "" + attributes["homeDirectory"] = f"/home/{username}" if username else None + attributes["oauth_id"] = user_dict.get("id", None) + attributes["sn"] = last_name if last_name else "" + attributes["uidNumber"] = user_dict["attributes"]["uid"][0] + output.append(attributes) + except KeyError: + pass + return output diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 857b553..b47f98c 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -1,5 +1,6 @@ import os from abc import ABC, abstractmethod +from http import HTTPStatus from typing import Any import requests @@ -108,26 +109,46 @@ def users(self) -> list[JSONDict]: """ pass - def query(self, 
url: str) -> dict[str, Any]: + def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: """ Make a query against the OAuth backend """ + kwargs = ( + { + "client_id": self.session_application._client.client_id, + "client_secret": self.client_secret, + } + if use_client_secret + else {} + ) + return self.request( + url=url, + method="GET", + **kwargs, + ) - def query_(url: str) -> requests.Response: - return self.session_application.get( # type: ignore[no-any-return] - url=url, + def request(self, *args: Any, method: str = "GET", **kwargs: Any) -> dict[str, Any]: + """ + Make a request to the OAuth backend + """ + + def query_(*args: Any, **kwargs: Any) -> requests.Response: + return self.session_application.request( # type: ignore[no-any-return] + method, + *args, + **kwargs, headers={"Authorization": f"Bearer {self.bearer_token}"}, - client_id=self.session_application._client.client_id, - client_secret=self.client_secret, ) try: - result = query_(url) + result = query_(*args, **kwargs) result.raise_for_status() except (TokenExpiredError, requests.exceptions.HTTPError): log.msg("Authentication token has expired.") self.bearer_token_ = None - result = query_(url) + result = query_(*args, **kwargs) + if result.status_code == HTTPStatus.NO_CONTENT: + return {} return result.json() # type: ignore def verify(self, username: str, password: str) -> bool: diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 2263af5..e2e6ea5 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -21,10 +21,13 @@ class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" - def __init__(self, domain: str, oauth_client: OAuthClient): + def __init__( + self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool + ): self.debug = oauth_client.debug self.oauth_client = oauth_client self.root_dn = "DC=" + domain.replace(".", ",DC=") + 
self.enable_mirrored_groups = enable_mirrored_groups # Retrieve and validate user and group information annotated_groups, annotated_users = self._retrieve_entries() @@ -105,20 +108,21 @@ def _retrieve_entries( # Add one group of groups for each existing group. # Its members are the primary user groups for each original group member. groups_of_groups = [] - for group in oauth_groups: - group_dict = {} - group_dict["cn"] = f"Primary user groups for {group['cn']}" - group_dict["description"] = ( - f"Primary user groups for members of '{group['cn']}'" - ) - # Replace each member user with a member group - group_dict["member"] = [ - str(member).replace("OU=users", "OU=groups") - for member in group["member"] - ] - # Groups do not have UIDs so memberUid must be empty - group_dict["memberUid"] = [] - groups_of_groups.append(group_dict) + if self.enable_mirrored_groups: + for group in oauth_groups: + group_dict = {} + group_dict["cn"] = f"Primary user groups for {group['cn']}" + group_dict["description"] = ( + f"Primary user groups for members of '{group['cn']}'" + ) + # Replace each member user with a member group + group_dict["member"] = [ + str(member).replace("OU=users", "OU=groups") + for member in group["member"] + ] + # Groups do not have UIDs so memberUid must be empty + group_dict["memberUid"] = [] + groups_of_groups.append(group_dict) # Ensure memberOf is set correctly for users for child_dict in oauth_users: diff --git a/apricot/types.py b/apricot/types.py index e93f9ea..5cc0617 100644 --- a/apricot/types.py +++ b/apricot/types.py @@ -1,5 +1,6 @@ from typing import Any JSONDict = dict[str, Any] +JSONKey = list[Any] | dict[str, Any] | Any LDAPAttributeDict = dict[str, list[str]] LDAPControlTuple = tuple[str, bool, Any] diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 8a21379..04261da 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -37,6 +37,10 @@ if [ -n "${DEBUG}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --debug" fi +if [ -n 
"${DISABLE_MIRRORED_GROUPS}" ]; then + EXTRA_OPTS="${EXTRA_OPTS} --disable-mirrored-groups" +fi + if [ -n "${ENTRA_TENANT_ID}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID" fi @@ -49,6 +53,14 @@ if [ -n "${REDIS_HOST}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT" fi +if [ -n "${KEYCLOAK_BASE_URL}" ]; then + if [ -z "${KEYCLOAK_REALM}" ]; then + echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] KEYCLOAK_REALM environment variable is not set" + exit 1 + fi + EXTRA_OPTS="${EXTRA_OPTS} --keycloak-base-url $KEYCLOAK_BASE_URL --keycloak-realm $KEYCLOAK_REALM" +fi + # Run the server hatch run python run.py \ --backend "${BACKEND}" \ diff --git a/run.py b/run.py index 5ac4230..c228f20 100644 --- a/run.py +++ b/run.py @@ -16,10 +16,18 @@ parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") + parser.add_argument("--disable-mirrored-groups", action="store_false", + dest="enable_mirrored_groups", default=True, + help="Disable creation of mirrored groups.") parser.add_argument("--debug", action="store_true", help="Enable debug logging.") # Options for Microsoft Entra backend entra_group = parser.add_argument_group("Microsoft Entra") entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False) + + # Options for Keycloak backend + keycloak_group = parser.add_argument_group("Keycloak") + keycloak_group.add_argument("--keycloak-base-url", type=str, help="Keycloak base URL.", required=False) + keycloak_group.add_argument("--keycloak-realm", type=str, help="Keycloak Realm.", required=False) # Options for Redis cache redis_group = parser.add_argument_group("Redis") redis_group.add_argument("--redis-host", type=str, help="Host for Redis server.")