Skip to content

Commit 2d08678

Browse files
committed
🔧 Group arguments by category in both Docker and run.py
1 parent 9ebf429 commit 2d08678

File tree

5 files changed

+125
-78
lines changed

5 files changed

+125
-78
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ The value of this attribute should be the same as the `--domain` argument to Apr
142142
Any users with this attribute missing or set to something else will be ignored by Apricot.
143143
This allows you to attach multiple Apricot servers to the same Keycloak instance, each with their own set of users.
144144

145+
:exclamation: You can disable user domain verification with the `--disable-user-domain-verification` command line option :exclamation:
146+
145147
#### Client application
146148

147149
You will need to register an application to interact with `Keycloak`.

apricot/apricot_server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
background_refresh: bool = False,
2929
debug: bool = False,
3030
enable_mirrored_groups: bool = True,
31+
enable_user_domain_verification: bool = True,
3132
redis_host: str | None = None,
3233
redis_port: int | None = None,
3334
refresh_interval: int = 60,
@@ -45,7 +46,8 @@ def __init__(
4546
@param port: Port to expose LDAP on
4647
@param background_refresh: Whether to refresh the LDAP tree in the background
4748
@param debug: Enable debug output
48-
@param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users
49+
@param enable_mirrored_groups: Whether to create a mirrored LDAP group-of-groups for each group-of-users
50+
@param enable_user_domain_verification: Whether to verify users belong to the correct domain
4951
@param redis_host: Host for a Redis cache (if used)
5052
@param redis_port: Port for a Redis cache (if used)
5153
@param refresh_interval: Interval after which the LDAP information is stale
@@ -93,6 +95,7 @@ def __init__(
9395
domain,
9496
oauth_client,
9597
enable_mirrored_groups=enable_mirrored_groups,
98+
enable_user_domain_verification=enable_user_domain_verification,
9699
)
97100

98101
# Create an LDAPServerFactory

apricot/oauth/oauth_data_adaptor.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,21 @@ def __init__(
3232
oauth_client: OAuthClient,
3333
*,
3434
enable_mirrored_groups: bool,
35+
enable_user_domain_verification: bool,
3536
) -> None:
3637
"""Initialise an OAuthDataAdaptor.
3738
3839
@param domain: The root domain of the LDAP tree
39-
@param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users
40+
@param enable_mirrored_groups: Whether to create a mirrored LDAP group-of-groups for each group-of-users
41+
@param enable_user_domain_verification: Whether to verify users belong to the correct domain
4042
@param oauth_client: An OAuth client used to construct the LDAP tree
4143
"""
4244
self.debug = oauth_client.debug
4345
self.domain = domain
4446
self.oauth_client = oauth_client
4547
self.root_dn = "DC=" + domain.replace(".", ",DC=")
4648
self.enable_mirrored_groups = enable_mirrored_groups
49+
self.enable_user_domain_verification = enable_user_domain_verification
4750

4851
def _dn_from_group_cn(self: Self, group_cn: str) -> str:
4952
return f"CN={group_cn},OU=groups,{self.root_dn}"
@@ -187,7 +190,6 @@ def _validate_groups(
187190
def _validate_users(
188191
self: Self,
189192
annotated_users: list[tuple[JSONDict, list[type[LDAPObjectClass]]]],
190-
domain: str,
191193
) -> list[LDAPAttributeAdaptor]:
192194
"""Return a list of LDAPAttributeAdaptors representing validated user data."""
193195
if self.debug:
@@ -196,18 +198,23 @@ def _validate_users(
196198
for user_dict, required_classes in annotated_users:
197199
name = user_dict.get("cn", "unknown")
198200
try:
199-
if (user_domain := user_dict.get("domain", None)) == domain:
200-
output.append(
201-
LDAPAttributeAdaptor.from_attributes(
202-
user_dict,
203-
required_classes=required_classes,
204-
),
205-
)
206-
else:
201+
# Verify user domain if enabled
202+
if (
203+
self.enable_user_domain_verification
204+
and (user_domain := user_dict.get("domain", None)) != self.domain
205+
):
207206
log.msg(f"... user '{name}' failed validation.")
208207
log.msg(
209-
f" -> 'domain': expected '{domain}' but '{user_domain}' was provided.",
208+
f" -> 'domain': expected '{self.domain}' but '{user_domain}' was provided.",
210209
)
210+
continue
211+
# Construct an LDAPAttributeAdaptor from the user attributes
212+
output.append(
213+
LDAPAttributeAdaptor.from_attributes(
214+
user_dict,
215+
required_classes=required_classes,
216+
),
217+
)
211218
except ValidationError as exc:
212219
log.msg(f"... user '{name}' failed validation.")
213220
for error in exc.errors():
@@ -222,7 +229,7 @@ def retrieve_all(
222229
"""Retrieve and return validated user and group information."""
223230
annotated_groups, annotated_users = self._retrieve_entries()
224231
validated_groups = self._validate_groups(annotated_groups)
225-
validated_users = self._validate_users(annotated_users, self.domain)
232+
validated_users = self._validate_users(annotated_users)
226233
if self.debug:
227234
log.msg(
228235
f"Validated {len(validated_groups)} groups and {len(validated_users)} users.",

docker/entrypoint.sh

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,37 @@
22
# shellcheck disable=SC2086
33
# shellcheck disable=SC2089
44

5-
# Required arguments
5+
# Optional arguments
6+
EXTRA_OPTS=""
7+
8+
9+
# Common server-level options
10+
if [ -z "${PORT}" ]; then
11+
PORT="1389"
12+
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] PORT environment variable is not set: using default of '${PORT}'"
13+
fi
14+
15+
if [ -n "${DEBUG}" ]; then
16+
EXTRA_OPTS="${EXTRA_OPTS} --debug"
17+
fi
18+
19+
20+
# LDAP tree arguments
21+
if [ -z "${DOMAIN}" ]; then
22+
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] DOMAIN environment variable is not set"
23+
exit 1
24+
fi
25+
26+
if [ -n "${DISABLE_MIRRORED_GROUPS}" ]; then
27+
EXTRA_OPTS="${EXTRA_OPTS} --disable-mirrored-groups"
28+
fi
29+
30+
if [ -n "${DISABLE_USER_DOMAIN_VERIFICATION}" ]; then
31+
EXTRA_OPTS="${EXTRA_OPTS} --disable-user-domain-verification"
32+
fi
33+
34+
35+
# OAuth client arguments
636
if [ -z "${BACKEND}" ]; then
737
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] BACKEND environment variable is not set"
838
exit 1
@@ -18,27 +48,14 @@ if [ -z "${CLIENT_SECRET}" ]; then
1848
exit 1
1949
fi
2050

21-
if [ -z "${DOMAIN}" ]; then
22-
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] DOMAIN environment variable is not set"
23-
exit 1
24-
fi
25-
26-
27-
# Arguments with defaults
28-
if [ -z "${PORT}" ]; then
29-
PORT="1389"
30-
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] PORT environment variable is not set: using default of '${PORT}'"
31-
fi
32-
3351

