Skip to content

Commit 79a1623

Browse files
authored
[minor] use custom claims overrides so we can directly specify special wlcg/scitoken claims (#7)
1 parent 41fbb8c commit 79a1623

File tree

3 files changed

+34
-25
lines changed

3 files changed

+34
-25
lines changed

src/scitoken_issuer/config.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dataclasses as dc
2+
import json
23
import logging
34
import secrets
5+
from typing import Any
46

57
from wipac_dev_tools import from_environment_as_dataclass
68

@@ -28,7 +30,7 @@ class EnvConfig:
2830
IDP_USERNAME_CLAIM: str = 'preferred_username'
2931

3032
ISSUER_ADDRESS: str = ''
31-
AUDIENCE: str = ''
33+
CUSTOM_CLAIMS: dict[str, Any] | str = ''
3234
KEY_TYPE: str = 'RS256'
3335

3436
ACCESS_TOKEN_EXPIRATION: int = 300 # seconds
@@ -68,10 +70,10 @@ def __post_init__(self) -> None:
6870
raise ConfigError('Must specify IDP_CLIENT_ID in production')
6971
if not self.IDP_CLIENT_SECRET:
7072
raise ConfigError('Must specify IDP_CLIENT_SECRET in production')
71-
if not self.AUDIENCE:
72-
raise ConfigError('Must specify AUDIENCE in production')
7373
if not self.MONGODB_URL:
7474
raise ConfigError('Must specify MONGODB_URL in production')
75+
if self.CUSTOM_CLAIMS and isinstance(self.CUSTOM_CLAIMS, str):
76+
object.__setattr__(self, 'CUSTOM_CLAIMS', json.loads(self.CUSTOM_CLAIMS))
7577
if self.KEY_TYPE not in DEFAULT_KEY_ALGORITHMS:
7678
raise ConfigError(f'KEY_TYPE must be one of {DEFAULT_KEY_ALGORITHMS}')
7779
if self.MONGODB_WRITE_CONCERN < 1:

src/scitoken_issuer/server.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -450,27 +450,32 @@ async def post(self):
450450
current_key = await self.state.get_current_key()
451451
auth = Auth(
452452
secret=current_key['private_key'],
453-
audience=config.ENV.AUDIENCE,
454453
issuer=config.ENV.ISSUER_ADDRESS,
455454
algorithm=config.ENV.KEY_TYPE,
456455
integer_times=True, # scitokens-cpp can't handle floats
457456
)
457+
access_claims = {
458+
'jti': uuid.uuid4().hex,
459+
config.ENV.IDP_USERNAME_CLAIM: username,
460+
'scope': access_scope,
461+
}
462+
if config.ENV.CUSTOM_CLAIMS:
463+
logger.info('custom claims: %r', config.ENV.CUSTOM_CLAIMS)
464+
if not set(access_claims).isdisjoint(config.ENV.CUSTOM_CLAIMS):
465+
logger.error('CUSTOM_CLAIMS should not override existing claims')
466+
raise OAuthError(500, error="invalid_claims")
467+
access_claims.update(config.ENV.CUSTOM_CLAIMS)
458468
access_token = auth.create_token(
459469
subject=username,
460470
expiration=config.ENV.ACCESS_TOKEN_EXPIRATION,
461-
payload={
462-
'aud': [config.ENV.AUDIENCE],
463-
config.ENV.IDP_USERNAME_CLAIM: username,
464-
'scope': access_scope,
465-
'jti': uuid.uuid4().hex,
466-
'wlcg.ver': '1.0',
467-
},
471+
payload=access_claims,
468472
headers={'kid': current_key['kid']},
469473
)
470474
refresh_token = auth.create_token(
471475
subject=username,
472476
expiration=config.ENV.REFRESH_TOKEN_EXPIRATION,
473477
payload={
478+
'jti': uuid.uuid4().hex,
474479
'aud': config.ENV.ISSUER_ADDRESS,
475480
config.ENV.IDP_USERNAME_CLAIM: username,
476481
'idp_username': username,

tests/test_server.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ async def fn():
126126
CI_TESTING=True,
127127
PORT=port,
128128
ISSUER_ADDRESS='',
129-
AUDIENCE='storage',
130129
POSIX_PATH=str(tmp_path),
131130
DEVICE_CODE_POLLING_INTERVAL=.1):
132131

@@ -307,7 +306,6 @@ async def common(scopes):
307306
assert login.call_count == 1
308307

309308
data = jwt.decode(token, options={"verify_signature": False})
310-
assert data['aud'] == [scitoken_issuer.config.ENV.AUDIENCE]
311309
assert data['sub'] == user
312310
return data
313311

@@ -455,15 +453,19 @@ async def test_scitokens(server, storage, monkeypatch):
455453
"""
456454
users, _, _ = storage
457455
user = 'test1'
458-
async with server() as address:
459-
login = Mock(side_effect=do_device_login(user))
460-
monkeypatch.setattr('rest_tools.client.device_client._print_qrcode', login)
461-
scopes = ['storage.read:/data/ana/project1/sub1']
462-
async with make_client(address, scopes) as rc:
463-
token = rc._openid_token()
464-
assert token
465-
access_data = jwt.decode(token, options={"verify_signature": False})
466-
logging.info('access token: %r', access_data)
467-
assert isinstance(access_data['exp'], int)
468-
469-
assert access_data['wlcg.ver'] == '1.0'
456+
457+
with env(CUSTOM_CLAIMS='{"aud": ["https://wlcg.cern.ch/jwt/v1/any"], "wlcg.ver":"1.0"}'):
458+
async with server() as address:
459+
login = Mock(side_effect=do_device_login(user))
460+
monkeypatch.setattr('rest_tools.client.device_client._print_qrcode', login)
461+
scopes = ['storage.read:/data/ana/project1/sub1']
462+
async with make_client(address, scopes) as rc:
463+
token = rc._openid_token()
464+
assert token
465+
access_data = jwt.decode(token, options={"verify_signature": False})
466+
logging.info('access token: %r', access_data)
467+
assert isinstance(access_data['exp'], int)
468+
469+
assert access_data['wlcg.ver'] == '1.0'
470+
471+
assert 'https://wlcg.cern.ch/jwt/v1/any' in access_data['aud']

0 commit comments

Comments
 (0)