Skip to content

Commit e2c7c0d

Browse files
authored
[Core] Update on_challenge in credential policies (#41857)
We will now use authorize_request internally, which will cache the token to maintain prior behavior of subclassing policies like ARM policies. Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 69a5346 commit e2c7c0d

File tree

5 files changed

+98
-8
lines changed

5 files changed

+98
-8
lines changed

sdk/core/azure-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
- A timeout error when using the `aiohttp` transport (the default for async SDKs) will now be raised as a `azure.core.exceptions.ServiceResponseTimeoutError`, a subtype of the previously raised `ServiceResponseError`.
2020
- When using with `aiohttp` 3.10 or later, a connection timeout error will now be raised as a `azure.core.exceptions.ServiceRequestTimeoutError`, which can be retried.
21+
- The default implementation of `on_challenge` in `BearerTokenCredentialPolicy` and `AsyncBearerTokenCredentialPolicy` will now cache the retrieved token. #41857
2122

2223
## 1.34.0 (2025-05-01)
2324

sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,7 @@ def on_challenge(
204204
padding_needed = -len(encoded_claims) % 4
205205
claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
206206
if claims:
207-
token = self._get_token(*self._scopes, claims=claims)
208-
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
209-
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
207+
self.authorize_request(request, *self._scopes, claims=claims)
210208
return True
211209
except Exception: # pylint:disable=broad-except
212210
return False

sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,7 @@ async def on_challenge(
149149
padding_needed = -len(encoded_claims) % 4
150150
claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode("utf-8")
151151
if claims:
152-
token = await self._get_token(*self._scopes, claims=claims)
153-
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], token).token
154-
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
152+
await self.authorize_request(request, *self._scopes, claims=claims)
155153
return True
156154
except Exception: # pylint:disable=broad-except
157155
return False

sdk/core/azure-core/tests/async_tests/test_authentication_async.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
import asyncio
7+
import base64
78
import sys
89
import time
910
from unittest.mock import Mock, patch, AsyncMock, create_autospec
@@ -12,7 +13,7 @@
1213
from azure.core.credentials import AccessToken, AccessTokenInfo
1314
from azure.core.credentials_async import AsyncTokenCredential, AsyncSupportsTokenInfo
1415
from azure.core.exceptions import ServiceRequestError
15-
from azure.core.pipeline import AsyncPipeline, PipelineRequest, PipelineContext
16+
from azure.core.pipeline import AsyncPipeline, PipelineRequest, PipelineContext, PipelineResponse
1617
from azure.core.pipeline.policies import (
1718
AsyncBearerTokenCredentialPolicy,
1819
SansIOHTTPPolicy,
@@ -593,3 +594,49 @@ def test_async_token_credential_sync():
593594
# Ensure trio isn't in sys.modules (i.e. imported).
594595
sys.modules.pop("trio", None)
595596
AsyncBearerTokenCredentialPolicy(Mock(), "scope")
597+
598+
599+
@pytest.mark.asyncio
600+
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
601+
async def test_async_bearer_policy_on_challenge_caches_token(http_request):
602+
"""Test that async on_challenge caches the token when handling claims challenges"""
603+
# Setup credentials that return different tokens for different calls
604+
initial_token = AccessToken("initial_token", int(time.time()) + 3600)
605+
claims_token = AccessToken("claims_token", int(time.time()) + 3600)
606+
607+
call_count = 0
608+
609+
async def mock_get_token_info(*scopes, options=None):
610+
nonlocal call_count
611+
call_count += 1
612+
if options and "claims" in options:
613+
return claims_token
614+
return initial_token
615+
616+
fake_credential = Mock(spec_set=["get_token_info"], get_token_info=mock_get_token_info)
617+
policy = AsyncBearerTokenCredentialPolicy(fake_credential, "scope")
618+
619+
# Create request and initial response
620+
http_req = http_request("GET", "https://example.com")
621+
request = PipelineRequest(http_req, PipelineContext(None))
622+
623+
# Create a 401 response with insufficient_claims challenge
624+
test_claims = '{"access_token":{"foo":"bar"}}'
625+
encoded_claims = base64.urlsafe_b64encode(test_claims.encode()).decode().rstrip("=")
626+
challenge_header = f'Bearer error="insufficient_claims", claims="{encoded_claims}"'
627+
628+
response_mock = Mock(status_code=401, headers={"WWW-Authenticate": challenge_header})
629+
response = PipelineResponse(request, response_mock, PipelineContext(None))
630+
631+
# Call on_challenge
632+
result = await policy.on_challenge(request, response)
633+
634+
# Verify the challenge was handled successfully
635+
assert result is True
636+
637+
# Verify the token was cached
638+
assert policy._token is claims_token
639+
assert policy._token.token == "claims_token"
640+
641+
# Verify the Authorization header was set correctly
642+
assert request.http_request.headers["Authorization"] == "Bearer claims_token"

sdk/core/azure-core/tests/test_authentication.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66
from collections import namedtuple
7+
import base64
78
import time
89
from itertools import product
910
from requests import Response
@@ -16,7 +17,7 @@
1617
AccessTokenInfo,
1718
)
1819
from azure.core.exceptions import ServiceRequestError
19-
from azure.core.pipeline import Pipeline, PipelineRequest, PipelineContext
20+
from azure.core.pipeline import Pipeline, PipelineRequest, PipelineContext, PipelineResponse
2021
from azure.core.pipeline.transport import HttpTransport, HttpRequest
2122
from azure.core.pipeline.policies import (
2223
BearerTokenCredentialPolicy,
@@ -791,3 +792,48 @@ def test_access_token_subscriptable():
791792
assert len(token) == 2
792793
assert token[0] == "token"
793794
assert token[1] == 42
795+
796+
797+
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
798+
def test_bearer_policy_on_challenge_caches_token_with_claims(http_request):
799+
"""Test that on_challenge caches the token when handling claims challenges"""
800+
# Setup credentials that return different tokens for different calls
801+
initial_token = AccessToken("initial_token", int(time.time()) + 3600)
802+
claims_token = AccessToken("claims_token", int(time.time()) + 3600)
803+
804+
call_count = 0
805+
806+
def mock_get_token_info(*scopes, options):
807+
nonlocal call_count
808+
call_count += 1
809+
if options and "claims" in options:
810+
return claims_token
811+
return initial_token
812+
813+
fake_credential = Mock(spec_set=["get_token_info"], get_token_info=mock_get_token_info)
814+
policy = BearerTokenCredentialPolicy(fake_credential, "scope")
815+
816+
# Create request and initial response
817+
http_req = http_request("GET", "https://example.com")
818+
request = PipelineRequest(
819+
http_req, PipelineContext(None)
820+
) # Create a 401 response with insufficient_claims challenge
821+
test_claims = '{"access_token":{"foo":"bar"}}'
822+
encoded_claims = base64.urlsafe_b64encode(test_claims.encode()).decode().rstrip("=")
823+
challenge_header = f'Bearer error="insufficient_claims", claims="{encoded_claims}"'
824+
825+
response_mock = Mock(status_code=401, headers={"WWW-Authenticate": challenge_header})
826+
response = PipelineResponse(request, response_mock, PipelineContext(None))
827+
828+
# Call on_challenge
829+
result = policy.on_challenge(request, response)
830+
831+
# Verify the challenge was handled successfully
832+
assert result is True
833+
834+
# Verify the token was cached
835+
assert policy._token is claims_token
836+
assert policy._token.token == "claims_token"
837+
838+
# Verify the Authorization header was set correctly
839+
assert request.http_request.headers["Authorization"] == "Bearer claims_token"

0 commit comments

Comments
 (0)