34-
# Optional arguments
35-
EXTRA_OPTS=""
36-
if [ -n "${DEBUG}" ]; then
37-
EXTRA_OPTS="${EXTRA_OPTS} --debug"
52+
# LDAP refresh arguments
53+
if [ -n "${BACKGROUND_REFRESH}" ]; then
54+
EXTRA_OPTS="${EXTRA_OPTS} --background-refresh"
3855
fi
3956

40-
if [ -n "${DISABLE_MIRRORED_GROUPS}" ]; then
41-
EXTRA_OPTS="${EXTRA_OPTS} --disable-mirrored-groups"
57+
if [ -n "${REFRESH_INTERVAL}" ]; then
58+
EXTRA_OPTS="${EXTRA_OPTS} --refresh-interval $REFRESH_INTERVAL"
4259
fi
4360

4461

@@ -61,13 +78,13 @@ if [ -n "${KEYCLOAK_DOMAIN_ATTRIBUTE}" ]; then
6178
fi
6279

6380

64-
# LDAP refresh arguments
65-
if [ -n "${BACKGROUND_REFRESH}" ]; then
66-
EXTRA_OPTS="${EXTRA_OPTS} --background-refresh"
67-
fi
68-
69-
if [ -n "${REFRESH_INTERVAL}" ]; then
70-
EXTRA_OPTS="${EXTRA_OPTS} --refresh-interval $REFRESH_INTERVAL"
81+
# Redis arguments
82+
if [ -n "${REDIS_HOST}" ]; then
83+
if [ -z "${REDIS_PORT}" ]; then
84+
REDIS_PORT="6379"
85+
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'"
86+
fi
87+
EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT"
7188
fi
7289

7390

@@ -85,16 +102,6 @@ if [ -n "${TLS_PORT}" ]; then
85102
fi
86103

87104

88-
# Redis arguments
89-
if [ -n "${REDIS_HOST}" ]; then
90-
if [ -z "${REDIS_PORT}" ]; then
91-
REDIS_PORT="6379"
92-
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'"
93-
fi
94-
EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT"
95-
fi
96-
97-
98105
# Run the server
99106
hatch run python run.py \
100107
--backend "${BACKEND}" \

