From 2550197c7aeb64040e98d37790ac41c38c6c70d2 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:23:17 -0400 Subject: [PATCH 01/52] session container fixes --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 101 +++++++++---- .../azure/cosmos/_cosmos_client_connection.py | 62 ++++---- .../azure-cosmos/azure/cosmos/_session.py | 134 +++++++++++++++--- .../aio/_cosmos_client_connection_async.py | 90 +++++++----- sdk/cosmos/azure-cosmos/tests/test_session.py | 2 +- .../tests/test_session_container.py | 4 +- 6 files changed, 278 insertions(+), 115 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 654b23c5d71f..ab2b658009d1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -167,37 +167,9 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches if options.get("indexingDirective"): headers[http_constants.HttpHeaders.IndexingDirective] = options["indexingDirective"] - consistency_level = None - - # get default client consistency level - default_client_consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) - - # set consistency level. 
check if set via options, this will override the default + # set request consistency level - if session consistency, the client should be setting this on its own if options.get("consistencyLevel"): - consistency_level = options["consistencyLevel"] - # TODO: move this line outside of if-else cause to remove the code duplication - headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level - elif default_client_consistency_level is not None: - consistency_level = default_client_consistency_level - headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level - - # figure out if consistency level for this request is session - is_session_consistency = consistency_level == documents.ConsistencyLevel.Session - - # set session token if required - if is_session_consistency is True and not IsMasterResource(resource_type): - # if there is a token set via option, then use it to override default - if options.get("sessionToken"): - headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] - else: - # check if the client's default consistency is session (and request consistency level is same), - # then update from session container - if default_client_consistency_level == documents.ConsistencyLevel.Session and \ - cosmos_client_connection.session: - # populate session token from the client's session container - headers[http_constants.HttpHeaders.SessionToken] = cosmos_client_connection.session.get_session_token( - path - ) + headers[http_constants.HttpHeaders.ConsistencyLevel] = options["consistencyLevel"] if options.get("enableScanInQuery"): headers[http_constants.HttpHeaders.EnableScanInQuery] = options["enableScanInQuery"] @@ -337,6 +309,75 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches return headers +def _is_session_token_request( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + resource_type: str, + operation_type: str) -> None: + consistency_level = 
headers.get(http_constants.HttpHeaders.ConsistencyLevel) + # Figure out if consistency level for this request is session + is_session_consistency = consistency_level == documents.ConsistencyLevel.Session + + # Verify that it is not a metadata request, and that it is either a read request, batch request, or an account + # configured to use multiple write regions + return (is_session_consistency is True and not IsMasterResource(resource_type) + and (documents._OperationType.IsReadOnlyOperation(operation_type) or operation_type == "Batch" + or cosmos_client_connection._global_endpoint_manager.get_use_multiple_write_locations())) + + +def set_session_token_header( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + path: str, + resource_type: str, + operation_type: str, + options: Mapping[str, Any], + partition_key_range_id: Optional[str] = None) -> None: + # set session token if required + if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type): + # if there is a token set via option, then use it to override default + if options.get("sessionToken"): + headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] + else: + # check if the client's default consistency is session (and request consistency level is same), + # then update from session container + if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ + cosmos_client_connection.session: + # populate session token from the client's session container + session_token = cosmos_client_connection.session.get_session_token(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id) + if session_token != "": + headers[http_constants.HttpHeaders.SessionToken] = session_token + +async def set_session_token_header_async( + cosmos_client_connection: 
Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + path: str, + resource_type: str, + operation_type: str, + options: Mapping[str, Any], + partition_key_range_id: Optional[str] = None) -> None: + # set session token if required + if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type): + # if there is a token set via option, then use it to override default + if options.get("sessionToken"): + headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] + else: + # check if the client's default consistency is session (and request consistency level is same), + # then update from session container + if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ + cosmos_client_connection.session: + # populate session token from the client's session container + session_token = await cosmos_client_connection.session.get_session_token_async(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id) + if session_token != "": + headers[http_constants.HttpHeaders.SessionToken] = session_token def GetResourceIdOrFullNameFromLink(resource_link: str) -> Optional[str]: """Gets resource id or full name from resource link. 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3934c23bcf99..38ae3f4a0bf2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2042,6 +2042,7 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch) request_data = {} @@ -2131,6 +2132,7 @@ def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) + base.set_session_token_header(self, headers, path, "docs", documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], @@ -2191,6 +2193,8 @@ def DeleteAllItemsByPartitionKey( collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) + base.set_session_token_header(self, headers, path, "partitionkey", documents._OperationType.Delete, + options) request_params = RequestObject("partitionkey", documents._OperationType.Delete) _, last_response_headers = self.__Post( path=path, @@ -2615,7 +2619,7 @@ def Create( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2625,7 +2629,7 @@ def Create( :param dict body: :param 
str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2642,11 +2646,12 @@ def Create( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create, - options) + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Create, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create) + request_params = RequestObject(resource_type, documents._OperationType.Create) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2660,7 +2665,7 @@ def Upsert( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2670,7 +2675,7 @@ def Upsert( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2687,12 +2692,13 @@ def Upsert( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert, - options) + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Upsert, options) headers[http_constants.HttpHeaders.IsUpsert] = True + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Upsert, options) # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params 
= RequestObject(resource_type, documents._OperationType.Upsert) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2705,7 +2711,7 @@ def Replace( self, resource: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2715,7 +2721,7 @@ def Replace( :param dict resource: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2732,10 +2738,11 @@ def Replace( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, - options) + headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, + documents._OperationType.Replace, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace) + request_params = RequestObject(resource_type, documents._OperationType.Replace) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2748,7 +2755,7 @@ def Replace( def Read( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2757,7 +2764,7 @@ def Read( """Reads an Azure Cosmos resource and returns it. 
:param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2774,9 +2781,11 @@ def Read( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) + headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type, + documents._OperationType.Read, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = RequestObject(typ, documents._OperationType.Read) + request_params = RequestObject(resource_type, documents._OperationType.Read) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2786,7 +2795,7 @@ def Read( def DeleteResource( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2812,10 +2821,11 @@ def DeleteResource( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, - options) + headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, + documents._OperationType.Delete, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = RequestObject(typ, documents._OperationType.Delete) + request_params = RequestObject(resource_type, documents._OperationType.Delete) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3063,6 +3073,8 @@ 
def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) + base.set_session_token_header(self, headers, path, resource_type, request_params.operation_type, options, + partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -3101,6 +3113,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) + base.set_session_token_header(self, req_headers, path, resource_type, documents._OperationType.SqlQuery, + options) # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) @@ -3355,7 +3369,7 @@ def _UpdateSessionIfRequired( if is_session_consistency and self.session: # update session - self.session.update_session(response_result, response_headers) + self.session.update_session(self, response_result, response_headers) def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]: partition_key_definition: Optional[Dict[str, Any]] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 84dd5914e208..811d778fca90 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -30,6 +30,8 @@ from . import http_constants from ._vector_session_token import VectorSessionToken from .exceptions import CosmosHttpResponseError +from .partition_key import PartitionKey +from typing import Any, Dict, Optional class SessionContainer(object): @@ -38,12 +40,24 @@ def __init__(self): self.rid_to_session_token = {} self.session_lock = threading.RLock() - def get_session_token(self, resource_path): - """Get Session Token for collection_link. 
+ def get_session_token( + self, + resource_path: str, + pk_value: str, + container_properties_cache: Dict[str, Dict[str, Any]], + routing_map_provider: Any, + partition_key_range_id: Optional[int]) -> str: + """Get Session Token for collection_link and operation_type. :param str resource_path: Self link / path to the resource - :return: Session Token dictionary for the collection_id - :rtype: dict + :param str operation_type: Operation type (e.g. 'Create', 'Read', 'Upsert', 'Replace') + :param str pk_value: The partition key value being used for the operation + :param container_properties_cache: Container properties cache used to fetch partition key definitions + :type container_properties_cache: Dict[str, Dict[str, Any]] + :param int partition_key_range_id: The partition key range ID used for the operation + :return: Session Token dictionary for the collection_id, will be empty string if not found or if the operation + does not require a session token (single master write operations). 
+ :rtype: str """ with self.session_lock: @@ -59,23 +73,91 @@ def get_session_token(self, resource_path): else: collection_rid = _base.GetItemContainerLink(resource_path) - if collection_rid in self.rid_to_session_token: + if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] - session_token_list = [] - for key in token_dict.keys(): - session_token_list.append("{0}:{1}".format(key, token_dict[key].convert_to_string())) - session_token = ",".join(session_token_list) + if partition_key_range_id is not None: + session_token = token_dict.get(partition_key_range_id) + else: + collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") + partition_key = PartitionKey(path=collection_pk_definition['paths'], + kind=collection_pk_definition['kind'], + version=collection_pk_definition['version']) + epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) + pk_range = routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) + vector_session_token = token_dict.get(pk_range[0]['id']) + session_token = "{0}:{1}".format(pk_range[0]['id'], vector_session_token.session_token) return session_token + return "" + except Exception: # pylint: disable=broad-except + return "" + + async def get_session_token_async( + self, + resource_path: str, + pk_value: str, + container_properties_cache: Dict[str, Dict[str, Any]], + routing_map_provider: Any, + partition_key_range_id: Optional[str]) -> str: + """Get Session Token for collection_link and operation_type. + + :param str resource_path: Self link / path to the resource + :param str operation_type: Operation type (e.g. 
'Create', 'Read', 'Upsert', 'Replace') + :param str pk_value: The partition key value being used for the operation + :param container_properties_cache: Container properties cache used to fetch partition key definitions + :type container_properties_cache: Dict[str, Dict[str, Any]] + :param Any routing_map_provider: The routing map provider containing the partition key range cache logic + :param str partition_key_range_id: The partition key range ID used for the operation + :return: Session Token dictionary for the collection_id, will be empty string if not found or if the operation + does not require a session token (single master write operations). + :rtype: str + """ + + with self.session_lock: + is_name_based = _base.IsNameBased(resource_path) + collection_rid = "" + session_token = "" - # return empty token if not found + try: + if is_name_based: + # get the collection name + collection_name = _base.GetItemContainerLink(resource_path) + collection_rid = self.collection_name_to_rid[collection_name] + else: + collection_rid = _base.GetItemContainerLink(resource_path) + + if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: + token_dict = self.rid_to_session_token[collection_rid] + if partition_key_range_id is not None: + vector_session_token = token_dict.get(partition_key_range_id) + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token.session_token) + else: + collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") + partition_key = PartitionKey(path=collection_pk_definition['paths'], + kind=collection_pk_definition['kind'], + version=collection_pk_definition['version']) + epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) + pk_range = await routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) + session_token_list = [] + parents = pk_range[0].get('parents').copy() + parents.append(pk_range[0]['id']) + for parent 
in parents: + vector_session_token = token_dict.get(parent) + session_token = "{0}:{1}".format(parent, vector_session_token.session_token) + session_token_list.append(session_token) + # if vector_session_token is not None: + # session_token = "{0}:{1}".format(parent, vector_session_token.session_token) + # session_token_list.append(session_token) + session_token = ",".join(session_token_list) + return session_token return "" except Exception: # pylint: disable=broad-except return "" - def set_session_token(self, response_result, response_headers): + def set_session_token(self, client_connection, response_result, response_headers): """Session token must only be updated from response of requests that successfully mutate resource on the server side (write, replace, delete etc). + :param client_connection: Client connection used to refresh the partition key range cache if needed :param dict response_result: :param dict response_headers: :return: None @@ -86,8 +168,6 @@ def set_session_token(self, response_result, response_headers): # x-ms-alt-content-path which is the string representation of the resource with self.session_lock: - collection_rid = "" - collection_name = "" try: self_link = response_result["_self"] @@ -105,10 +185,15 @@ def set_session_token(self, response_result, response_headers): response_result_id = response_result[response_result_id_key] else: return - collection_rid, collection_name = _base.GetItemContainerInfo( - self_link, alt_content_path, response_result_id - ) - + collection_rid, collection_name = _base.GetItemContainerInfo(self_link, alt_content_path, + response_result_id) + # if the response came in with a new partition key range id after a split, refresh the pk range cache + partition_key_range_id = response_headers.get(http_constants.HttpHeaders.PartitionKeyRangeID) + collection_ranges = None + if client_connection: + collection_ranges = client_connection._routing_map_provider._collection_routing_map_by_item.get(collection_name) + if 
collection_ranges and not collection_ranges._rangeById.get(partition_key_range_id): + client_connection.refresh_routing_map_provider() except ValueError: return except Exception: # pylint: disable=broad-except @@ -194,7 +279,7 @@ def parse_session_token(response_headers): class Session(object): - """State of a Azure Cosmos session. + """State of an Azure Cosmos session. This session object can be shared across clients within the same process. @@ -209,8 +294,13 @@ def __init__(self, url_connection): def clear_session_token(self, response_headers): self.session_container.clear_session_token(response_headers) - def update_session(self, response_result, response_headers): - self.session_container.set_session_token(response_result, response_headers) + def update_session(self, client_connection, response_result, response_headers): + self.session_container.set_session_token(client_connection, response_result, response_headers) + + def get_session_token(self, resource_path, pk_value, container_properties_cache, routing_map_provider, partition_key_range_id): + return self.session_container.get_session_token(resource_path, pk_value, container_properties_cache, + routing_map_provider, partition_key_range_id) - def get_session_token(self, resource_path): - return self.session_container.get_session_token(resource_path) + async def get_session_token_async(self, resource_path, pk_value, container_properties_cache, routing_map_provider, partition_key_range_id): + return await self.session_container.get_session_token_async(resource_path, pk_value, container_properties_cache, + routing_map_provider, partition_key_range_id) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 49219533a7e6..0fb466697109 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -737,7 +737,7 @@ async def Create( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -747,7 +747,7 @@ async def Create( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -763,11 +763,13 @@ async def Create( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, documents._OperationType.Create, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Create) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -874,7 +876,7 @@ async def Upsert( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -884,7 +886,7 @@ async def Upsert( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -900,21 +902,22 @@ async def Upsert( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert, - options) + headers = base.GetHeaders(self, initial_headers, 
"post", path, id, resource_type, + documents._OperationType.Upsert, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Upsert, options) headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Upsert) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request self._UpdateSessionIfRequired(headers, result, self.last_response_headers) if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def __Post( self, @@ -1178,7 +1181,7 @@ async def ReadConflict( async def Read( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1204,10 +1207,12 @@ async def Read( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, - options) + headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type, + documents._OperationType.Read, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Read) result, last_response_headers = await self.__Get(path, request_params, headers, 
**kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1456,16 +1461,18 @@ async def PatchItem( response_hook = kwargs.pop("response_hook", None) path = base.GetPathFromLink(document_link) document_id = base.GetResourceIdOrFullNameFromLink(document_link) - typ = "docs" + resource_type = "docs" if options is None: options = {} initial_headers = self.default_headers - headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, typ, + headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Patch) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Patch) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1540,7 +1547,7 @@ async def Replace( self, resource: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1566,10 +1573,12 @@ async def Replace( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, - options) + headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, + documents._OperationType.Replace, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params = 
_request_object.RequestObject(resource_type, documents._OperationType.Replace) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1864,7 +1873,7 @@ async def DeleteConflict( async def DeleteResource( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1889,10 +1898,12 @@ async def DeleteResource( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, - options) + headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, + documents._OperationType.Delete, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Delete) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2005,6 +2016,8 @@ async def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) + await base.set_session_token_header_async(self, headers, path, "docs", + documents._OperationType.Read, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2808,7 +2821,7 @@ async def QueryFeed( async def __QueryFeed( # pylint: 
disable=too-many-branches,too-many-statements,too-many-locals self, path: str, - typ: str, + resource_type: str, id_: Optional[str], result_fn: Callable[[Dict[str, Any]], List[Dict[str, Any]]], create_fn: Optional[Callable[['CosmosClientConnection', Dict[str, Any]], Dict[str, Any]]], @@ -2822,7 +2835,7 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, """Query for more than one Azure Cosmos resources. :param str path: - :param str typ: + :param str resource_type: :param str id_: :param function result_fn: :param function create_fn: @@ -2858,12 +2871,13 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if query is None: # Query operations will use ReadEndpoint even though it uses GET(for feed requests) request_params = _request_object.RequestObject( - typ, + resource_type, documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed ) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, - options, partition_key_range_id) - + headers = base.GetHeaders(self, initial_headers, "get", path, id_, resource_type, + request_params.operation_type, options, partition_key_range_id) + await base.set_session_token_header_async(self, headers, path, resource_type, + request_params.operation_type, options, partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: await change_feed_state.populate_request_headers_async(self._routing_map_provider, headers) @@ -2889,9 +2903,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) - req_headers = base.GetHeaders(self, initial_headers, "post", path, 
id_, typ, request_params.operation_type, - options, partition_key_range_id) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.SqlQuery) + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, resource_type, + request_params.operation_type, options, partition_key_range_id) + await base.set_session_token_header_async(self, req_headers, path, resource_type, + request_params.operation_type, options, partition_key_range_id) # check if query has prefix partition key cont_prop = kwargs.pop("containerProperties", None) @@ -3030,7 +3046,7 @@ def _UpdateSessionIfRequired( if is_session_consistency and self.session: # update session - self.session.update_session(response_result, response_headers) + self.session.update_session(self, response_result, response_headers) PartitionResolverErrorMessage = ( "Couldn't find any partition resolvers for the database link provided. " @@ -3258,6 +3274,8 @@ async def DeleteAllItemsByPartitionKey( initial_headers = dict(self.default_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) + await base.set_session_token_header_async(self, headers, path, "partitionkey", + documents._OperationType.Delete, options) request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/tests/test_session.py b/sdk/cosmos/azure-cosmos/tests/test_session.py index ab9307db3443..fde1e57be711 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session.py @@ -80,7 +80,7 @@ def test_clear_session_token(self): self.created_collection.read_item(item=created_document['id'], partition_key='mypk') except exceptions.CosmosHttpResponseError as e: 
self.assertEqual(self.client.client_connection.session.get_session_token( - 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id), "") + 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id, "Read"), "") self.assertEqual(e.status_code, StatusCodes.NOT_FOUND) self.assertEqual(e.sub_status, SubStatusCodes.READ_SESSION_NOTAVAILABLE) _retry_utility.ExecuteFunction = self.OriginalExecuteFunction diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_container.py b/sdk/cosmos/azure-cosmos/tests/test_session_container.py index 2ee352571204..12d513b66455 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session_container.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session_container.py @@ -35,7 +35,7 @@ def test_create_collection(self): u'id': u'sample collection'} create_collection_response_header = {'x-ms-session-token': '0:0#409#24=-1#12=-1', 'x-ms-alt-content-path': 'dbs/sample%20database'} - self.session.update_session(create_collection_response_result, create_collection_response_header) + self.session.update_session(None, create_collection_response_result, create_collection_response_header) token = self.session.get_session_token(u'/dbs/sample%20database/colls/sample%20collection') assert token == '0:0#409#24=-1#12=-1' @@ -53,7 +53,7 @@ def test_document_requests(self): 'x-ms-alt-content-path': 'dbs/sample%20database/colls/sample%20collection', 'x-ms-content-path': 'DdAkAPS2rAA='} - self.session.update_session(create_document_response_result, create_document_response_header) + self.session.update_session(None, create_document_response_result, create_document_response_header) token = self.session.get_session_token(u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/docs/DdAkAPS2rAACAAAAAAAAAA==/') assert token == '0:0#406#24=-1#12=-1' From 9cf54e7982343f4e759a21e9033c388373428c1c Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 8 May 2025 19:39:36 -0400 Subject: [PATCH 02/52] async changes --- 
.../azure/cosmos/_cosmos_client_connection.py | 6 +- .../aio/document_producer.py | 4 +- .../aio/execution_dispatcher.py | 43 +++++---- .../aio/hybrid_search_aggregator.py | 9 +- .../aio/multi_execution_aggregator.py | 7 +- .../aio/non_streaming_order_by_aggregator.py | 7 +- .../azure-cosmos/azure/cosmos/_session.py | 52 ++++++----- .../aio/_cosmos_client_connection_async.py | 90 +++++++++++-------- .../azure/cosmos/aio/_query_iterable_async.py | 4 +- 9 files changed, 137 insertions(+), 85 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3d265a2c061c..d7b41f85f13f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2797,6 +2797,9 @@ def Read( request_params = RequestObject(resource_type, documents._OperationType.Read) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) + + self._UpdateSessionIfRequired(headers, result, last_response_headers) + self.last_response_headers = last_response_headers if response_hook: response_hook(last_response_headers, result) @@ -3381,7 +3384,8 @@ def _UpdateSessionIfRequired( if documents.ConsistencyLevel.Session == request_headers[http_constants.HttpHeaders.ConsistencyLevel]: is_session_consistency = True - if is_session_consistency and self.session: + if (is_session_consistency and self.session and + not base.IsMasterResource(request_headers[http_constants.HttpHeaders.ThinClientProxyResourceType])): # update session self.session.update_session(self, response_result, response_headers) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py index 382e0ee0f2b4..7584ef142cbd 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py @@ -41,7 +41,7 @@ class _DocumentProducer(object): """ def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options, - response_hook): + response_hook, raw_response_hook): """ Constructor """ @@ -62,7 +62,7 @@ def __init__(self, partition_key_target_range, client, collection_link, query, d async def fetch_fn(options): return await self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"], - response_hook=response_hook) + response_hook=response_hook, raw_response_hook=raw_response_hook) self._ex_context = _DefaultQueryExecutionContext(client, self._options, fetch_fn) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index a85c6f2c9955..d7809ff39da5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -47,7 +47,8 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl to _MultiExecutionContextAggregator """ - def __init__(self, client, resource_link, query, options, fetch_function, response_hook): + def __init__(self, client, resource_link, query, options, fetch_function, + response_hook, raw_response_hook): """ Constructor """ @@ -58,6 +59,13 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon self._query = query self._fetch_function = fetch_function self._response_hook = response_hook + self._raw_response_hook = raw_response_hook + + async def _create_execution_context_with_query_plan(self): + query_to_use = self._query if self._query is not None else "Select * from root r" + 
query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway + (query_to_use, self._resource_link)) + self._execution_context = await self._create_pipelined_execution_context(query_execution_info) async def __anext__(self): """Returns the next query result. @@ -89,16 +97,14 @@ async def fetch_next_block(self): :return: List of results. :rtype: list """ - try: - return await self._execution_context.fetch_next_block() - except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - query_to_use = self._query if self._query is not None else "Select * from root r" - query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway - (query_to_use, self._resource_link)) - self._execution_context = await self._create_pipelined_execution_context(query_execution_info) - else: - raise e + if "enableCrossPartitionQuery" not in self._options: + try: + return await self._execution_context.fetch_next_block() + except CosmosHttpResponseError as e: + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + await self._create_execution_context_with_query_plan() + else: + await self._create_execution_context_with_query_plan() return await self._execution_context.fetch_next_block() @@ -129,7 +135,8 @@ async def _create_pipelined_execution_context(self, query_execution_info): self._query, self._options, query_execution_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) await execution_context_aggregator._configure_partition_ranges() elif query_execution_info.has_hybrid_search_query_info(): hybrid_search_query_info = query_execution_info._query_execution_info['hybridSearchQueryInfo'] @@ -140,15 +147,13 @@ async def _create_pipelined_execution_context(self, query_execution_info): self._options, query_execution_info, hybrid_search_query_info, - self._response_hook) + self._response_hook, + 
self._raw_response_hook) await execution_context_aggregator._run_hybrid_search() else: - execution_context_aggregator = multi_execution_aggregator._MultiExecutionContextAggregator(self._client, - self._resource_link, - self._query, - self._options, - query_execution_info, - self._response_hook) + execution_context_aggregator = multi_execution_aggregator._MultiExecutionContextAggregator( + self._client, self._resource_link, self._query, self._options, query_execution_info, + self._response_hook, self._raw_response_hook) await execution_context_aggregator._configure_partition_ranges() return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator, query_execution_info) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py index 8b6c04108400..3aaf28b54ec0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py @@ -45,7 +45,7 @@ class _HybridSearchContextAggregator(_QueryExecutionContextBase): """ def __init__(self, client, resource_link, options, partitioned_query_execution_info, - hybrid_search_query_info, response_hook): + hybrid_search_query_info, response_hook, raw_response_hook): super(_HybridSearchContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -58,6 +58,7 @@ def __init__(self, client, resource_link, options, partitioned_query_execution_i self._aggregated_global_statistics = None self._document_producer_comparator = None self._response_hook = response_hook + self._raw_response_hook = raw_response_hook async def _run_hybrid_search(self): # Check if we need to run global statistics queries, and if so do for every partition in the container @@ -119,7 +120,8 @@ async def _run_hybrid_search(self): 
rewritten_query['rewrittenQuery'], self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) # verify all document producers have items/ no splits @@ -225,7 +227,8 @@ async def _repair_document_producer(self, query, target_all_ranges=False): query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py index 77b38baf61f2..e6fa8f1ef75e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py @@ -62,7 +62,8 @@ def peek(self): def size(self): return len(self._heap) - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_MultiExecutionContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -73,6 +74,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._partitioned_query_ex_info = partitioned_query_ex_info self._sort_orders = partitioned_query_ex_info.get_order_by() self._response_hook = response_hook + self._raw_response_hook = raw_response_hook if self._sort_orders: self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders) @@ -155,7 +157,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) async def _get_target_partition_key_range(self): 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py index 1a6ed820d80c..4876ba477b3c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py @@ -22,7 +22,8 @@ class _NonStreamingOrderByContextAggregator(_QueryExecutionContextBase): by the user. """ - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_NonStreamingOrderByContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -36,6 +37,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._doc_producers = [] self._document_producer_comparator = document_producer._NonStreamingOrderByComparator(self._sort_orders) self._response_hook = response_hook + self._raw_response_hook = raw_response_hook async def __anext__(self): @@ -100,7 +102,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) async def _get_target_partition_key_range(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 811d778fca90..6dcf21ccb318 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -128,8 +128,10 @@ async def get_session_token_async( if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] 
if partition_key_range_id is not None: - vector_session_token = token_dict.get(partition_key_range_id) - session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token.session_token) + container_routing_map = routing_map_provider._collection_routing_map_by_item.get(collection_name) + current_range = container_routing_map._rangeById.get(partition_key_range_id) + if current_range is not None: + session_token = self._format_session_token(current_range, token_dict) else: collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") partition_key = PartitionKey(path=collection_pk_definition['paths'], @@ -137,17 +139,7 @@ async def get_session_token_async( version=collection_pk_definition['version']) epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) pk_range = await routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) - session_token_list = [] - parents = pk_range[0].get('parents').copy() - parents.append(pk_range[0]['id']) - for parent in parents: - vector_session_token = token_dict.get(parent) - session_token = "{0}:{1}".format(parent, vector_session_token.session_token) - session_token_list.append(session_token) - # if vector_session_token is not None: - # session_token = "{0}:{1}".format(parent, vector_session_token.session_token) - # session_token_list.append(session_token) - session_token = ",".join(session_token_list) + session_token = self._format_session_token(pk_range, token_dict) return session_token return "" except Exception: # pylint: disable=broad-except @@ -168,9 +160,12 @@ def set_session_token(self, client_connection, response_result, response_headers # x-ms-alt-content-path which is the string representation of the resource with self.session_lock: - try: - self_link = response_result["_self"] + self_link = response_result.get("_self") + # query results don't directly have a self_link - need to fetch it directly from one of the items + if self_link is None: + if 
'Documents' in response_result and len(response_result['Documents']) > 0: + self_link = response_result['Documents'][0].get('_self') # extract alternate content path from the response_headers # (only document level resource updates will have this), @@ -182,11 +177,17 @@ def set_session_token(self, client_connection, response_result, response_headers response_result_id = None if alt_content_path_key in response_headers: alt_content_path = response_headers[http_constants.HttpHeaders.AlternateContentPath] - response_result_id = response_result[response_result_id_key] + if response_result_id_key in response_result: + response_result_id = response_result[response_result_id_key] else: return - collection_rid, collection_name = _base.GetItemContainerInfo(self_link, alt_content_path, - response_result_id) + if self_link is not None: + collection_rid, collection_name = _base.GetItemContainerInfo(self_link, alt_content_path, + response_result_id) + else: + # if for whatever reason we don't have a _self link at this point, we use the container name + collection_name = alt_content_path + collection_rid = self.collection_name_to_rid.get(collection_name) # if the response came in with a new partition key range id after a split, refresh the pk range cache partition_key_range_id = response_headers.get(http_constants.HttpHeaders.PartitionKeyRangeID) collection_ranges = None @@ -228,10 +229,9 @@ def set_session_token(self, client_connection, response_result, response_headers self.rid_to_session_token[collection_rid][id_] = parsed_tokens[id_] else: self.rid_to_session_token[collection_rid][id_] = parsed_tokens[id_].merge(old_session_token) - self.collection_name_to_rid[collection_name] = collection_rid else: self.rid_to_session_token[collection_rid] = parsed_tokens - self.collection_name_to_rid[collection_name] = collection_rid + self.collection_name_to_rid[collection_name] = collection_rid def clear_session_token(self, response_headers): with self.session_lock: @@ -277,6 +277,18 @@ 
def parse_session_token(response_headers): id_to_sessionlsn[id_] = sessionToken return id_to_sessionlsn + def _format_session_token(self, pk_range, token_dict): + session_token_list = [] + parents = pk_range[0].get('parents').copy() + parents.append(pk_range[0]['id']) + for parent in parents: + vector_session_token = token_dict.get(parent) + if vector_session_token is not None: + session_token = "{0}:{1}".format(parent, vector_session_token.session_token) + session_token_list.append(session_token) + session_token = ",".join(session_token_list) + return session_token + class Session(object): """State of an Azure Cosmos session. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 0364dab69bbc..b1e8b39a733b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -1220,11 +1220,12 @@ async def Read( request_params = _request_object.RequestObject(resource_type, documents._OperationType.Read) request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) + # update session for request mutates data on server side + self._UpdateSessionIfRequired(headers, result, last_response_headers) self.last_response_headers = last_response_headers if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def __Get( self, @@ -1594,8 +1595,7 @@ async def Replace( self._UpdateSessionIfRequired(headers, result, self.last_response_headers) if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, 
response_headers=last_response_headers) async def __Put( self, @@ -2077,13 +2077,14 @@ def _QueryPartitionKeyRanges( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "pkranges") + resource_type = http_constants.ResourceType.PartitionKeyRange + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "pkranges", collection_id, lambda r: r["PartitionKeyRanges"], + path, resource_type, collection_id, lambda r: r["PartitionKeyRanges"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, @@ -2132,10 +2133,11 @@ def QueryDatabases( if options is None: options = {} + resource_type = http_constants.ResourceType.Database async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - "/dbs", "dbs", "", lambda r: r["Databases"], + "/dbs", resource_type, "", lambda r: r["Databases"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, @@ -2189,13 +2191,14 @@ def QueryContainers( if options is None: options = {} - path = base.GetPathFromLink(database_link, "colls") + resource_type = http_constants.ResourceType.Collection + path = base.GetPathFromLink(database_link, resource_type) database_id = base.GetResourceIdOrFullNameFromLink(database_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "colls", database_id, lambda r: r["DocumentCollections"], + path, resource_type, database_id, lambda r: r["DocumentCollections"], lambda _, body: body, query, options, **kwargs ), self.last_response_headers, @@ -2256,6 +2259,7 @@ def QueryItems( if options is None: options = {} + resource_type = http_constants.ResourceType.Document if 
base.IsDatabaseLink(database_or_container_link): return AsyncItemPaged( self, @@ -2266,14 +2270,14 @@ def QueryItems( page_iterator_class=query_iterable.QueryIterable ) - path = base.GetPathFromLink(database_or_container_link, "docs") + path = base.GetPathFromLink(database_or_container_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(database_or_container_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( path, - "docs", + resource_type, collection_id, lambda r: r["Documents"], lambda _, b: b, @@ -2292,7 +2296,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca fetch_function=fetch_fn, collection_link=database_or_container_link, page_iterator_class=query_iterable.QueryIterable, - response_hook=response_hook + response_hook=response_hook, + raw_response_hook=kwargs.get('raw_response_hook'), ) def QueryItemsChangeFeed( @@ -2410,10 +2415,11 @@ def QueryOffers( if options is None: options = {} + resource_type = http_constants.ResourceType.Offer async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - "/offers", "offers", "", lambda r: r["Offers"], lambda _, b: b, query, options, **kwargs + "/offers", resource_type, "", lambda r: r["Offers"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, ) @@ -2472,13 +2478,14 @@ def QueryUsers( if options is None: options = {} - path = base.GetPathFromLink(database_link, "users") + resource_type = http_constants.ResourceType.User + path = base.GetPathFromLink(database_link, resource_type) database_id = base.GetResourceIdOrFullNameFromLink(database_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "users", database_id, lambda r: r["Users"], + path, resource_type, database_id, lambda r: 
r["Users"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, @@ -2534,13 +2541,14 @@ def QueryPermissions( if options is None: options = {} - path = base.GetPathFromLink(user_link, "permissions") + resource_type = http_constants.ResourceType.Permission + path = base.GetPathFromLink(user_link, resource_type) user_id = base.GetResourceIdOrFullNameFromLink(user_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "permissions", user_id, lambda r: r["Permissions"], lambda _, b: b, query, options, **kwargs + path, resource_type, user_id, lambda r: r["Permissions"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, ) @@ -2595,13 +2603,14 @@ def QueryStoredProcedures( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "sprocs") + resource_type = http_constants.ResourceType.StoredProcedure + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "sprocs", collection_id, lambda r: r["StoredProcedures"], + path, resource_type, collection_id, lambda r: r["StoredProcedures"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, @@ -2657,13 +2666,14 @@ def QueryTriggers( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "triggers") + resource_type = http_constants.ResourceType.Trigger + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "triggers", collection_id, lambda r: r["Triggers"], lambda _, b: b, query, options, **kwargs + path, 
resource_type, collection_id, lambda r: r["Triggers"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, ) @@ -2718,13 +2728,14 @@ def QueryUserDefinedFunctions( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "udfs") + resource_type = http_constants.ResourceType.UserDefinedFunction + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "udfs", collection_id, lambda r: r["UserDefinedFunctions"], + path, resource_type, collection_id, lambda r: r["UserDefinedFunctions"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, @@ -2779,13 +2790,14 @@ def QueryConflicts( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "conflicts") + resource_type = http_constants.ResourceType.Conflict + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "conflicts", collection_id, lambda r: r["Conflicts"], + path, resource_type, collection_id, lambda r: r["Conflicts"], lambda _, b: b, query, options, **kwargs ), self.last_response_headers, @@ -2894,20 +2906,20 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if change_feed_state is not None: await change_feed_state.populate_request_headers_async(self._routing_map_provider, headers) - result, self.last_response_headers = await self.__Get(path, request_params, headers, **kwargs) + result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) + self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(headers, result, 
last_response_headers) if response_hook: response_hook(self.last_response_headers, result) return __GetBodiesFromQueryResult(result) query = self.__CheckAndUnifyQueryFormat(query) - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" if not is_query_plan: initial_headers[http_constants.HttpHeaders.IsQuery] = "true" - if ( - self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, - CosmosClientConnection._QueryCompatibilityMode.Query)): + if (self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, + CosmosClientConnection._QueryCompatibilityMode.Query)): initial_headers[http_constants.HttpHeaders.ContentType] = runtime_constants.MediaTypes.QueryJson elif self._query_compatibility_mode == CosmosClientConnection._QueryCompatibilityMode.SqlQuery: initial_headers[http_constants.HttpHeaders.ContentType] = runtime_constants.MediaTypes.SQL @@ -2919,8 +2931,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: request_params.set_excluded_location_from_options(options) req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, resource_type, request_params.operation_type, options, partition_key_range_id) - await base.set_session_token_header_async(self, req_headers, path, resource_type, - request_params.operation_type, options, partition_key_range_id) + if not is_query_plan: + await base.set_session_token_header_async(self, req_headers, path, resource_type, + request_params.operation_type, options, partition_key_range_id) # check if query has prefix partition key cont_prop = kwargs.pop("containerProperties", None) @@ -2971,13 +2984,15 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: req_headers[http_constants.HttpHeaders.StartEpkString] = EPK_sub_range.min req_headers[http_constants.HttpHeaders.EndEpkString] = EPK_sub_range.max req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" - 
partial_result, self.last_response_headers = await self.__Post( + partial_result, last_response_headers = await self.__Post( path, request_params, query, req_headers, **kwargs ) + self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) if results: # add up all the query results from all over lapping ranges results["Documents"].extend(partial_result["Documents"]) @@ -2989,7 +3004,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if results: return __GetBodiesFromQueryResult(results) - result, self.last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) + result, last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = last_response_headers + # update session for request mutates data on server side + self._UpdateSessionIfRequired(req_headers, result, last_response_headers) + # TODO: this part might become an issue since HTTP/2 can return read-only headers if self.last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization index_metrics_raw = self.last_response_headers[INDEX_METRICS_HEADER] @@ -3057,7 +3076,8 @@ def _UpdateSessionIfRequired( if documents.ConsistencyLevel.Session == request_headers[http_constants.HttpHeaders.ConsistencyLevel]: is_session_consistency = True - if is_session_consistency and self.session: + if (is_session_consistency and self.session and + not base.IsMasterResource(request_headers[http_constants.HttpHeaders.ThinClientProxyResourceType])): # update session self.session.update_session(self, response_result, response_headers) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py index 4a67671606dd..0fb3ad1c7fc4 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py @@ -45,6 +45,7 @@ def __init__( partition_key=None, continuation_token=None, response_hook=None, + raw_response_hook=None, ): """Instantiates a QueryIterable for non-client side partitioning queries. @@ -75,7 +76,8 @@ def __init__( self._database_link = database_link self._partition_key = partition_key self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( - self._client, self._collection_link, self._query, self._options, self._fetch_function, response_hook) + self._client, self._collection_link, self._query, self._options, self._fetch_function, + response_hook, raw_response_hook) super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) async def _unpack(self, block): From 2320353bfb41f49a13b15a6f88d5728d7a0adaae Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 8 May 2025 20:17:54 -0400 Subject: [PATCH 03/52] sync changes --- .../azure/cosmos/_cosmos_client_connection.py | 71 +++++++++++-------- .../_execution_context/document_producer.py | 4 +- .../execution_dispatcher.py | 40 ++++++----- .../hybrid_search_aggregator.py | 12 ++-- .../multi_execution_aggregator.py | 7 +- .../non_streaming_order_by_aggregator.py | 7 +- .../azure/cosmos/_query_iterable.py | 4 +- .../azure-cosmos/azure/cosmos/_session.py | 8 ++- .../aio/_cosmos_client_connection_async.py | 2 +- 9 files changed, 94 insertions(+), 61 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index d7b41f85f13f..346d55560754 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -446,9 +446,10 @@ def QueryDatabases( if options is None: options = {} + 
resource_type = http_constants.ResourceType.Database def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - "/dbs", "dbs", "", lambda r: r["Databases"], + "/dbs", resource_type, "", lambda r: r["Databases"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -501,12 +502,13 @@ def QueryContainers( if options is None: options = {} + resource_type = http_constants.ResourceType.Collection path = base.GetPathFromLink(database_link, "colls") database_id = base.GetResourceIdOrFullNameFromLink(database_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "colls", database_id, lambda r: r["DocumentCollections"], + path, resource_type, database_id, lambda r: r["DocumentCollections"], lambda _, body: body, query, options, **kwargs) return ItemPaged( @@ -730,12 +732,13 @@ def QueryUsers( if options is None: options = {} - path = base.GetPathFromLink(database_link, "users") + resource_type = http_constants.ResourceType.User + path = base.GetPathFromLink(database_link, resource_type) database_id = base.GetResourceIdOrFullNameFromLink(database_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "users", database_id, lambda r: r["Users"], + path, resource_type, database_id, lambda r: r["Users"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -908,12 +911,13 @@ def QueryPermissions( if options is None: options = {} - path = base.GetPathFromLink(user_link, "permissions") + resource_type = http_constants.ResourceType.Permission + path = base.GetPathFromLink(user_link, resource_type) user_id = base.GetResourceIdOrFullNameFromLink(user_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "permissions", user_id, lambda r: r["Permissions"], + path, 
resource_type, user_id, lambda r: r["Permissions"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -1081,6 +1085,7 @@ def QueryItems( if options is None: options = {} + resource_type = http_constants.ResourceType.Document if base.IsDatabaseLink(database_or_container_link): return ItemPaged( self, @@ -1091,13 +1096,13 @@ def QueryItems( page_iterator_class=query_iterable.QueryIterable ) - path = base.GetPathFromLink(database_or_container_link, "docs") + path = base.GetPathFromLink(database_or_container_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(database_or_container_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( path, - "docs", + resource_type, collection_id, lambda r: r["Documents"], lambda _, b: b, @@ -1113,7 +1118,8 @@ def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInse fetch_function=fetch_fn, collection_link=database_or_container_link, page_iterator_class=query_iterable.QueryIterable, - response_hook=response_hook + response_hook=response_hook, + raw_response_hook=kwargs.get('raw_response_hook'), ) def QueryItemsChangeFeed( @@ -1253,12 +1259,13 @@ def _QueryPartitionKeyRanges( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "pkranges") + resource_type = http_constants.ResourceType.PartitionKeyRange + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "pkranges", collection_id, lambda r: r["PartitionKeyRanges"], + path, resource_type, collection_id, lambda r: r["PartitionKeyRanges"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -1451,12 +1458,13 @@ def QueryTriggers( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "triggers") + 
resource_type = http_constants.ResourceType.Trigger + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "triggers", collection_id, lambda r: r["Triggers"], + path, resource_type, collection_id, lambda r: r["Triggers"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -1604,12 +1612,13 @@ def QueryUserDefinedFunctions( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "udfs") + resource_type = http_constants.ResourceType.UserDefinedFunction + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "udfs", collection_id, lambda r: r["UserDefinedFunctions"], + path, resource_type, collection_id, lambda r: r["UserDefinedFunctions"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -1760,12 +1769,13 @@ def QueryStoredProcedures( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "sprocs") + resource_type = http_constants.ResourceType.StoredProcedure + path = base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "sprocs", collection_id, lambda r: r["StoredProcedures"], + path, resource_type, collection_id, lambda r: r["StoredProcedures"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -1914,12 +1924,13 @@ def QueryConflicts( if options is None: options = {} - path = base.GetPathFromLink(collection_link, "conflicts") + resource_type = http_constants.ResourceType.Conflict + path 
= base.GetPathFromLink(collection_link, resource_type) collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - path, "conflicts", collection_id, lambda r: r["Conflicts"], + path, resource_type, collection_id, lambda r: r["Conflicts"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -2542,9 +2553,10 @@ def QueryOffers( if options is None: options = {} + resource_type = http_constants.ResourceType.Offer def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( - "/offers", "offers", "", lambda r: r["Offers"], + "/offers", resource_type, "", lambda r: r["Offers"], lambda _, b: b, query, options, **kwargs) return ItemPaged( @@ -2756,7 +2768,7 @@ def Replace( self.last_response_headers = last_response_headers # update session for request mutates data on server side - self._UpdateSessionIfRequired(headers, result, self.last_response_headers) + self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(last_response_headers, result) return CosmosDict(result, response_headers=last_response_headers) @@ -2797,7 +2809,7 @@ def Read( request_params = RequestObject(resource_type, documents._OperationType.Read) request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) - + # update session for request mutates data on server side self._UpdateSessionIfRequired(headers, result, last_response_headers) self.last_response_headers = last_response_headers @@ -2844,7 +2856,7 @@ def DeleteResource( self.last_response_headers = last_response_headers # update session for request mutates data on server side - self._UpdateSessionIfRequired(headers, result, self.last_response_headers) + self._UpdateSessionIfRequired(headers, result, last_response_headers) if 
response_hook: response_hook(last_response_headers, None) @@ -3103,9 +3115,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: query = self.__CheckAndUnifyQueryFormat(query) - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" if not is_query_plan: - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" # TODO: check why we have this weird logic + initial_headers[http_constants.HttpHeaders.IsQuery] = "true" if (self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, CosmosClientConnection._QueryCompatibilityMode.Query)): @@ -3129,8 +3140,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) - base.set_session_token_header(self, req_headers, path, resource_type, documents._OperationType.SqlQuery, - options) + if not is_query_plan: + base.set_session_token_header(self, req_headers, path, resource_type, documents._OperationType.SqlQuery, + options) # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) @@ -3179,6 +3191,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: path, request_params, query, req_headers, **kwargs ) self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) if results: # add up all the query results from all over lapping ranges results["Documents"].extend(partial_result["Documents"]) @@ -3191,12 +3204,12 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: return __GetBodiesFromQueryResult(results), last_response_headers result, last_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(req_headers, result, last_response_headers) if 
last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization index_metrics_raw = last_response_headers[INDEX_METRICS_HEADER] last_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) - self.last_response_headers = last_response_headers - if response_hook: response_hook(last_response_headers, result) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py index dc01334f1905..f77504d3e9a1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py @@ -40,7 +40,7 @@ class _DocumentProducer(object): """ def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options, - response_hook): + response_hook, raw_response_hook): """ Constructor """ @@ -61,7 +61,7 @@ def __init__(self, partition_key_target_range, client, collection_link, query, d def fetch_fn(options): return self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"], - response_hook=response_hook) + response_hook=response_hook, raw_response_hook=raw_response_hook) self._ex_context = _DefaultQueryExecutionContext(client, self._options, fetch_fn) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index a93377e3dcf0..11802baf0e0b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -77,7 +77,7 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl to _MultiExecutionContextAggregator """ - def __init__(self, 
client, resource_link, query, options, fetch_function, response_hook): + def __init__(self, client, resource_link, query, options, fetch_function, response_hook, raw_response_hook): """ Constructor """ @@ -88,6 +88,13 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon self._query = query self._fetch_function = fetch_function self._response_hook = response_hook + self._raw_response_hook = raw_response_hook + + def _create_execution_context_with_query_plan(self): + query_to_use = self._query if self._query is not None else "Select * from root r" + query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway + (query_to_use, self._resource_link)) + self._execution_context = self._create_pipelined_execution_context(query_execution_info) def __next__(self): """Returns the next query result. @@ -119,19 +126,15 @@ def fetch_next_block(self): :return: List of results. :rtype: list """ - # TODO: NEED to change this - make every query retrieve a query plan - # also, we can't have this logic being returned to so often - there should be no need for this - # need to split up query plan logic and actual query iterating logic - try: - return self._execution_context.fetch_next_block() - except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - query_to_use = self._query if self._query is not None else "Select * from root r" - query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway - (query_to_use, self._resource_link)) - self._execution_context = self._create_pipelined_execution_context(query_execution_info) - else: - raise e + + if "enableCrossPartitionQuery" not in self._options: + try: + return self._execution_context.fetch_next_block() + except CosmosHttpResponseError as e: + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + self._create_execution_context_with_query_plan() + 
else: + self._create_execution_context_with_query_plan() return self._execution_context.fetch_next_block() @@ -162,7 +165,8 @@ def _create_pipelined_execution_context(self, query_execution_info): self._query, self._options, query_execution_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) elif query_execution_info.has_hybrid_search_query_info(): hybrid_search_query_info = query_execution_info._query_execution_info['hybridSearchQueryInfo'] _verify_valid_hybrid_search_query(hybrid_search_query_info) @@ -172,7 +176,8 @@ def _create_pipelined_execution_context(self, query_execution_info): self._options, query_execution_info, hybrid_search_query_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) execution_context_aggregator._run_hybrid_search() else: execution_context_aggregator = \ @@ -181,7 +186,8 @@ def _create_pipelined_execution_context(self, query_execution_info): self._query, self._options, query_execution_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator, query_execution_info) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py index ce0676bec460..5dc5fd8c2db9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py @@ -150,7 +150,7 @@ class _HybridSearchContextAggregator(_QueryExecutionContextBase): """ def __init__(self, client, resource_link, options, - partitioned_query_execution_info, hybrid_search_query_info, response_hook): + partitioned_query_execution_info, hybrid_search_query_info, response_hook, raw_response_hook): super(_HybridSearchContextAggregator, self).__init__(client, options) # use the routing provider in the 
client @@ -163,6 +163,7 @@ def __init__(self, client, resource_link, options, self._aggregated_global_statistics = None self._document_producer_comparator = None self._response_hook = response_hook + self._raw_response_hook = raw_response_hook def _run_hybrid_search(self): # Check if we need to run global statistics queries, and if so do for every partition in the container @@ -181,7 +182,8 @@ def _run_hybrid_search(self): global_statistics_query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) @@ -223,7 +225,8 @@ def _run_hybrid_search(self): rewritten_query['rewrittenQuery'], self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) # verify all document producers have items/ no splits @@ -350,7 +353,8 @@ def _repair_document_producer(self, query, target_all_ranges=False): query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py index e1747c6d50ee..01596db61956 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py @@ -63,7 +63,8 @@ def peek(self): def size(self): return len(self._heap) - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_MultiExecutionContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -74,6 +75,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i 
self._partitioned_query_ex_info = partitioned_query_ex_info self._sort_orders = partitioned_query_ex_info.get_order_by() self._response_hook = response_hook + self._raw_response_hook = raw_response_hook if self._sort_orders: self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders) @@ -187,7 +189,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) def _get_target_partition_key_range(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py index 0bfc514e00d8..626dcd86a7d1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py @@ -22,7 +22,8 @@ class _NonStreamingOrderByContextAggregator(_QueryExecutionContextBase): by the user. 
""" - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_NonStreamingOrderByContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -34,6 +35,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._sort_orders = partitioned_query_ex_info.get_order_by() self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue() self._response_hook = response_hook + self._raw_response_hook = raw_response_hook # will be a list of (partition_min, partition_max) tuples targetPartitionRanges = self._get_target_partition_key_range() @@ -143,7 +145,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) def _get_target_partition_key_range(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py index 6663628dad5f..881ca9d9329e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py @@ -44,6 +44,7 @@ def __init__( partition_key=None, continuation_token=None, response_hook=None, + raw_response_hook=None, ): """Instantiates a QueryIterable for non-client side partitioning queries. 
@@ -74,7 +75,8 @@ def __init__( self._database_link = database_link self._partition_key = partition_key self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( - self._client, self._collection_link, self._query, self._options, self._fetch_function, response_hook) + self._client, self._collection_link, self._query, self._options, self._fetch_function, + response_hook, raw_response_hook) super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) def _unpack(self, block): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 6dcf21ccb318..c92d1f16ba9f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -76,7 +76,10 @@ def get_session_token( if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] if partition_key_range_id is not None: - session_token = token_dict.get(partition_key_range_id) + container_routing_map = routing_map_provider._collection_routing_map_by_item.get(collection_name) + current_range = container_routing_map._rangeById.get(partition_key_range_id) + if current_range is not None: + session_token = self._format_session_token(current_range, token_dict) else: collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") partition_key = PartitionKey(path=collection_pk_definition['paths'], @@ -84,8 +87,7 @@ def get_session_token( version=collection_pk_definition['version']) epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) pk_range = routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) - vector_session_token = token_dict.get(pk_range[0]['id']) - session_token = "{0}:{1}".format(pk_range[0]['id'], vector_session_token.session_token) + session_token = self._format_session_token(pk_range, 
token_dict) return session_token return "" except Exception: # pylint: disable=broad-except diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index b1e8b39a733b..f0a6c0e9c36e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -919,7 +919,7 @@ async def Upsert( result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request - self._UpdateSessionIfRequired(headers, result, self.last_response_headers) + self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(last_response_headers, result) return CosmosDict(result, response_headers=last_response_headers) From dcc11ddd532f2869862d9aff146a6aeeebac22dd Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 9 May 2025 13:46:28 -0400 Subject: [PATCH 04/52] mypy/pylint --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 25 ++++++----- .../azure/cosmos/_cosmos_client_connection.py | 6 +-- .../aio/hybrid_search_aggregator.py | 3 +- .../aio/non_streaming_order_by_aggregator.py | 7 ++-- .../non_streaming_order_by_aggregator.py | 6 +-- .../azure-cosmos/azure/cosmos/_session.py | 41 +++++++++++-------- .../aio/_cosmos_client_connection_async.py | 9 ++-- 7 files changed, 55 insertions(+), 42 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index c6a74bb6777f..4dfd7efef5a5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -45,6 +45,7 @@ from ._cosmos_client_connection import CosmosClientConnection from .aio._cosmos_client_connection_async import 
CosmosClientConnection as AsyncClientConnection +# pylint: disable=protected-access _COMMON_OPTIONS = { 'initial_headers': 'initialHeaders', @@ -319,7 +320,7 @@ def _is_session_token_request( cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], headers: dict, resource_type: str, - operation_type: str) -> None: + operation_type: str) -> bool: consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) # Figure out if consistency level for this request is session is_session_consistency = consistency_level == documents.ConsistencyLevel.Session @@ -350,11 +351,12 @@ def set_session_token_header( if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ cosmos_client_connection.session: # populate session token from the client's session container - session_token = cosmos_client_connection.session.get_session_token(path, - options.get('partitionKey'), - cosmos_client_connection._container_properties_cache, - cosmos_client_connection._routing_map_provider, - partition_key_range_id) + session_token = ( + cosmos_client_connection.session.get_session_token(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id)) if session_token != "": headers[http_constants.HttpHeaders.SessionToken] = session_token @@ -377,11 +379,12 @@ async def set_session_token_header_async( if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ cosmos_client_connection.session: # populate session token from the client's session container - session_token = await cosmos_client_connection.session.get_session_token_async(path, - options.get('partitionKey'), - cosmos_client_connection._container_properties_cache, - cosmos_client_connection._routing_map_provider, - partition_key_range_id) + session_token = \ + await 
cosmos_client_connection.session.get_session_token_async(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id) if session_token != "": headers[http_constants.HttpHeaders.SessionToken] = session_token diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 346d55560754..a344e0139a7f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2829,7 +2829,7 @@ def DeleteResource( """Deletes an Azure Cosmos resource and returns it. :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -3115,9 +3115,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: query = self.__CheckAndUnifyQueryFormat(query) - if not is_query_plan: - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" - if (self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, CosmosClientConnection._QueryCompatibilityMode.Query)): initial_headers[http_constants.HttpHeaders.ContentType] = runtime_constants.MediaTypes.QueryJson @@ -3141,6 +3138,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: partition_key_range_id ) if not is_query_plan: + req_headers[http_constants.HttpHeaders.IsQuery] = "true" base.set_session_token_header(self, req_headers, path, resource_type, documents._OperationType.SqlQuery, options) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py index 3aaf28b54ec0..196364aa55fa 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py @@ -77,7 +77,8 @@ async def _run_hybrid_search(self): global_statistics_query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py index 4876ba477b3c..20e0471ed68c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py @@ -32,10 +32,10 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._resource_link = resource_link self._query = query self._partitioned_query_ex_info = partitioned_query_ex_info - self._sort_orders = partitioned_query_ex_info.get_order_by() self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue() self._doc_producers = [] - self._document_producer_comparator = document_producer._NonStreamingOrderByComparator(self._sort_orders) + self._document_producer_comparator = ( + document_producer._NonStreamingOrderByComparator(partitioned_query_ex_info.get_order_by())) self._response_hook = response_hook self._raw_response_hook = raw_response_hook @@ -141,11 +141,12 @@ async def _configure_partition_ranges(self): pq_size = self._partitioned_query_ex_info.get_top() or\ self._partitioned_query_ex_info.get_limit() + self._partitioned_query_ex_info.get_offset() + sort_orders = self._partitioned_query_ex_info.get_order_by() for doc_producer in self._doc_producers: while True: try: result = await doc_producer.peek() - item_result = document_producer._NonStreamingItemResultProducer(result, 
self._sort_orders) + item_result = document_producer._NonStreamingItemResultProducer(result, sort_orders) await self._orderByPQ.push_async(item_result, self._document_producer_comparator) await doc_producer.__anext__() except StopAsyncIteration: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py index 626dcd86a7d1..c07864d7d767 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py @@ -32,7 +32,6 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._resource_link = resource_link self._query = query self._partitioned_query_ex_info = partitioned_query_ex_info - self._sort_orders = partitioned_query_ex_info.get_order_by() self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue() self._response_hook = response_hook self._raw_response_hook = raw_response_hook @@ -40,7 +39,8 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i # will be a list of (partition_min, partition_max) tuples targetPartitionRanges = self._get_target_partition_key_range() - self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders) + sort_orders = partitioned_query_ex_info.get_order_by() + self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(sort_orders) targetPartitionQueryExecutionContextList = [] for partitionTargetRange in targetPartitionRanges: @@ -70,7 +70,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i while True: try: result = doc_producer.peek() - item_result = document_producer._NonStreamingItemResultProducer(result, self._sort_orders) + item_result = 
document_producer._NonStreamingItemResultProducer(result, sort_orders) self._orderByPQ.push(item_result) next(doc_producer) except StopIteration: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index c92d1f16ba9f..3e4271b83519 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -25,14 +25,17 @@ import sys import traceback import threading +from typing import Any, Dict, Optional from . import _base from . import http_constants +from ._routing.routing_map_provider import SmartRoutingMapProvider +from ._routing.aio.routing_map_provider import SmartRoutingMapProvider as SmartRoutingMapProviderAsync from ._vector_session_token import VectorSessionToken from .exceptions import CosmosHttpResponseError from .partition_key import PartitionKey -from typing import Any, Dict, Optional +# pylint: disable=protected-access class SessionContainer(object): def __init__(self): @@ -45,12 +48,13 @@ def get_session_token( resource_path: str, pk_value: str, container_properties_cache: Dict[str, Dict[str, Any]], - routing_map_provider: Any, + routing_map_provider: SmartRoutingMapProvider, partition_key_range_id: Optional[int]) -> str: - """Get Session Token for collection_link and operation_type. + """Get Session Token for the given collection and partition key information. :param str resource_path: Self link / path to the resource - :param str operation_type: Operation type (e.g. 
'Create', 'Read', 'Upsert', 'Replace') + :param ~azure.cosmos.SmartRoutingMapProvider routing_map_provider: routing map containing relevant session + information, such as partition key ranges for a given collection :param str pk_value: The partition key value being used for the operation :param container_properties_cache: Container properties cache used to fetch partition key definitions :type container_properties_cache: Dict[str, Dict[str, Any]] @@ -60,7 +64,7 @@ def get_session_token( :rtype: str """ - with self.session_lock: + with (self.session_lock): is_name_based = _base.IsNameBased(resource_path) collection_rid = "" session_token = "" @@ -76,12 +80,12 @@ def get_session_token( if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] if partition_key_range_id is not None: - container_routing_map = routing_map_provider._collection_routing_map_by_item.get(collection_name) + container_routing_map = routing_map_provider._collection_routing_map_by_item[collection_name] current_range = container_routing_map._rangeById.get(partition_key_range_id) if current_range is not None: session_token = self._format_session_token(current_range, token_dict) else: - collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") + collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] partition_key = PartitionKey(path=collection_pk_definition['paths'], kind=collection_pk_definition['kind'], version=collection_pk_definition['version']) @@ -98,12 +102,13 @@ async def get_session_token_async( resource_path: str, pk_value: str, container_properties_cache: Dict[str, Dict[str, Any]], - routing_map_provider: Any, + routing_map_provider: SmartRoutingMapProviderAsync, partition_key_range_id: Optional[str]) -> str: - """Get Session Token for collection_link and operation_type. 
+ """Get Session Token for the given collection and partition key information. :param str resource_path: Self link / path to the resource - :param str operation_type: Operation type (e.g. 'Create', 'Read', 'Upsert', 'Replace') + :param ~azure.cosmos.SmartRoutingMapProviderAsync routing_map_provider: routing map containing relevant session + information, such as partition key ranges for a given collection :param str pk_value: The partition key value being used for the operation :param container_properties_cache: Container properties cache used to fetch partition key definitions :type container_properties_cache: Dict[str, Dict[str, Any]] @@ -130,12 +135,12 @@ async def get_session_token_async( if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] if partition_key_range_id is not None: - container_routing_map = routing_map_provider._collection_routing_map_by_item.get(collection_name) + container_routing_map = routing_map_provider._collection_routing_map_by_item[collection_name] current_range = container_routing_map._rangeById.get(partition_key_range_id) if current_range is not None: session_token = self._format_session_token(current_range, token_dict) else: - collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") + collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] partition_key = PartitionKey(path=collection_pk_definition['paths'], kind=collection_pk_definition['kind'], version=collection_pk_definition['version']) @@ -152,6 +157,7 @@ def set_session_token(self, client_connection, response_result, response_headers successfully mutate resource on the server side (write, replace, delete etc). 
:param client_connection: Client connection used to refresh the partition key range cache if needed + :type client_connection: Union[azure.cosmos.CosmosClientConnection, azure.cosmos.aio.CosmosClientConnection] :param dict response_result: :param dict response_headers: :return: None @@ -161,7 +167,7 @@ def set_session_token(self, client_connection, response_result, response_headers # self link which has the rid representation of the resource, and # x-ms-alt-content-path which is the string representation of the resource - with self.session_lock: + with (self.session_lock): try: self_link = response_result.get("_self") # query results don't directly have a self_link - need to fetch it directly from one of the items @@ -194,7 +200,8 @@ def set_session_token(self, client_connection, response_result, response_headers partition_key_range_id = response_headers.get(http_constants.HttpHeaders.PartitionKeyRangeID) collection_ranges = None if client_connection: - collection_ranges = client_connection._routing_map_provider._collection_routing_map_by_item.get(collection_name) + collection_ranges = \ + client_connection._routing_map_provider._collection_routing_map_by_item.get(collection_name) if collection_ranges and not collection_ranges._rangeById.get(partition_key_range_id): client_connection.refresh_routing_map_provider() except ValueError: @@ -311,10 +318,12 @@ def clear_session_token(self, response_headers): def update_session(self, client_connection, response_result, response_headers): self.session_container.set_session_token(client_connection, response_result, response_headers) - def get_session_token(self, resource_path, pk_value, container_properties_cache, routing_map_provider, partition_key_range_id): + def get_session_token(self, resource_path, pk_value, container_properties_cache, routing_map_provider, + partition_key_range_id): return self.session_container.get_session_token(resource_path, pk_value, container_properties_cache, routing_map_provider, 
partition_key_range_id) - async def get_session_token_async(self, resource_path, pk_value, container_properties_cache, routing_map_provider, partition_key_range_id): + async def get_session_token_async(self, resource_path, pk_value, container_properties_cache, routing_map_provider, + partition_key_range_id): return await self.session_container.get_session_token_async(resource_path, pk_value, container_properties_cache, routing_map_provider, partition_key_range_id) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index f0a6c0e9c36e..ce1edf40d186 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -1195,7 +1195,7 @@ async def Read( """Reads an Azure Cosmos resource and returns it. :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -1565,7 +1565,7 @@ async def Replace( :param dict resource: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -1890,7 +1890,7 @@ async def DeleteResource( """Deletes an Azure Cosmos resource and returns it. 
:param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2673,7 +2673,8 @@ def QueryTriggers( async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, resource_type, collection_id, lambda r: r["Triggers"], lambda _, b: b, query, options, **kwargs + path, resource_type, collection_id, lambda r: r["Triggers"], lambda _, b: b, query, options, + **kwargs ), self.last_response_headers, ) From 10c2c9a183c4b289c2590c173d005e6a3a4bb17b Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 9 May 2025 15:03:00 -0400 Subject: [PATCH 05/52] Update _session.py --- sdk/cosmos/azure-cosmos/azure/cosmos/_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 3e4271b83519..19704ca95909 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -64,7 +64,7 @@ def get_session_token( :rtype: str """ - with (self.session_lock): + with self.session_lock: is_name_based = _base.IsNameBased(resource_path) collection_rid = "" session_token = "" @@ -167,7 +167,7 @@ def set_session_token(self, client_connection, response_result, response_headers # self link which has the rid representation of the resource, and # x-ms-alt-content-path which is the string representation of the resource - with (self.session_lock): + with self.session_lock: try: self_link = response_result.get("_self") # query results don't directly have a self_link - need to fetch it directly from one of the items From 2c7a2db7a787e393be9925e3dd178181ab65efcf Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 9 May 2025 17:17:06 -0400 Subject: [PATCH 06/52] mark query plan as fetched for 
query --- .../cosmos/_execution_context/aio/execution_dispatcher.py | 4 +++- .../azure/cosmos/_execution_context/execution_dispatcher.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index d7809ff39da5..3ff7657bc1ee 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -60,8 +60,10 @@ def __init__(self, client, resource_link, query, options, fetch_function, self._fetch_function = fetch_function self._response_hook = response_hook self._raw_response_hook = raw_response_hook + self._fetched_query_plan = False async def _create_execution_context_with_query_plan(self): + self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway (query_to_use, self._resource_link)) @@ -97,7 +99,7 @@ async def fetch_next_block(self): :return: List of results. 
:rtype: list """ - if "enableCrossPartitionQuery" not in self._options: + if self._fetched_query_plan or "enableCrossPartitionQuery" not in self._options: try: return await self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 11802baf0e0b..cf4946a8bfb1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -89,8 +89,10 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon self._fetch_function = fetch_function self._response_hook = response_hook self._raw_response_hook = raw_response_hook + self._fetched_query_plan = False def _create_execution_context_with_query_plan(self): + self._fetched_query_plan = True query_to_use = self._query if self._query is not None else "Select * from root r" query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway (query_to_use, self._resource_link)) @@ -127,7 +129,7 @@ def fetch_next_block(self): :rtype: list """ - if "enableCrossPartitionQuery" not in self._options: + if self._fetched_query_plan or "enableCrossPartitionQuery" not in self._options: try: return self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: From 31a927e162a462e46e170073c50ac75a66712c7f Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 12 May 2025 11:49:18 -0400 Subject: [PATCH 07/52] adjust logic after merging --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 21 ++++++------ .../azure/cosmos/_cosmos_client_connection.py | 22 +++++------- .../aio/_cosmos_client_connection_async.py | 34 +++++++------------ 3 files changed, 31 insertions(+), 46 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 4dfd7efef5a5..699876e3132c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from ._cosmos_client_connection import CosmosClientConnection from .aio._cosmos_client_connection_async import CosmosClientConnection as AsyncClientConnection + from ._request_object import RequestObject # pylint: disable=protected-access @@ -319,29 +320,28 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches def _is_session_token_request( cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], headers: dict, - resource_type: str, - operation_type: str) -> bool: + request_object) -> bool: consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) # Figure out if consistency level for this request is session is_session_consistency = consistency_level == documents.ConsistencyLevel.Session # Verify that it is not a metadata request, and that it is either a read request, batch request, or an account # configured to use multiple write regions - return (is_session_consistency is True and not IsMasterResource(resource_type) - and (documents._OperationType.IsReadOnlyOperation(operation_type) or operation_type == "Batch" - or cosmos_client_connection._global_endpoint_manager.get_use_multiple_write_locations())) + return (is_session_consistency is True and not IsMasterResource(request_object.resource_type) + and (documents._OperationType.IsReadOnlyOperation(request_object.operation_type) + or request_object.operation_type == "Batch" + or cosmos_client_connection._global_endpoint_manager.get_use_multiple_write_locations())) def set_session_token_header( cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], headers: dict, path: str, - resource_type: str, - operation_type: str, + request_object: 
"RequestObject", options: Mapping[str, Any], partition_key_range_id: Optional[str] = None) -> None: # set session token if required - if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type): + if _is_session_token_request(cosmos_client_connection, headers, request_object): # if there is a token set via option, then use it to override default if options.get("sessionToken"): headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] @@ -364,12 +364,11 @@ async def set_session_token_header_async( cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], headers: dict, path: str, - resource_type: str, - operation_type: str, + request_object, options: Mapping[str, Any], partition_key_range_id: Optional[str] = None) -> None: # set session token if required - if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type): + if _is_session_token_request(cosmos_client_connection, headers, request_object): # if there is a token set via option, then use it to override default if options.get("sessionToken"): headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index a344e0139a7f..670dc2418819 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2057,10 +2057,10 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) - base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch) 
request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2148,9 +2148,9 @@ def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - base.set_session_token_header(self, headers, path, "docs", documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2666,11 +2666,10 @@ def Create( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, documents._OperationType.Create, options) - base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(resource_type, documents._OperationType.Create) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2714,11 +2713,10 @@ def Upsert( headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, documents._OperationType.Upsert, options) headers[http_constants.HttpHeaders.IsUpsert] = True - base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Upsert, options) - # Upsert will use 
WriteEndpoint since it uses POST operation request_params = RequestObject(resource_type, documents._OperationType.Upsert) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2760,10 +2758,10 @@ def Replace( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, documents._OperationType.Replace, options) - base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Replace) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2804,10 +2802,10 @@ def Read( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type, documents._OperationType.Read, options) - base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation request_params = RequestObject(resource_type, documents._OperationType.Read) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) # update session for request mutates data on server side self._UpdateSessionIfRequired(headers, result, last_response_headers) @@ 
-2848,10 +2846,10 @@ def DeleteResource( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, documents._OperationType.Delete, options) - base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation request_params = RequestObject(resource_type, documents._OperationType.Delete) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3100,8 +3098,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) - base.set_session_token_header(self, headers, path, resource_type, request_params.operation_type, options, - partition_key_range_id) + base.set_session_token_header(self, headers, path, request_params, options, partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -3139,8 +3136,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: ) if not is_query_plan: req_headers[http_constants.HttpHeaders.IsQuery] = "true" - base.set_session_token_header(self, req_headers, path, resource_type, documents._OperationType.SqlQuery, - options) + base.set_session_token_header(self, req_headers, path, request_params, options, partition_key_range_id) # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index ce1edf40d186..715f53a6a1fd 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -768,12 +768,10 @@ async def Create( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, documents._OperationType.Create, options) - await base.set_session_token_header_async(self, headers, path, resource_type, - documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(resource_type, documents._OperationType.Create) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -908,14 +906,12 @@ async def Upsert( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, documents._OperationType.Upsert, options) - await base.set_session_token_header_async(self, headers, path, resource_type, - documents._OperationType.Upsert, options) - headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(resource_type, documents._OperationType.Upsert) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1214,11 +1210,10 @@ async def Read( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, 
"get", path, id, resource_type, documents._OperationType.Read, options) - await base.set_session_token_header_async(self, headers, path, resource_type, - documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation request_params = _request_object.RequestObject(resource_type, documents._OperationType.Read) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) # update session for request mutates data on server side self._UpdateSessionIfRequired(headers, result, last_response_headers) @@ -1476,11 +1471,10 @@ async def PatchItem( initial_headers = self.default_headers headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) - await base.set_session_token_header_async(self, headers, path, resource_type, - documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(resource_type, documents._OperationType.Patch) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1492,8 +1486,7 @@ async def PatchItem( self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def ReplaceOffer( self, @@ -1583,11 +1576,10 @@ async def Replace( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, 
documents._OperationType.Replace, options) - await base.set_session_token_header_async(self, headers, path, resource_type, - documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(resource_type, documents._OperationType.Replace) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1908,11 +1900,10 @@ async def DeleteResource( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, documents._OperationType.Delete, options) - await base.set_session_token_header_async(self, headers, path, resource_type, - documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation request_params = _request_object.RequestObject(resource_type, documents._OperationType.Delete) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2025,10 +2016,9 @@ async def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - await base.set_session_token_header_async(self, headers, path, "docs", - documents._OperationType.Read, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) 
result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2901,8 +2891,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: request_params.set_excluded_location_from_options(options) headers = base.GetHeaders(self, initial_headers, "get", path, id_, resource_type, request_params.operation_type, options, partition_key_range_id) - await base.set_session_token_header_async(self, headers, path, resource_type, - request_params.operation_type, options, partition_key_range_id) + await base.set_session_token_header_async(self, headers, path, request_params, options, + partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: await change_feed_state.populate_request_headers_async(self._routing_map_provider, headers) @@ -2933,8 +2923,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, resource_type, request_params.operation_type, options, partition_key_range_id) if not is_query_plan: - await base.set_session_token_header_async(self, req_headers, path, resource_type, - request_params.operation_type, options, partition_key_range_id) + await base.set_session_token_header_async(self, req_headers, path, request_params, options, + partition_key_range_id) # check if query has prefix partition key cont_prop = kwargs.pop("containerProperties", None) From 73fa94ca45a8a29ea21d96835e2e9ed95d4cf15a Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 12 May 2025 11:49:45 -0400 Subject: [PATCH 08/52] Update _base.py --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 
699876e3132c..6e7be69f02f8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -364,7 +364,7 @@ async def set_session_token_header_async( cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], headers: dict, path: str, - request_object, + request_object: "RequestObject", options: Mapping[str, Any], partition_key_range_id: Optional[str] = None) -> None: # set session token if required From 450070a33f3999ff595319f24d7b40f7dd5aa486 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 12 May 2025 14:23:51 -0400 Subject: [PATCH 09/52] Update _base.py --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 6e7be69f02f8..e68f22a3e9c7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -330,7 +330,7 @@ def _is_session_token_request( return (is_session_consistency is True and not IsMasterResource(request_object.resource_type) and (documents._OperationType.IsReadOnlyOperation(request_object.operation_type) or request_object.operation_type == "Batch" - or cosmos_client_connection._global_endpoint_manager.get_use_multiple_write_locations())) + or cosmos_client_connection._global_endpoint_manager.can_use_multiple_write_locations(request_object))) def set_session_token_header( From 7411502c564f017c17fa9efe8393e2946afc6648 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 12 May 2025 19:07:01 -0400 Subject: [PATCH 10/52] tests - missing sync mwr --- sdk/cosmos/azure-cosmos/tests/test_config.py | 10 +- sdk/cosmos/azure-cosmos/tests/test_session.py | 67 ++++-- .../azure-cosmos/tests/test_session_async.py | 202 ++++++++++++++++++ 3 files changed, 256 insertions(+), 23 
deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_session_async.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index cbab2f07f710..71b47d6ef10a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -10,7 +10,7 @@ from azure.cosmos._retry_utility import _has_database_account_header, _has_read_retryable_headers from azure.cosmos.cosmos_client import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError -from azure.cosmos.http_constants import StatusCodes +from azure.cosmos.http_constants import StatusCodes, HttpHeaders from azure.cosmos.partition_key import PartitionKey from azure.cosmos import (ContainerProxy, DatabaseProxy, documents, exceptions, http_constants, _retry_utility) @@ -297,6 +297,14 @@ def __init__(self, headers=None, status_code=200, message="test-message"): def body(self): return None +def no_token_response_hook(raw_response): + request_headers = raw_response.http_request.headers + assert request_headers.get(HttpHeaders.SessionToken) is None + +def token_response_hook(raw_response): + request_headers = raw_response.http_request.headers + assert request_headers.get(HttpHeaders.SessionToken) is not None + class MockConnectionRetryPolicy(RetryPolicy): def __init__(self, resource_type, error=None, **kwargs): diff --git a/sdk/cosmos/azure-cosmos/tests/test_session.py b/sdk/cosmos/azure-cosmos/tests/test_session.py index fde1e57be711..00128107e5e9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session.py @@ -7,11 +7,10 @@ import pytest -import azure.cosmos._synchronized_request as synchronized_request import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.exceptions as exceptions import test_config -from azure.cosmos import DatabaseProxy +from azure.cosmos import DatabaseProxy, PartitionKey from azure.cosmos import _retry_utility from 
azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders @@ -44,25 +43,45 @@ def setUpClass(cls): cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) cls.created_collection = cls.created_db.get_container_client(cls.TEST_COLLECTION_ID) - def _MockRequest(self, global_endpoint_manager, request_params, connection_policy, pipeline_client, request): - if HttpHeaders.SessionToken in request.headers: - self.last_session_token_sent = request.headers[HttpHeaders.SessionToken] - else: - self.last_session_token_sent = None - return self._OriginalRequest(global_endpoint_manager, request_params, connection_policy, pipeline_client, - request) - - def test_session_token_not_sent_for_master_resource_ops(self): - self._OriginalRequest = synchronized_request._Request - synchronized_request._Request = self._MockRequest - created_document = self.created_collection.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}) - self.created_collection.read_item(item=created_document['id'], partition_key='mypk') - self.assertNotEqual(self.last_session_token_sent, None) - self.created_db.get_container_client(container=self.created_collection).read() - self.assertEqual(self.last_session_token_sent, None) - self.created_collection.read_item(item=created_document['id'], partition_key='mypk') - self.assertNotEqual(self.last_session_token_sent, None) - synchronized_request._Request = self._OriginalRequest + def test_session_token_sm_for_ops(self): + + # Session token should not be sent for control plane operations + test_container = self.created_db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), raw_response_hook=test_config.no_token_response_hook) + self.created_db.get_container_client(container=self.created_collection).read(raw_response_hook=test_config.no_token_response_hook) + self.created_db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for document read/batch 
requests only - verify it is not sent for write requests + up_item = self.created_collection.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + replaced_item = self.created_collection.replace_item(item=up_item['id'], body={'id': up_item['id'], 'song': 'song', 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + created_document = self.created_collection.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = self.created_collection.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = self.created_collection.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], {"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = self.created_collection.execute_item_batch(batch_operations, 'mypk', raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Verify no session tokens are sent for delete requests either - but verify session token is updated + 
self.created_collection.delete_item(created_document['id'], created_document['pk'], raw_response_hook=test_config.no_token_response_hook) + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token def _MockExecuteFunctionSessionReadFailureOnce(self, function, *args, **kwargs): response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.READ_SESSION_NOTAVAILABLE}) @@ -80,7 +99,11 @@ def test_clear_session_token(self): self.created_collection.read_item(item=created_document['id'], partition_key='mypk') except exceptions.CosmosHttpResponseError as e: self.assertEqual(self.client.client_connection.session.get_session_token( - 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id, "Read"), "") + 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id, + None, + None, + None, + None), "") self.assertEqual(e.status_code, StatusCodes.NOT_FOUND) self.assertEqual(e.sub_status, SubStatusCodes.READ_SESSION_NOTAVAILABLE) _retry_utility.ExecuteFunction = self.OriginalExecuteFunction diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_async.py b/sdk/cosmos/azure-cosmos/tests/test_session_async.py new file mode 100644 index 000000000000..7e80380bc5b3 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_session_async.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import unittest +import uuid + +import pytest + +from _fault_injection_transport_async import FaultInjectionTransportAsync +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos.aio import CosmosClient, _retry_utility_async +from azure.cosmos import DatabaseProxy, PartitionKey +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders +from azure.core.pipeline.transport import AioHttpTransport +from azure.core.pipeline.transport._aiohttp import AioHttpTransportResponse +from azure.core.rest import HttpRequest, AsyncHttpResponse +from typing import Awaitable, Callable + + + +@pytest.mark.cosmosEmulator +class TestSessionAsync(unittest.IsolatedAsyncioTestCase): + """Test to ensure escaping of non-ascii characters from partition key""" + + created_db: DatabaseProxy = None + client: CosmosClient = None + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + configs = test_config.TestConfig + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + TEST_COLLECTION_ID = configs.TEST_MULTI_PARTITION_CONTAINER_ID + + @classmethod + def setUpClass(cls): + if cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]': + raise Exception("You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + async def asyncSetUp(self): + self.client = CosmosClient(self.host, self.masterKey) + await self.client.__aenter__() + self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) + self.created_container = self.created_db.get_container_client(self.TEST_COLLECTION_ID) + + async def asyncTearDown(self): + await self.client.close() + + async def test_session_token_swr_for_ops_async(self): + # Session token should not be sent for control plane operations + test_container = await self.created_db.create_container(str(uuid.uuid4()), 
PartitionKey(path="/id"), raw_response_hook=test_config.no_token_response_hook) + await self.created_db.get_container_client(container=self.created_container).read(raw_response_hook=test_config.no_token_response_hook) + await self.created_db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for document read/batch requests only - verify it is not sent for write requests + up_item = await self.created_container.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + replaced_item = await self.created_container.replace_item(item=up_item['id'], body={'id': up_item['id'], 'song': 'song', 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + created_document = await self.created_container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = await self.created_container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = await self.created_container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], {"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = await 
self.created_container.execute_item_batch(batch_operations, 'mypk', raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Verify no session tokens are sent for delete requests either - but verify session token is updated + await self.created_container.delete_item(created_document['id'], created_document['pk'], raw_response_hook=test_config.no_token_response_hook) + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + + async def test_session_token_mwr_for_ops_async(self): + # For multiple write regions, all document requests should send out session tokens + # We will use fault injection to simulate the regions the emulator needs + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransportAsync() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], Awaitable[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + client = CosmosClient(self.host, self.masterKey, consistency_level="Session", + transport=custom_transport, multiple_write_locations=True) + db = 
client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_COLLECTION_ID) + await client.__aenter__() + + # Session token should not be sent for control plane operations + test_container = await db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), + raw_response_hook=test_config.no_token_response_hook) + await db.get_container_client(container=self.created_container).read( + raw_response_hook=test_config.no_token_response_hook) + await db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for all document requests since we have mwr configuration + up_item = await container.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + replaced_item = await container.replace_item(item=up_item['id'], + body={'id': up_item['id'], 'song': 'song', + 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + created_document = await container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = await container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = await container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], 
{"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = await container.execute_item_batch(batch_operations, 'mypk', + raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Should get sent now that we have mwr configuration + await container.delete_item(created_document['id'], created_document['pk'], + raw_response_hook=test_config.token_response_hook) + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + + # Clean up + await client.delete_database(db.id) + + + def _MockExecuteFunctionSessionReadFailureOnce(self, function, *args, **kwargs): + response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.READ_SESSION_NOTAVAILABLE}) + raise exceptions.CosmosHttpResponseError( + status_code=StatusCodes.NOT_FOUND, + message="Read Session not available", + response=response) + + async def test_clear_session_token_async(self): + created_document = await self.created_container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}) + + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunctionSessionReadFailureOnce + try: + await self.created_container.read_item(item=created_document['id'], partition_key='mypk') + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(self.client.client_connection.session.get_session_token( + 'dbs/' + self.created_db.id + '/colls/' + self.created_container.id, + None, + None, + None, + None), "") + self.assertEqual(e.status_code, StatusCodes.NOT_FOUND) + self.assertEqual(e.sub_status, SubStatusCodes.READ_SESSION_NOTAVAILABLE) + 
_retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction + + async def _MockExecuteFunctionInvalidSessionTokenAsync(self, function, *args, **kwargs): + response = {'_self': 'dbs/90U1AA==/colls/90U1AJ4o6iA=/docs/90U1AJ4o6iABCT0AAAAABA==/', 'id': '1'} + headers = {HttpHeaders.SessionToken: '0:2', + HttpHeaders.AlternateContentPath: 'dbs/testDatabase/colls/testCollection'} + return (response, headers) + + async def test_internal_server_error_raised_for_invalid_session_token_received_from_server_async(self): + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunctionInvalidSessionTokenAsync + try: + await self.created_container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}) + self.fail('Test did not fail as expected') + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.http_error_message, "Could not parse the received session token: 2") + self.assertEqual(e.status_code, StatusCodes.INTERNAL_SERVER_ERROR) + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction From b2fb9431ad57a60d77bbdd9dc65e04fc5712482f Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Tue, 13 May 2025 12:16:50 -0400 Subject: [PATCH 11/52] sync mwr tests, test fixes --- sdk/cosmos/azure-cosmos/tests/test_config.py | 6 +- sdk/cosmos/azure-cosmos/tests/test_session.py | 74 ++++++++++++++++++- .../azure-cosmos/tests/test_session_async.py | 13 ++-- .../tests/test_session_container.py | 67 ----------------- 4 files changed, 82 insertions(+), 78 deletions(-) delete mode 100644 sdk/cosmos/azure-cosmos/tests/test_session_container.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index 71b47d6ef10a..725f7af60a9a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -57,10 +57,10 @@ class 
TestConfig(object): THROUGHPUT_FOR_2_PARTITIONS = 12000 THROUGHPUT_FOR_1_PARTITION = 400 - TEST_DATABASE_ID = os.getenv('COSMOS_TEST_DATABASE_ID', "Python SDK Test Database " + str(uuid.uuid4())) + TEST_DATABASE_ID = os.getenv('COSMOS_TEST_DATABASE_ID', "PythonSDKTestDatabase-" + str(uuid.uuid4())) - TEST_SINGLE_PARTITION_CONTAINER_ID = "Single Partition Test Container " + str(uuid.uuid4()) - TEST_MULTI_PARTITION_CONTAINER_ID = "Multi Partition Test Container " + str(uuid.uuid4()) + TEST_SINGLE_PARTITION_CONTAINER_ID = "SinglePartitionTestContainer-" + str(uuid.uuid4()) + TEST_MULTI_PARTITION_CONTAINER_ID = "MultiPartitionTestContainer-" + str(uuid.uuid4()) TEST_CONTAINER_PARTITION_KEY = "pk" diff --git a/sdk/cosmos/azure-cosmos/tests/test_session.py b/sdk/cosmos/azure-cosmos/tests/test_session.py index 00128107e5e9..15cac3f2e2b5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session.py @@ -10,9 +10,12 @@ import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.exceptions as exceptions import test_config +from _fault_injection_transport import FaultInjectionTransport +from azure.core.rest import HttpRequest from azure.cosmos import DatabaseProxy, PartitionKey from azure.cosmos import _retry_utility from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders +from typing import Callable @pytest.mark.cosmosEmulator @@ -79,10 +82,79 @@ def test_session_token_sm_for_ops(self): assert batch_response_token != response_session_token # Verify no session tokens are sent for delete requests either - but verify session token is updated - self.created_collection.delete_item(created_document['id'], created_document['pk'], raw_response_hook=test_config.no_token_response_hook) + self.created_collection.delete_item(replaced_item['id'], replaced_item['pk'], raw_response_hook=test_config.no_token_response_hook) assert 
self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + def test_session_token_mwr_for_ops(self): + # For multiple write regions, all document requests should send out session tokens + # We will use fault injection to simulate the regions the emulator needs + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + client = cosmos_client.CosmosClient(self.host, self.masterKey, consistency_level="Session", + transport=custom_transport, multiple_write_locations=True) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_COLLECTION_ID) + + # Session token should not be sent for control plane operations + test_container = db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), + raw_response_hook=test_config.no_token_response_hook) + db.get_container_client(container=self.created_collection).read( + raw_response_hook=test_config.no_token_response_hook) + db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for all document requests since we have mwr configuration + # First write request won't 
have since tokens need to be populated on the client first + container.upsert_item(body={'id': '0' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + up_item = container.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + replaced_item = container.replace_item(item=up_item['id'], body={'id': up_item['id'], 'song': 'song', + 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + created_document = container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], {"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = container.execute_item_batch(batch_operations, 'mypk', + raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Should get sent now that we have mwr configuration + container.delete_item(replaced_item['id'], 
replaced_item['pk'], + raw_response_hook=test_config.token_response_hook) + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + + def _MockExecuteFunctionSessionReadFailureOnce(self, function, *args, **kwargs): response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.READ_SESSION_NOTAVAILABLE}) raise exceptions.CosmosHttpResponseError( diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_async.py b/sdk/cosmos/azure-cosmos/tests/test_session_async.py index 7e80380bc5b3..ee06ea913a56 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session_async.py @@ -13,7 +13,6 @@ from azure.cosmos.aio import CosmosClient, _retry_utility_async from azure.cosmos import DatabaseProxy, PartitionKey from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders -from azure.core.pipeline.transport import AioHttpTransport from azure.core.pipeline.transport._aiohttp import AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse from typing import Awaitable, Callable @@ -84,7 +83,7 @@ async def test_session_token_swr_for_ops_async(self): assert batch_response_token != response_session_token # Verify no session tokens are sent for delete requests either - but verify session token is updated - await self.created_container.delete_item(created_document['id'], created_document['pk'], raw_response_hook=test_config.no_token_response_hook) + await self.created_container.delete_item(replaced_item['id'], replaced_item['pk'], raw_response_hook=test_config.no_token_response_hook) assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token @@ -119,8 +118,11 @@ async def 
test_session_token_mwr_for_ops_async(self): await db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) # Session token should be sent for all document requests since we have mwr configuration + # First write request won't have since tokens need to be populated on the client first + await container.upsert_item(body={'id': '0' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) up_item = await container.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, - raw_response_hook=test_config.token_response_hook) + raw_response_hook=test_config.token_response_hook) replaced_item = await container.replace_item(item=up_item['id'], body={'id': up_item['id'], 'song': 'song', 'pk': 'mypk'}, @@ -150,14 +152,11 @@ async def test_session_token_mwr_for_ops_async(self): assert batch_response_token != response_session_token # Should get sent now that we have mwr configuration - await container.delete_item(created_document['id'], created_document['pk'], + await container.delete_item(replaced_item['id'], replaced_item['pk'], raw_response_hook=test_config.token_response_hook) assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token - # Clean up - await client.delete_database(db.id) - def _MockExecuteFunctionSessionReadFailureOnce(self, function, *args, **kwargs): response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.READ_SESSION_NOTAVAILABLE}) diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_container.py b/sdk/cosmos/azure-cosmos/tests/test_session_container.py deleted file mode 100644 index 12d513b66455..000000000000 --- a/sdk/cosmos/azure-cosmos/tests/test_session_container.py +++ /dev/null @@ -1,67 +0,0 @@ -# The MIT License (MIT) -# Copyright (c) Microsoft Corporation. All rights reserved. 
- -import unittest - -import pytest - -import azure.cosmos.cosmos_client as cosmos_client -import test_config - - -# from types import * - -@pytest.mark.cosmosEmulator -class TestSessionContainer(unittest.TestCase): - # this test doesn't need real credentials, or connection to server - host = test_config.TestConfig.host - master_key = test_config.TestConfig.masterKey - connectionPolicy = test_config.TestConfig.connectionPolicy - - def setUp(self): - self.client = cosmos_client.CosmosClient(self.host, self.master_key, consistency_level="Session", - connection_policy=self.connectionPolicy) - self.session = self.client.client_connection.Session - - def tearDown(self): - pass - - def test_create_collection(self): - # validate session token population after create collection request - session_token = self.session.get_session_token('') - assert session_token == '' - - create_collection_response_result = {u'_self': u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/', u'_rid': u'DdAkAPS2rAA=', - u'id': u'sample collection'} - create_collection_response_header = {'x-ms-session-token': '0:0#409#24=-1#12=-1', - 'x-ms-alt-content-path': 'dbs/sample%20database'} - self.session.update_session(None, create_collection_response_result, create_collection_response_header) - - token = self.session.get_session_token(u'/dbs/sample%20database/colls/sample%20collection') - assert token == '0:0#409#24=-1#12=-1' - - token = self.session.get_session_token(u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/') - assert token == '0:0#409#24=-1#12=-1' - return True - - def test_document_requests(self): - # validate session token for rid based requests - create_document_response_result = {u'_self': u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/docs/DdAkAPS2rAACAAAAAAAAAA==/', - u'_rid': u'DdAkAPS2rAACAAAAAAAAAA==', - u'id': u'eb391181-5c49-415a-ab27-848ce21d5d11'} - create_document_response_header = {'x-ms-session-token': '0:0#406#24=-1#12=-1', - 'x-ms-alt-content-path': 'dbs/sample%20database/colls/sample%20collection', - 
'x-ms-content-path': 'DdAkAPS2rAA='} - - self.session.update_session(None, create_document_response_result, create_document_response_header) - - token = self.session.get_session_token(u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/docs/DdAkAPS2rAACAAAAAAAAAA==/') - assert token == '0:0#406#24=-1#12=-1' - - token = self.session.get_session_token( - u'dbs/sample%20database/colls/sample%20collection/docs/eb391181-5c49-415a-ab27-848ce21d5d11') - assert token == '0:0#406#24=-1#12=-1' - - -if __name__ == '__main__': - unittest.main() From 6899cdb92ff67bd8644ba92eaa4ab7bf3d616f9d Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Tue, 13 May 2025 17:00:33 -0400 Subject: [PATCH 12/52] Update test_session_async.py --- sdk/cosmos/azure-cosmos/tests/test_session_async.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_async.py b/sdk/cosmos/azure-cosmos/tests/test_session_async.py index ee06ea913a56..b30102435dd0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session_async.py @@ -156,6 +156,7 @@ async def test_session_token_mwr_for_ops_async(self): raw_response_hook=test_config.token_response_hook) assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + await client.close() def _MockExecuteFunctionSessionReadFailureOnce(self, function, *args, **kwargs): From 209d3d70d8d7efac3ce9404a5dbc1d43ef791e2b Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 21 May 2025 18:02:48 -0400 Subject: [PATCH 13/52] Update test_backwards_compatibility.py --- .../tests/test_backwards_compatibility.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py 
b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py index a2efa52abf3d..4db6e42c6c6b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py @@ -67,10 +67,11 @@ def test_session_token_compatibility(self): database_list = list(self.client.list_databases(session_token=str(uuid.uuid4()))) database_list2 = list(self.client.query_databases(query="select * from c", session_token=str(uuid.uuid4()))) assert len(database_list) > 0 - # assert database_list == database_list2 + assert len(database_list2) > 0 database_read = database.read(session_token=str(uuid.uuid4())) assert database_read is not None self.client.delete_database(database2.id, session_token=str(uuid.uuid4())) + self.client.delete_database(database.id, session_token=str(uuid.uuid4())) try: database2.read() pytest.fail("Database read should have failed") @@ -78,14 +79,14 @@ def test_session_token_compatibility(self): assert e.status_code == 404 # Container - container = database.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) + container = self.databaseForTest.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) assert container is not None - container2 = database.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) + container2 = self.databaseForTest.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) assert container2 is not None - container_list = list(database.list_containers(session_token=str(uuid.uuid4()))) - container_list2 = list(database.query_containers(query="select * from c", session_token=str(uuid.uuid4()))) + container_list = list(self.databaseForTest.list_containers(session_token=str(uuid.uuid4()))) + container_list2 = list(self.databaseForTest.query_containers(query="select * from c", 
session_token=str(uuid.uuid4()))) assert len(container_list) > 0 - assert container_list == container_list2 + assert len(container_list2) > 0 container2_read = container2.read(session_token=str(uuid.uuid4())) assert container2_read is not None replace_container = database.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, session_token=str(uuid.uuid4())) @@ -93,15 +94,14 @@ def test_session_token_compatibility(self): assert replace_container is not None assert replace_container_read != container2_read assert 'defaultTtl' in replace_container_read # Check for default_ttl as a new additional property - database.delete_container(replace_container.id, session_token=str(uuid.uuid4())) + self.databaseForTest.delete_container(replace_container.id, session_token=str(uuid.uuid4())) + self.databaseForTest.delete_container(container.id, session_token=str(uuid.uuid4())) try: container2.read() pytest.fail("Container read should have failed") except CosmosHttpResponseError as e: assert e.status_code == 404 - self.client.delete_database(database.id) - def test_etag_match_condition_compatibility(self): # Verifying that behavior is unaffected across the board for using `etag`/`match_condition` on irrelevant methods # Database From e80d2c7526273aa1762cfbf341f79568b149164e Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 22 May 2025 17:18:47 -0400 Subject: [PATCH 14/52] Update test_backwards_compatibility.py --- sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py index 4db6e42c6c6b..c5529be9e7c9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py @@ -89,7 +89,7 @@ def test_session_token_compatibility(self): assert 
len(container_list2) > 0 container2_read = container2.read(session_token=str(uuid.uuid4())) assert container2_read is not None - replace_container = database.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, session_token=str(uuid.uuid4())) + replace_container = self.databaseForTest.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, session_token=str(uuid.uuid4())) replace_container_read = replace_container.read() assert replace_container is not None assert replace_container_read != container2_read From 24c36b8c0ee80de81ff09965d617fb3a1610927b Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 22 May 2025 23:15:33 -0400 Subject: [PATCH 15/52] Update test_backwards_compatibility_async.py --- .../test_backwards_compatibility_async.py | 41 ++++++++----------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py index 86c5e6e1e4eb..9f2515df29b9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py @@ -49,9 +49,10 @@ async def test_session_token_compatibility_async(self): database_list = [db async for db in self.client.list_databases(session_token=str(uuid.uuid4()))] database_list2 = [db async for db in self.client.query_databases(query="select * from c", session_token=str(uuid.uuid4()))] assert len(database_list) > 0 - # assert database_list == database_list2 + assert len(database_list2) > 0 database_read = await database.read(session_token=str(uuid.uuid4())) assert database_read is not None + await self.client.delete_database(database.id, session_token=str(uuid.uuid4())) await self.client.delete_database(database2.id, session_token=str(uuid.uuid4())) try: await database2.read() @@ -60,31 +61,30 @@ async def 
test_session_token_compatibility_async(self): assert e.status_code == 404 # Container - container = await database.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) + container = await self.created_database.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) assert container is not None - container2 = await database.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) + container2 = await self.created_database.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) assert container2 is not None - container_list = [cont async for cont in database.list_containers(session_token=str(uuid.uuid4()))] - container_list2 = [cont async for cont in database.query_containers(query="select * from c", session_token=str(uuid.uuid4()))] + container_list = [cont async for cont in self.created_database.list_containers(session_token=str(uuid.uuid4()))] + container_list2 = [cont async for cont in self.created_database.query_containers(query="select * from c", session_token=str(uuid.uuid4()))] assert len(container_list) > 0 - assert container_list == container_list2 + assert len(container_list2) > 0 container2_read = await container2.read(session_token=str(uuid.uuid4())) assert container2_read is not None - replace_container = await database.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, session_token=str(uuid.uuid4())) + replace_container = await self.created_database.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, session_token=str(uuid.uuid4())) replace_container_read = await replace_container.read() assert replace_container is not None assert replace_container_read != container2_read assert 'defaultTtl' in replace_container_read # Check for default_ttl as a new additional property assert replace_container_read['defaultTtl'] == 30 - await 
database.delete_container(replace_container.id, session_token=str(uuid.uuid4())) + await self.created_database.delete_container(container.id, session_token=str(uuid.uuid4())) + await self.created_database.delete_container(container2.id, session_token=str(uuid.uuid4())) try: await container2.read() pytest.fail("Container read should have failed") except CosmosHttpResponseError as e: assert e.status_code == 404 - await self.client.delete_database(database.id) - async def test_etag_match_condition_compatibility_async(self): # Verifying that behavior is unaffected across the board for using `etag`/`match_condition` on irrelevant methods # Database @@ -93,33 +93,24 @@ async def test_etag_match_condition_compatibility_async(self): database2 = await self.client.create_database_if_not_exists(str(uuid.uuid4()), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) assert database2 is not None await self.client.delete_database(database2.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) - try: - await database2.read() - pytest.fail("Database read should have failed") - except CosmosHttpResponseError as e: - assert e.status_code == 404 + await self.client.delete_database(database.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) # Container - container = await database.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), + container = await self.created_database.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) assert container is not None - container2 = await database.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), + container2 = await self.created_database.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) assert container2 is not None container2_read = await container2.read() assert container2_read is 
not None - replace_container = await database.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, + replace_container = await self.created_database.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) replace_container_read = await replace_container.read() assert replace_container is not None assert replace_container_read != container2_read assert 'defaultTtl' in replace_container_read # Check for default_ttl as a new additional property - await database.delete_container(replace_container.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) - try: - await container2.read() - pytest.fail("Container read should have failed") - except CosmosHttpResponseError as e: - assert e.status_code == 404 + await self.created_database.delete_container(container2.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) # Item item = await container.create_item({"id": str(uuid.uuid4()), "pk": 0}, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) @@ -143,7 +134,7 @@ async def test_etag_match_condition_compatibility_async(self): for result in batch_results: assert result['statusCode'] in (200, 201) - await self.client.delete_database(database.id) + await self.created_database.delete_container(container.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) if __name__ == '__main__': From d0373046aec325885022f9b1756bf132a7185bc9 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 22 May 2025 23:19:05 -0400 Subject: [PATCH 16/52] Update test_backwards_compatibility_async.py --- .../tests/test_backwards_compatibility_async.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py index 
9f2515df29b9..eeeeb7f6e448 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py @@ -93,7 +93,7 @@ async def test_etag_match_condition_compatibility_async(self): database2 = await self.client.create_database_if_not_exists(str(uuid.uuid4()), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) assert database2 is not None await self.client.delete_database(database2.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) - await self.client.delete_database(database.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) + await self.client.delete_database(database.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) # Container container = await self.created_database.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), @@ -110,7 +110,7 @@ async def test_etag_match_condition_compatibility_async(self): assert replace_container is not None assert replace_container_read != container2_read assert 'defaultTtl' in replace_container_read # Check for default_ttl as a new additional property - await self.created_database.delete_container(container2.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) + await self.created_database.delete_container(container2.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) # Item item = await container.create_item({"id": str(uuid.uuid4()), "pk": 0}, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) @@ -134,7 +134,7 @@ async def test_etag_match_condition_compatibility_async(self): for result in batch_results: assert result['statusCode'] in (200, 201) - await self.created_database.delete_container(container.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) + await self.created_database.delete_container(container.id, etag=str(uuid.uuid4()), 
match_condition=MatchConditions.IfModified) if __name__ == '__main__': From a78d07d6aa1a5c29fb6ec44377a0a8c2f178b440 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 22 May 2025 23:28:32 -0400 Subject: [PATCH 17/52] Update test_backwards_compatibility.py --- .../tests/test_backwards_compatibility.py | 31 ++++--------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py index c5529be9e7c9..11d38792fe3e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py @@ -72,11 +72,6 @@ def test_session_token_compatibility(self): assert database_read is not None self.client.delete_database(database2.id, session_token=str(uuid.uuid4())) self.client.delete_database(database.id, session_token=str(uuid.uuid4())) - try: - database2.read() - pytest.fail("Database read should have failed") - except CosmosHttpResponseError as e: - assert e.status_code == 404 # Container container = self.databaseForTest.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) @@ -96,11 +91,6 @@ def test_session_token_compatibility(self): assert 'defaultTtl' in replace_container_read # Check for default_ttl as a new additional property self.databaseForTest.delete_container(replace_container.id, session_token=str(uuid.uuid4())) self.databaseForTest.delete_container(container.id, session_token=str(uuid.uuid4())) - try: - container2.read() - pytest.fail("Container read should have failed") - except CosmosHttpResponseError as e: - assert e.status_code == 404 def test_etag_match_condition_compatibility(self): # Verifying that behavior is unaffected across the board for using `etag`/`match_condition` on irrelevant methods @@ -110,33 +100,24 @@ def test_etag_match_condition_compatibility(self): 
database2 = self.client.create_database_if_not_exists(str(uuid.uuid4()), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) assert database2 is not None self.client.delete_database(database2.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) - try: - database2.read() - pytest.fail("Database read should have failed") - except CosmosHttpResponseError as e: - assert e.status_code == 404 + self.client.delete_database(database.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) # Container - container = database.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), + container = self.databaseForTest.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) assert container is not None - container2 = database.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), + container2 = self.databaseForTest.create_container_if_not_exists(str(uuid.uuid4()), PartitionKey(path="/pk"), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) assert container2 is not None container2_read = container2.read() assert container2_read is not None - replace_container = database.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, + replace_container = self.databaseForTest.replace_container(container2, PartitionKey(path="/pk"), default_ttl=30, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) replace_container_read = replace_container.read() assert replace_container is not None assert replace_container_read != container2_read assert 'defaultTtl' in replace_container_read # Check for default_ttl as a new additional property - database.delete_container(replace_container.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) - try: - container2.read() - pytest.fail("Container read should have failed") - except CosmosHttpResponseError as e: - assert e.status_code == 404 
+ self.databaseForTest.delete_container(replace_container.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) # Item item = container.create_item({"id": str(uuid.uuid4()), "pk": 0}, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) @@ -159,7 +140,7 @@ def test_etag_match_condition_compatibility(self): for result in batch_results: assert result['statusCode'] in (200, 201) - self.client.delete_database(database.id) + self.databaseForTest.delete_container(container.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) if __name__ == "__main__": unittest.main() From 9ca6054a416ff7b8f3634ce77de626daa7ee46a4 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Tue, 27 May 2025 18:33:06 -0400 Subject: [PATCH 18/52] Update execution_dispatcher.py --- .../azure/cosmos/_execution_context/aio/execution_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 226db53d6897..b9d2d0e8ba2d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -110,7 +110,7 @@ async def fetch_next_block(self): else: await self._create_execution_context_with_query_plan() - return await self._execution_context.fetch_next_block() return await self._execution_context.__anext__() + return await self._execution_context.fetch_next_block() async def _create_pipelined_execution_context(self, query_execution_info): From b806e20e8bf8420a8f884610a339dcbaecc3e30a Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Tue, 27 May 2025 19:02:11 -0400 Subject: [PATCH 19/52] merge leftovers --- 
.../azure-cosmos/azure/cosmos/_cosmos_client_connection.py | 3 --- .../cosmos/_execution_context/aio/execution_dispatcher.py | 4 ---- .../azure/cosmos/_execution_context/execution_dispatcher.py | 4 ---- .../azure/cosmos/aio/_cosmos_client_connection_async.py | 1 - 4 files changed, 12 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index d696b3f24352..81c68cc30f7d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -446,7 +446,6 @@ def QueryDatabases( if options is None: options = {} - resource_type = http_constants.ResourceType.Database def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( "/dbs", http_constants.ResourceType.Database, "", lambda r: r["Databases"], @@ -1086,7 +1085,6 @@ def QueryItems( if options is None: options = {} - resource_type = http_constants.ResourceType.Document if base.IsDatabaseLink(database_or_container_link): return ItemPaged( self, @@ -2552,7 +2550,6 @@ def QueryOffers( if options is None: options = {} - resource_type = http_constants.ResourceType.Offer def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return self.__QueryFeed( "/offers", http_constants.ResourceType.Offer, "", lambda r: r["Offers"], diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index b9d2d0e8ba2d..a942c1a6b48d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -88,8 +88,6 @@ async def __anext__(self): else: await self._create_execution_context_with_query_plan() - return await 
self._execution_context.__anext__() - async def fetch_next_block(self): """Returns a block of results. @@ -110,8 +108,6 @@ async def fetch_next_block(self): else: await self._create_execution_context_with_query_plan() - return await self._execution_context.fetch_next_block() - async def _create_pipelined_execution_context(self, query_execution_info): assert self._resource_link, "code bug, resource_link is required." diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 732606165c5e..d14cfd3f529b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -117,8 +117,6 @@ def __next__(self): else: self._create_execution_context_with_query_plan() - return next(self._execution_context) - def fetch_next_block(self): """Returns a block of results. @@ -139,8 +137,6 @@ def fetch_next_block(self): else: self._create_execution_context_with_query_plan() - return self._execution_context.fetch_next_block() - def _create_pipelined_execution_context(self, query_execution_info): assert self._resource_link, "code bug, resource_link is required." 
if query_execution_info.has_aggregates() and not query_execution_info.has_select_value(): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index d1c9e6c963e5..f1a463f2577b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2265,7 +2265,6 @@ def QueryItems( if options is None: options = {} - resource_type = http_constants.ResourceType.Document if base.IsDatabaseLink(database_or_container_link): return AsyncItemPaged( self, From 34d63a8e1ec22cb95d7a08d697b897a627d3892a Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Tue, 27 May 2025 19:38:47 -0400 Subject: [PATCH 20/52] slip --- .../cosmos/_execution_context/aio/execution_dispatcher.py | 4 ++++ .../azure/cosmos/_execution_context/execution_dispatcher.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index a942c1a6b48d..b9d2d0e8ba2d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -88,6 +88,8 @@ async def __anext__(self): else: await self._create_execution_context_with_query_plan() + return await self._execution_context.__anext__() + async def fetch_next_block(self): """Returns a block of results. @@ -108,6 +110,8 @@ async def fetch_next_block(self): else: await self._create_execution_context_with_query_plan() + return await self._execution_context.fetch_next_block() + async def _create_pipelined_execution_context(self, query_execution_info): assert self._resource_link, "code bug, resource_link is required." 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index d14cfd3f529b..732606165c5e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -117,6 +117,8 @@ def __next__(self): else: self._create_execution_context_with_query_plan() + return next(self._execution_context) + def fetch_next_block(self): """Returns a block of results. @@ -137,6 +139,8 @@ def fetch_next_block(self): else: self._create_execution_context_with_query_plan() + return self._execution_context.fetch_next_block() + def _create_pipelined_execution_context(self, query_execution_info): assert self._resource_link, "code bug, resource_link is required." if query_execution_info.has_aggregates() and not query_execution_info.has_select_value(): From 82ff60c7ff200921f5c9ab794e86052f0b7c17b2 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 10:36:52 -0400 Subject: [PATCH 21/52] Update test_backwards_compatibility.py --- sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py index 3d754d254e71..4a94a2353984 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py @@ -96,8 +96,6 @@ def test_session_token_compatibility(self): except CosmosHttpResponseError as e: assert e.status_code == 404 - self.client.delete_database(database.id) - def test_etag_match_condition_compatibility(self): # Verifying that behavior is unaffected across the board for using `etag`/`match_condition` on irrelevant methods # Database From 
b06ad7d0cdb5192d6404f23fa8ea27f48778d809 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 10:37:23 -0400 Subject: [PATCH 22/52] Update test_backwards_compatibility.py --- sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py index 4a94a2353984..2b2f51771135 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py @@ -71,7 +71,6 @@ def test_session_token_compatibility(self): database_read = database.read(session_token=str(uuid.uuid4())) assert database_read is not None self.client.delete_database(database2.id, session_token=str(uuid.uuid4())) - self.client.delete_database(database.id, session_token=str(uuid.uuid4())) # Container container = self.databaseForTest.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), session_token=str(uuid.uuid4())) @@ -96,6 +95,8 @@ def test_session_token_compatibility(self): except CosmosHttpResponseError as e: assert e.status_code == 404 + self.client.delete_database(database.id) + def test_etag_match_condition_compatibility(self): # Verifying that behavior is unaffected across the board for using `etag`/`match_condition` on irrelevant methods # Database From eb885dde07bd0e1873f1bb665d3a7858037bc96d Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 13:20:51 -0400 Subject: [PATCH 23/52] Update test_query_hybrid_search_async.py --- sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py index 86f33f658204..d779ac2472e0 
100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py @@ -295,7 +295,7 @@ async def test_hybrid_search_weighted_reciprocal_rank_fusion_async(self): query = "SELECT c.index, c.title FROM c " \ "ORDER BY RANK RRF(FullTextScore(c.text, 'United States'), VectorDistance(c.vector, {}), [1,1]) " \ "OFFSET 0 LIMIT 10".format(item_vector) - results = self.test_container.query_items(query, enable_cross_partition_query=True) + results = self.test_container.query_items(query) result_list = [res async for res in results] assert len(result_list) == 10 result_list = [res['index'] for res in result_list] From 008ac695202c0115f1480b26de9fcd5fd082bc4c Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 14:47:04 -0400 Subject: [PATCH 24/52] further changes, changelog --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 5 +- .../aio/execution_dispatcher.py | 6 +- .../execution_dispatcher.py | 7 ++- .../multi_execution_aggregator.py | 57 ++++++++++--------- 4 files changed, 42 insertions(+), 33 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index c6f7ad8f7b42..d5b028281603 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,14 +1,17 @@ ## Release History ### 4.12.0b2 (Unreleased) -* Added ability to use request level `excluded_locations` on metadata calls, such as getting container properties. See [PR 40905](https://github.com/Azure/azure-sdk-for-python/pull/40905) #### Features Added +* Added ability to use request level `excluded_locations` on metadata calls, such as getting container properties. See [PR 40905](https://github.com/Azure/azure-sdk-for-python/pull/40905) #### Bugs Fixed * Fixed issue where Query Change Feed did not return items if the container uses legacy Hash V1 Partition Keys. 
This also fixes issues with not being able to change feed query for Specific Partition Key Values for HPK. See [PR 41270](https://github.com/Azure/azure-sdk-for-python/pull/41270/) +* Fixed session container compound session token logic. The SDK will now only send the relevant tokens for each read request, as opposed to the entire compound session token for the container. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). +* Write requests will no longer send session tokens when using session consistency. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). #### Other Changes +* Cross-partition queries will now always send a query plan before attempting to execute. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). ### 4.12.0b1 (2025-05-19) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index b9d2d0e8ba2d..b5762b3e4a2a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -77,7 +77,7 @@ async def __anext__(self): :raises StopIteration: If no more result is left. """ - if self._fetched_query_plan or "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options: try: return await self._execution_context.__anext__() except CosmosHttpResponseError as e: @@ -99,7 +99,7 @@ async def fetch_next_block(self): :return: List of results. 
:rtype: list """ - if self._fetched_query_plan or "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options: try: return await self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: @@ -120,6 +120,8 @@ async def _create_pipelined_execution_context(self, query_execution_info): and self._options["enableCrossPartitionQuery"]): raise CosmosHttpResponseError(StatusCodes.BAD_REQUEST, "Cross partition query only supports 'VALUE ' for aggregates") + if self._fetched_query_plan: + self._options.pop("enableCrossPartitionQuery", None) # throw exception here for vector search query without limit filter or limit > max_limit if query_execution_info.get_non_streaming_order_by(): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 732606165c5e..51299494e351 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -106,7 +106,7 @@ def __next__(self): :raises StopIteration: If no more result is left. """ - if self._fetched_query_plan or "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options: try: return next(self._execution_context) except CosmosHttpResponseError as e: @@ -128,7 +128,7 @@ def fetch_next_block(self): :return: List of results. 
:rtype: list """ - if self._fetched_query_plan or "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options: try: return self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: @@ -149,6 +149,8 @@ def _create_pipelined_execution_context(self, query_execution_info): raise CosmosHttpResponseError( StatusCodes.BAD_REQUEST, "Cross partition query only supports 'VALUE ' for aggregates") + if self._fetched_query_plan: + self._options.pop("enableCrossPartitionQuery", None) # throw exception here for vector search query without limit filter or limit > max_limit if query_execution_info.get_non_streaming_order_by(): @@ -191,6 +193,7 @@ def _create_pipelined_execution_context(self, query_execution_info): query_execution_info, self._response_hook, self._raw_response_hook) + execution_context_aggregator._configure_partition_ranges() return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator, query_execution_info) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py index 62bf5d9dadfe..8c24ab41721a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py @@ -82,36 +82,8 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i else: self._document_producer_comparator = document_producer._PartitionKeyRangeDocumentProducerComparator() - # will be a list of (partition_min, partition_max) tuples - targetPartitionRanges = self._get_target_partition_key_range() - - targetPartitionQueryExecutionContextList = [] - for partitionTargetRange in targetPartitionRanges: - # create and add the child execution context for the target range - targetPartitionQueryExecutionContextList.append( - 
self._createTargetPartitionQueryExecutionContext(partitionTargetRange) - ) - self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue() - for targetQueryExContext in targetPartitionQueryExecutionContextList: - try: - # TODO: we can also use more_itertools.peekable to be more python friendly - targetQueryExContext.peek() - # if there are matching results in the target ex range add it to the priority queue - - self._orderByPQ.push(targetQueryExContext) - - except exceptions.CosmosHttpResponseError as e: - if exceptions._partition_range_is_gone(e): - # repairing document producer context on partition split - self._repair_document_producer() - else: - raise - - except StopIteration: - continue - def __next__(self): """Returns the next result @@ -139,6 +111,35 @@ def fetch_next_block(self): raise NotImplementedError("You should use pipeline's fetch_next_block.") + def _configure_partition_ranges(self): + # will be a list of (partition_min, partition_max) tuples + targetPartitionRanges = self._get_target_partition_key_range() + + targetPartitionQueryExecutionContextList = [] + for partitionTargetRange in targetPartitionRanges: + # create and add the child execution context for the target range + targetPartitionQueryExecutionContextList.append( + self._createTargetPartitionQueryExecutionContext(partitionTargetRange) + ) + + for targetQueryExContext in targetPartitionQueryExecutionContextList: + try: + # TODO: we can also use more_itertools.peekable to be more python friendly + targetQueryExContext.peek() + # if there are matching results in the target ex range add it to the priority queue + + self._orderByPQ.push(targetQueryExContext) + + except exceptions.CosmosHttpResponseError as e: + if exceptions._partition_range_is_gone(e): + # repairing document producer context on partition split + self._repair_document_producer() + else: + raise + + except StopIteration: + continue + def _repair_document_producer(self): """Repairs the document producer context by using 
the re-initialized routing map provider in the client, which loads in a refreshed partition key range cache to re-create the partition key ranges. From 77758600c07e80e4ce315c3349fb265a7c2b2c5d Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 15:52:51 -0400 Subject: [PATCH 25/52] add tests --- .../azure/cosmos/_cosmos_client_connection.py | 1 - .../aio/_cosmos_client_connection_async.py | 1 - sdk/cosmos/azure-cosmos/tests/test_config.py | 27 +++++ .../tests/test_partition_split_query.py | 39 ++++--- .../tests/test_partition_split_query_async.py | 104 ++++++++++++++++++ 5 files changed, 152 insertions(+), 20 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 81c68cc30f7d..61e52456dca0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3245,7 +3245,6 @@ def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, excluded_l } if excluded_locations is not None: options["excludedLocations"] = excluded_locations - resource_link = base.TrimBeginningAndEndingSlashes(resource_link) path = base.GetPathFromLink(resource_link, http_constants.ResourceType.Document) resource_id = base.GetResourceIdOrFullNameFromLink(resource_link) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index f1a463f2577b..52fda5822068 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -3270,7 +3270,6 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, } if 
excluded_locations is not None: options["excludedLocations"] = excluded_locations - resource_link = base.TrimBeginningAndEndingSlashes(resource_link) path = base.GetPathFromLink(resource_link, http_constants.ResourceType.Document) resource_id = base.GetResourceIdOrFullNameFromLink(resource_link) diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index ee9b59fa58e3..28e6d8c08b61 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -3,6 +3,7 @@ import collections import os +import random import time import unittest import uuid @@ -290,6 +291,32 @@ def get_full_text_policy(path): ] } +def get_test_item(): + test_item = { + 'id': 'Item_' + str(uuid.uuid4()), + 'test_object': True, + 'lastName': 'Smith', + 'attr1': random.randint(0, 10) + } + return test_item + +def pre_split_hook(response): + request_headers = response.http_request.headers + session_token = request_headers.get('x-ms-session-token') + assert len(session_token) <= 20 + assert session_token.startswith('0:0') + assert session_token.count(':') == 1 + assert session_token.count(',') == 0 + +def post_split_hook(response): + request_headers = response.http_request.headers + session_token = request_headers.get('x-ms-session-token') + assert len(session_token) > 30 + assert len(session_token) < 60 # should only be 0-1 or 0-2, not 0-1-2 + assert session_token.startswith('0:0') + assert session_token.count(':') == 2 + assert session_token.count(',') == 1 + class ResponseHookCaller: def __init__(self): self.count = 0 diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index 68cd10722b01..d56de1dd8850 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -4,7 +4,6 @@ import random import time import unittest -import uuid import os import 
pytest @@ -12,17 +11,7 @@ import azure.cosmos.cosmos_client as cosmos_client import test_config from azure.cosmos import DatabaseProxy, PartitionKey, ContainerProxy -from azure.cosmos.exceptions import CosmosClientTimeoutError, CosmosHttpResponseError - - -def get_test_item(): - test_item = { - 'id': 'Item_' + str(uuid.uuid4()), - 'test_object': True, - 'lastName': 'Smith', - 'attr1': random.randint(0, 10) - } - return test_item +from azure.cosmos.exceptions import CosmosHttpResponseError def run_queries(container, iterations): @@ -40,8 +29,19 @@ def run_queries(container, iterations): assert str(attr_number) == curr # verify that all results match their randomly generated attributes print("validation succeeded for all query results") +def run_session_token_query(container, split): + query = "select * from c" + # verify session token sent makes sense for number of partitions present + if split: + query_iterable = container.query_items(query=query, enable_cross_partition_query=True, + raw_response_hook=test_config.post_split_hook) + else: + query_iterable = container.query_items(query=query, enable_cross_partition_query=True, + raw_response_hook=test_config.pre_split_hook) + list(query_iterable) + -@pytest.mark.cosmosQuery +@pytest.mark.cosmosSplit class TestPartitionSplitQuery(unittest.TestCase): database: DatabaseProxy = None container: ContainerProxy = None @@ -59,7 +59,8 @@ def setUpClass(cls): cls.database = cls.client.get_database_client(cls.TEST_DATABASE_ID) cls.container = cls.database.create_container( id=cls.TEST_CONTAINER_ID, - partition_key=PartitionKey(path="/id")) + partition_key=PartitionKey(path="/id"), + offer_throughput=cls.throughput) if cls.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" @@ -72,29 +73,31 @@ def tearDownClass(cls) -> None: def test_partition_split_query(self): for i in range(100): - body = get_test_item() + body = test_config.get_test_item() 
self.container.create_item(body=body) start_time = time.time() print("created items, changing offer to 11k and starting queries") - self.database.replace_throughput(11000) + self.container.replace_throughput(11000) offer_time = time.time() print("changed offer to 11k") print("--------------------------------") print("now starting queries") run_queries(self.container, 100) # initial check for queries before partition split + run_session_token_query(self.container, False) # initial session token check before partition split print("initial check succeeded, now reading offer until replacing is done") - offer = self.database.get_throughput() + offer = self.container.get_throughput() while True: if time.time() - start_time > 60 * 25: # timeout test at 25 minutes unittest.skip("Partition split didn't complete in time.") if offer.properties['content'].get('isOfferReplacePending', False): time.sleep(10) - offer = self.database.get_throughput() + offer = self.container.get_throughput() else: print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) run_queries(self.container, 100) # check queries work post partition split + run_session_token_query(self.container, True) # check session token works post partition split self.assertTrue(offer.offer_throughput > self.throughput) return diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py new file mode 100644 index 000000000000..63d13e528edb --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -0,0 +1,104 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+ +import time +import unittest +import random + +import pytest + +import test_config +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy + +async def run_queries(container, iterations): + ret_list = [] + for i in range(iterations): + curr = str(random.randint(0, 10)) + query = 'SELECT * FROM c WHERE c.attr1=' + curr + ' order by c.attr1' + qlist = [item async for item in container.query_items(query=query, enable_cross_partition_query=True)] + ret_list.append((curr, qlist)) + for ret in ret_list: + curr = ret[0] + if len(ret[1]) != 0: + for results in ret[1]: + attr_number = results['attr1'] + assert str(attr_number) == curr # verify that all results match their randomly generated attributes + print("validation succeeded for all query results") + +async def run_session_token_query(container, split): + query = "select * from c" + # verify session token sent makes sense for number of partitions present + if split: + query_iterable = container.query_items(query=query, enable_cross_partition_query=True, + raw_response_hook=test_config.post_split_hook) + else: + query_iterable = container.query_items(query=query, enable_cross_partition_query=True, + raw_response_hook=test_config.pre_split_hook) + [item async for item in query_iterable] + + +@pytest.mark.cosmosSplit +class TestPartitionSplitQueryAsync(unittest.IsolatedAsyncioTestCase): + database: DatabaseProxy = None + container: ContainerProxy = None + client: CosmosClient = None + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + throughput = 400 + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + TEST_CONTAINER_ID = "Single-partition-container-without-throughput-async" + + @classmethod + def setUpClass(cls): + if (cls.masterKey == '[YOUR_KEY_HERE]' or + cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run 
the " + "tests.") + + async def asyncSetUp(self): + self.client = CosmosClient(self.host, self.masterKey) + self.created_database = self.client.get_database_client(self.TEST_DATABASE_ID) + self.container = await self.created_database.create_container( + id=self.TEST_CONTAINER_ID, + partition_key=PartitionKey(path="/id"), + offer_throughput=self.throughput) + + async def asyncTearDown(self): + await self.client.close() + + async def test_partition_split_query_async(self): + for i in range(100): + body = test_config.get_test_item() + await self.container.create_item(body=body) + + start_time = time.time() + print("created items, changing offer to 11k and starting queries") + await self.container.replace_throughput(11000) + offer_time = time.time() + print("changed offer to 11k") + print("--------------------------------") + print("now starting queries") + + await run_queries(self.container, 100) # initial check for queries before partition split + await run_session_token_query(self.container, False) # initial session token check before partition split + print("initial check succeeded, now reading offer until replacing is done") + offer = await self.container.get_throughput() + while True: + if time.time() - start_time > 60 * 25: # timeout test at 25 minutes + unittest.skip("Partition split didn't complete in time.") + if offer.properties['content'].get('isOfferReplacePending', False): + time.sleep(10) + offer = await self.container.get_throughput() + else: + print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) + await run_queries(self.container, 100) # check queries work post partition split + await run_session_token_query(self.container, True) # check session token works post partition split + self.assertTrue(offer.offer_throughput > self.throughput) + return + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From d4dd2d4ff87502c1f2ffa165fad0d41f081f9aff Mon Sep 17 00:00:00 2001 From: Simon Moreno 
<30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 16:01:07 -0400 Subject: [PATCH 26/52] typehint --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index e68f22a3e9c7..1720f53d8397 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -320,7 +320,7 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches def _is_session_token_request( cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], headers: dict, - request_object) -> bool: + request_object: "RequestObject") -> bool: consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) # Figure out if consistency level for this request is session is_session_consistency = consistency_level == documents.ConsistencyLevel.Session From 6c124bcf096fee263d469d35756aa524cd2d4af0 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 22:33:54 -0400 Subject: [PATCH 27/52] address comments --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 4 ++-- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 40c484f136cf..5200547d3518 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -8,8 +8,8 @@ #### Bugs Fixed * Fixed issue where Query Change Feed did not return items if the container uses legacy Hash V1 Partition Keys. This also fixes issues with not being able to change feed query for Specific Partition Key Values for HPK. See [PR 41270](https://github.com/Azure/azure-sdk-for-python/pull/41270/) -* Fixed session container compound session token logic. 
The SDK will now only send the relevant tokens for each read request, as opposed to the entire compound session token for the container. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). -* Write requests will no longer send session tokens when using session consistency. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). +* Fixed session container compound session token logic. The SDK will now only send the relevant partition-local session tokens for each read request, as opposed to the entire compound session token for the container. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). +* Write requests for single-write region accounts will no longer send session tokens when using session consistency. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). #### Other Changes * Cross-partition queries will now always send a query plan before attempting to execute. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index de3a5c8a2a0c..3cbaab634f38 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -329,7 +329,8 @@ def _is_session_token_request( # Verify that it is not a metadata request, and that it is either a read request, batch request, or an account # configured to use multiple write regions - return (is_session_consistency is True and not IsMasterResource(request_object.resource_type) + return (is_session_consistency is True and cosmos_client_connection.session + and not IsMasterResource(request_object.resource_type) and (documents._OperationType.IsReadOnlyOperation(request_object.operation_type) or request_object.operation_type == "Batch" or cosmos_client_connection._global_endpoint_manager.can_use_multiple_write_locations(request_object))) From 2854cb55cc63500c8aeff0942566c329aba02cca Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 28 May 2025 22:41:08 -0400 Subject: [PATCH 28/52] Update _session.py --- sdk/cosmos/azure-cosmos/azure/cosmos/_session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 19704ca95909..9320d67d14e3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -66,7 +66,6 @@ def get_session_token( with self.session_lock: is_name_based = _base.IsNameBased(resource_path) - collection_rid = "" session_token = "" try: From cce381fb7be6a61846e64268583426374f3792f1 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 29 May 2025 09:11:00 -0400 Subject: [PATCH 29/52] Update _base.py --- sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 3cbaab634f38..bcad576b9e9c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -329,7 +329,7 @@ def _is_session_token_request( # Verify that it is not a metadata request, and that it is either a read request, batch request, or an account # configured to use multiple write regions - return (is_session_consistency is True and cosmos_client_connection.session + return (is_session_consistency is True and cosmos_client_connection.session is not None and not IsMasterResource(request_object.resource_type) and (documents._OperationType.IsReadOnlyOperation(request_object.operation_type) or request_object.operation_type == "Batch" From b20a89fc620b986abc6d98fa6882062dd3e24d24 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 29 May 2025 11:24:20 -0400 Subject: [PATCH 30/52] ci tests --- .../tests/test_backwards_compatibility.py | 13 +++++++++++-- .../tests/test_backwards_compatibility_async.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py index 2b2f51771135..290308d6dcee 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility.py @@ -71,6 +71,11 @@ def test_session_token_compatibility(self): database_read = database.read(session_token=str(uuid.uuid4())) assert database_read is not None self.client.delete_database(database2.id, session_token=str(uuid.uuid4())) + try: + database2.read() + pytest.fail("Database read should have failed") + except CosmosHttpResponseError as e: + assert e.status_code == 404 # Container container = self.databaseForTest.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), 
session_token=str(uuid.uuid4())) @@ -105,7 +110,11 @@ def test_etag_match_condition_compatibility(self): database2 = self.client.create_database_if_not_exists(str(uuid.uuid4()), etag=str(uuid.uuid4()), match_condition=MatchConditions.IfNotModified) assert database2 is not None self.client.delete_database(database2.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) - self.client.delete_database(database.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) + try: + database2.read() + pytest.fail("Database read should have failed") + except CosmosHttpResponseError as e: + assert e.status_code == 404 # Container container = self.databaseForTest.create_container(str(uuid.uuid4()), PartitionKey(path="/pk"), @@ -150,7 +159,7 @@ def test_etag_match_condition_compatibility(self): for result in batch_results: assert result['statusCode'] in (200, 201) - self.databaseForTest.delete_container(container.id, etag=str(uuid.uuid4()), match_condition=MatchConditions.IfModified) + self.client.delete_database(database.id) if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py index 450de9657693..5167d0f05ea8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_backwards_compatibility_async.py @@ -147,4 +147,4 @@ async def test_etag_match_condition_compatibility_async(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From a5b18fe7377532b7b7c787a02cac7992088f74f6 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 29 May 2025 13:51:32 -0400 Subject: [PATCH 31/52] merging main --- .../azure-cosmos/azure/cosmos/_cosmos_client_connection.py | 3 +-- .../azure/cosmos/aio/_cosmos_client_connection_async.py | 6 +++--- 2 files changed, 4 insertions(+), 5 
deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 0ab14ab1abdf..e9bfddf529e1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3119,14 +3119,13 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) - base.set_session_token_header(self, headers, path, request_params, options, partition_key_range_id) - request_params = RequestObject( resource_type, op_type, headers ) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options, partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index db036da42103..ecae7e5d8231 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2911,7 +2911,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if query is None: op_type = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, op_typ, + headers = base.GetHeaders(self, initial_headers, "get", path, id_, resource_type, op_type, options, partition_key_range_id) request_params = _request_object.RequestObject( resource_type, @@ -2952,10 +2952,10 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise 
SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, resource_type, + documents._OperationType.SqlQuery, options, partition_key_range_id) request_params = _request_object.RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) request_params.set_excluded_location_from_options(options) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, resource_type, - request_params.operation_type, options, partition_key_range_id) if not is_query_plan: await base.set_session_token_header_async(self, req_headers, path, request_params, options, partition_key_range_id) From bc3d85342d0c6548e60c94de8b554fe4f8de00c8 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 29 May 2025 16:17:28 -0400 Subject: [PATCH 32/52] Update _cosmos_client_connection.py --- .../azure-cosmos/azure/cosmos/_cosmos_client_connection.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index e9bfddf529e1..df1bdb362b80 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3162,13 +3162,12 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) + request_params = RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) if not is_query_plan: req_headers[http_constants.HttpHeaders.IsQuery] = "true" base.set_session_token_header(self, req_headers, path, request_params, options, partition_key_range_id) - request_params = RequestObject(resource_type, 
documents._OperationType.SqlQuery, req_headers) - request_params.set_excluded_location_from_options(options) - # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) if isPrefixPartitionQuery and "partitionKeyDefinition" in kwargs: From 575ea0ce46bc9dd363412d63f002e1f90cff27e9 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 29 May 2025 19:34:33 -0400 Subject: [PATCH 33/52] tests --- sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py | 2 +- .../azure-cosmos/tests/test_partition_split_query_async.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index d56de1dd8850..5891dd4efb88 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -41,7 +41,7 @@ def run_session_token_query(container, split): list(query_iterable) -@pytest.mark.cosmosSplit +@pytest.mark.cosmosQuery class TestPartitionSplitQuery(unittest.TestCase): database: DatabaseProxy = None container: ContainerProxy = None diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index 63d13e528edb..f684cb4ab668 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -35,10 +35,10 @@ async def run_session_token_query(container, split): else: query_iterable = container.query_items(query=query, enable_cross_partition_query=True, raw_response_hook=test_config.pre_split_hook) - [item async for item in query_iterable] + item_list = [item async for item in query_iterable] -@pytest.mark.cosmosSplit +@pytest.mark.cosmosQuery class TestPartitionSplitQueryAsync(unittest.IsolatedAsyncioTestCase): 
database: DatabaseProxy = None container: ContainerProxy = None From 2a19516fc10e8119687e699ce4bef19a7cff773f Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 30 May 2025 00:07:04 -0400 Subject: [PATCH 34/52] change session token logic --- .../aio/execution_dispatcher.py | 9 +++-- .../execution_dispatcher.py | 10 +++-- .../azure/cosmos/_query_iterable.py | 5 ++- .../azure-cosmos/azure/cosmos/_session.py | 39 +++++++++++++++++-- .../aio/_cosmos_client_connection_async.py | 31 +++++++++------ .../azure/cosmos/aio/_query_iterable_async.py | 5 ++- 6 files changed, 72 insertions(+), 27 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index b5762b3e4a2a..6d6cf8f12e8b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -33,7 +33,7 @@ from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos.documents import _DistinctType from azure.cosmos.exceptions import CosmosHttpResponseError -from azure.cosmos.http_constants import StatusCodes +from azure.cosmos.http_constants import StatusCodes, ResourceType from ..._constants import _Constants as Constants # pylint: disable=protected-access @@ -48,7 +48,7 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl """ def __init__(self, client, resource_link, query, options, fetch_function, - response_hook, raw_response_hook): + response_hook, raw_response_hook, resource_type): """ Constructor """ @@ -58,6 +58,7 @@ def __init__(self, client, resource_link, query, options, fetch_function, self._resource_link = resource_link self._query = query self._fetch_function = fetch_function + self._resource_type = resource_type 
self._response_hook = response_hook self._raw_response_hook = raw_response_hook self._fetched_query_plan = False @@ -77,7 +78,7 @@ async def __anext__(self): :raises StopIteration: If no more result is left. """ - if "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: try: return await self._execution_context.__anext__() except CosmosHttpResponseError as e: @@ -99,7 +100,7 @@ async def fetch_next_block(self): :return: List of results. :rtype: list """ - if "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: try: return await self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 51299494e351..2bdadae96026 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -32,7 +32,7 @@ from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos.documents import _DistinctType -from azure.cosmos.http_constants import StatusCodes, SubStatusCodes +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, ResourceType from .._constants import _Constants as Constants # pylint: disable=protected-access @@ -77,7 +77,8 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl to _MultiExecutionContextAggregator """ - def __init__(self, client, resource_link, query, options, fetch_function, response_hook, raw_response_hook): + def __init__(self, client, resource_link, query, 
options, fetch_function, response_hook, + raw_response_hook, resource_type): """ Constructor """ @@ -87,6 +88,7 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon self._resource_link = resource_link self._query = query self._fetch_function = fetch_function + self._resource_type = resource_type self._response_hook = response_hook self._raw_response_hook = raw_response_hook self._fetched_query_plan = False @@ -106,7 +108,7 @@ def __next__(self): :raises StopIteration: If no more result is left. """ - if "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: try: return next(self._execution_context) except CosmosHttpResponseError as e: @@ -128,7 +130,7 @@ def fetch_next_block(self): :return: List of results. :rtype: list """ - if "enableCrossPartitionQuery" not in self._options: + if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: try: return self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py index 881ca9d9329e..be06b24478a8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py @@ -43,6 +43,7 @@ def __init__( database_link=None, partition_key=None, continuation_token=None, + resource_type=None, response_hook=None, raw_response_hook=None, ): @@ -55,7 +56,7 @@ def __init__( :param (str or dict) query: :param dict options: The request options for the request. :param method fetch_function: - :param method resource_type: The type of the resource being queried + :param str resource_type: The type of the resource being queried :param str resource_link: If this is a Document query/feed collection_link is required. 
Example of `fetch_function`: @@ -76,7 +77,7 @@ def __init__( self._partition_key = partition_key self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( self._client, self._collection_link, self._query, self._options, self._fetch_function, - response_hook, raw_response_hook) + response_hook, raw_response_hook, resource_type) super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) def _unpack(self, block): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 9320d67d14e3..f91f26b57667 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -79,10 +79,22 @@ def get_session_token( if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] if partition_key_range_id is not None: - container_routing_map = routing_map_provider._collection_routing_map_by_item[collection_name] - current_range = container_routing_map._rangeById.get(partition_key_range_id) - if current_range is not None: - session_token = self._format_session_token(current_range, token_dict) + # if we find a cached session token for the relevant pk range id, use that session token + if token_dict.get(partition_key_range_id): + vector_session_token = token_dict.get(partition_key_range_id) + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token.session_token) + # if we don't find it, we do a session token merge for the parent pk ranges + # this should only happen immediately after a partition split + else: + container_routing_map = routing_map_provider._collection_routing_map_by_item[collection_name] + current_range = container_routing_map._rangeById.get(partition_key_range_id) + if current_range is not None: + vector_session_token = self._resolve_partition_local_session_token(current_range, + token_dict) 
+ session_token = "{0}:{1}".format(partition_key_range_id, + vector_session_token.session_token) + else: + print(3) else: collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] partition_key = PartitionKey(path=collection_pk_definition['paths'], @@ -91,6 +103,7 @@ def get_session_token( epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) pk_range = routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) session_token = self._format_session_token(pk_range, token_dict) + # session_token = token_dict.get(pk_range[0]['id']) return session_token return "" except Exception: # pylint: disable=broad-except @@ -297,6 +310,24 @@ def _format_session_token(self, pk_range, token_dict): session_token = ",".join(session_token_list) return session_token + def _resolve_partition_local_session_token(self, pk_range, token_dict) -> VectorSessionToken: + parent_session_token = None + parents = pk_range[0].get('parents').copy() + parents.append(pk_range[0]['id']) + for parent in parents: + vector_session_token = token_dict.get(parent) + # set initial token to be returned + if parent_session_token is None: + parent_session_token = vector_session_token + else: + # if initial token is already set, and the next parent's token is cached, merge vector session tokens + if vector_session_token is not None: + vector_token_1 = VectorSessionToken.create(parent_session_token) + vector_token_2 = VectorSessionToken.create(vector_session_token) + vector_token = vector_token_1.merge(vector_token_2) + parent_session_token = vector_token.session_token + return parent_session_token + class Session(object): """State of an Azure Cosmos session. 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index ecae7e5d8231..4d9136af4a29 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -2106,7 +2106,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.PartitionKeyRange ) def ReadDatabases( @@ -2313,6 +2314,7 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca page_iterator_class=query_iterable.QueryIterable, response_hook=response_hook, raw_response_hook=kwargs.get('raw_response_hook'), + resource_type=http_constants.ResourceType.Document ) def QueryItemsChangeFeed( @@ -2444,7 +2446,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca query, options, fetch_function=fetch_fn, - page_iterator_class=query_iterable.QueryIterable + page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Offer ) def ReadUsers( @@ -2506,7 +2509,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.User ) def ReadPermissions( @@ -2555,20 +2559,21 @@ def QueryPermissions( if options is None: options = {} - resource_type = http_constants.ResourceType.Permission - path = base.GetPathFromLink(user_link, resource_type) + 
path = base.GetPathFromLink(user_link, http_constants.ResourceType.Permission) user_id = base.GetResourceIdOrFullNameFromLink(user_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, resource_type, user_id, lambda r: r["Permissions"], lambda _, b: b, query, options, **kwargs + path, http_constants.ResourceType.Permission, user_id, lambda r: r["Permissions"], lambda _, b: b, + query, options, **kwargs ), self.last_response_headers, ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Permission ) def ReadStoredProcedures( @@ -2630,7 +2635,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.StoredProcedure ) def ReadTriggers( @@ -2692,7 +2698,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Trigger ) def ReadUserDefinedFunctions( @@ -2755,7 +2762,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + 
resource_type=http_constants.ResourceType.UserDefinedFunction ) def ReadConflicts( @@ -2816,7 +2824,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Conflict ) async def QueryFeed( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py index 0fb3ad1c7fc4..d0304ccfde60 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py @@ -44,6 +44,7 @@ def __init__( database_link=None, partition_key=None, continuation_token=None, + resource_type=None, response_hook=None, raw_response_hook=None, ): @@ -56,7 +57,7 @@ def __init__( :param (str or dict) query: :param dict options: The request options for the request. :param method fetch_function: - :param method resource_type: The type of the resource being queried + :param str resource_type: The type of the resource being queried :param str resource_link: If this is a Document query/feed collection_link is required. 
Example of `fetch_function`: @@ -77,7 +78,7 @@ def __init__( self._partition_key = partition_key self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( self._client, self._collection_link, self._query, self._options, self._fetch_function, - response_hook, raw_response_hook) + response_hook, raw_response_hook, resource_type) super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) async def _unpack(self, block): From f94098fa99f74d169c35c63edcdf71f2dfc9ee8b Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 30 May 2025 09:05:15 -0400 Subject: [PATCH 35/52] Update _session.py --- sdk/cosmos/azure-cosmos/azure/cosmos/_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index f91f26b57667..29d38ebacfa4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -86,7 +86,8 @@ def get_session_token( # if we don't find it, we do a session token merge for the parent pk ranges # this should only happen immediately after a partition split else: - container_routing_map = routing_map_provider._collection_routing_map_by_item[collection_name] + container_routing_map = \ + routing_map_provider._collection_routing_map_by_item[collection_name] current_range = container_routing_map._rangeById.get(partition_key_range_id) if current_range is not None: vector_session_token = self._resolve_partition_local_session_token(current_range, From 883c8619f09234baf3ded37f14231743ea367b26 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 30 May 2025 11:55:28 -0400 Subject: [PATCH 36/52] small fixes --- .../azure/cosmos/_cosmos_client_connection.py | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/_session.py | 11 ++++------- 2 files changed, 5 
insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index df1bdb362b80..14c054883926 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -1121,6 +1121,7 @@ def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInse page_iterator_class=query_iterable.QueryIterable, response_hook=response_hook, raw_response_hook=kwargs.get('raw_response_hook'), + resource_type=http_constants.ResourceType.Document ) def QueryItemsChangeFeed( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 29d38ebacfa4..4bc5ccfd09fa 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -92,8 +92,7 @@ def get_session_token( if current_range is not None: vector_session_token = self._resolve_partition_local_session_token(current_range, token_dict) - session_token = "{0}:{1}".format(partition_key_range_id, - vector_session_token.session_token) + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) else: print(3) else: @@ -311,17 +310,15 @@ def _format_session_token(self, pk_range, token_dict): session_token = ",".join(session_token_list) return session_token - def _resolve_partition_local_session_token(self, pk_range, token_dict) -> VectorSessionToken: + def _resolve_partition_local_session_token(self, pk_range, token_dict) -> str: parent_session_token = None parents = pk_range[0].get('parents').copy() - parents.append(pk_range[0]['id']) for parent in parents: - vector_session_token = token_dict.get(parent) - # set initial token to be returned + vector_session_token = token_dict.get(parent).session_token if parent_session_token is None: parent_session_token = vector_session_token + # if initial token 
is already set, and the next parent's token is cached, merge vector session tokens else: - # if initial token is already set, and the next parent's token is cached, merge vector session tokens if vector_session_token is not None: vector_token_1 = VectorSessionToken.create(parent_session_token) vector_token_2 = VectorSessionToken.create(vector_session_token) From b2741265d44e1d4d2a7731312bd21e554daf3760 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 30 May 2025 12:39:45 -0400 Subject: [PATCH 37/52] Update _session.py --- .../azure-cosmos/azure/cosmos/_session.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 4bc5ccfd09fa..3b0db03356e2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -92,9 +92,8 @@ def get_session_token( if current_range is not None: vector_session_token = self._resolve_partition_local_session_token(current_range, token_dict) - session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) - else: - print(3) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) else: collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] partition_key = PartitionKey(path=collection_pk_definition['paths'], @@ -103,7 +102,6 @@ def get_session_token( epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) pk_range = routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) session_token = self._format_session_token(pk_range, token_dict) - # session_token = token_dict.get(pk_range[0]['id']) return session_token return "" except Exception: # pylint: disable=broad-except @@ -147,10 +145,22 @@ async def get_session_token_async( if collection_rid in 
self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] if partition_key_range_id is not None: - container_routing_map = routing_map_provider._collection_routing_map_by_item[collection_name] - current_range = container_routing_map._rangeById.get(partition_key_range_id) - if current_range is not None: - session_token = self._format_session_token(current_range, token_dict) + # if we find a cached session token for the relevant pk range id, use that session token + if token_dict.get(partition_key_range_id): + vector_session_token = token_dict.get(partition_key_range_id) + session_token = "{0}:{1}".format(partition_key_range_id, + vector_session_token.session_token) + # if we don't find it, we do a session token merge for the parent pk ranges + # this should only happen immediately after a partition split + else: + container_routing_map = \ + routing_map_provider._collection_routing_map_by_item[collection_name] + current_range = container_routing_map._rangeById.get(partition_key_range_id) + if current_range is not None: + vector_session_token = self._resolve_partition_local_session_token(current_range, + token_dict) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) else: collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] partition_key = PartitionKey(path=collection_pk_definition['paths'], @@ -310,7 +320,7 @@ def _format_session_token(self, pk_range, token_dict): session_token = ",".join(session_token_list) return session_token - def _resolve_partition_local_session_token(self, pk_range, token_dict) -> str: + def _resolve_partition_local_session_token(self, pk_range, token_dict): parent_session_token = None parents = pk_range[0].get('parents').copy() for parent in parents: From 3918a5cc51fe2fbf4dd63c2687e405283e854e30 Mon Sep 17 00:00:00 2001 From: Simon Moreno 
<30335873+simorenoh@users.noreply.github.com> Date: Fri, 30 May 2025 13:15:13 -0400 Subject: [PATCH 38/52] Update _session.py --- sdk/cosmos/azure-cosmos/azure/cosmos/_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 3b0db03356e2..4fa38c32722e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -35,7 +35,7 @@ from .exceptions import CosmosHttpResponseError from .partition_key import PartitionKey -# pylint: disable=protected-access +# pylint: disable=protected-access,too-many-nested-blocks class SessionContainer(object): def __init__(self): From 33276520b72acf5ba132fd0659ee0bf756a3c8ce Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 30 May 2025 16:52:03 -0400 Subject: [PATCH 39/52] updates --- .../cosmos/_execution_context/execution_dispatcher.py | 10 ++++++---- sdk/cosmos/azure-cosmos/dev_requirements.txt | 2 +- .../tests/test_partition_split_query_async.py | 7 +++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 2bdadae96026..7d2e7228ed7f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -108,7 +108,8 @@ def __next__(self): :raises StopIteration: If no more result is left. 
""" - if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: + if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or + self._resource_type != ResourceType.Document): try: return next(self._execution_context) except CosmosHttpResponseError as e: @@ -130,7 +131,8 @@ def fetch_next_block(self): :return: List of results. :rtype: list """ - if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: + if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or + self._resource_type != ResourceType.Document): try: return self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: @@ -151,8 +153,8 @@ def _create_pipelined_execution_context(self, query_execution_info): raise CosmosHttpResponseError( StatusCodes.BAD_REQUEST, "Cross partition query only supports 'VALUE ' for aggregates") - if self._fetched_query_plan: - self._options.pop("enableCrossPartitionQuery", None) + # if self._fetched_query_plan: + # self._options.pop("enableCrossPartitionQuery", None) # throw exception here for vector search query without limit filter or limit > max_limit if query_execution_info.get_non_streaming_order_by(): diff --git a/sdk/cosmos/azure-cosmos/dev_requirements.txt b/sdk/cosmos/azure-cosmos/dev_requirements.txt index 401c284ac2ce..5a354b1f90af 100644 --- a/sdk/cosmos/azure-cosmos/dev_requirements.txt +++ b/sdk/cosmos/azure-cosmos/dev_requirements.txt @@ -1,4 +1,4 @@ -aiohttp>=3.8.5 +aiohttp<=3.12.2 -e ../../core/azure-core -e ../../identity/azure-identity -e ../../../tools/azure-sdk-tools diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index f684cb4ab668..b3089626beee 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -30,11 
+30,9 @@ async def run_session_token_query(container, split): query = "select * from c" # verify session token sent makes sense for number of partitions present if split: - query_iterable = container.query_items(query=query, enable_cross_partition_query=True, - raw_response_hook=test_config.post_split_hook) + query_iterable = container.query_items(query=query, raw_response_hook=test_config.post_split_hook) else: - query_iterable = container.query_items(query=query, enable_cross_partition_query=True, - raw_response_hook=test_config.pre_split_hook) + query_iterable = container.query_items(query=query, raw_response_hook=test_config.pre_split_hook) item_list = [item async for item in query_iterable] @@ -61,6 +59,7 @@ def setUpClass(cls): async def asyncSetUp(self): self.client = CosmosClient(self.host, self.masterKey) + await self.client.__aenter__() self.created_database = self.client.get_database_client(self.TEST_DATABASE_ID) self.container = await self.created_database.create_container( id=self.TEST_CONTAINER_ID, From c7b5ca53cb88e8dda5d5e4d95eb7393238fca2b6 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Fri, 30 May 2025 18:49:17 -0400 Subject: [PATCH 40/52] update tests that checked for compound tokens - this will no longer be the case --- .../tests/test_latest_session_token.py | 14 ++++---------- .../tests/test_latest_session_token_async.py | 15 +++++---------- 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py index 7fd18a53326b..7ea83b17749b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py @@ -93,16 +93,10 @@ def test_latest_session_token_from_pk(self): phys_feed_ranges_and_session_tokens) phys_session_token = container.get_latest_session_token(phys_feed_ranges_and_session_tokens, 
phys_target_feed_range) - assert is_compound_session_token(phys_session_token) - session_tokens = phys_session_token.split(",") - assert len(session_tokens) == 2 - pk_range_id1, session_token1 = parse_session_token(session_tokens[0]) - pk_range_id2, session_token2 = parse_session_token(session_tokens[1]) - pk_range_ids = [pk_range_id1, pk_range_id2] - - assert 620 <= (session_token1.global_lsn + session_token2.global_lsn) - assert '1' in pk_range_ids - assert '2' in pk_range_ids + pk_range_id, session_token = parse_session_token(phys_session_token) + + assert session_token.global_lsn >= 360 + assert '2' in pk_range_id self.database.delete_container(container.id) def test_latest_session_token_hpk(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py index 139078683bd3..c7fbfcffd7b6 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py @@ -95,16 +95,11 @@ async def test_latest_session_token_from_pk_async(self): phys_feed_ranges_and_session_tokens) phys_session_token = await container.get_latest_session_token(phys_feed_ranges_and_session_tokens, phys_target_feed_range) - assert is_compound_session_token(phys_session_token) - session_tokens = phys_session_token.split(",") - assert len(session_tokens) == 2 - pk_range_id1, session_token1 = parse_session_token(session_tokens[0]) - pk_range_id2, session_token2 = parse_session_token(session_tokens[1]) - pk_range_ids = [pk_range_id1, pk_range_id2] - - assert 620 <= (session_token1.global_lsn + session_token2.global_lsn) - assert '1' in pk_range_ids - assert '2' in pk_range_ids + pk_range_id, session_token = parse_session_token(phys_session_token) + + assert session_token.global_lsn >= 360 + assert '2' in pk_range_id + self.database.delete_container(container.id) await self.database.delete_container(container.id) async def 
test_latest_session_token_hpk(self): From 941719698687853f036675ae9fbca1b325ef0d1e Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:12:02 -0400 Subject: [PATCH 41/52] Update test_config.py --- sdk/cosmos/azure-cosmos/tests/test_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index 6cff5f412cc4..3137f2ec536e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -303,7 +303,7 @@ def pre_split_hook(response): request_headers = response.http_request.headers session_token = request_headers.get('x-ms-session-token') assert len(session_token) <= 20 - assert session_token.startswith('0:0') + assert session_token.startswith('0') assert session_token.count(':') == 1 assert session_token.count(',') == 0 @@ -312,7 +312,7 @@ def post_split_hook(response): session_token = request_headers.get('x-ms-session-token') assert len(session_token) > 30 assert len(session_token) < 60 # should only be 0-1 or 0-2, not 0-1-2 - assert session_token.startswith('0:0') + assert session_token.startswith('0') is False assert session_token.count(':') == 2 assert session_token.count(',') == 1 From 525ad6c6b0464d2600a09a723fbdc6054606b700 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:58:02 -0400 Subject: [PATCH 42/52] Update execution_dispatcher.py --- .../cosmos/_execution_context/aio/execution_dispatcher.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 6d6cf8f12e8b..0a9765bcf843 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -78,7 +78,8 @@ async def __anext__(self): :raises StopIteration: If no more result is left. """ - if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: + if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or + self._resource_type != ResourceType.Document): try: return await self._execution_context.__anext__() except CosmosHttpResponseError as e: @@ -100,7 +101,8 @@ async def fetch_next_block(self): :return: List of results. :rtype: list """ - if "enableCrossPartitionQuery" not in self._options or self._resource_type != ResourceType.Document: + if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or + self._resource_type != ResourceType.Document): try: return await self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: From d1a014f9876642a37d4a0f2e99863316b6b66fae Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 2 Jun 2025 16:28:54 -0400 Subject: [PATCH 43/52] remove query logic --- .../aio/execution_dispatcher.py | 36 +++++++---------- .../execution_dispatcher.py | 36 +++++++---------- .../azure-cosmos/azure/cosmos/_session.py | 39 ++++++++++++------- 3 files changed, 52 insertions(+), 59 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 0a9765bcf843..caa595218f40 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -78,17 +78,13 @@ async def __anext__(self): :raises StopIteration: If no more result is left. 
""" - if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or - self._resource_type != ResourceType.Document): - try: - return await self._execution_context.__anext__() - except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - await self._create_execution_context_with_query_plan() - else: - raise e - else: - await self._create_execution_context_with_query_plan() + try: + return await self._execution_context.__anext__() + except CosmosHttpResponseError as e: + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + await self._create_execution_context_with_query_plan() + else: + raise e return await self._execution_context.__anext__() @@ -101,17 +97,13 @@ async def fetch_next_block(self): :return: List of results. :rtype: list """ - if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or - self._resource_type != ResourceType.Document): - try: - return await self._execution_context.fetch_next_block() - except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - await self._create_execution_context_with_query_plan() - else: - raise e - else: - await self._create_execution_context_with_query_plan() + try: + return await self._execution_context.fetch_next_block() + except CosmosHttpResponseError as e: + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + await self._create_execution_context_with_query_plan() + else: + raise e return await self._execution_context.fetch_next_block() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 7d2e7228ed7f..79b8d6996b74 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -108,17 +108,13 @@ def __next__(self): :raises StopIteration: If no more result is left. """ - if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or - self._resource_type != ResourceType.Document): - try: - return next(self._execution_context) - except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - self._create_execution_context_with_query_plan() - else: - raise e - else: - self._create_execution_context_with_query_plan() + try: + return next(self._execution_context) + except CosmosHttpResponseError as e: + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + self._create_execution_context_with_query_plan() + else: + raise e return next(self._execution_context) @@ -131,17 +127,13 @@ def fetch_next_block(self): :return: List of results. :rtype: list """ - if ("enableCrossPartitionQuery" not in self._options or self._fetched_query_plan or - self._resource_type != ResourceType.Document): - try: - return self._execution_context.fetch_next_block() - except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - self._create_execution_context_with_query_plan() - else: - raise e - else: - self._create_execution_context_with_query_plan() + try: + return self._execution_context.fetch_next_block() + except CosmosHttpResponseError as e: + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + self._create_execution_context_with_query_plan() + else: + raise e return self._execution_context.fetch_next_block() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 4fa38c32722e..d07e6f6ba2e0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -46,7 +46,7 
@@ def __init__(self): def get_session_token( self, resource_path: str, - pk_value: str, + pk_value: Any, container_properties_cache: Dict[str, Dict[str, Any]], routing_map_provider: SmartRoutingMapProvider, partition_key_range_id: Optional[int]) -> str: @@ -55,7 +55,7 @@ def get_session_token( :param str resource_path: Self link / path to the resource :param ~azure.cosmos.SmartRoutingMapProvider routing_map_provider: routing map containing relevant session information, such as partition key ranges for a given collection - :param str pk_value: The partition key value being used for the operation + :param Any pk_value: The partition key value being used for the operation :param container_properties_cache: Container properties cache used to fetch partition key definitions :type container_properties_cache: Dict[str, Dict[str, Any]] :param int partition_key_range_id: The partition key range ID used for the operation @@ -94,14 +94,18 @@ def get_session_token( token_dict) if vector_session_token is not None: session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) - else: + elif pk_value is not None: collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] partition_key = PartitionKey(path=collection_pk_definition['paths'], kind=collection_pk_definition['kind'], version=collection_pk_definition['version']) epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) pk_range = routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) - session_token = self._format_session_token(pk_range, token_dict) + if len(pk_range) > 0: + partition_key_range_id = pk_range[0]['id'] + vector_session_token = self._resolve_partition_local_session_token(pk_range, token_dict) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) return session_token return "" except Exception: # pylint: disable=broad-except @@ -110,7 +114,7 @@ def 
get_session_token( async def get_session_token_async( self, resource_path: str, - pk_value: str, + pk_value: Any, container_properties_cache: Dict[str, Dict[str, Any]], routing_map_provider: SmartRoutingMapProviderAsync, partition_key_range_id: Optional[str]) -> str: @@ -119,7 +123,7 @@ async def get_session_token_async( :param str resource_path: Self link / path to the resource :param ~azure.cosmos.SmartRoutingMapProviderAsync routing_map_provider: routing map containing relevant session information, such as partition key ranges for a given collection - :param str pk_value: The partition key value being used for the operation + :param Any pk_value: The partition key value being used for the operation :param container_properties_cache: Container properties cache used to fetch partition key definitions :type container_properties_cache: Dict[str, Dict[str, Any]] :param Any routing_map_provider: The routing map provider containing the partition key range cache logic @@ -131,7 +135,6 @@ async def get_session_token_async( with self.session_lock: is_name_based = _base.IsNameBased(resource_path) - collection_rid = "" session_token = "" try: @@ -161,14 +164,18 @@ async def get_session_token_async( token_dict) if vector_session_token is not None: session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) - else: + elif pk_value is not None: collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] partition_key = PartitionKey(path=collection_pk_definition['paths'], kind=collection_pk_definition['kind'], version=collection_pk_definition['version']) epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) pk_range = await routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) - session_token = self._format_session_token(pk_range, token_dict) + if len(pk_range) > 0: + partition_key_range_id = pk_range[0]['id'] + vector_session_token = self._resolve_partition_local_session_token(pk_range, 
token_dict) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) return session_token return "" except Exception: # pylint: disable=broad-except @@ -323,13 +330,15 @@ def _format_session_token(self, pk_range, token_dict): def _resolve_partition_local_session_token(self, pk_range, token_dict): parent_session_token = None parents = pk_range[0].get('parents').copy() + parents.append(pk_range[0]['id']) for parent in parents: - vector_session_token = token_dict.get(parent).session_token - if parent_session_token is None: - parent_session_token = vector_session_token - # if initial token is already set, and the next parent's token is cached, merge vector session tokens - else: - if vector_session_token is not None: + session_token = token_dict.get(parent) + if session_token is not None: + vector_session_token = session_token.session_token + if parent_session_token is None: + parent_session_token = vector_session_token + # if initial token is already set, and the next parent's token is cached, merge vector session tokens + else: vector_token_1 = VectorSessionToken.create(parent_session_token) vector_token_2 = VectorSessionToken.create(vector_session_token) vector_token = vector_token_1.merge(vector_token_2) From 5c825e7ed3eac47d24468b3be068e4e4f0270b8f Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 2 Jun 2025 17:13:28 -0400 Subject: [PATCH 44/52] remove partition split testing --- .../tests/test_partition_split_query.py | 13 ------------- .../tests/test_partition_split_query_async.py | 11 ----------- 2 files changed, 24 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index 5891dd4efb88..f3a7f1d475d7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -29,17 +29,6 @@ 
def run_queries(container, iterations): assert str(attr_number) == curr # verify that all results match their randomly generated attributes print("validation succeeded for all query results") -def run_session_token_query(container, split): - query = "select * from c" - # verify session token sent makes sense for number of partitions present - if split: - query_iterable = container.query_items(query=query, enable_cross_partition_query=True, - raw_response_hook=test_config.post_split_hook) - else: - query_iterable = container.query_items(query=query, enable_cross_partition_query=True, - raw_response_hook=test_config.pre_split_hook) - list(query_iterable) - @pytest.mark.cosmosQuery class TestPartitionSplitQuery(unittest.TestCase): @@ -85,7 +74,6 @@ def test_partition_split_query(self): print("now starting queries") run_queries(self.container, 100) # initial check for queries before partition split - run_session_token_query(self.container, False) # initial session token check before partition split print("initial check succeeded, now reading offer until replacing is done") offer = self.container.get_throughput() while True: @@ -97,7 +85,6 @@ def test_partition_split_query(self): else: print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) run_queries(self.container, 100) # check queries work post partition split - run_session_token_query(self.container, True) # check session token works post partition split self.assertTrue(offer.offer_throughput > self.throughput) return diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index b3089626beee..15c994f14132 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -26,15 +26,6 @@ async def run_queries(container, iterations): assert str(attr_number) == curr # verify that all results match their 
randomly generated attributes print("validation succeeded for all query results") -async def run_session_token_query(container, split): - query = "select * from c" - # verify session token sent makes sense for number of partitions present - if split: - query_iterable = container.query_items(query=query, raw_response_hook=test_config.post_split_hook) - else: - query_iterable = container.query_items(query=query, raw_response_hook=test_config.pre_split_hook) - item_list = [item async for item in query_iterable] - @pytest.mark.cosmosQuery class TestPartitionSplitQueryAsync(unittest.IsolatedAsyncioTestCase): @@ -83,7 +74,6 @@ async def test_partition_split_query_async(self): print("now starting queries") await run_queries(self.container, 100) # initial check for queries before partition split - await run_session_token_query(self.container, False) # initial session token check before partition split print("initial check succeeded, now reading offer until replacing is done") offer = await self.container.get_throughput() while True: @@ -95,7 +85,6 @@ async def test_partition_split_query_async(self): else: print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) await run_queries(self.container, 100) # check queries work post partition split - await run_session_token_query(self.container, True) # check session token works post partition split self.assertTrue(offer.offer_throughput > self.throughput) return From 76f707b6341a6fdcdc0d13efbdee536847cbeb5b Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Mon, 2 Jun 2025 18:24:48 -0400 Subject: [PATCH 45/52] pylint --- .../azure/cosmos/_execution_context/aio/execution_dispatcher.py | 2 +- .../azure/cosmos/_execution_context/execution_dispatcher.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py
b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index caa595218f40..56fc2ddd89e9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -33,7 +33,7 @@ from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos.documents import _DistinctType from azure.cosmos.exceptions import CosmosHttpResponseError -from azure.cosmos.http_constants import StatusCodes, ResourceType +from azure.cosmos.http_constants import StatusCodes from ..._constants import _Constants as Constants # pylint: disable=protected-access diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 79b8d6996b74..cbe6d67a0da9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -32,7 +32,7 @@ from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo from azure.cosmos.documents import _DistinctType -from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, ResourceType +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes from .._constants import _Constants as Constants # pylint: disable=protected-access From 29ea4b1dc75ecc7579afc55be82751cefb4415cf Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Wed, 4 Jun 2025 11:55:23 -0400 Subject: [PATCH 46/52] Update dev_requirements.txt --- sdk/cosmos/azure-cosmos/dev_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sdk/cosmos/azure-cosmos/dev_requirements.txt b/sdk/cosmos/azure-cosmos/dev_requirements.txt index 5a354b1f90af..401c284ac2ce 100644 --- a/sdk/cosmos/azure-cosmos/dev_requirements.txt +++ b/sdk/cosmos/azure-cosmos/dev_requirements.txt @@ -1,4 +1,4 @@ -aiohttp<=3.12.2 +aiohttp>=3.8.5 -e ../../core/azure-core -e ../../identity/azure-identity -e ../../../tools/azure-sdk-tools From 9123dfe94e8bf1e9b08cf592706743c1b3fef6a9 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Tue, 10 Jun 2025 17:29:59 -0400 Subject: [PATCH 47/52] delete duplicate --- sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py index c7fbfcffd7b6..5e1fbffa5921 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py @@ -99,7 +99,6 @@ async def test_latest_session_token_from_pk_async(self): assert session_token.global_lsn >= 360 assert '2' in pk_range_id - self.database.delete_container(container.id) await self.database.delete_container(container.id) async def test_latest_session_token_hpk(self): From 69af5400d0bbf2f1e136d33a392a57c677cd3f59 Mon Sep 17 00:00:00 2001 From: bambriz Date: Thu, 19 Jun 2025 13:59:34 -0700 Subject: [PATCH 48/52] update tests for partition split --- .../azure-cosmos/tests/test_partition_split_query.py | 7 ++++--- .../tests/test_partition_split_query_async.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index a431b0e79bb0..c22cd57b17a4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -41,6 +41,7 @@ class 
TestPartitionSplitQuery(unittest.TestCase): throughput = 400 TEST_DATABASE_ID = configs.TEST_DATABASE_ID TEST_CONTAINER_ID = "Single-partition-container-without-throughput" + MAX_TIME = 60 * 7 # 7 minutes for the test to complete, should be enough for partition split to complete @classmethod def setUpClass(cls): @@ -77,10 +78,10 @@ def test_partition_split_query(self): print("initial check succeeded, now reading offer until replacing is done") offer = self.container.get_throughput() while True: - if time.time() - start_time > 60 * 25: # timeout test at 25 minutes - raise unittest.SkipTest("Partition split didn't complete in time") + if time.time() - start_time > self.MAX_TIME: # timeout test at 25 minutes + self.skipTest("Partition split didn't complete in time") if offer.properties['content'].get('isOfferReplacePending', False): - time.sleep(10) + time.sleep(30) # wait for the offer to be replaced, check every 30 seconds offer = self.container.get_throughput() else: print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index 15c994f14132..56b912ca506d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -38,6 +38,7 @@ class TestPartitionSplitQueryAsync(unittest.IsolatedAsyncioTestCase): throughput = 400 TEST_DATABASE_ID = configs.TEST_DATABASE_ID TEST_CONTAINER_ID = "Single-partition-container-without-throughput-async" + MAX_TIME = 60 * 7 # 7 minutes for the test to complete, should be enough for partition split to complete @classmethod def setUpClass(cls): @@ -77,10 +78,10 @@ async def test_partition_split_query_async(self): print("initial check succeeded, now reading offer until replacing is done") offer = await self.container.get_throughput() while True: - if time.time() - 
start_time > 60 * 25: # timeout test at 25 minutes - unittest.skip("Partition split didn't complete in time.") + if time.time() - start_time > self.MAX_TIME: # timeout test at 25 minutes + self.skipTest("Partition split didn't complete in time.") if offer.properties['content'].get('isOfferReplacePending', False): - time.sleep(10) + time.sleep(30) # wait for the offer to be replaced, check every 30 seconds offer = await self.container.get_throughput() else: print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) @@ -89,4 +90,4 @@ async def test_partition_split_query_async(self): return if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From db5233f51619ef18fd1430aaeef9f76dd476a72c Mon Sep 17 00:00:00 2001 From: bambriz Date: Fri, 20 Jun 2025 17:33:04 -0700 Subject: [PATCH 49/52] change timeout of split partition key tests reduces timeout from 25 minutes to 7 minutes --- sdk/cosmos/azure-cosmos/tests/test_config.py | 8 ++++---- .../azure-cosmos/tests/test_latest_session_token.py | 2 +- .../azure-cosmos/tests/test_partition_split_query.py | 2 +- .../tests/test_partition_split_query_async.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index 51d1df4cd6b0..68a9d7317f96 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -27,8 +27,8 @@ except: print("no urllib3") -SPLIT_TIMEOUT = 60*25 # timeout test at 25 minutes -SLEEP_TIME = 60 # sleep for 1 minutes +SPLIT_TIMEOUT = 60*7 # timeout test at 7 minutes +SLEEP_TIME = 30 # sleep for 1 minutes class TestConfig(object): local_host = 'https://localhost:8081/' @@ -215,7 +215,7 @@ def trigger_split(container, throughput): while True: offer = container.get_throughput() if offer.properties['content'].get('isOfferReplacePending', False): - if time.time() - start_time > 
SPLIT_TIMEOUT: # timeout test at 25 minutes + if time.time() - start_time > SPLIT_TIMEOUT: # timeout test at 7 minutes raise unittest.SkipTest("Partition split didn't complete in time") else: print("Waiting for split to complete") @@ -236,7 +236,7 @@ async def trigger_split_async(container, throughput): while True: offer = await container.get_throughput() if offer.properties['content'].get('isOfferReplacePending', False): - if time.time() - start_time > SPLIT_TIMEOUT: # timeout test at 25 minutes + if time.time() - start_time > SPLIT_TIMEOUT: # timeout test at 7 minutes raise unittest.SkipTest("Partition split didn't complete in time") else: print("Waiting for split to complete") diff --git a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py index 7ea83b17749b..0eb8a828cf78 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py @@ -95,7 +95,7 @@ def test_latest_session_token_from_pk(self): phys_session_token = container.get_latest_session_token(phys_feed_ranges_and_session_tokens, phys_target_feed_range) pk_range_id, session_token = parse_session_token(phys_session_token) - assert session_token.global_lsn >= 360 + assert session_token.global_lsn >= 350 assert '2' in pk_range_id self.database.delete_container(container.id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index c22cd57b17a4..908b0433e26e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -78,7 +78,7 @@ def test_partition_split_query(self): print("initial check succeeded, now reading offer until replacing is done") offer = self.container.get_throughput() while True: - if time.time() - start_time > self.MAX_TIME: # timeout test at 25 minutes + if time.time() - start_time > 
self.MAX_TIME: # timeout test at 7 minutes self.skipTest("Partition split didn't complete in time") if offer.properties['content'].get('isOfferReplacePending', False): time.sleep(30) # wait for the offer to be replaced, check every 30 seconds diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index 56b912ca506d..5b0efa04d106 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -78,7 +78,7 @@ async def test_partition_split_query_async(self): print("initial check succeeded, now reading offer until replacing is done") offer = await self.container.get_throughput() while True: - if time.time() - start_time > self.MAX_TIME: # timeout test at 25 minutes + if time.time() - start_time > self.MAX_TIME: # timeout test at 7 minutes self.skipTest("Partition split didn't complete in time.") if offer.properties['content'].get('isOfferReplacePending', False): time.sleep(30) # wait for the offer to be replaced, check every 30 seconds From af0b1e0584cc7559f054c77b55e1e5018e3f0d58 Mon Sep 17 00:00:00 2001 From: bambriz Date: Mon, 23 Jun 2025 16:36:28 -0700 Subject: [PATCH 50/52] update test --- sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py | 2 +- .../azure-cosmos/tests/test_partition_split_query_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index 908b0433e26e..793302973a87 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -30,7 +30,7 @@ def run_queries(container, iterations): print("validation succeeded for all query results") -@pytest.mark.cosmosQuery +@pytest.mark.cosmosSplit class TestPartitionSplitQuery(unittest.TestCase): 
database: DatabaseProxy = None container: ContainerProxy = None diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index 5b0efa04d106..c860735c0e28 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -27,7 +27,7 @@ async def run_queries(container, iterations): print("validation succeeded for all query results") -@pytest.mark.cosmosQuery +@pytest.mark.cosmosSplit class TestPartitionSplitQueryAsync(unittest.IsolatedAsyncioTestCase): database: DatabaseProxy = None container: ContainerProxy = None From d7364acc4173760803ed114b4eb54dff9ac86223 Mon Sep 17 00:00:00 2001 From: bambriz Date: Mon, 30 Jun 2025 10:32:33 -0700 Subject: [PATCH 51/52] Test fixes Fixes tests to accomodate the extra readfeed that may happen during a read item operation. This also includes other general test fixes. --- ...hangefeed_partition_key_variation_async.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_config.py | 8 +-- .../tests/test_partition_split_query.py | 4 +- .../tests/test_partition_split_query_async.py | 4 +- .../tests/test_service_retry_policies.py | 44 +++++++++++++++- .../test_service_retry_policies_async.py | 47 ++++++++++++++++- .../test_timeout_and_failover_retry_policy.py | 51 ++++++++++++------- 7 files changed, 129 insertions(+), 31 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py b/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py index 48d00337ae7f..87f74f4af7c1 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_changefeed_partition_key_variation_async.py @@ -262,7 +262,7 @@ async def _get_properties_override(): for item in items: try: - epk_range = container._get_epk_range_for_partition_key(container_properties, 
item["pk"]) + epk_range = await container._get_epk_range_for_partition_key(item["pk"]) assert epk_range is not None, f"EPK range should not be None for partition key {item['pk']}." except Exception as e: assert False, f"Failed to get EPK range for partition key {item['pk']}: {str(e)}" diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index 68a9d7317f96..7a0371d8a62f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -27,8 +27,8 @@ except: print("no urllib3") -SPLIT_TIMEOUT = 60*7 # timeout test at 7 minutes -SLEEP_TIME = 30 # sleep for 1 minutes +SPLIT_TIMEOUT = 60*10 # timeout test at 10 minutes +SLEEP_TIME = 30 # sleep for 30 seconds class TestConfig(object): local_host = 'https://localhost:8081/' @@ -215,7 +215,7 @@ def trigger_split(container, throughput): while True: offer = container.get_throughput() if offer.properties['content'].get('isOfferReplacePending', False): - if time.time() - start_time > SPLIT_TIMEOUT: # timeout test at 7 minutes + if time.time() - start_time > SPLIT_TIMEOUT: # timeout test at 10 minutes raise unittest.SkipTest("Partition split didn't complete in time") else: print("Waiting for split to complete") @@ -236,7 +236,7 @@ async def trigger_split_async(container, throughput): while True: offer = await container.get_throughput() if offer.properties['content'].get('isOfferReplacePending', False): - if time.time() - start_time > SPLIT_TIMEOUT: # timeout test at 7 minutes + if time.time() - start_time > SPLIT_TIMEOUT: # timeout test at 10 minutes raise unittest.SkipTest("Partition split didn't complete in time") else: print("Waiting for split to complete") diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index 793302973a87..c5dbfe651239 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -41,7 +41,7 @@ class TestPartitionSplitQuery(unittest.TestCase): throughput = 400 TEST_DATABASE_ID = configs.TEST_DATABASE_ID TEST_CONTAINER_ID = "Single-partition-container-without-throughput" - MAX_TIME = 60 * 7 # 7 minutes for the test to complete, should be enough for partition split to complete + MAX_TIME = 60 * 10 # 10 minutes for the test to complete, should be enough for partition split to complete @classmethod def setUpClass(cls): @@ -78,7 +78,7 @@ def test_partition_split_query(self): print("initial check succeeded, now reading offer until replacing is done") offer = self.container.get_throughput() while True: - if time.time() - start_time > self.MAX_TIME: # timeout test at 7 minutes + if time.time() - start_time > self.MAX_TIME: # timeout test at 10 minutes self.skipTest("Partition split didn't complete in time") if offer.properties['content'].get('isOfferReplacePending', False): time.sleep(30) # wait for the offer to be replaced, check every 30 seconds diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index c860735c0e28..3567ebd023dd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -38,7 +38,7 @@ class TestPartitionSplitQueryAsync(unittest.IsolatedAsyncioTestCase): throughput = 400 TEST_DATABASE_ID = configs.TEST_DATABASE_ID TEST_CONTAINER_ID = "Single-partition-container-without-throughput-async" - MAX_TIME = 60 * 7 # 7 minutes for the test to complete, should be enough for partition split to complete + MAX_TIME = 60 * 10 # 10 minutes for the test to complete, should be enough for partition split to complete @classmethod def setUpClass(cls): @@ -78,7 +78,7 @@ async def test_partition_split_query_async(self): print("initial check succeeded, now reading offer until replacing is done") offer = 
await self.container.get_throughput() while True: - if time.time() - start_time > self.MAX_TIME: # timeout test at 7 minutes + if time.time() - start_time > self.MAX_TIME: # timeout test at 10 minutes self.skipTest("Partition split didn't complete in time.") if offer.properties['content'].get('isOfferReplacePending', False): time.sleep(30) # wait for the offer to be replaced, check every 30 seconds diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py index 2e862e1c3371..4acaf7773011 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies.py @@ -10,6 +10,7 @@ from azure.cosmos import (CosmosClient, _retry_utility, DatabaseAccount, _global_endpoint_manager, _location_cache) from azure.cosmos._location_cache import RegionalRoutingContext +from azure.cosmos._request_object import RequestObject @pytest.mark.cosmosEmulator @@ -58,7 +59,7 @@ def test_service_request_retry_policy(self): expected_counter = len(original_location_cache.read_regional_routing_contexts) try: # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestException() + mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) _retry_utility.ExecuteFunction = mf container.read_item(created_item['id'], created_item['pk']) pytest.fail("Exception was not raised.") @@ -145,7 +146,7 @@ def test_service_response_retry_policy(self): self.REGIONAL_ENDPOINT] try: # Mock the function to return the ServiceResponseException we retry - mf = self.MockExecuteServiceResponseException(Exception) + mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(Exception, self.original_execute_function) _retry_utility.ExecuteFunction = mf container.read_item(created_item['id'], created_item['pk']) pytest.fail("Exception was not raised.") @@ -281,6 +282,25 @@ def __call__(self, func, *args, 
**kwargs): exception.exc_type = Exception raise exception + class MockExecuteServiceRequestExceptionIgnoreQuery(object): + def __init__(self, original_execute_function): + self.counter = 0 + self.original_execute_function = original_execute_function + + def __call__(self, func, *args, **kwargs): + + if args and isinstance(args[1], RequestObject): + request_obj = args[1] + if request_obj.resource_type == "docs" and request_obj.operation_type == "Query" or\ + request_obj.resource_type == "pkranges" and request_obj.operation_type == "ReadFeed": + # Ignore query requests, As an additional ReadFeed might occur during a regular Read operation + return self.original_execute_function(func, *args, **kwargs) + self.counter = self.counter + 1 + exception = ServiceRequestError("mock exception") + exception.exc_type = Exception + raise exception + return self.original_execute_function(func, *args, **kwargs) + class MockExecuteServiceResponseException(object): def __init__(self, err_type): self.err_type = err_type @@ -292,6 +312,26 @@ def __call__(self, func, *args, **kwargs): exception.exc_type = self.err_type raise exception + class MockExecuteServiceResponseExceptionIgnoreQuery(object): + def __init__(self, err_type, original_execute_function): + self.err_type = err_type + self.counter = 0 + self.original_execute_function = original_execute_function + + def __call__(self, func, *args, **kwargs): + + if args and isinstance(args[1], RequestObject): + request_obj = args[1] + if request_obj.resource_type == "docs" and request_obj.operation_type == "Query" or\ + request_obj.resource_type == "pkranges" and request_obj.operation_type == "ReadFeed": + # Ignore query requests, As an additional ReadFeed might occur during a regular Read operation + return self.original_execute_function(func, *args, **kwargs) + self.counter = self.counter + 1 + exception = ServiceResponseError("mock exception") + exception.exc_type = self.err_type + raise exception + return 
self.original_execute_function(func, *args, **kwargs) + def MockGetDatabaseAccountStub(self, endpoint): read_regions = ["West US", "East US"] read_locations = [] diff --git a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py index 408ede83d647..1a1b985a74bc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_service_retry_policies_async.py @@ -12,6 +12,7 @@ import test_config from azure.cosmos import DatabaseAccount, _location_cache from azure.cosmos._location_cache import RegionalRoutingContext +from azure.cosmos._request_object import RequestObject from azure.cosmos.aio import CosmosClient, _retry_utility_async, _global_endpoint_manager_async from azure.cosmos.exceptions import CosmosHttpResponseError @@ -68,7 +69,7 @@ async def test_service_request_retry_policy_async(self): expected_counter = len(original_location_cache.read_regional_routing_contexts) try: # Mock the function to return the ServiceRequestException we retry - mf = self.MockExecuteServiceRequestException() + mf = self.MockExecuteServiceRequestExceptionIgnoreQuery(self.original_execute_function) _retry_utility_async.ExecuteFunctionAsync = mf await container.read_item(created_item['id'], created_item['pk']) pytest.fail("Exception was not raised.") @@ -159,7 +160,8 @@ async def test_service_response_retry_policy_async(self): self.REGIONAL_ENDPOINT] try: # Mock the function to return the ClientConnectionError we retry - mf = self.MockExecuteServiceResponseException(AttributeError, None) + mf = self.MockExecuteServiceResponseExceptionIgnoreQuery(AttributeError, + None, self.original_execute_function) _retry_utility_async.ExecuteFunctionAsync = mf await container.read_item(created_item['id'], created_item['pk']) pytest.fail("Exception was not raised.") @@ -442,6 +444,25 @@ def __call__(self, func, *args, **kwargs): exception.exc_type = Exception raise 
exception + class MockExecuteServiceRequestExceptionIgnoreQuery(object): + def __init__(self, original_execute_function): + self.counter = 0 + self.original_execute_function = original_execute_function + + def __call__(self, func, *args, **kwargs): + + if args and isinstance(args[1], RequestObject): + request_obj = args[1] + if request_obj.resource_type == "docs" and request_obj.operation_type == "Query" or\ + request_obj.resource_type == "pkranges" and request_obj.operation_type == "ReadFeed": + # Ignore query requests, As an additional ReadFeed might occur during a regular Read operation + return self.original_execute_function(func, *args, **kwargs) + self.counter = self.counter + 1 + exception = ServiceRequestError("mock exception") + exception.exc_type = Exception + raise exception + return self.original_execute_function(func, *args, **kwargs) + class MockExecuteServiceResponseException(object): def __init__(self, err_type, inner_exception): self.err_type = err_type @@ -455,6 +476,28 @@ def __call__(self, func, *args, **kwargs): exception.inner_exception = self.inner_exception raise exception + class MockExecuteServiceResponseExceptionIgnoreQuery(object): + def __init__(self, err_type, inner_exception, original_execute_function): + self.err_type = err_type + self.inner_exception = inner_exception + self.counter = 0 + self.original_execute_function = original_execute_function + + def __call__(self, func, *args, **kwargs): + + if args and isinstance(args[1], RequestObject): + request_obj = args[1] + if request_obj.resource_type == "docs" and request_obj.operation_type == "Query" or \ + request_obj.resource_type == "pkranges" and request_obj.operation_type == "ReadFeed": + # Ignore query requests, As an additional ReadFeed might occur during a regular Read operation + return self.original_execute_function(func, *args, **kwargs) + self.counter = self.counter + 1 + exception = ServiceResponseError("mock exception") + exception.exc_type = self.err_type + 
exception.inner_exception = self.inner_exception + raise exception + return self.original_execute_function(func, *args, **kwargs) + async def MockGetDatabaseAccountStub(self, endpoint): read_regions = ["West US", "East US"] read_locations = [] diff --git a/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py index 3d583c4c0692..4c21ed121441 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/tests/test_timeout_and_failover_retry_policy.py @@ -11,6 +11,7 @@ import test_config from azure.cosmos import _retry_utility, PartitionKey from azure.cosmos._location_cache import RegionalRoutingContext, EndpointOperationType +from azure.cosmos._request_object import RequestObject COLLECTION = "created_collection" @pytest.fixture(scope="class") @@ -150,14 +151,21 @@ def __init__(self, org_func, num_exceptions, status_code): self.status_code = status_code def __call__(self, func, *args, **kwargs): - if self.counter != 0 and self.counter >= self.num_exceptions: - return self.org_func(func, *args, **kwargs) - else: - self.counter += 1 - raise exceptions.CosmosHttpResponseError( - status_code=self.status_code, - message="Some Exception", - response=test_config.FakeResponse({})) + if args and isinstance(args[1], RequestObject): + request_obj = args[1] + if request_obj.resource_type == "docs" and request_obj.operation_type == "Query" or \ + request_obj.resource_type == "pkranges" and request_obj.operation_type == "ReadFeed": + # Ignore query or ReadFeed requests + return self.org_func(func, *args, **kwargs) + if self.counter != 0 and self.counter >= self.num_exceptions: + return self.org_func(func, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + return self.org_func(func, *args, **kwargs) 
class MockExecuteFunctionCrossRegion(object): def __init__(self, org_func, status_code, location_endpoint_to_route): @@ -167,16 +175,23 @@ def __init__(self, org_func, status_code, location_endpoint_to_route): self.location_endpoint_to_route = location_endpoint_to_route def __call__(self, func, *args, **kwargs): - if self.counter == 1: - assert args[1].location_endpoint_to_route == self.location_endpoint_to_route - args[1].location_endpoint_to_route = test_config.TestConfig.host - return self.org_func(func, *args, **kwargs) - else: - self.counter += 1 - raise exceptions.CosmosHttpResponseError( - status_code=self.status_code, - message="Some Exception", - response=test_config.FakeResponse({})) + if args and isinstance(args[1], RequestObject): + request_obj = args[1] + if request_obj.resource_type == "docs" and request_obj.operation_type == "Query" or \ + request_obj.resource_type == "pkranges" and request_obj.operation_type == "ReadFeed": + # Ignore query or ReadFeed requests + return self.org_func(func, *args, **kwargs) + if self.counter == 1: + assert args[1].location_endpoint_to_route == self.location_endpoint_to_route + args[1].location_endpoint_to_route = test_config.TestConfig.host + return self.org_func(func, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + return self.org_func(func, *args, **kwargs) From b52c4ac2a8e9aabf4f0d8dabcd13d7d28cc86b93 Mon Sep 17 00:00:00 2001 From: bambriz Date: Tue, 1 Jul 2025 11:47:21 -0700 Subject: [PATCH 52/52] update tests --- .../azure-cosmos/tests/test_latest_session_token_async.py | 1 + .../azure-cosmos/tests/test_partition_split_query_async.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py index 5e1fbffa5921..fbc367aad21f 100644 
--- a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py @@ -46,6 +46,7 @@ class TestLatestSessionTokenAsync(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.client = CosmosClient(self.host, self.masterKey) + await self.client.__aenter__() self.database = self.client.get_database_client(self.TEST_DATABASE_ID) async def asyncTearDown(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index 3567ebd023dd..37642712fb8b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -16,7 +16,7 @@ async def run_queries(container, iterations): for i in range(iterations): curr = str(random.randint(0, 10)) query = 'SELECT * FROM c WHERE c.attr1=' + curr + ' order by c.attr1' - qlist = [item async for item in container.query_items(query=query, enable_cross_partition_query=True)] + qlist = [item async for item in container.query_items(query=query)] ret_list.append((curr, qlist)) for ret in ret_list: curr = ret[0]