diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index ea72cef9cdf0..b8f33d239c10 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -10,6 +10,7 @@ #### Bugs Fixed * Fixed bug where replacing manual throughput using `ThroughputProperties` would not work. See [PR 41564](https://github.com/Azure/azure-sdk-for-python/pull/41564) * Fixed bug where constantly raising Service Request Error Exceptions would cause the Service Request Retry Policy to indefinitely retry the request during a query or when a request was sent without a request object. See [PR 41804](https://github.com/Azure/azure-sdk-for-python/pull/41804) +* Fixed retry logic if `raise_for_status` is set to `True` in the transport. See [PR 41940](https://github.com/Azure/azure-sdk-for-python/pull/41940). #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py index 59fca57e1c76..b5e4a1767912 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py @@ -7,7 +7,6 @@ only do cross regional retries for read operations. """ -import logging from azure.cosmos.documents import _OperationType class ServiceResponseRetryPolicy(object): @@ -23,7 +22,6 @@ def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, if self.request: self.location_endpoint = (self.global_endpoint_manager .resolve_service_endpoint_for_partition(self.request, pk_range_wrapper)) - self.logger = logging.getLogger('azure.cosmos.ServiceResponseRetryPolicy') def ShouldRetry(self): """Returns true if the request should retry based on preferred regions and retries already done. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 79e674eaa31c..83af01646171 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -26,7 +26,9 @@ import time from urllib.parse import urlparse -from azure.core.exceptions import DecodeError # type: ignore + +from aiohttp import ClientResponseError +from azure.core.exceptions import AzureError, DecodeError # type: ignore from .. import exceptions from .. import http_constants @@ -34,6 +36,18 @@ from .._synchronized_request import _request_body_from_data, _replace_url_prefix +def _check_status_code_for_retry(err: AzureError): + if isinstance(err.inner_exception, ClientResponseError) and err.inner_exception and err.inner_exception.code: + status_code = err.inner_exception.code + if status_code == 404: + raise exceptions.CosmosResourceNotFoundError(error=err) + if status_code == 409: + raise exceptions.CosmosResourceExistsError(error=err) + if status_code == 412: + raise exceptions.CosmosAccessConditionFailedError(error=err) + if status_code >= 400: + raise exceptions.CosmosHttpResponseError(error=err, status_code=status_code) + async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. @@ -97,32 +111,37 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p and not connection_policy.DisableSSLVerification ) - if connection_policy.SSLConfiguration or "connection_cert" in kwargs: - ca_certs = connection_policy.SSLConfiguration.SSLCaCerts - cert_files = (connection_policy.SSLConfiguration.SSLCertFile, connection_policy.SSLConfiguration.SSLKeyFile) - response = await _PipelineRunFunction( - pipeline_client, - request, - connection_timeout=connection_timeout, - read_timeout=read_timeout, - connection_verify=kwargs.pop("connection_verify", ca_certs), - connection_cert=kwargs.pop("connection_cert", cert_files), - request_params=request_params, - global_endpoint_manager=global_endpoint_manager, - **kwargs - ) - else: - response = await _PipelineRunFunction( - pipeline_client, - request, - connection_timeout=connection_timeout, - read_timeout=read_timeout, - # If SSL is disabled, verify = false - connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), - request_params=request_params, - global_endpoint_manager=global_endpoint_manager, - **kwargs - ) + try: + if connection_policy.SSLConfiguration or "connection_cert" in kwargs: + ca_certs = connection_policy.SSLConfiguration.SSLCaCerts + cert_files = (connection_policy.SSLConfiguration.SSLCertFile, connection_policy.SSLConfiguration.SSLKeyFile) + response = await _PipelineRunFunction( + pipeline_client, + request, + connection_timeout=connection_timeout, + read_timeout=read_timeout, + connection_verify=kwargs.pop("connection_verify", ca_certs), + connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, + **kwargs + ) + else: + response = await _PipelineRunFunction( + pipeline_client, + request, + connection_timeout=connection_timeout, + read_timeout=read_timeout, + # If SSL is disabled, verify = false + connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, + **kwargs + ) + except AzureError as err: + # If the error is an AzureError, we need to check the status code and raise the appropriate Cosmos exception. + _check_status_code_for_retry(err) + raise err response = response.http_response headers = copy.copy(response.headers) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 33b9c0785b38..7af85ff1cfcc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -291,7 +291,6 @@ async def send(self, request): timeout_error.history = retry_settings['history'] raise except ServiceRequestError as err: - retry_error = err # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if (not _has_database_account_header(request.http_request.headers) @@ -303,7 +302,6 @@ async def send(self, request): continue raise err except ServiceResponseError as err: - retry_error = err if (_has_database_account_header(request.http_request.headers) or request_params.healthy_tentative_location): raise err @@ -329,7 +327,6 @@ async def send(self, request): except CosmosHttpResponseError as err: raise err except AzureError as err: - retry_error = err if (_has_database_account_header(request.http_request.headers) or request_params.healthy_tentative_location): raise err diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 994357323b81..89cbd790ea54 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -28,6 +28,7 @@ import sys from typing import Callable, Optional, Any, Dict, List, Awaitable, MutableMapping import aiohttp +from aiohttp import ConnectionTimeoutError from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse from azure.cosmos import documents @@ -191,6 +192,10 @@ async def error_region_down() -> Exception: message="Injected region down.", ) + @staticmethod + async def connection_refused() -> Exception: + return ConnectionTimeoutError() + @staticmethod async def error_service_response() -> Exception: return ServiceResponseError( @@ -222,6 +227,23 @@ async def transform_topology_swr_mrr( return response + @staticmethod + async def change_status_code( + inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + + response = await inner() + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransportAsync.predicate_targets_region(response.request, "https://tomasvaron-full-fidelity-westus3.documents.azure.com:443/") and + (FaultInjectionTransportAsync.predicate_is_operation_type(response.request, "Read") and + FaultInjectionTransportAsync.predicate_is_document_operation(response.request))) + if is_request_to_read_region: + response.status_code = 500 + response.reason = "Not Ok" + return response + + + return response + @staticmethod async def transform_topology_mwr( first_region_name: str, diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index f8c2f7832bdb..f2e78dcdab8b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -13,7 +13,7 @@ from azure.cosmos.http_constants import StatusCodes from azure.cosmos.partition_key import PartitionKey from azure.cosmos import (ContainerProxy, DatabaseProxy, documents, exceptions, - http_constants, _retry_utility) + http_constants, _location_cache) from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError, ClientAuthenticationError from azure.core.pipeline.policies import AsyncRetryPolicy, RetryPolicy from devtools_testutils.azure_recorded_testcase import get_credential @@ -68,6 +68,10 @@ class TestConfig(object): TEST_CONTAINER_PARTITION_KEY = "pk" TEST_CONTAINER_PREFIX_PARTITION_KEY = ["pk1", "pk2"] TEST_CONTAINER_PREFIX_PARTITION_KEY_PATH = ['/pk1', '/pk2'] + WRITE_REGION = "West US 3" + READ_REGION = "West US" + WRITE_LOCATIONAL_ENDPOINT = _location_cache.LocationCache.GetLocationalEndpoint(host, WRITE_REGION) + READ_LOCATIONAL_ENDPOINT = _location_cache.LocationCache.GetLocationalEndpoint(host, READ_REGION) @classmethod def create_database_if_not_exist(cls, client): diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 8a47d2bc5943..4d7448fb0a26 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -12,6 +12,7 @@ from unittest import IsolatedAsyncioTestCase import pytest +from aiohttp import ClientResponseError, RequestInfo from azure.core.pipeline.transport import AioHttpTransport from azure.core.pipeline.transport._aiohttp import AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse @@ -22,8 +23,9 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy +from azure.cosmos.documents import _OperationType from azure.cosmos.exceptions import CosmosHttpResponseError -from azure.core.exceptions import ServiceRequestError +from azure.core.exceptions import ServiceRequestError, ServiceResponseError MGMT_TIMEOUT = 5.0 logger = logging.getLogger('azure.cosmos') @@ -34,8 +36,11 @@ master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID SINGLE_PARTITION_CONTAINER_NAME = os.path.basename(__file__) + str(uuid.uuid4()) +WRITE_LOCATIONAL_ENDPOINT = test_config.TestConfig.WRITE_LOCATIONAL_ENDPOINT +READ_LOCATIONAL_ENDPOINT = test_config.TestConfig.READ_LOCATIONAL_ENDPOINT +WRITE_REGION = test_config.TestConfig.WRITE_REGION +READ_REGION = test_config.TestConfig.READ_REGION -@pytest.mark.cosmosEmulator @pytest.mark.asyncio class TestFaultInjectionTransportAsync(IsolatedAsyncioTestCase): @classmethod @@ -101,6 +106,7 @@ async def cleanup_method(initialized_objects: Dict[str, Any]): except Exception as close_error: logger.warning(f"Exception trying to close method client. {close_error}") + @pytest.mark.cosmosEmulator async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsync"): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, @@ -131,6 +137,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy finally: await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + @pytest.mark.cosmosEmulator async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") @@ -182,6 +189,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): finally: await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + @pytest.mark.cosmosEmulator async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") @@ -242,6 +250,77 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection finally: await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + @pytest.mark.cosmosMultiRegion + async def test_service_response_error_with_status_async(self: "TestFaultInjectionTransportAsync"): + custom_transport = FaultInjectionTransportAsync() + + is_request_to_write_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransportAsync.predicate_targets_region(r, WRITE_REGION) and + (FaultInjectionTransportAsync.predicate_is_operation_type(r, _OperationType.Read) and + FaultInjectionTransportAsync.predicate_is_document_operation(r))) + + inner_error = ClientResponseError( + RequestInfo(url=host, method="GET", headers={}), + (), + status=404, + ) + error = ServiceResponseError(status_code=404, message="Not Found", error=inner_error) + + custom_transport.add_fault( + is_request_to_write_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error))) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( + custom_transport, + preferred_locations=[WRITE_REGION, READ_REGION]) + container: ContainerProxy = initialized_objects["col"] + await container.create_item(body=document_definition) + + try: + # the status code should dictate the retry policy even if error is raised + with pytest.raises(CosmosHttpResponseError): + await container.read_item(document_definition["id"], document_definition["pk"]) + + finally: + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + + @pytest.mark.cosmosMultiRegion + async def test_service_response_error_with_status_async(self): + custom_transport = FaultInjectionTransportAsync() + + # Inject rule to simulate regional outage in "write region" for reads + is_request_to_write_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransportAsync.predicate_targets_region(r, WRITE_LOCATIONAL_ENDPOINT) and + (FaultInjectionTransportAsync.predicate_is_operation_type(r, _OperationType.Read) and + FaultInjectionTransportAsync.predicate_is_document_operation(r))) + custom_transport.add_fault( + is_request_to_write_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_service_response())) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["West US 3", "West US"]) + container: ContainerProxy = initialized_objects["col"] + await container.create_item(body=document_definition) + try: + read_document = await container.read_item(document_definition["id"], document_definition["pk"]) + request: HttpRequest = read_document.get_response_headers()["_request"] + assert request.url.startswith(READ_LOCATIONAL_ENDPOINT) + finally: + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + + @pytest.mark.cosmosEmulator async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") @@ -308,6 +387,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + @pytest.mark.cosmosEmulator async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -353,6 +433,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): finally: await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + @pytest.mark.cosmosEmulator async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host @@ -407,6 +488,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport finally: await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + @pytest.mark.cosmosEmulator async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") @@ -470,6 +552,7 @@ async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection finally: await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + @pytest.mark.cosmosEmulator async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")