run.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,68 +10,93 @@
1010
prog="Apricot",
1111
description="Apricot is a proxy for delegating LDAP requests to an OpenID Connect backend.",
1212
)
13-
# Common options needed for all backends
13+
# Common server-level options
1414
parser.add_argument(
15-
"-b",
16-
"--backend",
17-
type=OAuthBackend,
18-
help="Which OAuth backend to use.",
15+
"-p",
16+
"--port",
17+
type=int,
18+
default=1389,
19+
help="Port to run on.",
1920
)
2021
parser.add_argument(
22+
"--debug",
23+
action="store_true",
24+
help="Enable debug logging.",
25+
)
26+
27+
# LDAP tree settings
28+
ldap_group = parser.add_argument_group("LDAP tree settings")
29+
ldap_group.add_argument(
2130
"-d",
2231
"--domain",
2332
type=str,
2433
help="Which domain users belong to.",
34+
required=True,
2535
)
26-
parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.")
27-
parser.add_argument(
28-
"-p",
29-
"--port",
30-
type=int,
31-
default=1389,
32-
help="Port to run on.",
36+
ldap_group.add_argument(
37+
"--disable-mirrored-groups",
38+
action="store_false",
39+
default=True,
40+
dest="enable_mirrored_groups",
41+
help="Disable creation of mirrored groups.",
3342
)
34-
parser.add_argument(
43+
ldap_group.add_argument(
44+
"--disable-user-domain-verification",
45+
action="store_false",
46+
default=True,
47+
dest="enable_user_domain_verification",
48+
help="Disable check that users belong to the correct domain.",
49+
)
50+
51+
# OAuth client settings
52+
oauth_group = parser.add_argument_group("OAuth settings")
53+
oauth_group.add_argument(
54+
"-b",
55+
"--backend",
56+
type=OAuthBackend,
57+
help="Which OAuth backend to use.",
58+
required=True,
59+
)
60+
oauth_group.add_argument(
61+
"-i",
62+
"--client-id",
63+
type=str,
64+
help="OAuth client ID.",
65+
required=True,
66+
)
67+
oauth_group.add_argument(
3568
"-s",
3669
"--client-secret",
3770
type=str,
3871
help="OAuth client secret.",
72+
required=True,
3973
)
40-
parser.add_argument(
74+
75+
# Options for refreshing the tree
76+
refresh_group = parser.add_argument_group("Refresh settings")
77+
refresh_group.add_argument(
4178
"--background-refresh",
4279
action="store_true",
4380
default=False,
4481
help="Refresh in the background instead of as needed per request",
4582
)
46-
parser.add_argument(
47-
"--debug",
48-
action="store_true",
49-
help="Enable debug logging.",
50-
)
51-
parser.add_argument(
52-
"--disable-mirrored-groups",
53-
action="store_false",
54-
default=True,
55-
dest="enable_mirrored_groups",
56-
help="Disable creation of mirrored groups.",
57-
)
58-
parser.add_argument(
83+
refresh_group.add_argument(
5984
"--refresh-interval",
6085
type=int,
6186
default=60,
6287
help="How often to refresh the database in seconds",
6388
)
6489

6590
# Options for Microsoft Entra backend
66-
entra_group = parser.add_argument_group("Microsoft Entra")
91+
entra_group = parser.add_argument_group("Microsoft Entra backend")
6792
entra_group.add_argument(
6893
"--entra-tenant-id",
6994
type=str,
7095
help="Microsoft Entra tenant ID.",
7196
)
7297

7398
# Options for Keycloak backend
74-
keycloak_group = parser.add_argument_group("Keycloak")
99+
keycloak_group = parser.add_argument_group("Keycloak backend")
75100
keycloak_group.add_argument(
76101
"--keycloak-base-url",
77102
type=str,
@@ -88,6 +113,7 @@
88113
default="domain",
89114
help="The attribute in Keycloak that contains the users' domain.",
90115
)
116+
91117
# Options for Redis cache
92118
redis_group = parser.add_argument_group("Redis")
93119
redis_group.add_argument(
@@ -100,6 +126,7 @@
100126
type=int,
101127
help="Port for Redis server.",
102128
)
129+
103130
# Options for TLS
104131
tls_group = parser.add_argument_group("TLS")
105132
tls_group.add_argument(
@@ -118,6 +145,7 @@
118145
type=str,
119146
help="Location of TLS private key (pem).",
120147
)
148+
121149
# Parse arguments
122150
args = parser.parse_args()
123151

0 commit comments

Comments
 (0)