Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/container.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ jobs:
with:
image_namespace: wipacrepo
image_name: iceprod
mode: BUILD
101 changes: 27 additions & 74 deletions iceprod/website/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
from cachetools.func import ttl_cache
from prometheus_client import Info, start_http_server
import tornado.web
import tornado.httpserver
import tornado.gen
import jwt
import tornado.concurrent
import requests.exceptions
from rest_tools.client import RestClient, ClientCredentialsAuth
from rest_tools.server import catch_error, RestServer, RestHandlerSetup, RestHandler, OpenIDCookieHandlerMixin, OpenIDLoginHandler
from rest_tools.server.session import SessionMixin, Session
Expand All @@ -37,7 +35,6 @@
from iceprod.roles_groups import GROUPS
from iceprod.core.config import CONFIG_SCHEMA as DATASET_SCHEMA
from iceprod.server.config import CONFIG_SCHEMA as SERVER_SCHEMA
import iceprod.core.functions
from iceprod.server import documentation
import iceprod.server.states
from iceprod.server.util import datetime2str, nowstr
Expand Down Expand Up @@ -219,120 +216,76 @@ def clear_tokens(self):
self._session_mgr.delete_session(username)


class TokenStorageMixin(OpenIDCookieHandlerMixin, RestHandler):
class TokenStorageMixin(RestHandler):
"""
Store/load current user's `OpenIDLoginHandler` tokens in iceprod credentials API.
Store/load current user's tokens in iceprod credentials API.
"""
def initialize(self, cred_rest_client: RestClient, full_url: str, **kwargs): # type: ignore
def initialize(self, *args, cred_rest_client, **kwargs):
super().initialize(**kwargs)
self.cred_rest_client = cred_rest_client
self.full_url = full_url

def get_current_user(self):
return None

async def get_current_user_async(self):
"""Get the current user, and set auth-related attributes."""
@authenticated
async def get_cred_tokens(self, url):
"""Get selected tokens from the credential service."""
try:
assert self.auth
username = self.get_secure_cookie('iceprod_username')
if not username:
return None
if isinstance(username, bytes):
username = username.decode('utf-8')
creds = await self.cred_rest_client.request('GET', f'/users/{username}/credentials', {'url': self.full_url})
cred = creds[self.full_url]
access_token = cred['access_token']
try:
data = self.auth.validate(access_token)
except jwt.ExpiredSignatureError:
logger.debug('user access_token expired')
return None
self.auth_data = data

# lookup groups
auth_groups = set()
try:
for name in GROUPS:
for expression in GROUPS[name]:
ret = eval_expression(data, expression)
auth_groups.update(match.expand(name) for match in ret)
except Exception:
logger.info('cannot determine groups', exc_info=True)
self.auth_groups = sorted(auth_groups)

self.auth_access_token = access_token
self.auth_refresh_token = cred.get('refresh_token', '')
return username

except Exception:
logger.debug('failed auth', exc_info=True)
username = self.current_user
creds = await self.cred_rest_client.request('GET', f'/users/{username}/credentials', {'url': url})
return creds[url]
except requests.exceptions.RequestException:
logger.warning('failed to get credentials', exc_info=True)
return None

def store_tokens(
@authenticated
async def put_cred_tokens(
self,
url,
access_token,
access_token_exp,
refresh_token=None,
refresh_token_exp=None,
user_info=None,
user_info_exp=None,
):
"""
Store jwt tokens and user info from OpenID-compliant auth source.
Store jwt tokens from OpenID-compliant auth source.

Args:
url (str): site url
access_token (str): jwt access token
access_token_exp (int): access token expiration in seconds
refresh_token (str): jwt refresh token
refresh_token_exp (int): refresh token expiration in seconds
user_info (dict): user info (from id token or user info lookup)
user_info_exp (int): user info expiration in seconds
"""
assert self.auth
if not user_info:
user_info = self.auth.validate(access_token)
username = user_info.get('preferred_username')
if not username:
username = user_info.get('upn')
if not username:
raise tornado.web.HTTPError(400, reason='no username in token')
username = self.current_user
args = {
'url': self.full_url,
'url': url,
'type': 'oauth',
'access_token': access_token,
}
if refresh_token:
args['refresh_token'] = refresh_token

self.cred_rest_client.request_seq('POST', f'/users/{username}/credentials', args)

self.set_secure_cookie('iceprod_username', username, expires_days=30)
await self.cred_rest_client.request('POST', f'/users/{username}/credentials', args)

def clear_tokens(self):
async def clear_cred_tokens(self):
"""
Clear token data, usually on logout.
Clear all token data.
"""
self.clear_cookie('iceprod_username')
username = self.current_user
await self.cred_rest_client.request('DELETE', f'/users/{username}/credentials', {})


class Login(LoginMixin, PromRequestMixin, OpenIDLoginHandler): # type: ignore
pass


class PublicHandler(LoginMixin, PromRequestMixin, RestHandler):
class PublicHandler(LoginMixin, TokenStorageMixin, PromRequestMixin, RestHandler):
"""Default Handler"""
def initialize(self, rest_api, cred_rest_client, system_rest_client, **kwargs): # type: ignore
def initialize(self, rest_api, system_rest_client, **kwargs): # type: ignore
"""
Get some params from the website module

:param rest_api: the rest api url
:param cred_rest_client: the rest api url for the cred service
:param system_rest_client: the rest client for the system role
"""
super().initialize(**kwargs)
self.rest_api = rest_api
self.cred_rest_client = cred_rest_client
self.system_rest_client = system_rest_client
self.rest_client: RestClient | None = None

Expand Down Expand Up @@ -869,7 +822,6 @@ def __init__(self):
raise RuntimeError('ICEPROD_CRED_CLIENT_ID or ICEPROD_CRED_CLIENT_SECRET not specified, and CI_TESTING not enabled!')

handler_args = RestHandlerSetup(rest_config)
handler_args['cred_rest_client'] = cred_client
if config.CI_TESTING:
self.session = Session()
else:
Expand Down Expand Up @@ -906,6 +858,7 @@ def __init__(self):

handler_args.update({
'rest_api': rest_address,
'cred_rest_client': cred_client,
'system_rest_client': rest_client,
})
if config.COOKIE_SECRET:
Expand Down
Loading