Skip to content

raise_for_status fix retry handling #41940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#### Bugs Fixed
* Fixed bug where replacing manual throughput using `ThroughputProperties` would not work. See [PR 41564](https://github.com/Azure/azure-sdk-for-python/pull/41564)
* Fixed bug where constantly raising Service Request Error Exceptions would cause the Service Request Retry Policy to indefinitely retry the request during a query or when a request was sent without a request object. See [PR 41804](https://github.com/Azure/azure-sdk-for-python/pull/41804)
* Fixed bug where the retry logic did not take the response status code into account when `raise_for_status` was set to `True` in the transport. See [PR 41940](https://github.com/Azure/azure-sdk-for-python/pull/41940)

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
only do cross regional retries for read operations.
"""

import logging
from azure.cosmos.documents import _OperationType

class ServiceResponseRetryPolicy(object):
Expand All @@ -23,7 +22,6 @@ def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper,
if self.request:
self.location_endpoint = (self.global_endpoint_manager
.resolve_service_endpoint_for_partition(self.request, pk_range_wrapper))
self.logger = logging.getLogger('azure.cosmos.ServiceResponseRetryPolicy')

def ShouldRetry(self):
"""Returns true if the request should retry based on preferred regions and retries already done.
Expand Down
73 changes: 46 additions & 27 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,28 @@
import time

from urllib.parse import urlparse
from azure.core.exceptions import DecodeError # type: ignore

from aiohttp import ClientResponseError
from azure.core.exceptions import AzureError, DecodeError # type: ignore

from .. import exceptions
from .. import http_constants
from . import _retry_utility_async
from .._synchronized_request import _request_body_from_data, _replace_url_prefix


def _check_status_code_for_retry(err: AzureError) -> None:
    """Re-raise a status-carrying transport failure as the matching Cosmos exception.

    If ``err.inner_exception`` is an aiohttp ``ClientResponseError`` with a
    non-zero HTTP status, the status is mapped to a Cosmos exception so that
    callers see the same exception type they would get from a regular
    response with that status. Otherwise the function returns and the caller
    decides what to do with ``err``.

    :param err: The error raised while running the pipeline.
    :type err: ~azure.core.exceptions.AzureError
    :raises ~azure.cosmos.exceptions.CosmosResourceNotFoundError: status 404.
    :raises ~azure.cosmos.exceptions.CosmosResourceExistsError: status 409.
    :raises ~azure.cosmos.exceptions.CosmosAccessConditionFailedError: status 412.
    :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: any other status >= 400.
    """
    inner = err.inner_exception
    if not isinstance(inner, ClientResponseError):
        return

    # ``ClientResponseError.code`` is a deprecated alias of ``status``; read
    # ``status`` directly to avoid the deprecation warning.
    status_code = inner.status
    if not status_code:
        # No HTTP status was recorded on the error, so there is nothing to map.
        return

    if status_code == 404:
        raise exceptions.CosmosResourceNotFoundError(error=err)
    if status_code == 409:
        raise exceptions.CosmosResourceExistsError(error=err)
    if status_code == 412:
        raise exceptions.CosmosAccessConditionFailedError(error=err)
    if status_code >= 400:
        raise exceptions.CosmosHttpResponseError(error=err, status_code=status_code)

async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements
"""Makes one http request using the requests module.
Expand Down Expand Up @@ -97,32 +111,37 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p
and not connection_policy.DisableSSLVerification
)

if connection_policy.SSLConfiguration or "connection_cert" in kwargs:
ca_certs = connection_policy.SSLConfiguration.SSLCaCerts
cert_files = (connection_policy.SSLConfiguration.SSLCertFile, connection_policy.SSLConfiguration.SSLKeyFile)
response = await _PipelineRunFunction(
pipeline_client,
request,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
connection_verify=kwargs.pop("connection_verify", ca_certs),
connection_cert=kwargs.pop("connection_cert", cert_files),
request_params=request_params,
global_endpoint_manager=global_endpoint_manager,
**kwargs
)
else:
response = await _PipelineRunFunction(
pipeline_client,
request,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
# If SSL is disabled, verify = false
connection_verify=kwargs.pop("connection_verify", is_ssl_enabled),
request_params=request_params,
global_endpoint_manager=global_endpoint_manager,
**kwargs
)
try:
if connection_policy.SSLConfiguration or "connection_cert" in kwargs:
ca_certs = connection_policy.SSLConfiguration.SSLCaCerts
cert_files = (connection_policy.SSLConfiguration.SSLCertFile, connection_policy.SSLConfiguration.SSLKeyFile)
response = await _PipelineRunFunction(
pipeline_client,
request,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
connection_verify=kwargs.pop("connection_verify", ca_certs),
connection_cert=kwargs.pop("connection_cert", cert_files),
request_params=request_params,
global_endpoint_manager=global_endpoint_manager,
**kwargs
)
else:
response = await _PipelineRunFunction(
pipeline_client,
request,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
# If SSL is disabled, verify = false
connection_verify=kwargs.pop("connection_verify", is_ssl_enabled),
request_params=request_params,
global_endpoint_manager=global_endpoint_manager,
**kwargs
)
except AzureError as err:
# If the error is an AzureError, we need to check the status code and raise the appropriate Cosmos exception.
_check_status_code_for_retry(err)
raise err

response = response.http_response
headers = copy.copy(response.headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ async def send(self, request):
timeout_error.history = retry_settings['history']
raise
except ServiceRequestError as err:
retry_error = err
# the request ran into a socket timeout or failed to establish a new connection
# since request wasn't sent, raise exception immediately to be dealt with in client retry policies
if (not _has_database_account_header(request.http_request.headers)
Expand All @@ -303,7 +302,6 @@ async def send(self, request):
continue
raise err
except ServiceResponseError as err:
retry_error = err
if (_has_database_account_header(request.http_request.headers) or
request_params.healthy_tentative_location):
raise err
Expand All @@ -329,7 +327,6 @@ async def send(self, request):
except CosmosHttpResponseError as err:
raise err
except AzureError as err:
retry_error = err
if (_has_database_account_header(request.http_request.headers) or
request_params.healthy_tentative_location):
raise err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import sys
from typing import Callable, Optional, Any, Dict, List, Awaitable, MutableMapping
import aiohttp
from aiohttp import ConnectionTimeoutError
from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse
from azure.core.rest import HttpRequest, AsyncHttpResponse
from azure.cosmos import documents
Expand Down Expand Up @@ -191,6 +192,10 @@ async def error_region_down() -> Exception:
message="Injected region down.",
)

@staticmethod
async def connection_refused() -> Exception:
    """Create a new aiohttp ``ConnectionTimeoutError`` and return it without raising.

    NOTE(review): the method name says "refused" but the object produced is a
    timeout error -- confirm which of the two is intended.

    :returns: A freshly constructed ``ConnectionTimeoutError`` instance.
    :rtype: Exception
    """
    injected_error: Exception = ConnectionTimeoutError()
    return injected_error

@staticmethod
async def error_service_response() -> Exception:
return ServiceResponseError(
Expand Down Expand Up @@ -222,6 +227,23 @@ async def transform_topology_swr_mrr(

return response

@staticmethod
async def change_status_code(
        inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse:
    """Await ``inner`` and rewrite matching responses to a 500 status.

    A response is rewritten only when its request targets the hard-coded
    endpoint below, is a "Read" operation, and is a document operation.
    All other responses are returned untouched.

    :param inner: Zero-argument callable producing the awaitable real response.
    :returns: The (possibly modified) response obtained from ``inner``.
    """
    response = await inner()
    request = response.request

    # NOTE(review): this endpoint is hard-coded to one specific account;
    # consider deriving it from the test configuration instead.
    target_endpoint = "https://tomasvaron-full-fidelity-westus3.documents.azure.com:443/"

    # Evaluate the predicates. The previous code built a lambda and tested the
    # lambda object itself, which is always truthy, so every response was
    # rewritten regardless of its request.
    should_rewrite = (
        FaultInjectionTransportAsync.predicate_targets_region(request, target_endpoint)
        and FaultInjectionTransportAsync.predicate_is_operation_type(request, "Read")
        and FaultInjectionTransportAsync.predicate_is_document_operation(request)
    )
    if should_rewrite:
        response.status_code = 500
        response.reason = "Not Ok"

    return response

@staticmethod
async def transform_topology_mwr(
first_region_name: str,
Expand Down
6 changes: 5 additions & 1 deletion sdk/cosmos/azure-cosmos/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from azure.cosmos.http_constants import StatusCodes
from azure.cosmos.partition_key import PartitionKey
from azure.cosmos import (ContainerProxy, DatabaseProxy, documents, exceptions,
http_constants, _retry_utility)
http_constants, _location_cache)
from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError, ClientAuthenticationError
from azure.core.pipeline.policies import AsyncRetryPolicy, RetryPolicy
from devtools_testutils.azure_recorded_testcase import get_credential
Expand Down Expand Up @@ -68,6 +68,10 @@ class TestConfig(object):
TEST_CONTAINER_PARTITION_KEY = "pk"
TEST_CONTAINER_PREFIX_PARTITION_KEY = ["pk1", "pk2"]
TEST_CONTAINER_PREFIX_PARTITION_KEY_PATH = ['/pk1', '/pk2']
WRITE_REGION = "West US 3"
READ_REGION = "West US"
WRITE_LOCATIONAL_ENDPOINT = _location_cache.LocationCache.GetLocationalEndpoint(host, WRITE_REGION)
READ_LOCATIONAL_ENDPOINT = _location_cache.LocationCache.GetLocationalEndpoint(host, READ_REGION)

@classmethod
def create_database_if_not_exist(cls, client):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from unittest import IsolatedAsyncioTestCase

import pytest
from aiohttp import ClientResponseError, RequestInfo
from azure.core.pipeline.transport import AioHttpTransport
from azure.core.pipeline.transport._aiohttp import AioHttpTransportResponse
from azure.core.rest import HttpRequest, AsyncHttpResponse
Expand All @@ -22,8 +23,9 @@
from azure.cosmos.aio import CosmosClient
from azure.cosmos.aio._container import ContainerProxy
from azure.cosmos.aio._database import DatabaseProxy
from azure.cosmos.documents import _OperationType
from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.core.exceptions import ServiceRequestError
from azure.core.exceptions import ServiceRequestError, ServiceResponseError

MGMT_TIMEOUT = 5.0
logger = logging.getLogger('azure.cosmos')
Expand All @@ -34,8 +36,11 @@
master_key = test_config.TestConfig.masterKey
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
SINGLE_PARTITION_CONTAINER_NAME = os.path.basename(__file__) + str(uuid.uuid4())
WRITE_LOCATIONAL_ENDPOINT = test_config.TestConfig.WRITE_LOCATIONAL_ENDPOINT
READ_LOCATIONAL_ENDPOINT = test_config.TestConfig.READ_LOCATIONAL_ENDPOINT
WRITE_REGION = test_config.TestConfig.WRITE_REGION
READ_REGION = test_config.TestConfig.READ_REGION

@pytest.mark.cosmosEmulator
@pytest.mark.asyncio
class TestFaultInjectionTransportAsync(IsolatedAsyncioTestCase):
@classmethod
Expand Down Expand Up @@ -101,6 +106,7 @@ async def cleanup_method(initialized_objects: Dict[str, Any]):
except Exception as close_error:
logger.warning(f"Exception trying to close method client. {close_error}")

@pytest.mark.cosmosEmulator
async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsync"):
id_value: str = str(uuid.uuid4())
document_definition = {'id': id_value,
Expand Down Expand Up @@ -131,6 +137,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy
finally:
await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosEmulator
async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"):
expected_read_region_uri: str = test_config.TestConfig.local_host
expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1")
Expand Down Expand Up @@ -182,6 +189,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"):
finally:
await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosEmulator
async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjectionTransportAsync"):
expected_read_region_uri: str = test_config.TestConfig.local_host
expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1")
Expand Down Expand Up @@ -242,6 +250,77 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection
finally:
await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosMultiRegion
async def test_service_response_error_with_status_async(self: "TestFaultInjectionTransportAsync"):
    """A ServiceResponseError wrapping a 404 ClientResponseError surfaces as a Cosmos HTTP error.

    A fault is injected for document reads against the write region; the
    injected error carries an aiohttp ``ClientResponseError`` with status 404
    as its inner exception. The read is expected to raise
    ``CosmosHttpResponseError``.
    """
    custom_transport = FaultInjectionTransportAsync()

    # The other callers of predicate_targets_region in this change pass an
    # endpoint URL, not a region display name, so match on the locational
    # endpoint here as well; with the region name the fault would presumably
    # never be injected.
    is_request_to_write_region: Callable[[HttpRequest], bool] = lambda \
        r: (FaultInjectionTransportAsync.predicate_targets_region(r, WRITE_LOCATIONAL_ENDPOINT) and
            (FaultInjectionTransportAsync.predicate_is_operation_type(r, _OperationType.Read) and
             FaultInjectionTransportAsync.predicate_is_document_operation(r)))

    inner_error = ClientResponseError(
        RequestInfo(url=host, method="GET", headers={}),
        (),
        status=404,
    )
    error = ServiceResponseError(status_code=404, message="Not Found", error=inner_error)

    custom_transport.add_fault(
        is_request_to_write_region,
        lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(0, error)))

    id_value: str = str(uuid.uuid4())
    document_definition = {'id': id_value,
                           'pk': id_value,
                           'name': 'sample document',
                           'key': 'value'}
    initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport(
        custom_transport,
        preferred_locations=[WRITE_REGION, READ_REGION])
    container: ContainerProxy = initialized_objects["col"]
    await container.create_item(body=document_definition)

    try:
        # the status code should dictate the retry policy even if error is raised
        with pytest.raises(CosmosHttpResponseError):
            await container.read_item(document_definition["id"], document_definition["pk"])

    finally:
        await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosMultiRegion
async def test_service_response_error_retries_in_read_region_async(self: "TestFaultInjectionTransportAsync"):
    """A plain ServiceResponseError on a write-region read is retried in the read region.

    Renamed from ``test_service_response_error_with_status_async``: that name
    was already used by another test in this class, and a second ``def`` with
    the same name rebinds it, so only one of the two tests was ever collected.
    """
    custom_transport = FaultInjectionTransportAsync()

    # Inject rule to simulate regional outage in "write region" for reads
    is_request_to_write_region: Callable[[HttpRequest], bool] = lambda \
        r: (FaultInjectionTransportAsync.predicate_targets_region(r, WRITE_LOCATIONAL_ENDPOINT) and
            (FaultInjectionTransportAsync.predicate_is_operation_type(r, _OperationType.Read) and
             FaultInjectionTransportAsync.predicate_is_document_operation(r)))
    custom_transport.add_fault(
        is_request_to_write_region,
        lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_service_response()))

    id_value: str = str(uuid.uuid4())
    document_definition = {'id': id_value,
                           'pk': id_value,
                           'name': 'sample document',
                           'key': 'value'}

    # Use the shared region constants so the preferred locations stay in sync
    # with the endpoints the predicate and the assertion are built from.
    initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport(
        custom_transport,
        preferred_locations=[WRITE_REGION, READ_REGION])
    container: ContainerProxy = initialized_objects["col"]
    await container.create_item(body=document_definition)
    try:
        read_document = await container.read_item(document_definition["id"], document_definition["pk"])
        request: HttpRequest = read_document.get_response_headers()["_request"]
        assert request.url.startswith(READ_LOCATIONAL_ENDPOINT)
    finally:
        await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosEmulator
async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInjectionTransportAsync"):
expected_read_region_uri: str = test_config.TestConfig.local_host
expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1")
Expand Down Expand Up @@ -308,6 +387,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj
await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)


@pytest.mark.cosmosEmulator
async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"):
first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
custom_transport = FaultInjectionTransportAsync()
Expand Down Expand Up @@ -353,6 +433,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"):
finally:
await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosEmulator
async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"):
first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
second_region_uri: str = test_config.TestConfig.local_host
Expand Down Expand Up @@ -407,6 +488,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport
finally:
await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosEmulator
async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjectionTransportAsync"):
expected_read_region_uri: str = test_config.TestConfig.local_host
expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1")
Expand Down Expand Up @@ -470,6 +552,7 @@ async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection
finally:
await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)

@pytest.mark.cosmosEmulator
async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsync"):

first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
Expand Down