diff --git a/servicex/servicex_adapter.py b/servicex/servicex_adapter.py index baa7b5b0..6c85e85a 100644 --- a/servicex/servicex_adapter.py +++ b/servicex/servicex_adapter.py @@ -85,11 +85,21 @@ def _get_bearer_token_file(): bearer_token = f.read().strip() return bearer_token + @staticmethod + def _effective_expiration(tok) -> float: + import sys + + decoded_tok = jwt.decode(tok, verify=False) + if "exp" not in decoded_tok: + # Token is missing an expiration. Return the maximum float + return sys.float_info.max + return float(decoded_tok["exp"]) + async def _get_authorization(self, force_reauth: bool = False) -> Dict[str, str]: now = time.time() if ( self.token - and jwt.decode(self.token, verify=False)["exp"] - now > 60 + and self._effective_expiration(self.token) - now > 60 and not force_reauth ): # if less than one minute validity, renew @@ -105,7 +115,7 @@ async def _get_authorization(self, force_reauth: bool = False) -> Dict[str, str] if ( not self.token or force_reauth - or float(jwt.decode(self.token, verify=False)["exp"]) - now < 60 + or self._effective_expiration(self.token) - now < 60 ): await self._get_token() return {"Authorization": f"Bearer {self.token}"} diff --git a/tests/test_servicex_adapter.py b/tests/test_servicex_adapter.py index 99d8f5c0..2c329caf 100644 --- a/tests/test_servicex_adapter.py +++ b/tests/test_servicex_adapter.py @@ -106,16 +106,15 @@ async def test_get_transforms_wlcg_bearer_token( ) token_file.close() - os.environ["BEARER_TOKEN_FILE"] = token_file.name + with patch.dict(os.environ, {"BEARER_TOKEN_FILE": token_file.name}, clear=True): - # Try with an expired token - with pytest.raises(AuthorizationError) as err: - decode.return_value = {"exp": 0.0} - await servicex.get_transforms() - assert "ServiceX access token request rejected:" in str(err.value) + # Try with an expired token + with pytest.raises(AuthorizationError) as err: + decode.return_value = {"exp": 0.0} + await servicex.get_transforms() + assert "ServiceX access token request rejected:" in str(err.value) os.remove(token_file.name) - del os.environ["BEARER_TOKEN_FILE"] @pytest.mark.asyncio