Skip to content

Better logging #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 41 additions & 31 deletions apricot/apricot_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import inspect
import sys
import logging
from typing import TYPE_CHECKING, Any, Self, cast

from twisted.internet import reactor, task
from twisted.internet.endpoints import quoteStringArgument, serverFromString
from twisted.logger import Logger
from twisted.python import log

from apricot.cache import LocalCache, RedisCache, UidCache
Expand Down Expand Up @@ -59,44 +60,58 @@ def __init__(
@param tls_certificate: TLS certificate for LDAPS
@param tls_private_key: TLS private key for LDAPS
"""
self.debug = debug
# Set up Python root logger
logging.basicConfig(
level=logging.INFO,
datefmt=r"%Y-%m-%d %H:%M:%S",
format=r"%(asctime)s [%(levelname)-8s] %(message)s",
)
if debug:
logging.getLogger("apricot").setLevel(logging.DEBUG)

# Log to stdout
log.startLogging(sys.stdout)
# Configure Twisted loggers to write to Python logging
observer = log.PythonLoggingObserver("apricot")
observer.start()
self.logger = Logger()

# Load the Twisted reactor
self.reactor = cast("IReactorCore", reactor)

# Initialise the UID cache
uid_cache: UidCache
if redis_host and redis_port:
log.msg(
f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'.",
self.logger.info(
"Using a Redis user-id cache at host '{host}' on port '{port}'.",
host=redis_host,
port=redis_port,
)
uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port)
else:
log.msg("Using a local user-id cache.")
self.logger.info("Using a local user-id cache.")
uid_cache = LocalCache()

# Initialise the appropriate OAuth client
try:
if self.debug:
log.msg(f"Creating an OAuthClient for {backend}.")
self.logger.debug(
"Creating an OAuthClient for the {backend} backend.",
backend=backend.value,
)
oauth_backend = OAuthClientMap[backend]
oauth_backend_args = inspect.getfullargspec(
oauth_backend.__init__, # type: ignore[misc]
).args
oauth_client = oauth_backend(
client_id=client_id,
client_secret=client_secret,
debug=debug,
uid_cache=uid_cache,
**{k: v for k, v in kwargs.items() if k in oauth_backend_args},
)
except Exception as exc:
msg = f"Could not construct an OAuth client for the '{backend}' backend.\n{exc!s}"
msg = f"Could not construct an OAuth client for the {backend.value} backend.\n{exc!s}"
raise ValueError(msg) from exc

# Initialise the OAuth data adaptor
if self.debug:
log.msg("Creating an OAuthDataAdaptor.")
self.logger.debug("Creating an OAuthDataAdaptor.")
oauth_adaptor = OAuthDataAdaptor(
domain,
oauth_client,
Expand All @@ -105,9 +120,8 @@ def __init__(
enable_user_domain_verification=enable_user_domain_verification,
)

# Create an LDAPServerFactory
if self.debug:
log.msg("Creating an LDAPServerFactory.")
# Create an OAuthLDAPServerFactory
self.logger.debug("Creating an OAuthLDAPServerFactory.")
factory = OAuthLDAPServerFactory(
oauth_adaptor,
oauth_client,
Expand All @@ -116,17 +130,16 @@ def __init__(
)

if background_refresh:
if self.debug:
log.msg(
f"Starting background refresh (interval={refresh_interval})",
)
self.logger.info(
"Starting background refresh (interval={interval})",
interval=refresh_interval,
)
loop = task.LoopingCall(factory.adaptor.refresh)
loop.start(refresh_interval)

# Attach a listening endpoint
if self.debug:
log.msg("Attaching a listening endpoint (plain).")
endpoint: IStreamServerEndpoint = serverFromString(reactor, f"tcp:{port}")
self.logger.info("Listening for LDAP requests on port {port}.", port=port)
endpoint: IStreamServerEndpoint = serverFromString(self.reactor, f"tcp:{port}")
endpoint.listen(factory)

# Attach a listening endpoint
Expand All @@ -137,19 +150,16 @@ def __init__(
if not tls_private_key:
msg = "No TLS private key provided. Please provide one with --tls-private-key or disable TLS."
raise ValueError(msg)
if self.debug:
log.msg("Attaching a listening endpoint (TLS).")
self.logger.info(
"Listening for LDAPS requests on port {port}.",
port=tls_port,
)
ssl_endpoint: IStreamServerEndpoint = serverFromString(
reactor,
self.reactor,
f"ssl:{tls_port}:privateKey={quoteStringArgument(tls_private_key)}:certKey={quoteStringArgument(tls_certificate)}",
)
ssl_endpoint.listen(factory)

# Load the Twisted reactor
self.reactor = cast("IReactorCore", reactor)

def run(self: Self) -> None:
"""Start the Twisted reactor."""
if self.debug:
log.msg("Starting the Twisted reactor.")
self.reactor.run()
9 changes: 7 additions & 2 deletions apricot/ldap/oauth_ldap_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
LDAPInvalidCredentials,
)
from twisted.internet import defer
from twisted.python import log
from twisted.logger import Logger

from apricot.oauth import LDAPAttributeDict, OAuthClient

Expand All @@ -35,6 +35,7 @@ def __init__(
@param attributes: Attributes of the object.
@param oauth_client: An OAuth client used for binding
"""
self.logger = Logger()
self.oauth_client_ = oauth_client
if not isinstance(dn, DistinguishedName):
dn = DistinguishedName(stringValue=dn)
Expand Down Expand Up @@ -73,7 +74,10 @@ def add_child(
try:
output = self.addChild(rdn, attributes)
except LDAPEntryAlreadyExists:
log.msg(f"Refusing to add child '{rdn.getText()}' as it already exists.")
self.logger.warn(
"Refusing to add child '{child}' as it already exists.",
child=rdn.getText(),
)
output = self._children[rdn.getText()]
return cast("OAuthLDAPEntry", output)

Expand All @@ -84,6 +88,7 @@ def _bind(password: bytes) -> OAuthLDAPEntry:
if self.oauth_client.verify(username=oauth_username, password=s_password):
return self
msg = f"Invalid password for user '{oauth_username}'."
self.logger.error(msg)
raise LDAPInvalidCredentials(msg)

return defer.maybeDeferred(_bind, password)
Expand Down
2 changes: 1 addition & 1 deletion apricot/ldap/oauth_ldap_server_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ def buildProtocol(self: Self, addr: IAddress) -> Protocol: # noqa: N802
@param addr: an object implementing L{IAddress}
"""
id(addr) # ignore unused arguments
proto = ReadOnlyLDAPServer(debug=self.adaptor.debug)
proto = ReadOnlyLDAPServer()
proto.factory = self.adaptor
return proto
84 changes: 57 additions & 27 deletions apricot/ldap/oauth_ldap_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

from ldaptor.interfaces import IConnectedLDAPEntry, ILDAPEntry
from ldaptor.protocols.ldap.distinguishedname import DistinguishedName
from twisted.python import log
from twisted.logger import Logger
from zope.interface import implementer

from apricot.ldap.oauth_ldap_entry import OAuthLDAPEntry

if TYPE_CHECKING:
from twisted.internet import defer
from twisted.python.failure import Failure

from apricot.oauth import OAuthClient, OAuthDataAdaptor

Expand All @@ -36,8 +37,8 @@ def __init__(
@param refresh_interval: Interval in seconds after which the tree must be refreshed
"""
self.background_refresh = background_refresh
self.debug = oauth_client.debug
self.last_update = time.monotonic()
self.logger = Logger()
self.oauth_adaptor = oauth_adaptor
self.oauth_client = oauth_client
self.refresh_interval = refresh_interval
Expand Down Expand Up @@ -66,15 +67,37 @@ def __repr__(self: Self) -> str:
return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}"

def lookup(self: Self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]:
"""Lookup the referred to by dn.
"""Lookup a DistinguishedName in the LDAP tree.

@return: A Deferred returning an ILDAPEntry.
"""

def result_callback(ldap_entry: OAuthLDAPEntry | None) -> OAuthLDAPEntry | None:
if ldap_entry:
self.logger.debug(
"LDAP lookup succeeded: found {dn}",
dn=ldap_entry.dn.getText(),
)
return ldap_entry

def failure_callback(failure: Failure) -> Failure:
self.logger.debug(
"LDAP lookup failed: {error}",
error=failure.getErrorMessage(),
)
return failure

# Construct a complete DN
if not isinstance(dn, DistinguishedName):
dn = DistinguishedName(stringValue=dn)
if self.debug:
log.msg(f"Starting an LDAP lookup for '{dn.getText()}'.")
return self.root.lookup(dn)
self.logger.info("Starting an LDAP lookup for '{dn}'.", dn=dn.getText())

# Attach debug callbacks to the lookup and return
return (
self.root.lookup(dn)
.addErrback(failure_callback)
.addCallback(result_callback)
)

def refresh(self: Self) -> None:
"""Refresh the LDAP tree."""
Expand All @@ -83,11 +106,11 @@ def refresh(self: Self) -> None:
or (time.monotonic() - self.last_update) > self.refresh_interval
):
# Update users and groups from the OAuth server
log.msg("Retrieving OAuth data.")
self.logger.info("Retrieving OAuth data.")
oauth_groups, oauth_users = self.oauth_adaptor.retrieve_all()

# Create a root node for the tree
log.msg("Rebuilding LDAP tree.")
self.logger.info("Rebuilding LDAP tree.")
self.root_ = OAuthLDAPEntry(
dn=self.oauth_adaptor.root_dn,
attributes={"objectClass": ["dcObject"]},
Expand All @@ -105,31 +128,38 @@ def refresh(self: Self) -> None:
)

# Add groups to the groups OU
if self.debug:
log.msg(
f"Attempting to add {len(oauth_groups)} groups to the LDAP tree.",
)
self.logger.debug(
"Attempting to add {n_groups} groups to the LDAP tree.",
n_groups=len(oauth_groups),
)
for group_attrs in oauth_groups:
groups_ou.add_child(f"CN={group_attrs.cn}", group_attrs.to_dict())
if self.debug:
children = groups_ou.list_children()
for child in children:
log.msg(f"... {child.dn.getText()}")
log.msg(f"There are {len(children)} groups in the LDAP tree.")
ldap_groups = groups_ou.list_children()
self.logger.info(
"There are {n_groups} groups in the LDAP tree.",
n_groups=len(ldap_groups),
)
for ldap_group in ldap_groups:
self.logger.debug(
"... {ldap_group}",
ldap_group=ldap_group.dn.getText(),
)

# Add users to the users OU
if self.debug:
log.msg(
f"Attempting to add {len(oauth_users)} users to the LDAP tree.",
)
self.logger.debug(
"Attempting to add {n_users} users to the LDAP tree.",
n_users=len(oauth_users),
)
for user_attrs in oauth_users:
users_ou.add_child(f"CN={user_attrs.cn}", user_attrs.to_dict())
if self.debug:
children = users_ou.list_children()
for child in children:
log.msg(f"... {child.dn.getText()}")
log.msg(f"There are {len(children)} users in the LDAP tree.")
ldap_users = users_ou.list_children()
self.logger.info(
"There are {n_users} users in the LDAP tree.",
n_users=len(ldap_users),
)
for ldap_user in ldap_users:
self.logger.debug("... {ldap_user}", ldap_user=ldap_user.dn.getText())

# Set last updated time
log.msg("Finished building LDAP tree.")
self.logger.info("Finished building LDAP tree.")
self.last_update = time.monotonic()
Loading
Loading