diff --git a/sdk/storage/azure-storage-blob/CHANGELOG.md b/sdk/storage/azure-storage-blob/CHANGELOG.md index 0ac68a729f26..bebb7d2866ca 100644 --- a/sdk/storage/azure-storage-blob/CHANGELOG.md +++ b/sdk/storage/azure-storage-blob/CHANGELOG.md @@ -5,6 +5,9 @@ ### Features Added - Stable release of features from 12.26.0b1 +### Bugs Fixed +- Fixed an issue where `BlobClient`'s `start_copy_from_url` with `incremental_copy=True` results in `TypeError`. + ## 12.26.0b1 (2025-05-06) ### Features Added diff --git a/sdk/storage/azure-storage-blob/assets.json b/sdk/storage/azure-storage-blob/assets.json index 8f0795ef0392..5fa65eb24cb2 100644 --- a/sdk/storage/azure-storage-blob/assets.json +++ b/sdk/storage/azure-storage-blob/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-blob", - "Tag": "python/storage/azure-storage-blob_56ef3e2a11" + "Tag": "python/storage/azure-storage-blob_e7a8cad1a0" } diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/__init__.py b/sdk/storage/azure-storage-blob/azure/storage/blob/__init__.py index 2386595611bd..9871a54c4cb7 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/__init__.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/__init__.py @@ -122,7 +122,7 @@ def upload_blob_to_url( entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. :keyword str encoding: Encoding to use if text is supplied as input. Defaults to UTF-8. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: dict(str, Any) """ with BlobClient.from_blob_url(blob_url, credential=credential) as client: @@ -153,7 +153,7 @@ def download_blob_from_url( :param output: Where the data should be downloaded to. This could be either a file path to write to, or an open IO handle to write to. - :type output: str or writable stream. + :type output: str or IO. :param credential: The credentials with which to authenticate. This is optional if the blob URL already has a SAS token or the blob is public. The value can be a SAS token string, @@ -190,6 +190,7 @@ def download_blob_from_url( blob. Also note that if enabled, the memory-efficient upload algorithm will not be used, because computing the MD5 hash requires buffering entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :return: None :rtype: None """ overwrite = kwargs.pop('overwrite', False) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index 5e75c24417aa..e549004366ce 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -236,7 +236,7 @@ def from_blob_url( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A Blob client. + :return: A Blob client. :rtype: ~azure.storage.blob.BlobClient """ account_url, container_name, blob_name, path_snapshot = _from_blob_url(blob_url=blob_url, snapshot=snapshot) @@ -284,7 +284,7 @@ def from_connection_string( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. 
Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A Blob client. + :return: A Blob client. :rtype: ~azure.storage.blob.BlobClient .. admonition:: Example: @@ -311,7 +311,7 @@ def get_account_information(self, **kwargs: Any) -> Dict[str, str]: The information can also be retrieved if the user has a SAS to a container or blob. The keys in the returned dictionary include 'sku_name' and 'account_kind'. - :returns: A dict of account information (SKU and account type). + :return: A dict of account information (SKU and account type). :rtype: dict(str, str) """ try: @@ -431,7 +431,7 @@ def upload_blob_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. :paramtype source_token_intent: Literal['backup'] - :returns: Blob-updated property Dict (Etag and last modified) + :return: Blob-updated property Dict (Etag and last modified) :rtype: Dict[str, Any] """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -580,7 +580,7 @@ def upload_blob( see `here `__. This method may make multiple calls to the service and the timeout will apply to each call individually. - :returns: Blob-updated property Dict (Etag and last modified) + :return: Blob-updated property Dict (Etag and last modified) :rtype: Dict[str, Any] .. admonition:: Example: @@ -725,7 +725,7 @@ def download_blob( the timeout will apply to each call individually. multiple calls to the Azure service and the timeout will apply to each call individually. - :returns: A streaming object (StorageStreamDownloader) + :return: A streaming object (StorageStreamDownloader) :rtype: ~azure.storage.blob.StorageStreamDownloader .. admonition:: Example: @@ -829,7 +829,7 @@ def query_blob(self, query_expression: str, **kwargs: Any) -> BlobQueryReader: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A streaming object (BlobQueryReader) + :return: A streaming object (BlobQueryReader) :rtype: ~azure.storage.blob.BlobQueryReader .. admonition:: Example: @@ -874,7 +874,7 @@ def delete_blob(self, delete_snapshots: Optional[str] = None, **kwargs: Any) -> and retains the blob for a specified number of days. After the specified number of days, the blob's data is removed from the service during garbage collection. Soft deleted blob is accessible through :func:`~ContainerClient.list_blobs()` specifying `include=['deleted']` - option. Soft-deleted blob can be restored using :func:`undelete` operation. + option. Soft-deleted blob can be restored using :func:`~BlobClient.undelete_blob()` operation. :param Optional[str] delete_snapshots: Required if the blob has associated snapshots. Values include: @@ -922,6 +922,7 @@ def delete_blob(self, delete_snapshots: Optional[str] = None, **kwargs: Any) -> This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -960,6 +961,7 @@ def undelete_blob(self, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -991,7 +993,7 @@ def exists(self, **kwargs: Any) -> bool: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. 
- :returns: boolean + :return: boolean :rtype: bool """ version_id = get_version_id(self.version_id, kwargs) @@ -1061,7 +1063,7 @@ def get_blob_properties(self, **kwargs: Any) -> BlobProperties: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: BlobProperties + :return: BlobProperties :rtype: ~azure.storage.blob.BlobProperties .. admonition:: Example: @@ -1147,7 +1149,7 @@ def set_http_headers(self, content_settings: Optional["ContentSettings"] = None, This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: Dict[str, Any] """ options = _set_http_headers_options(content_settings=content_settings, **kwargs) @@ -1214,7 +1216,7 @@ def set_blob_metadata( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: Dict[str, Union[str, datetime]] """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -1250,7 +1252,7 @@ def set_immutability_policy( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, str] """ @@ -1276,7 +1278,7 @@ def delete_immutability_policy(self, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, str] """ @@ -1301,7 +1303,7 @@ def set_legal_hold(self, legal_hold: bool, **kwargs: Any) -> Dict[str, Union[str This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, Union[str, datetime, bool]] """ @@ -1398,7 +1400,7 @@ def create_page_blob( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict[str, Any] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -1494,7 +1496,7 @@ def create_append_blob( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict[str, Any] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -1573,7 +1575,7 @@ def create_snapshot( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Snapshot ID, Etag, and last modified). + :return: Blob-updated property dict (Snapshot ID, Etag, and last modified). :rtype: dict[str, Any] .. admonition:: Example: @@ -1773,7 +1775,7 @@ def start_copy_from_url( .. versionadded:: 12.10.0 - :returns: A dictionary of copy properties (etag, last_modified, copy_id, copy_status). 
+ :return: A dictionary of copy properties (etag, last_modified, copy_id, copy_status). :rtype: dict[str, Union[str, ~datetime.datetime]] .. admonition:: Example: @@ -1812,6 +1814,7 @@ def abort_copy( The copy operation to abort. This can be either an ID string, or an instance of BlobProperties. :type copy_id: str or ~azure.storage.blob.BlobProperties + :return: None :rtype: None .. admonition:: Example: @@ -1874,7 +1877,7 @@ def acquire_lease(self, lease_duration: int =-1, lease_id: Optional[str] = None, This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A BlobLeaseClient object. + :return: A BlobLeaseClient object. :rtype: ~azure.storage.blob.BlobLeaseClient .. admonition:: Example: @@ -1930,6 +1933,7 @@ def set_standard_blob_tier(self, standard_blob_tier: Union[str, "StandardBlobTie Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.BlobLeaseClient or str + :return: None :rtype: None """ access_conditions = get_access_conditions(kwargs.pop('lease', None)) @@ -2000,7 +2004,7 @@ def stage_block( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob property dict. + :return: Blob property dict. :rtype: dict[str, Any] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -2075,7 +2079,7 @@ def stage_block_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. :paramtype source_token_intent: Literal['backup'] - :returns: Blob property dict. + :return: Blob property dict. :rtype: dict[str, Any] """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -2120,7 +2124,7 @@ def get_block_list( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A tuple of two lists - committed and uncommitted blocks + :return: A tuple of two lists - committed and uncommitted blocks :rtype: Tuple[List[BlobBlock], List[BlobBlock]] """ access_conditions = get_access_conditions(kwargs.pop('lease', None)) @@ -2232,7 +2236,7 @@ def commit_block_list( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -2274,6 +2278,7 @@ def set_premium_page_blob_tier(self, premium_page_blob_tier: "PremiumPageBlobTie Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.BlobLeaseClient or str + :return: None :rtype: None """ access_conditions = get_access_conditions(kwargs.pop('lease', None)) @@ -2329,7 +2334,7 @@ def set_blob_tags(self, tags: Optional[Dict[str, str]] = None, **kwargs: Any) -> This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: Dict[str, Any] """ version_id = get_version_id(self.version_id, kwargs) @@ -2362,7 +2367,7 @@ def get_blob_tags(self, **kwargs: Any) -> Dict[str, str]: This value is not tracked or validated on the client. 
To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, str] """ version_id = get_version_id(self.version_id, kwargs) @@ -2434,7 +2439,7 @@ def get_page_ranges( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: + :return: A tuple of two lists of page ranges as dictionaries with 'start' and 'end' keys. The first element are filled page ranges, the 2nd element is cleared page ranges. :rtype: tuple(list(dict(str, str), list(dict(str, str)) @@ -2527,7 +2532,7 @@ def list_page_ranges( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) of PageRange. + :return: An iterable (auto-paging) of PageRange. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.PageRange] """ results_per_page = kwargs.pop('results_per_page', None) @@ -2610,7 +2615,7 @@ def get_page_range_diff_for_managed_disk( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: + :return: A tuple of two lists of page ranges as dictionaries with 'start' and 'end' keys. The first element are filled page ranges, the 2nd element is cleared page ranges. :rtype: tuple(list(dict(str, str), list(dict(str, str)) @@ -2675,7 +2680,7 @@ def set_sequence_number( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ options = _set_sequence_number_options(sequence_number_action, sequence_number=sequence_number, **kwargs) @@ -2731,7 +2736,7 @@ def resize_blob(self, size: int, **kwargs: Any) -> Dict[str, Union[str, datetime This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -2827,7 +2832,7 @@ def upload_page( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -2958,7 +2963,7 @@ def upload_pages_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. :paramtype source_token_intent: Literal['backup'] - :returns: Response after uploading pages from specified URL. + :return: Response after uploading pages from specified URL. :rtype: Dict[str, Any] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3038,7 +3043,7 @@ def clear_page(self, offset: int, length: int, **kwargs: Any) -> Dict[str, Union This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). 
:rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3135,7 +3140,7 @@ def append_block( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag, last modified, append offset, committed block count). + :return: Blob-updated property dict (Etag, last modified, append offset, committed block count). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3259,7 +3264,7 @@ def append_block_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. :paramtype source_token_intent: Literal['backup'] - :returns: Result after appending a new block. + :return: Result after appending a new block. :rtype: Dict[str, Union[str, datetime, int]] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3317,7 +3322,7 @@ def seal_append_blob(self, **kwargs: Any) -> Dict[str, Union[str, datetime, int] This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag, last modified, append offset, committed block count). + :return: Blob-updated property dict (Etag, last modified, append offset, committed block count). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3334,7 +3339,7 @@ def _get_container_client(self) -> "ContainerClient": The container need not already exist. Defaults to current blob's credentials. - :returns: A ContainerClient. + :return: A ContainerClient. :rtype: ~azure.storage.blob.ContainerClient .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py index 200b89c8ddc2..48cf47367cfd 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py @@ -270,7 +270,7 @@ def _download_blob_options( The string representing the SDK package version. :param AzureBlobStorage client: The generated Blob Storage client. - :returns: A dictionary containing the download blob options. + :return: A dictionary containing the download blob options. 
:rtype: Dict[str, Any] """ if length is not None: @@ -658,19 +658,20 @@ def _start_copy_from_url_options( # pylint:disable=too-many-statements options = { 'copy_source': source_url, - 'seal_blob': kwargs.pop('seal_destination_blob', None), 'timeout': timeout, 'modified_access_conditions': dest_mod_conditions, - 'blob_tags_string': blob_tags_string, 'headers': headers, 'cls': return_response_headers, } + if not incremental_copy: source_mod_conditions = get_source_conditions(kwargs) dest_access_conditions = get_access_conditions(kwargs.pop('destination_lease', None)) options['source_modified_access_conditions'] = source_mod_conditions options['lease_access_conditions'] = dest_access_conditions options['tier'] = tier.value if tier else None + options['seal_blob'] = kwargs.pop('seal_destination_blob', None) + options['blob_tags_string'] = blob_tags_string options.update(kwargs) return options diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py index f6e17cb756f0..359b8014b065 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py @@ -137,7 +137,7 @@ def _format_url(self, hostname): :param str hostname: The hostname of the current location mode. - :returns: A formatted endpoint URL including current location mode hostname. + :return: A formatted endpoint URL including current location mode hostname. :rtype: str """ return f"{self.scheme}://{hostname}/{self._query_str}" @@ -169,7 +169,7 @@ def from_connection_string( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A Blob service client. + :return: A Blob service client. :rtype: ~azure.storage.blob.BlobServiceClient .. admonition:: Example: @@ -206,7 +206,7 @@ def get_user_delegation_key( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: The user delegation key. + :return: The user delegation key. :rtype: ~azure.storage.blob.UserDelegationKey """ key_info = KeyInfo(start=_to_utc_datetime(key_start_time), expiry=_to_utc_datetime(key_expiry_time)) @@ -227,7 +227,7 @@ def get_account_information(self, **kwargs: Any) -> Dict[str, str]: The information can also be retrieved if the user has a SAS to a container or blob. The keys in the returned dictionary include 'sku_name' and 'account_kind'. - :returns: A dict of account information (SKU and account type). + :return: A dict of account information (SKU and account type). :rtype: dict(str, str) .. admonition:: Example: @@ -270,7 +270,7 @@ def get_service_stats(self, **kwargs: Any) -> Dict[str, Any]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: The blob service stats. + :return: The blob service stats. :rtype: Dict[str, Any] .. admonition:: Example: @@ -301,7 +301,7 @@ def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. 
- :returns: An object containing blob service properties such as + :return: An object containing blob service properties such as analytics logging, hour/minute metrics, cors rules, etc. :rtype: Dict[str, Any] @@ -371,6 +371,7 @@ def set_service_properties( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -435,7 +436,7 @@ def list_containers( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) of ContainerProperties. + :return: An iterable (auto-paging) of ContainerProperties. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.ContainerProperties] .. admonition:: Example: @@ -489,7 +490,7 @@ def find_blobs_by_tags(self, filter_expression: str, **kwargs: Any) -> ItemPaged This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of BlobProperties. + :return: An iterable (auto-paging) response of BlobProperties. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.FilteredBlob] """ @@ -538,7 +539,7 @@ def create_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A container client to interact with the newly created container. + :return: A container client to interact with the newly created container. :rtype: ~azure.storage.blob.ContainerClient .. admonition:: Example: @@ -600,6 +601,8 @@ def delete_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None + :rtype: None .. admonition:: Example: @@ -638,7 +641,7 @@ def _rename_container(self, name: str, new_name: str, **kwargs: Any) -> Containe This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A container client for the renamed container. + :return: A container client for the renamed container. :rtype: ~azure.storage.blob.ContainerClient """ renamed_container = self.get_container_client(new_name) @@ -677,7 +680,7 @@ def undelete_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: The undeleted ContainerClient. + :return: The undeleted ContainerClient. :rtype: ~azure.storage.blob.ContainerClient """ new_name = kwargs.pop('new_name', None) @@ -701,7 +704,7 @@ def get_container_client(self, container: Union[ContainerProperties, str]) -> Co The container. This can either be the name of the container, or an instance of ContainerProperties. :type container: str or ~azure.storage.blob.ContainerProperties - :returns: A ContainerClient. + :return: A ContainerClient. :rtype: ~azure.storage.blob.ContainerClient .. admonition:: Example: @@ -750,7 +753,7 @@ def get_blob_client( :type snapshot: str or dict(str, Any) :keyword str version_id: The version id parameter is an opaque DateTime value that, when present, specifies the version of the blob to operate on. - :returns: A BlobClient. + :return: A BlobClient. :rtype: ~azure.storage.blob.BlobClient .. 
admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index 783df6bc753e..07ccbb594e42 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -193,7 +193,7 @@ def from_container_url( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A container client. + :return: A container client. :rtype: ~azure.storage.blob.ContainerClient """ try: @@ -246,7 +246,7 @@ def from_connection_string( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A container client. + :return: A container client. :rtype: ~azure.storage.blob.ContainerClient .. admonition:: Example: @@ -293,7 +293,7 @@ def create_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A dictionary of response headers. + :return: A dictionary of response headers. :rtype: Dict[str, Union[str, datetime]] .. admonition:: Example: @@ -338,7 +338,7 @@ def _rename_container(self, new_name: str, **kwargs: Any) -> "ContainerClient": This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: The renamed container client. + :return: The renamed container client. :rtype: ~azure.storage.blob.ContainerClient """ lease = kwargs.pop('lease', None) @@ -392,6 +392,7 @@ def delete_container(self, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -458,7 +459,7 @@ def acquire_lease( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A BlobLeaseClient object, that can be run in a context manager. + :return: A BlobLeaseClient object, that can be run in a context manager. :rtype: ~azure.storage.blob.BlobLeaseClient .. admonition:: Example: @@ -483,7 +484,7 @@ def get_account_information(self, **kwargs: Any) -> Dict[str, str]: The information can also be retrieved if the user has a SAS to a container or blob. The keys in the returned dictionary include 'sku_name' and 'account_kind'. - :returns: A dict of account information (SKU and account type). + :return: A dict of account information (SKU and account type). :rtype: dict(str, str) """ try: @@ -543,7 +544,7 @@ def exists(self, **kwargs: Any) -> bool: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: boolean + :return: boolean :rtype: bool """ try: @@ -594,7 +595,7 @@ def set_container_metadata( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Container-updated property dict (Etag and last modified). + :return: Container-updated property dict (Etag and last modified). :rtype: dict[str, str or datetime] .. 
admonition:: Example: @@ -629,7 +630,7 @@ def _get_blob_service_client(self) -> "BlobServiceClient": Defaults to current container's credentials. - :returns: A BlobServiceClient. + :return: A BlobServiceClient. :rtype: ~azure.storage.blob.BlobServiceClient .. admonition:: Example: @@ -671,7 +672,7 @@ def get_container_access_policy(self, **kwargs: Any) -> Dict[str, Any]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Access policy information in a dict. + :return: Access policy information in a dict. :rtype: dict[str, Any] .. admonition:: Example: @@ -738,7 +739,7 @@ def set_container_access_policy( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Container-updated property dict (Etag and last modified). + :return: Container-updated property dict (Etag and last modified). :rtype: dict[str, str or ~datetime.datetime] .. admonition:: Example: @@ -801,7 +802,7 @@ def list_blobs( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of BlobProperties. + :return: An iterable (auto-paging) response of BlobProperties. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.BlobProperties] .. admonition:: Example: @@ -850,7 +851,7 @@ def list_blob_names(self, **kwargs: Any) -> ItemPaged[str]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of blob names as strings. + :return: An iterable (auto-paging) response of blob names as strings. :rtype: ~azure.core.paging.ItemPaged[str] """ if kwargs.pop('prefix', None): @@ -883,7 +884,7 @@ def walk_blobs( include: Optional[Union[List[str], str]] = None, delimiter: str = "/", **kwargs: Any - ) -> ItemPaged[BlobProperties]: + ) -> ItemPaged[Union[BlobProperties, BlobPrefix]]: """Returns a generator to list the blobs under the specified container. The generator will lazily follow the continuation tokens returned by the service. This operation will list blobs in accordance with a hierarchy, @@ -908,8 +909,8 @@ def walk_blobs( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of BlobProperties. - :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.BlobProperties] + :return: An iterable (auto-paging) response of BlobProperties or BlobPrefix. + :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.BlobProperties or ~azure.storage.blob.BlobPrefix] """ if kwargs.pop('prefix', None): raise ValueError("Passing 'prefix' has no effect on filtering, " + @@ -954,7 +955,7 @@ def find_blobs_by_tags( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of FilteredBlob. + :return: An iterable (auto-paging) response of FilteredBlob. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.BlobProperties] """ results_per_page = kwargs.pop('results_per_page', None) @@ -1078,7 +1079,7 @@ def upload_blob( function(current: int, total: Optional[int]) where current is the number of bytes transferred so far, and total is the size of the blob or None if the size is unknown. :paramtype progress_hook: Callable[[int, Optional[int]], None] - :returns: A BlobClient to interact with the newly uploaded blob. 
+ :return: A BlobClient to interact with the newly uploaded blob. :rtype: ~azure.storage.blob.BlobClient .. admonition:: Example: @@ -1128,7 +1129,8 @@ def delete_blob( and retains the blob or snapshot for specified number of days. After specified number of days, blob's data is removed from the service during garbage collection. Soft deleted blob or snapshot is accessible through :func:`list_blobs()` specifying `include=["deleted"]` - option. Soft-deleted blob or snapshot can be restored using :func:`~azure.storage.blob.BlobClient.undelete()` + option. Soft-deleted blob or snapshot can be restored using + :func:`~azure.storage.blob.BlobClient.undelete_blob()` :param str blob: The blob with which to interact. :param str delete_snapshots: @@ -1176,6 +1178,7 @@ def delete_blob( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None """ if isinstance(blob, BlobProperties): @@ -1302,7 +1305,7 @@ def download_blob( the timeout will apply to each call individually. multiple calls to the Azure service and the timeout will apply to each call individually. - :returns: A streaming object (StorageStreamDownloader) + :return: A streaming object (StorageStreamDownloader) :rtype: ~azure.storage.blob.StorageStreamDownloader """ if isinstance(blob, BlobProperties): @@ -1334,7 +1337,7 @@ def delete_blobs( # pylint: disable=delete-operation-wrong-return-type and retains the blobs or snapshots for specified number of days. After specified number of days, blobs' data is removed from the service during garbage collection. Soft deleted blobs or snapshots are accessible through :func:`list_blobs()` specifying `include=["deleted"]` - Soft-deleted blobs or snapshots can be restored using :func:`~azure.storage.blob.BlobClient.undelete()` + Soft-deleted blobs or snapshots can be restored using :func:`~azure.storage.blob.BlobClient.undelete_blob()` The maximum number of blobs that can be deleted in a single request is 256. @@ -1586,7 +1589,7 @@ def get_blob_client( or the response returned from :func:`~BlobClient.create_snapshot()`. :keyword str version_id: The version id parameter is an opaque DateTime value that, when present, specifies the version of the blob to operate on. - :returns: A BlobClient. + :return: A BlobClient. :rtype: ~azure.storage.blob.BlobClient .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_deserialize.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_deserialize.py index b6ee916097a1..19ec4c07e338 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_deserialize.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_deserialize.py @@ -194,7 +194,7 @@ def parse_tags(generated_tags: Optional["BlobTags"]) -> Optional[Dict[str, str]] :param Optional[BlobTags] generated_tags: A list containing the BlobTag objects from generated code. - :returns: A dictionary of the BlobTag objects. + :return: A dictionary of the BlobTag objects. :rtype: Optional[Dict[str, str]] """ if generated_tags: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py index 090c226c6094..2fec8f18c13b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py @@ -547,7 +547,7 @@ def chunks(self) -> Iterator[bytes]: NOTE: If the stream has been partially read, some data may be re-downloaded by the iterator. 
- :returns: An iterator of the chunks in the download stream. + :return: An iterator of the chunks in the download stream. :rtype: Iterator[bytes] .. admonition:: Example: @@ -621,7 +621,7 @@ def read(self, size: int = -1, *, chars: Optional[int] = None) -> T: The number of chars to download from the stream. Leave unspecified or set negative to download all chars. Note, this can only be used when encoding is specified on `download_blob`. - :returns: + :return: The requested data as bytes or a string if encoding was specified. If the return value is empty, there is no more data to read. :rtype: T @@ -757,7 +757,7 @@ def readall(self) -> T: Read the entire contents of this blob. This operation is blocking until all data is downloaded. - :returns: The requested data as bytes or a string if encoding was specified. + :return: The requested data as bytes or a string if encoding was specified. :rtype: T """ return self.read() @@ -769,7 +769,7 @@ def readinto(self, stream: IO[bytes]) -> int: The stream to download to. This can be an open file-handle, or any writable stream. The stream must be seekable if the download uses more than one parallel connection. - :returns: The number of bytes read. + :return: The number of bytes read. :rtype: int """ if self._text_mode: @@ -866,7 +866,7 @@ def content_as_bytes(self, max_concurrency=1): :param int max_concurrency: The number of parallel connections with which to download. - :returns: The contents of the file as bytes. + :return: The contents of the file as bytes. :rtype: bytes """ warnings.warn( @@ -891,7 +891,7 @@ def content_as_text(self, max_concurrency=1, encoding="UTF-8"): The number of parallel connections with which to download. :param str encoding: Test encoding to decode the downloaded bytes. Default is UTF-8. - :returns: The content of the file as a str. + :return: The content of the file as a str. :rtype: str """ warnings.warn( @@ -917,7 +917,7 @@ def download_to_stream(self, stream, max_concurrency=1): uses more than one parallel connection. :param int max_concurrency: The number of parallel connections with which to download. - :returns: The properties of the downloaded blob. + :return: The properties of the downloaded blob. :rtype: Any """ warnings.warn( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_encryption.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_encryption.py index 42f5c51d0762..2153d1da1da6 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_encryption.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_encryption.py @@ -38,51 +38,46 @@ from cryptography.hazmat.primitives.padding import PaddingContext -_ENCRYPTION_PROTOCOL_V1 = '1.0' -_ENCRYPTION_PROTOCOL_V2 = '2.0' -_ENCRYPTION_PROTOCOL_V2_1 = '2.1' +_ENCRYPTION_PROTOCOL_V1 = "1.0" +_ENCRYPTION_PROTOCOL_V2 = "2.0" +_ENCRYPTION_PROTOCOL_V2_1 = "2.1" _VALID_ENCRYPTION_PROTOCOLS = [_ENCRYPTION_PROTOCOL_V1, _ENCRYPTION_PROTOCOL_V2, _ENCRYPTION_PROTOCOL_V2_1] _ENCRYPTION_V2_PROTOCOLS = [_ENCRYPTION_PROTOCOL_V2, _ENCRYPTION_PROTOCOL_V2_1] _GCM_REGION_DATA_LENGTH = 4 * 1024 * 1024 _GCM_NONCE_LENGTH = 12 _GCM_TAG_LENGTH = 16 -_ERROR_OBJECT_INVALID = \ - '{0} does not define a complete interface. Value of {1} is either missing or invalid.' +_ERROR_OBJECT_INVALID = "{0} does not define a complete interface. Value of {1} is either missing or invalid." 
_ERROR_UNSUPPORTED_METHOD_FOR_ENCRYPTION = ( - 'The require_encryption flag is set, but encryption is not supported' - ' for this method.') + "The require_encryption flag is set, but encryption is not supported for this method." +) class KeyEncryptionKey(Protocol): - def wrap_key(self, key: bytes) -> bytes: - ... + def wrap_key(self, key: bytes) -> bytes: ... - def unwrap_key(self, key: bytes, algorithm: str) -> bytes: - ... + def unwrap_key(self, key: bytes, algorithm: str) -> bytes: ... - def get_kid(self) -> str: - ... + def get_kid(self) -> str: ... - def get_key_wrap_algorithm(self) -> str: - ... + def get_key_wrap_algorithm(self) -> str: ... def _validate_not_none(param_name: str, param: Any): if param is None: - raise ValueError(f'{param_name} should not be None.') + raise ValueError(f"{param_name} should not be None.") def _validate_key_encryption_key_wrap(kek: KeyEncryptionKey): # Note that None is not callable and so will fail the second clause of each check. - if not hasattr(kek, 'wrap_key') or not callable(kek.wrap_key): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key')) - if not hasattr(kek, 'get_kid') or not callable(kek.get_kid): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) - if not hasattr(kek, 'get_key_wrap_algorithm') or not callable(kek.get_key_wrap_algorithm): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_key_wrap_algorithm')) + if not hasattr(kek, "wrap_key") or not callable(kek.wrap_key): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "wrap_key")) + if not hasattr(kek, "get_kid") or not callable(kek.get_kid): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "get_kid")) + if not hasattr(kek, "get_key_wrap_algorithm") or not callable(kek.get_key_wrap_algorithm): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "get_key_wrap_algorithm")) class StorageEncryptionMixin(object): @@ -91,19 +86,22 @@ def _configure_encryption(self, kwargs: Dict[str, Any]): self.encryption_version = kwargs.get("encryption_version", "1.0") self.key_encryption_key = kwargs.get("key_encryption_key") self.key_resolver_function = kwargs.get("key_resolver_function") - if self.key_encryption_key and self.encryption_version == '1.0': - warnings.warn("This client has been configured to use encryption with version 1.0. " + - "Version 1.0 is deprecated and no longer considered secure. It is highly " + - "recommended that you switch to using version 2.0. The version can be " + - "specified using the 'encryption_version' keyword.") + if self.key_encryption_key and self.encryption_version == "1.0": + warnings.warn( + "This client has been configured to use encryption with version 1.0. " + + "Version 1.0 is deprecated and no longer considered secure. It is highly " + + "recommended that you switch to using version 2.0. The version can be " + + "specified using the 'encryption_version' keyword." + ) class _EncryptionAlgorithm(object): """ Specifies which client encryption algorithm is used. """ - AES_CBC_256 = 'AES_CBC_256' - AES_GCM_256 = 'AES_GCM_256' + + AES_CBC_256 = "AES_CBC_256" + AES_GCM_256 = "AES_GCM_256" class _WrappedContentKey: @@ -120,9 +118,9 @@ def __init__(self, algorithm: str, encrypted_key: bytes, key_id: str) -> None: :param str key_id: The key-encryption-key identifier string. 
""" - _validate_not_none('algorithm', algorithm) - _validate_not_none('encrypted_key', encrypted_key) - _validate_not_none('key_id', key_id) + _validate_not_none("algorithm", algorithm) + _validate_not_none("encrypted_key", encrypted_key) + _validate_not_none("key_id", key_id) self.algorithm = algorithm self.encrypted_key = encrypted_key @@ -144,9 +142,9 @@ def __init__(self, data_length: int, nonce_length: int, tag_length: int) -> None :param int tag_length: The length of the encryption tag. """ - _validate_not_none('data_length', data_length) - _validate_not_none('nonce_length', nonce_length) - _validate_not_none('tag_length', tag_length) + _validate_not_none("data_length", data_length) + _validate_not_none("nonce_length", nonce_length) + _validate_not_none("tag_length", tag_length) self.data_length = data_length self.nonce_length = nonce_length @@ -166,8 +164,8 @@ def __init__(self, encryption_algorithm: _EncryptionAlgorithm, protocol: str) -> :param str protocol: The protocol version used for encryption. """ - _validate_not_none('encryption_algorithm', encryption_algorithm) - _validate_not_none('protocol', protocol) + _validate_not_none("encryption_algorithm", encryption_algorithm) + _validate_not_none("protocol", protocol) self.encryption_algorithm = str(encryption_algorithm) self.protocol = protocol @@ -179,11 +177,12 @@ class _EncryptionData: """ def __init__( - self, content_encryption_IV: Optional[bytes], + self, + content_encryption_IV: Optional[bytes], encrypted_region_info: Optional[_EncryptedRegionInfo], encryption_agent: _EncryptionAgent, wrapped_content_key: _WrappedContentKey, - key_wrapping_metadata: Dict[str, Any] + key_wrapping_metadata: Dict[str, Any], ) -> None: """ :param Optional[bytes] content_encryption_IV: @@ -200,14 +199,14 @@ def __init__( :param Dict[str, Any] key_wrapping_metadata: A dict containing metadata related to the key wrapping. """ - _validate_not_none('encryption_agent', encryption_agent) - _validate_not_none('wrapped_content_key', wrapped_content_key) + _validate_not_none("encryption_agent", encryption_agent) + _validate_not_none("wrapped_content_key", wrapped_content_key) # Validate we have the right matching optional parameter for the specified algorithm if encryption_agent.encryption_algorithm == _EncryptionAlgorithm.AES_CBC_256: - _validate_not_none('content_encryption_IV', content_encryption_IV) + _validate_not_none("content_encryption_IV", content_encryption_IV) elif encryption_agent.encryption_algorithm == _EncryptionAlgorithm.AES_GCM_256: - _validate_not_none('encrypted_region_info', encrypted_region_info) + _validate_not_none("encrypted_region_info", encrypted_region_info) else: raise ValueError("Invalid encryption algorithm.") @@ -225,8 +224,10 @@ class GCMBlobEncryptionStream: will use the same encryption key and will generate a guaranteed unique nonce for each encryption region. """ + def __init__( - self, content_encryption_key: bytes, + self, + content_encryption_key: bytes, data_stream: IO[bytes], ) -> None: """ @@ -237,7 +238,7 @@ def __init__( self.data_stream = data_stream self.offset = 0 - self.current = b'' + self.current = b"" self.nonce_counter = 0 def read(self, size: int = -1) -> bytes: @@ -286,7 +287,7 @@ def encrypt_data_v2(data: bytes, nonce: int, key: bytes) -> bytes: :return: The encrypted bytes in the form: nonce + ciphertext + tag. 
:rtype: bytes """ - nonce_bytes = nonce.to_bytes(_GCM_NONCE_LENGTH, 'big') + nonce_bytes = nonce.to_bytes(_GCM_NONCE_LENGTH, "big") aesgcm = AESGCM(key) # Returns ciphertext + tag @@ -307,11 +308,8 @@ def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: def modify_user_agent_for_encryption( - user_agent: str, - moniker: str, - encryption_version: str, - request_options: Dict[str, Any] - ) -> None: + user_agent: str, moniker: str, encryption_version: str, request_options: Dict[str, Any] +) -> None: """ Modifies the request options to contain a user agent string updated with encryption information. Adds azstorage-clientsideencryption/ immediately proceeding the SDK descriptor. @@ -322,7 +320,7 @@ def modify_user_agent_for_encryption( :param Dict[str, Any] request_options: The reuqest options to add the user agent override to. """ # If the user has specified user_agent_overwrite=True, don't make any modifications - if request_options.get('user_agent_overwrite'): + if request_options.get("user_agent_overwrite"): return # If the feature flag is already present, don't add it again @@ -333,11 +331,11 @@ def modify_user_agent_for_encryption( index = user_agent.find(f"azsdk-python-{moniker}") user_agent = f"{user_agent[:index]}{feature_flag} {user_agent[index:]}" # Since we are using user_agent_overwrite=True, we must prepend the user's user_agent if there is one - if request_options.get('user_agent'): + if request_options.get("user_agent"): user_agent = f"{request_options.get('user_agent')} {user_agent}" - request_options['user_agent'] = user_agent - request_options['user_agent_overwrite'] = True + request_options["user_agent"] = user_agent + request_options["user_agent_overwrite"] = True def get_adjusted_upload_size(length: int, encryption_version: str) -> int: @@ -362,10 +360,8 @@ def get_adjusted_upload_size(length: int, encryption_version: str) -> int: def get_adjusted_download_range_and_offset( - start: int, - end: int, - length: Optional[int], - encryption_data: Optional[_EncryptionData]) -> Tuple[Tuple[int, int], Tuple[int, int]]: + start: int, end: int, length: Optional[int], encryption_data: Optional[_EncryptionData] +) -> Tuple[Tuple[int, int], Tuple[int, int]]: """ Gets the new download range and offsets into the decrypted data for the given user-specified range. The new download range will include all @@ -453,7 +449,7 @@ def parse_encryption_data(metadata: Dict[str, Any]) -> Optional[_EncryptionData] try: # Use case insensitive dict as key needs to be case-insensitive case_insensitive_metadata = CaseInsensitiveDict(metadata) - return _dict_to_encryption_data(loads(case_insensitive_metadata['encryptiondata'])) + return _dict_to_encryption_data(loads(case_insensitive_metadata["encryptiondata"])) except: # pylint: disable=bare-except return None @@ -468,9 +464,11 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp :return: The new blob size. 
:rtype: int """ - if (encryption_data is not None and - encryption_data.encrypted_region_info is not None and - is_encryption_v2(encryption_data)): + if ( + encryption_data is not None + and encryption_data.encrypted_region_info is not None + and is_encryption_v2(encryption_data) + ): nonce_length = encryption_data.encrypted_region_info.nonce_length data_length = encryption_data.encrypted_region_info.data_length @@ -485,11 +483,8 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp def _generate_encryption_data_dict( - kek: KeyEncryptionKey, - cek: bytes, - iv: Optional[bytes], - version: str - ) -> TypedOrderedDict[str, Any]: + kek: KeyEncryptionKey, cek: bytes, iv: Optional[bytes], version: str +) -> TypedOrderedDict[str, Any]: """ Generates and returns the encryption metadata as a dict. @@ -506,7 +501,7 @@ def _generate_encryption_data_dict( # For V2, we include the encryption version in the wrapped key. elif version == _ENCRYPTION_PROTOCOL_V2: # We must pad the version to 8 bytes for AES Keywrap algorithms - to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') + cek + to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b"\0") + cek wrapped_cek = kek.wrap_key(to_wrap) else: raise ValueError("Invalid encryption version specified.") @@ -514,31 +509,31 @@ def _generate_encryption_data_dict( # Build the encryption_data dict. # Use OrderedDict to comply with Java's ordering requirement. wrapped_content_key = OrderedDict() - wrapped_content_key['KeyId'] = kek.get_kid() - wrapped_content_key['EncryptedKey'] = encode_base64(wrapped_cek) - wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() + wrapped_content_key["KeyId"] = kek.get_kid() + wrapped_content_key["EncryptedKey"] = encode_base64(wrapped_cek) + wrapped_content_key["Algorithm"] = kek.get_key_wrap_algorithm() encryption_agent = OrderedDict() - encryption_agent['Protocol'] = version + encryption_agent["Protocol"] = version if version == _ENCRYPTION_PROTOCOL_V1: - encryption_agent['EncryptionAlgorithm'] = _EncryptionAlgorithm.AES_CBC_256 + encryption_agent["EncryptionAlgorithm"] = _EncryptionAlgorithm.AES_CBC_256 elif version == _ENCRYPTION_PROTOCOL_V2: - encryption_agent['EncryptionAlgorithm'] = _EncryptionAlgorithm.AES_GCM_256 + encryption_agent["EncryptionAlgorithm"] = _EncryptionAlgorithm.AES_GCM_256 encrypted_region_info = OrderedDict() - encrypted_region_info['DataLength'] = _GCM_REGION_DATA_LENGTH - encrypted_region_info['NonceLength'] = _GCM_NONCE_LENGTH + encrypted_region_info["DataLength"] = _GCM_REGION_DATA_LENGTH + encrypted_region_info["NonceLength"] = _GCM_NONCE_LENGTH encryption_data_dict: TypedOrderedDict[str, Any] = OrderedDict() - encryption_data_dict['WrappedContentKey'] = wrapped_content_key - encryption_data_dict['EncryptionAgent'] = encryption_agent + encryption_data_dict["WrappedContentKey"] = wrapped_content_key + encryption_data_dict["EncryptionAgent"] = encryption_agent if version == _ENCRYPTION_PROTOCOL_V1: - encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) + encryption_data_dict["ContentEncryptionIV"] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: - encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info - encryption_data_dict['KeyWrappingMetadata'] = OrderedDict({'EncryptionLibrary': 'Python ' + VERSION}) + encryption_data_dict["EncryptedRegionInfo"] = encrypted_region_info + encryption_data_dict["KeyWrappingMetadata"] = OrderedDict({"EncryptionLibrary": "Python " + VERSION}) return encryption_data_dict @@ -554,43 +549,42 
@@ def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _Encryptio :rtype: _EncryptionData """ try: - protocol = encryption_data_dict['EncryptionAgent']['Protocol'] + protocol = encryption_data_dict["EncryptionAgent"]["Protocol"] if protocol not in _VALID_ENCRYPTION_PROTOCOLS: raise ValueError("Unsupported encryption version.") except KeyError as exc: raise ValueError("Unsupported encryption version.") from exc - wrapped_content_key = encryption_data_dict['WrappedContentKey'] - wrapped_content_key = _WrappedContentKey(wrapped_content_key['Algorithm'], - decode_base64_to_bytes(wrapped_content_key['EncryptedKey']), - wrapped_content_key['KeyId']) - - encryption_agent = encryption_data_dict['EncryptionAgent'] - encryption_agent = _EncryptionAgent(encryption_agent['EncryptionAlgorithm'], - encryption_agent['Protocol']) - - if 'KeyWrappingMetadata' in encryption_data_dict: - key_wrapping_metadata = encryption_data_dict['KeyWrappingMetadata'] + wrapped_content_key = encryption_data_dict["WrappedContentKey"] + wrapped_content_key = _WrappedContentKey( + wrapped_content_key["Algorithm"], + decode_base64_to_bytes(wrapped_content_key["EncryptedKey"]), + wrapped_content_key["KeyId"], + ) + + encryption_agent = encryption_data_dict["EncryptionAgent"] + encryption_agent = _EncryptionAgent(encryption_agent["EncryptionAlgorithm"], encryption_agent["Protocol"]) + + if "KeyWrappingMetadata" in encryption_data_dict: + key_wrapping_metadata = encryption_data_dict["KeyWrappingMetadata"] else: key_wrapping_metadata = None # AES-CBC only encryption_iv = None - if 'ContentEncryptionIV' in encryption_data_dict: - encryption_iv = decode_base64_to_bytes(encryption_data_dict['ContentEncryptionIV']) + if "ContentEncryptionIV" in encryption_data_dict: + encryption_iv = decode_base64_to_bytes(encryption_data_dict["ContentEncryptionIV"]) # AES-GCM only region_info = None - if 'EncryptedRegionInfo' in encryption_data_dict: - encrypted_region_info = encryption_data_dict['EncryptedRegionInfo'] - region_info = _EncryptedRegionInfo(encrypted_region_info['DataLength'], - encrypted_region_info['NonceLength'], - _GCM_TAG_LENGTH) - - encryption_data = _EncryptionData(encryption_iv, - region_info, - encryption_agent, - wrapped_content_key, - key_wrapping_metadata) + if "EncryptedRegionInfo" in encryption_data_dict: + encrypted_region_info = encryption_data_dict["EncryptedRegionInfo"] + region_info = _EncryptedRegionInfo( + encrypted_region_info["DataLength"], encrypted_region_info["NonceLength"], _GCM_TAG_LENGTH + ) + + encryption_data = _EncryptionData( + encryption_iv, region_info, encryption_agent, wrapped_content_key, key_wrapping_metadata + ) return encryption_data @@ -614,7 +608,7 @@ def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: def _validate_and_unwrap_cek( encryption_data: _EncryptionData, key_encryption_key: Optional[KeyEncryptionKey] = None, - key_resolver: Optional[Callable[[str], KeyEncryptionKey]] = None + key_resolver: Optional[Callable[[str], KeyEncryptionKey]] = None, ) -> bytes: """ Extracts and returns the content_encryption_key stored in the encryption_data object @@ -636,15 +630,15 @@ def _validate_and_unwrap_cek( :rtype: bytes """ - _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key) + _validate_not_none("encrypted_key", encryption_data.wrapped_content_key.encrypted_key) # Validate we have the right info for the specified version if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V1: - 
_validate_not_none('content_encryption_IV', encryption_data.content_encryption_IV) + _validate_not_none("content_encryption_IV", encryption_data.content_encryption_IV) elif encryption_data.encryption_agent.protocol in _ENCRYPTION_V2_PROTOCOLS: - _validate_not_none('encrypted_region_info', encryption_data.encrypted_region_info) + _validate_not_none("encrypted_region_info", encryption_data.encrypted_region_info) else: - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") content_encryption_key: Optional[bytes] = None @@ -654,29 +648,29 @@ def _validate_and_unwrap_cek( if key_encryption_key is None: raise ValueError("Unable to decrypt. key_resolver and key_encryption_key cannot both be None.") - if not hasattr(key_encryption_key, 'get_kid') or not callable(key_encryption_key.get_kid): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) - if not hasattr(key_encryption_key, 'unwrap_key') or not callable(key_encryption_key.unwrap_key): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'unwrap_key')) + if not hasattr(key_encryption_key, "get_kid") or not callable(key_encryption_key.get_kid): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "get_kid")) + if not hasattr(key_encryption_key, "unwrap_key") or not callable(key_encryption_key.unwrap_key): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "unwrap_key")) if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): - raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') + raise ValueError("Provided or resolved key-encryption-key does not match the id of key used to encrypt.") # Will throw an exception if the specified algorithm is not supported. content_encryption_key = key_encryption_key.unwrap_key( - encryption_data.wrapped_content_key.encrypted_key, - encryption_data.wrapped_content_key.algorithm) + encryption_data.wrapped_content_key.encrypted_key, encryption_data.wrapped_content_key.algorithm + ) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. if encryption_data.encryption_agent.protocol in _ENCRYPTION_V2_PROTOCOLS: - version_2_bytes = encryption_data.encryption_agent.protocol.encode().ljust(8, b'\0') - cek_version_bytes = content_encryption_key[:len(version_2_bytes)] + version_2_bytes = encryption_data.encryption_agent.protocol.encode().ljust(8, b"\0") + cek_version_bytes = content_encryption_key[: len(version_2_bytes)] if cek_version_bytes != version_2_bytes: - raise ValueError('The encryption metadata is not valid and may have been modified.') + raise ValueError("The encryption metadata is not valid and may have been modified.") # Remove version from the start of the cek. 
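# --- Illustrative aside (not part of the diff) -------------------------------
# Sketch of the V2 wrapped-key framing referenced by the comment above: the
# protocol string (assumed here to be "2.0") is NUL-padded to 8 bytes and
# prepended to the CEK before kek.wrap_key(); after unwrap_key() the prefix is
# verified and then stripped, which is what the next diff line does.
import os

protocol_v2 = "2.0"                                   # assumption for _ENCRYPTION_PROTOCOL_V2
cek = os.urandom(32)                                  # 256-bit content encryption key

to_wrap = protocol_v2.encode().ljust(8, b"\0") + cek  # bytes handed to kek.wrap_key()

unwrapped = to_wrap                                   # pretend kek.unwrap_key() round-tripped the bytes
version_bytes = protocol_v2.encode().ljust(8, b"\0")
if unwrapped[: len(version_bytes)] != version_bytes:
    raise ValueError("The encryption metadata is not valid and may have been modified.")
recovered_cek = unwrapped[len(version_bytes):]
assert recovered_cek == cek
# ------------------------------------------------------------------------------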
- content_encryption_key = content_encryption_key[len(version_2_bytes):] + content_encryption_key = content_encryption_key[len(version_2_bytes) :] - _validate_not_none('content_encryption_key', content_encryption_key) + _validate_not_none("content_encryption_key", content_encryption_key) return content_encryption_key @@ -685,7 +679,7 @@ def _decrypt_message( message: bytes, encryption_data: _EncryptionData, key_encryption_key: Optional[KeyEncryptionKey] = None, - resolver: Optional[Callable[[str], KeyEncryptionKey]] = None + resolver: Optional[Callable[[str], KeyEncryptionKey]] = None, ) -> bytes: """ Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. @@ -710,7 +704,7 @@ def _decrypt_message( :return: The decrypted plaintext. :rtype: bytes """ - _validate_not_none('message', message) + _validate_not_none("message", message) content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, resolver) if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V1: @@ -721,11 +715,11 @@ def _decrypt_message( # decrypt data decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(message) + decryptor.finalize()) + decrypted_data = decryptor.update(message) + decryptor.finalize() # unpad data unpadder = PKCS7(128).unpadder() - decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize()) + decrypted_data = unpadder.update(decrypted_data) + unpadder.finalize() elif encryption_data.encryption_agent.protocol in _ENCRYPTION_V2_PROTOCOLS: block_info = encryption_data.encrypted_region_info @@ -745,7 +739,7 @@ def _decrypt_message( decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) else: - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") return decrypted_data @@ -773,8 +767,8 @@ def encrypt_blob(blob: bytes, key_encryption_key: KeyEncryptionKey, version: str :rtype: (str, bytes) """ - _validate_not_none('blob', blob) - _validate_not_none('key_encryption_key', key_encryption_key) + _validate_not_none("blob", blob) + _validate_not_none("key_encryption_key", key_encryption_key) _validate_key_encryption_key_wrap(key_encryption_key) if version == _ENCRYPTION_PROTOCOL_V1: @@ -805,16 +799,16 @@ def encrypt_blob(blob: bytes, key_encryption_key: KeyEncryptionKey, version: str else: raise ValueError("Invalid encryption version specified.") - encryption_data = _generate_encryption_data_dict(key_encryption_key, content_encryption_key, - initialization_vector, version) - encryption_data['EncryptionMode'] = 'FullBlob' + encryption_data = _generate_encryption_data_dict( + key_encryption_key, content_encryption_key, initialization_vector, version + ) + encryption_data["EncryptionMode"] = "FullBlob" return dumps(encryption_data), encrypted_data def generate_blob_encryption_data( - key_encryption_key: Optional[KeyEncryptionKey], - version: str + key_encryption_key: Optional[KeyEncryptionKey], version: str ) -> Tuple[Optional[bytes], Optional[bytes], Optional[str]]: """ Generates the encryption_metadata for the blob. 
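# --- Illustrative aside (not part of the diff) -------------------------------
# Sketch of the V2 region layout that _decrypt_message and decrypt_blob consume:
# each encrypted region is <nonce><ciphertext><16-byte GCM tag>, and AESGCM from
# `cryptography` expects the tag appended to the ciphertext, exactly as it is
# passed above. The 12-byte nonce length here is an illustrative assumption.
import os
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

nonce_length = 12
cek = AESGCM.generate_key(bit_length=256)
aesgcm = AESGCM(cek)

plaintext_region = b"example region contents"
nonce = os.urandom(nonce_length)
# AESGCM.encrypt returns the ciphertext with the 16-byte tag already appended.
encrypted_region = nonce + aesgcm.encrypt(nonce, plaintext_region, None)

# Mirrors the decrypt path: split off the nonce, hand the rest (ciphertext + tag) to decrypt.
recovered = aesgcm.decrypt(encrypted_region[:nonce_length], encrypted_region[nonce_length:], None)
assert recovered == plaintext_region
# ------------------------------------------------------------------------------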
@@ -836,24 +830,23 @@ def generate_blob_encryption_data( # Initialization vector only needed for V1 if version == _ENCRYPTION_PROTOCOL_V1: initialization_vector = os.urandom(16) - encryption_data_dict = _generate_encryption_data_dict(key_encryption_key, - content_encryption_key, - initialization_vector, - version) - encryption_data_dict['EncryptionMode'] = 'FullBlob' + encryption_data_dict = _generate_encryption_data_dict( + key_encryption_key, content_encryption_key, initialization_vector, version + ) + encryption_data_dict["EncryptionMode"] = "FullBlob" encryption_data = dumps(encryption_data_dict) return content_encryption_key, initialization_vector, encryption_data def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements - require_encryption: bool, - key_encryption_key: Optional[KeyEncryptionKey], - key_resolver: Optional[Callable[[str], KeyEncryptionKey]], - content: bytes, - start_offset: int, - end_offset: int, - response_headers: Dict[str, Any] + require_encryption: bool, + key_encryption_key: Optional[KeyEncryptionKey], + key_resolver: Optional[Callable[[str], KeyEncryptionKey]], + content: bytes, + start_offset: int, + end_offset: int, + response_headers: Dict[str, Any], ) -> bytes: """ Decrypts the given blob contents and returns only the requested range. @@ -885,39 +878,40 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements :rtype: bytes """ try: - encryption_data = _dict_to_encryption_data(loads(response_headers['x-ms-meta-encryptiondata'])) + encryption_data = _dict_to_encryption_data(loads(response_headers["x-ms-meta-encryptiondata"])) except Exception as exc: # pylint: disable=broad-except if require_encryption: raise ValueError( - 'Encryption required, but received data does not contain appropriate metadata.' + \ - 'Data was either not encrypted or metadata has been lost.') from exc + "Encryption required, but received data does not contain appropriate metadata." + + "Data was either not encrypted or metadata has been lost." 
+ ) from exc return content algorithm = encryption_data.encryption_agent.encryption_algorithm - if algorithm not in(_EncryptionAlgorithm.AES_CBC_256, _EncryptionAlgorithm.AES_GCM_256): - raise ValueError('Specified encryption algorithm is not supported.') + if algorithm not in (_EncryptionAlgorithm.AES_CBC_256, _EncryptionAlgorithm.AES_GCM_256): + raise ValueError("Specified encryption algorithm is not supported.") version = encryption_data.encryption_agent.protocol if version not in _VALID_ENCRYPTION_PROTOCOLS: - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, key_resolver) if version == _ENCRYPTION_PROTOCOL_V1: - blob_type = response_headers['x-ms-blob-type'] + blob_type = response_headers["x-ms-blob-type"] iv: Optional[bytes] = None unpad = False - if 'content-range' in response_headers: - content_range = response_headers['content-range'] + if "content-range" in response_headers: + content_range = response_headers["content-range"] # Format: 'bytes x-y/size' # Ignore the word 'bytes' - content_range = content_range.split(' ') + content_range = content_range.split(" ") - content_range = content_range[1].split('-') - content_range = content_range[1].split('/') + content_range = content_range[1].split("-") + content_range = content_range[1].split("/") end_range = int(content_range[0]) blob_size = int(content_range[1]) @@ -934,7 +928,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements unpad = True iv = encryption_data.content_encryption_IV - if blob_type == 'PageBlob': + if blob_type == "PageBlob": unpad = False if iv is None: @@ -948,7 +942,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements unpadder = PKCS7(128).unpadder() content = unpadder.update(content) + unpadder.finalize() - return content[start_offset: len(content) - end_offset] + return content[start_offset : len(content) - end_offset] if version in _ENCRYPTION_V2_PROTOCOLS: # We assume the content contains only full encryption regions @@ -967,7 +961,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements while offset < total_size: # Process one encryption region at a time process_size = min(region_length, total_size) - encrypted_region = content[offset:offset + process_size] + encrypted_region = content[offset : offset + process_size] # First bytes are the nonce nonce = encrypted_region[:nonce_length] @@ -982,13 +976,11 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements # Read the caller requested data from the decrypted content return decrypted_content[start_offset:end_offset] - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") def get_blob_encryptor_and_padder( - cek: Optional[bytes], - iv: Optional[bytes], - should_pad: bool + cek: Optional[bytes], iv: Optional[bytes], should_pad: bool ) -> Tuple[Optional["AEADEncryptionContext"], Optional["PaddingContext"]]: encryptor = None padder = None @@ -1022,13 +1014,13 @@ def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, ve :rtype: str """ - _validate_not_none('message', message) - _validate_not_none('key_encryption_key', key_encryption_key) + _validate_not_none("message", message) + _validate_not_none("key_encryption_key", key_encryption_key) _validate_key_encryption_key_wrap(key_encryption_key) # Queue 
encoding functions all return unicode strings, and encryption should # operate on binary strings. - message_as_bytes: bytes = message.encode('utf-8') + message_as_bytes: bytes = message.encode("utf-8") if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks @@ -1062,11 +1054,12 @@ def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, ve raise ValueError("Invalid encryption version specified.") # Build the dictionary structure. - queue_message = {'EncryptedMessageContents': encode_base64(encrypted_data), - 'EncryptionData': _generate_encryption_data_dict(key_encryption_key, - content_encryption_key, - initialization_vector, - version)} + queue_message = { + "EncryptedMessageContents": encode_base64(encrypted_data), + "EncryptionData": _generate_encryption_data_dict( + key_encryption_key, content_encryption_key, initialization_vector, version + ), + } return dumps(queue_message) @@ -1076,7 +1069,7 @@ def decrypt_queue_message( response: "PipelineResponse", require_encryption: bool, key_encryption_key: Optional[KeyEncryptionKey], - resolver: Optional[Callable[[str], KeyEncryptionKey]] + resolver: Optional[Callable[[str], KeyEncryptionKey]], ) -> str: """ Returns the decrypted message contents from an EncryptedQueueMessage. @@ -1106,22 +1099,22 @@ def decrypt_queue_message( try: deserialized_message: Dict[str, Any] = loads(message) - encryption_data = _dict_to_encryption_data(deserialized_message['EncryptionData']) - decoded_data = decode_base64_to_bytes(deserialized_message['EncryptedMessageContents']) + encryption_data = _dict_to_encryption_data(deserialized_message["EncryptionData"]) + decoded_data = decode_base64_to_bytes(deserialized_message["EncryptedMessageContents"]) except (KeyError, ValueError) as exc: # Message was not json formatted and so was not encrypted # or the user provided a json formatted message # or the metadata was malformed. if require_encryption: raise ValueError( - 'Encryption required, but received message does not contain appropriate metatadata. ' + \ - 'Message was either not encrypted or metadata was incorrect.') from exc + "Encryption required, but received message does not contain appropriate metatadata. " + + "Message was either not encrypted or metadata was incorrect." + ) from exc return message try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode("utf-8") except Exception as error: raise HttpResponseError( - message="Decryption failed.", - response=response, #type: ignore [arg-type] - error=error) from error + message="Decryption failed.", response=response, error=error # type: ignore [arg-type] + ) from error diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_lease.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_lease.py index b8b5684d7c23..19bd7179c09a 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_lease.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_lease.py @@ -102,6 +102,7 @@ def acquire(self, lease_duration: int = -1, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. 
+ :return: None :rtype: None """ mod_conditions = get_modify_conditions(kwargs) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_quick_query_helper.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_quick_query_helper.py index 60bda00db5c2..ae2afbafd2ff 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_quick_query_helper.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_quick_query_helper.py @@ -84,7 +84,7 @@ def readall(self) -> bytes: This operation is blocking until all data is downloaded. - :returns: The query results. + :return: The query results. :rtype: bytes """ stream = BytesIO() @@ -100,7 +100,7 @@ def readinto(self, stream: IO) -> None: :param IO stream: The stream to download to. This can be an open file-handle, or any writable stream. - :returns: None + :return: None """ for record in self._iter_stream(): stream.write(record) @@ -110,7 +110,7 @@ def records(self) -> Iterable[bytes]: Records will be returned line by line. - :returns: A record generator for the query result. + :return: A record generator for the query result. :rtype: Iterable[bytes] """ delimiter = self.record_delimiter.encode('utf-8') diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/__init__.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/__init__.py index a8b1a27d48f9..4dbbb7ed7b09 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/__init__.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/__init__.py @@ -11,7 +11,7 @@ try: from urllib.parse import quote, unquote except ImportError: - from urllib2 import quote, unquote # type: ignore + from urllib2 import quote, unquote # type: ignore def url_quote(url): @@ -24,20 +24,20 @@ def url_unquote(url): def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") def decode_base64_to_bytes(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") return base64.b64decode(data) def decode_base64_to_text(data): decoded_bytes = decode_base64_to_bytes(data) - return decoded_bytes.decode('utf-8') + return decoded_bytes.decode("utf-8") def sign_string(key, string_to_sign, key_is_base64=True): @@ -45,9 +45,9 @@ def sign_string(key, string_to_sign, key_is_base64=True): key = decode_base64_to_bytes(key) else: if isinstance(key, str): - key = key.encode('utf-8') + key = key.encode("utf-8") if isinstance(string_to_sign, str): - string_to_sign = string_to_sign.encode('utf-8') + string_to_sign = string_to_sign.encode("utf-8") signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) digest = signed_hmac_sha256.digest() encoded_digest = encode_base64(digest) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/authentication.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/authentication.py index b41f2391ed4a..f778dc71eec4 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/authentication.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/authentication.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) +# fmt: off table_lv0 = [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, @@ -51,6 +52,8 @@ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, ] +# fmt: on + def compare(lhs: str, rhs: str) -> int: # pylint:disable=too-many-return-statements tables = [table_lv0, table_lv4] @@ -95,6 +98,7 @@ def _wrap_exception(ex, desired_type): msg = ex.args[0] return desired_type(msg) + # This method attempts to emulate the sorting done by the service def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]: @@ -135,38 +139,42 @@ def __init__(self, account_name, account_key): @staticmethod def _get_headers(request, headers_to_sign): headers = dict((name.lower(), value) for name, value in request.http_request.headers.items() if value) - if 'content-length' in headers and headers['content-length'] == '0': - del headers['content-length'] - return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n' + if "content-length" in headers and headers["content-length"] == "0": + del headers["content-length"] + return "\n".join(headers.get(x, "") for x in headers_to_sign) + "\n" @staticmethod def _get_verb(request): - return request.http_request.method + '\n' + return request.http_request.method + "\n" def _get_canonicalized_resource(self, request): uri_path = urlparse(request.http_request.url).path try: - if isinstance(request.context.transport, AioHttpTransport) or \ - isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) or \ - isinstance(getattr(getattr(request.context.transport, "_transport", None), "_transport", None), - AioHttpTransport): + if ( + isinstance(request.context.transport, AioHttpTransport) + or isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) + or isinstance( + getattr(getattr(request.context.transport, "_transport", None), "_transport", None), + AioHttpTransport, + ) + ): uri_path = URL(uri_path) - return '/' + self.account_name + str(uri_path) + return "/" + self.account_name + str(uri_path) except TypeError: pass - return '/' + self.account_name + uri_path + return "/" + self.account_name + uri_path @staticmethod def _get_canonicalized_headers(request): - string_to_sign = '' + string_to_sign = "" x_ms_headers = [] for name, value in request.http_request.headers.items(): - if name.startswith('x-ms-'): + if name.startswith("x-ms-"): x_ms_headers.append((name.lower(), value)) x_ms_headers = _storage_header_sort(x_ms_headers) for name, value in x_ms_headers: if value is not None: - string_to_sign += ''.join([name, ':', value, '\n']) + string_to_sign += "".join([name, ":", value, "\n"]) return string_to_sign @staticmethod @@ -174,37 +182,46 @@ def _get_canonicalized_resource_query(request): sorted_queries = list(request.http_request.query.items()) sorted_queries.sort() - string_to_sign = '' + string_to_sign = "" for name, value in sorted_queries: if value is not None: - string_to_sign += '\n' + name.lower() + ':' + unquote(value) + string_to_sign += "\n" + name.lower() + ":" + unquote(value) return string_to_sign def _add_authorization_header(self, request, string_to_sign): try: signature = sign_string(self.account_key, string_to_sign) - auth_string = 'SharedKey ' + self.account_name + ':' + signature - request.http_request.headers['Authorization'] = auth_string + auth_string = "SharedKey " + self.account_name + ":" + signature + request.http_request.headers["Authorization"] = auth_string except Exception as ex: # Wrap any error that occurred as signing error # Doing so will clarify/locate the source of problem raise _wrap_exception(ex, AzureSigningError) from ex def on_request(self, request): - 
string_to_sign = \ - self._get_verb(request) + \ - self._get_headers( + string_to_sign = ( + self._get_verb(request) + + self._get_headers( request, [ - 'content-encoding', 'content-language', 'content-length', - 'content-md5', 'content-type', 'date', 'if-modified-since', - 'if-match', 'if-none-match', 'if-unmodified-since', 'byte_range' - ] - ) + \ - self._get_canonicalized_headers(request) + \ - self._get_canonicalized_resource(request) + \ - self._get_canonicalized_resource_query(request) + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", + ], + ) + + self._get_canonicalized_headers(request) + + self._get_canonicalized_resource(request) + + self._get_canonicalized_resource_query(request) + ) self._add_authorization_header(request, string_to_sign) # logger.debug("String_to_sign=%s", string_to_sign) @@ -212,7 +229,7 @@ def on_request(self, request): class StorageHttpChallenge(object): def __init__(self, challenge): - """ Parses an HTTP WWW-Authentication Bearer challenge from the Storage service. """ + """Parses an HTTP WWW-Authentication Bearer challenge from the Storage service.""" if not challenge: raise ValueError("Challenge cannot be empty") @@ -221,7 +238,7 @@ def __init__(self, challenge): # name=value pairs either comma or space separated with values possibly being # enclosed in quotes - for item in re.split('[, ]', trimmed_challenge): + for item in re.split("[, ]", trimmed_challenge): comps = item.split("=") if len(comps) == 2: key = comps[0].strip(' "') @@ -230,11 +247,11 @@ def __init__(self, challenge): self._parameters[key] = value # Extract and verify required parameters - self.authorization_uri = self._parameters.get('authorization_uri') + self.authorization_uri = self._parameters.get("authorization_uri") if not self.authorization_uri: raise ValueError("Authorization Uri not found") - self.resource_id = self._parameters.get('resource_id') + self.resource_id = self._parameters.get("resource_id") if not self.resource_id: raise ValueError("Resource id not found") diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io.py index 7b59165b60da..f63cf78e4b74 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io.py @@ -42,8 +42,8 @@ # ------------------------------------------------------------------------------ # Constants -STRUCT_FLOAT = struct.Struct('= 0), n + assert n >= 0, n input_bytes = self.reader.read(n) if n > 0 and not input_bytes: raise StopIteration - assert (len(input_bytes) == n), input_bytes + assert len(input_bytes) == n, input_bytes return input_bytes @staticmethod @@ -149,7 +150,7 @@ def read_bytes(self): Bytes are encoded as a long followed by that many bytes of data. 
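# --- Illustrative aside (not part of the diff) -------------------------------
# The read_long() used just above is not shown in this hunk; per the standard
# Avro binary encoding it is a little-endian base-128 varint followed by zig-zag
# decoding. This is a standalone sketch of that wire format for illustration,
# not the SDK's decoder.
import io

def read_long(stream: io.BytesIO) -> int:
    byte = stream.read(1)[0]
    n = byte & 0x7F
    shift = 7
    while byte & 0x80:                     # high bit set => more varint bytes follow
        byte = stream.read(1)[0]
        n |= (byte & 0x7F) << shift
        shift += 7
    return (n >> 1) ^ -(n & 1)             # zig-zag decode: 0,1,2,3 -> 0,-1,1,-2

def read_bytes(stream: io.BytesIO) -> bytes:
    return stream.read(read_long(stream))  # length prefix (a long), then that many bytes

buf = io.BytesIO(b"\x06abc")               # long 3 (zig-zag encoded as 0x06), then b"abc"
assert read_bytes(buf) == b"abc"
# A union value (read_union below) is framed the same way: a long branch index,
# followed by the value encoded with the selected branch's schema.
# ------------------------------------------------------------------------------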
""" nbytes = self.read_long() - assert (nbytes >= 0), nbytes + assert nbytes >= 0, nbytes return self.read(nbytes) def read_utf8(self): @@ -160,9 +161,9 @@ def read_utf8(self): input_bytes = self.read_bytes() if PY3: try: - return input_bytes.decode('utf-8') + return input_bytes.decode("utf-8") except UnicodeDecodeError as exn: - logger.error('Invalid UTF-8 input bytes: %r', input_bytes) # pylint: disable=do-not-log-raised-errors + logger.error("Invalid UTF-8 input bytes: %r", input_bytes) # pylint: disable=do-not-log-raised-errors raise exn else: # PY2 @@ -216,41 +217,40 @@ def __init__(self, writer_schema=None): def set_writer_schema(self, writer_schema): self._writer_schema = writer_schema - writer_schema = property(lambda self: self._writer_schema, - set_writer_schema) + writer_schema = property(lambda self: self._writer_schema, set_writer_schema) def read(self, decoder): return self.read_data(self.writer_schema, decoder) def read_data(self, writer_schema, decoder): # function dispatch for reading data based on type of writer's schema - if writer_schema.type == 'null': + if writer_schema.type == "null": result = decoder.read_null() - elif writer_schema.type == 'boolean': + elif writer_schema.type == "boolean": result = decoder.read_boolean() - elif writer_schema.type == 'string': + elif writer_schema.type == "string": result = decoder.read_utf8() - elif writer_schema.type == 'int': + elif writer_schema.type == "int": result = decoder.read_int() - elif writer_schema.type == 'long': + elif writer_schema.type == "long": result = decoder.read_long() - elif writer_schema.type == 'float': + elif writer_schema.type == "float": result = decoder.read_float() - elif writer_schema.type == 'double': + elif writer_schema.type == "double": result = decoder.read_double() - elif writer_schema.type == 'bytes': + elif writer_schema.type == "bytes": result = decoder.read_bytes() - elif writer_schema.type == 'fixed': + elif writer_schema.type == "fixed": result = self.read_fixed(writer_schema, decoder) - elif writer_schema.type == 'enum': + elif writer_schema.type == "enum": result = self.read_enum(writer_schema, decoder) - elif writer_schema.type == 'array': + elif writer_schema.type == "array": result = self.read_array(writer_schema, decoder) - elif writer_schema.type == 'map': + elif writer_schema.type == "map": result = self.read_map(writer_schema, decoder) - elif writer_schema.type in ['union', 'error_union']: + elif writer_schema.type in ["union", "error_union"]: result = self.read_union(writer_schema, decoder) - elif writer_schema.type in ['record', 'error', 'request']: + elif writer_schema.type in ["record", "error", "request"]: result = self.read_record(writer_schema, decoder) else: fail_msg = f"Cannot read unknown schema type: {writer_schema.type}" @@ -258,35 +258,35 @@ def read_data(self, writer_schema, decoder): return result def skip_data(self, writer_schema, decoder): - if writer_schema.type == 'null': + if writer_schema.type == "null": result = decoder.skip_null() - elif writer_schema.type == 'boolean': + elif writer_schema.type == "boolean": result = decoder.skip_boolean() - elif writer_schema.type == 'string': + elif writer_schema.type == "string": result = decoder.skip_utf8() - elif writer_schema.type == 'int': + elif writer_schema.type == "int": result = decoder.skip_int() - elif writer_schema.type == 'long': + elif writer_schema.type == "long": result = decoder.skip_long() - elif writer_schema.type == 'float': + elif writer_schema.type == "float": result = decoder.skip_float() - elif 
writer_schema.type == 'double': + elif writer_schema.type == "double": result = decoder.skip_double() - elif writer_schema.type == 'bytes': + elif writer_schema.type == "bytes": result = decoder.skip_bytes() - elif writer_schema.type == 'fixed': + elif writer_schema.type == "fixed": result = self.skip_fixed(writer_schema, decoder) - elif writer_schema.type == 'enum': + elif writer_schema.type == "enum": result = self.skip_enum(decoder) - elif writer_schema.type == 'array': + elif writer_schema.type == "array": self.skip_array(writer_schema, decoder) result = None - elif writer_schema.type == 'map': + elif writer_schema.type == "map": self.skip_map(writer_schema, decoder) result = None - elif writer_schema.type in ['union', 'error_union']: + elif writer_schema.type in ["union", "error_union"]: result = self.skip_union(writer_schema, decoder) - elif writer_schema.type in ['record', 'error', 'request']: + elif writer_schema.type in ["record", "error", "request"]: self.skip_record(writer_schema, decoder) result = None else: @@ -389,8 +389,9 @@ def read_union(self, writer_schema, decoder): # schema resolution index_of_schema = int(decoder.read_long()) if index_of_schema >= len(writer_schema.schemas): - fail_msg = (f"Can't access branch index {index_of_schema} " - f"for union with {len(writer_schema.schemas)} branches") + fail_msg = ( + f"Can't access branch index {index_of_schema} " f"for union with {len(writer_schema.schemas)} branches" + ) raise SchemaResolutionException(fail_msg, writer_schema) selected_writer_schema = writer_schema.schemas[index_of_schema] @@ -400,8 +401,9 @@ def read_union(self, writer_schema, decoder): def skip_union(self, writer_schema, decoder): index_of_schema = int(decoder.read_long()) if index_of_schema >= len(writer_schema.schemas): - fail_msg = (f"Can't access branch index {index_of_schema} " - f"for union with {len(writer_schema.schemas)} branches") + fail_msg = ( + f"Can't access branch index {index_of_schema} " f"for union with {len(writer_schema.schemas)} branches" + ) raise SchemaResolutionException(fail_msg, writer_schema) return self.skip_data(writer_schema.schemas[index_of_schema], decoder) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io_async.py index 487f5e4b47bb..b56a75e0c64c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/avro_io_async.py @@ -61,14 +61,14 @@ async def read(self, n): """Read n bytes. :param int n: Number of bytes to read. - :returns: The next n bytes from the input. + :return: The next n bytes from the input. :rtype: bytes """ - assert (n >= 0), n + assert n >= 0, n input_bytes = await self.reader.read(n) if n > 0 and not input_bytes: raise StopAsyncIteration - assert (len(input_bytes) == n), input_bytes + assert len(input_bytes) == n, input_bytes return input_bytes @staticmethod @@ -132,7 +132,7 @@ async def read_bytes(self): Bytes are encoded as a long followed by that many bytes of data. 
""" nbytes = await self.read_long() - assert (nbytes >= 0), nbytes + assert nbytes >= 0, nbytes return await self.read(nbytes) async def read_utf8(self): @@ -143,13 +143,13 @@ async def read_utf8(self): input_bytes = await self.read_bytes() if PY3: try: - return input_bytes.decode('utf-8') + return input_bytes.decode("utf-8") except UnicodeDecodeError as exn: - logger.error('Invalid UTF-8 input bytes: %r', input_bytes) # pylint: disable=do-not-log-raised-errors + logger.error("Invalid UTF-8 input bytes: %r", input_bytes) # pylint: disable=do-not-log-raised-errors raise exn else: # PY2 - return unicode(input_bytes, "utf-8") # pylint: disable=undefined-variable + return unicode(input_bytes, "utf-8") # pylint: disable=undefined-variable def skip_null(self): pass @@ -200,41 +200,40 @@ def __init__(self, writer_schema=None): def set_writer_schema(self, writer_schema): self._writer_schema = writer_schema - writer_schema = property(lambda self: self._writer_schema, - set_writer_schema) + writer_schema = property(lambda self: self._writer_schema, set_writer_schema) async def read(self, decoder): return await self.read_data(self.writer_schema, decoder) async def read_data(self, writer_schema, decoder): # function dispatch for reading data based on type of writer's schema - if writer_schema.type == 'null': + if writer_schema.type == "null": result = decoder.read_null() - elif writer_schema.type == 'boolean': + elif writer_schema.type == "boolean": result = await decoder.read_boolean() - elif writer_schema.type == 'string': + elif writer_schema.type == "string": result = await decoder.read_utf8() - elif writer_schema.type == 'int': + elif writer_schema.type == "int": result = await decoder.read_int() - elif writer_schema.type == 'long': + elif writer_schema.type == "long": result = await decoder.read_long() - elif writer_schema.type == 'float': + elif writer_schema.type == "float": result = await decoder.read_float() - elif writer_schema.type == 'double': + elif writer_schema.type == "double": result = await decoder.read_double() - elif writer_schema.type == 'bytes': + elif writer_schema.type == "bytes": result = await decoder.read_bytes() - elif writer_schema.type == 'fixed': + elif writer_schema.type == "fixed": result = await self.read_fixed(writer_schema, decoder) - elif writer_schema.type == 'enum': + elif writer_schema.type == "enum": result = await self.read_enum(writer_schema, decoder) - elif writer_schema.type == 'array': + elif writer_schema.type == "array": result = await self.read_array(writer_schema, decoder) - elif writer_schema.type == 'map': + elif writer_schema.type == "map": result = await self.read_map(writer_schema, decoder) - elif writer_schema.type in ['union', 'error_union']: + elif writer_schema.type in ["union", "error_union"]: result = await self.read_union(writer_schema, decoder) - elif writer_schema.type in ['record', 'error', 'request']: + elif writer_schema.type in ["record", "error", "request"]: result = await self.read_record(writer_schema, decoder) else: fail_msg = f"Cannot read unknown schema type: {writer_schema.type}" @@ -242,35 +241,35 @@ async def read_data(self, writer_schema, decoder): return result async def skip_data(self, writer_schema, decoder): - if writer_schema.type == 'null': + if writer_schema.type == "null": result = decoder.skip_null() - elif writer_schema.type == 'boolean': + elif writer_schema.type == "boolean": result = await decoder.skip_boolean() - elif writer_schema.type == 'string': + elif writer_schema.type == "string": result = await 
decoder.skip_utf8() - elif writer_schema.type == 'int': + elif writer_schema.type == "int": result = await decoder.skip_int() - elif writer_schema.type == 'long': + elif writer_schema.type == "long": result = await decoder.skip_long() - elif writer_schema.type == 'float': + elif writer_schema.type == "float": result = await decoder.skip_float() - elif writer_schema.type == 'double': + elif writer_schema.type == "double": result = await decoder.skip_double() - elif writer_schema.type == 'bytes': + elif writer_schema.type == "bytes": result = await decoder.skip_bytes() - elif writer_schema.type == 'fixed': + elif writer_schema.type == "fixed": result = await self.skip_fixed(writer_schema, decoder) - elif writer_schema.type == 'enum': + elif writer_schema.type == "enum": result = await self.skip_enum(decoder) - elif writer_schema.type == 'array': + elif writer_schema.type == "array": await self.skip_array(writer_schema, decoder) result = None - elif writer_schema.type == 'map': + elif writer_schema.type == "map": await self.skip_map(writer_schema, decoder) result = None - elif writer_schema.type in ['union', 'error_union']: + elif writer_schema.type in ["union", "error_union"]: result = await self.skip_union(writer_schema, decoder) - elif writer_schema.type in ['record', 'error', 'request']: + elif writer_schema.type in ["record", "error", "request"]: await self.skip_record(writer_schema, decoder) result = None else: @@ -373,8 +372,9 @@ async def read_union(self, writer_schema, decoder): # schema resolution index_of_schema = int(await decoder.read_long()) if index_of_schema >= len(writer_schema.schemas): - fail_msg = (f"Can't access branch index {index_of_schema} " - f"for union with {len(writer_schema.schemas)} branches") + fail_msg = ( + f"Can't access branch index {index_of_schema} " f"for union with {len(writer_schema.schemas)} branches" + ) raise SchemaResolutionException(fail_msg, writer_schema) selected_writer_schema = writer_schema.schemas[index_of_schema] @@ -384,8 +384,9 @@ async def read_union(self, writer_schema, decoder): async def skip_union(self, writer_schema, decoder): index_of_schema = int(await decoder.read_long()) if index_of_schema >= len(writer_schema.schemas): - fail_msg = (f"Can't access branch index {index_of_schema} " - f"for union with {len(writer_schema.schemas)} branches") + fail_msg = ( + f"Can't access branch index {index_of_schema} " f"for union with {len(writer_schema.schemas)} branches" + ) raise SchemaResolutionException(fail_msg, writer_schema) return await self.skip_data(writer_schema.schemas[index_of_schema], decoder) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile.py index 757e0329cd07..0c60651023ef 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile.py @@ -26,17 +26,18 @@ VERSION = 1 if PY3: - MAGIC = b'Obj' + bytes([VERSION]) + MAGIC = b"Obj" + bytes([VERSION]) MAGIC_SIZE = len(MAGIC) else: - MAGIC = 'Obj' + chr(VERSION) + MAGIC = "Obj" + chr(VERSION) MAGIC_SIZE = len(MAGIC) # Size of the synchronization marker, in number of bytes: SYNC_SIZE = 16 # Schema of the container header: -META_SCHEMA = schema.parse(""" +META_SCHEMA = schema.parse( + """ { "type": "record", "name": "org.apache.avro.file.Header", "fields": [{ @@ -50,13 +51,15 @@ "type": {"type": "fixed", "name": "sync", "size": %(sync_size)d} }] } -""" % { - 'magic_size': 
MAGIC_SIZE, - 'sync_size': SYNC_SIZE, -}) +""" + % { + "magic_size": MAGIC_SIZE, + "sync_size": SYNC_SIZE, + } +) # Codecs supported by container files: -VALID_CODECS = frozenset(['null', 'deflate']) +VALID_CODECS = frozenset(["null", "deflate"]) # Metadata key associated to the schema: SCHEMA_KEY = "avro.schema" @@ -69,6 +72,7 @@ class DataFileException(schema.AvroException): """Problem reading or writing file object containers.""" + # ------------------------------------------------------------------------------ @@ -84,7 +88,7 @@ def __init__(self, reader, datum_reader, **kwargs): """ self._reader = reader self._raw_decoder = avro_io.BinaryDecoder(reader) - self._header_reader = kwargs.pop('header_reader', None) + self._header_reader = kwargs.pop("header_reader", None) self._header_decoder = None if self._header_reader is None else avro_io.BinaryDecoder(self._header_reader) self._datum_decoder = None # Maybe reset at every block. self._datum_reader = datum_reader @@ -97,11 +101,11 @@ def __init__(self, reader, datum_reader, **kwargs): self._read_header() # ensure codec is valid - avro_codec_raw = self.get_meta('avro.codec') + avro_codec_raw = self.get_meta("avro.codec") if avro_codec_raw is None: self.codec = "null" else: - self.codec = avro_codec_raw.decode('utf-8') + self.codec = avro_codec_raw.decode("utf-8") if self.codec not in VALID_CODECS: raise DataFileException(f"Unknown codec: {self.codec}.") @@ -110,7 +114,7 @@ def __init__(self, reader, datum_reader, **kwargs): # object_position is to support reading from current position in the future read, # no need to downloading from the beginning of avro. - if hasattr(self._reader, 'object_position'): + if hasattr(self._reader, "object_position"): self.reader.track_object_position() self._cur_object_index = 0 @@ -120,8 +124,7 @@ def __init__(self, reader, datum_reader, **kwargs): if self._header_reader is not None: self._datum_decoder = self._raw_decoder - self.datum_reader.writer_schema = ( - schema.parse(self.get_meta(SCHEMA_KEY).decode('utf-8'))) + self.datum_reader.writer_schema = schema.parse(self.get_meta(SCHEMA_KEY).decode("utf-8")) def __enter__(self): return self @@ -168,7 +171,7 @@ def get_meta(self, key): """Reports the value of a given metadata key. :param str key: Metadata key to report the value of. - :returns: Value associated to the metadata key, as bytes. + :return: Value associated to the metadata key, as bytes. :rtype: bytes """ return self._meta.get(key) @@ -184,15 +187,15 @@ def _read_header(self): header = self.datum_reader.read_data(META_SCHEMA, header_decoder) # check magic number - if header.get('magic') != MAGIC: + if header.get("magic") != MAGIC: fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC!r}." raise schema.AvroException(fail_msg) # set metadata - self._meta = header['meta'] + self._meta = header["meta"] # set sync marker - self._sync_marker = header['sync'] + self._sync_marker = header["sync"] def _read_block_header(self): self._block_count = self.raw_decoder.read_long() @@ -200,7 +203,7 @@ def _read_block_header(self): # Skip a long; we don't need to use the length. self.raw_decoder.skip_long() self._datum_decoder = self._raw_decoder - elif self.codec == 'deflate': + elif self.codec == "deflate": # Compressed data is stored as (length, data), which # corresponds to how the "bytes" type is encoded. 
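# --- Illustrative aside (not part of the diff) -------------------------------
# Sketch of the object-container layout that DataFileReader._read_header and
# _read_block_header above walk through: a 4-byte magic b"Obj\x01", a metadata
# map (including "avro.schema" and "avro.codec"), a 16-byte sync marker, then a
# sequence of blocks, each prefixed by an object count and a byte length. The
# helper below only mirrors the header checks visible in this hunk.
MAGIC = b"Obj" + bytes([1])
SYNC_SIZE = 16

def check_header(meta: dict, magic: bytes) -> str:
    """Mirror the header checks made above; return the codec in use."""
    if magic != MAGIC:
        raise ValueError(f"Not an Avro data file: {magic!r} doesn't match {MAGIC!r}.")
    if "avro.schema" not in meta:
        raise ValueError("Container is missing the avro.schema metadata entry.")
    codec = meta.get("avro.codec", b"null").decode("utf-8")
    if codec not in ("null", "deflate"):
        raise ValueError(f"Unknown codec: {codec}.")
    return codec

assert check_header({"avro.schema": b'"string"', "avro.codec": b"null"}, MAGIC) == "null"
# ------------------------------------------------------------------------------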
data = self.raw_decoder.read_bytes() @@ -229,7 +232,7 @@ def __next__(self): # object_position is to support reading from current position in the future read, # no need to downloading from the beginning of avro file with this attr. - if hasattr(self._reader, 'object_position'): + if hasattr(self._reader, "object_position"): self.reader.track_object_position() self._cur_object_index = 0 @@ -242,7 +245,7 @@ def __next__(self): # object_position is to support reading from current position in the future read, # This will track the index of the next item to be read. # This will also track the offset before the next sync marker. - if hasattr(self._reader, 'object_position'): + if hasattr(self._reader, "object_position"): if self.block_count == 0: # the next event to be read is at index 0 in the new chunk of blocks, self.reader.track_object_position() diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile_async.py index 85dc5cb582b3..dfba76113133 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/datafile_async.py @@ -24,7 +24,7 @@ # Constants # Codecs supported by container files: -VALID_CODECS = frozenset(['null']) +VALID_CODECS = frozenset(["null"]) class AsyncDataFileReader(object): # pylint: disable=too-many-instance-attributes @@ -39,9 +39,10 @@ def __init__(self, reader, datum_reader, **kwargs): """ self._reader = reader self._raw_decoder = avro_io_async.AsyncBinaryDecoder(reader) - self._header_reader = kwargs.pop('header_reader', None) - self._header_decoder = None if self._header_reader is None else \ - avro_io_async.AsyncBinaryDecoder(self._header_reader) + self._header_reader = kwargs.pop("header_reader", None) + self._header_decoder = ( + None if self._header_reader is None else avro_io_async.AsyncBinaryDecoder(self._header_reader) + ) self._datum_decoder = None # Maybe reset at every block. self._datum_reader = datum_reader self.codec = "null" @@ -59,11 +60,11 @@ async def init(self): await self._read_header() # ensure codec is valid - avro_codec_raw = self.get_meta('avro.codec') + avro_codec_raw = self.get_meta("avro.codec") if avro_codec_raw is None: self.codec = "null" else: - self.codec = avro_codec_raw.decode('utf-8') + self.codec = avro_codec_raw.decode("utf-8") if self.codec not in VALID_CODECS: raise DataFileException(f"Unknown codec: {self.codec}.") @@ -72,7 +73,7 @@ async def init(self): # object_position is to support reading from current position in the future read, # no need to downloading from the beginning of avro. - if hasattr(self._reader, 'object_position'): + if hasattr(self._reader, "object_position"): self.reader.track_object_position() # header_reader indicates reader only has partial content. The reader doesn't have block header, @@ -80,8 +81,7 @@ async def init(self): # Also ChangeFeed only has codec==null, so use _raw_decoder is good. if self._header_reader is not None: self._datum_decoder = self._raw_decoder - self.datum_reader.writer_schema = ( - schema.parse(self.get_meta(SCHEMA_KEY).decode('utf-8'))) + self.datum_reader.writer_schema = schema.parse(self.get_meta(SCHEMA_KEY).decode("utf-8")) return self async def __aenter__(self): @@ -129,7 +129,7 @@ def get_meta(self, key): """Reports the value of a given metadata key. :param str key: Metadata key to report the value of. - :returns: Value associated to the metadata key, as bytes. 
+ :return: Value associated to the metadata key, as bytes. :rtype: bytes """ return self._meta.get(key) @@ -145,15 +145,15 @@ async def _read_header(self): header = await self.datum_reader.read_data(META_SCHEMA, header_decoder) # check magic number - if header.get('magic') != MAGIC: + if header.get("magic") != MAGIC: fail_msg = f"Not an Avro data file: {header.get('magic')} doesn't match {MAGIC!r}." raise schema.AvroException(fail_msg) # set metadata - self._meta = header['meta'] + self._meta = header["meta"] # set sync marker - self._sync_marker = header['sync'] + self._sync_marker = header["sync"] async def _read_block_header(self): self._block_count = await self.raw_decoder.read_long() @@ -182,7 +182,7 @@ async def __anext__(self): # object_position is to support reading from current position in the future read, # no need to downloading from the beginning of avro file with this attr. - if hasattr(self._reader, 'object_position'): + if hasattr(self._reader, "object_position"): await self.reader.track_object_position() self._cur_object_index = 0 @@ -195,7 +195,7 @@ async def __anext__(self): # object_position is to support reading from current position in the future read, # This will track the index of the next item to be read. # This will also track the offset before the next sync marker. - if hasattr(self._reader, 'object_position'): + if hasattr(self._reader, "object_position"): if self.block_count == 0: # the next event to be read is at index 0 in the new chunk of blocks, await self.reader.track_object_position() diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/schema.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/schema.py index d5484abcdd9d..62275c7ad601 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/schema.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/avro/schema.py @@ -29,6 +29,7 @@ import json import logging import re + logger = logging.getLogger(__name__) # ------------------------------------------------------------------------------ @@ -37,43 +38,47 @@ # Log level more verbose than DEBUG=10, INFO=20, etc. 
DEBUG_VERBOSE = 5 -NULL = 'null' -BOOLEAN = 'boolean' -STRING = 'string' -BYTES = 'bytes' -INT = 'int' -LONG = 'long' -FLOAT = 'float' -DOUBLE = 'double' -FIXED = 'fixed' -ENUM = 'enum' -RECORD = 'record' -ERROR = 'error' -ARRAY = 'array' -MAP = 'map' -UNION = 'union' +NULL = "null" +BOOLEAN = "boolean" +STRING = "string" +BYTES = "bytes" +INT = "int" +LONG = "long" +FLOAT = "float" +DOUBLE = "double" +FIXED = "fixed" +ENUM = "enum" +RECORD = "record" +ERROR = "error" +ARRAY = "array" +MAP = "map" +UNION = "union" # Request and error unions are part of Avro protocols: -REQUEST = 'request' -ERROR_UNION = 'error_union' - -PRIMITIVE_TYPES = frozenset([ - NULL, - BOOLEAN, - STRING, - BYTES, - INT, - LONG, - FLOAT, - DOUBLE, -]) - -NAMED_TYPES = frozenset([ - FIXED, - ENUM, - RECORD, - ERROR, -]) +REQUEST = "request" +ERROR_UNION = "error_union" + +PRIMITIVE_TYPES = frozenset( + [ + NULL, + BOOLEAN, + STRING, + BYTES, + INT, + LONG, + FLOAT, + DOUBLE, + ] +) + +NAMED_TYPES = frozenset( + [ + FIXED, + ENUM, + RECORD, + ERROR, + ] +) VALID_TYPES = frozenset.union( PRIMITIVE_TYPES, @@ -87,31 +92,37 @@ ], ) -SCHEMA_RESERVED_PROPS = frozenset([ - 'type', - 'name', - 'namespace', - 'fields', # Record - 'items', # Array - 'size', # Fixed - 'symbols', # Enum - 'values', # Map - 'doc', -]) - -FIELD_RESERVED_PROPS = frozenset([ - 'default', - 'name', - 'doc', - 'order', - 'type', -]) - -VALID_FIELD_SORT_ORDERS = frozenset([ - 'ascending', - 'descending', - 'ignore', -]) +SCHEMA_RESERVED_PROPS = frozenset( + [ + "type", + "name", + "namespace", + "fields", # Record + "items", # Array + "size", # Fixed + "symbols", # Enum + "values", # Map + "doc", + ] +) + +FIELD_RESERVED_PROPS = frozenset( + [ + "default", + "name", + "doc", + "order", + "type", + ] +) + +VALID_FIELD_SORT_ORDERS = frozenset( + [ + "ascending", + "descending", + "ignore", + ] +) # ------------------------------------------------------------------------------ @@ -141,12 +152,12 @@ def __init__(self, data_type, other_props=None): other_props: Optional dictionary of additional properties. """ if data_type not in VALID_TYPES: - raise SchemaParseException(f'{data_type!r} is not a valid Avro type.') + raise SchemaParseException(f"{data_type!r} is not a valid Avro type.") # All properties of this schema, as a map: property name -> property value self._props = {} - self._props['type'] = data_type + self._props["type"] = data_type self._type = data_type if other_props: @@ -155,7 +166,7 @@ def __init__(self, data_type, other_props=None): @property def namespace(self): """Returns: the namespace this schema belongs to, if any, or None.""" - return self._props.get('namespace', None) + return self._props.get("namespace", None) @property def type(self): @@ -165,7 +176,7 @@ def type(self): @property def doc(self): """Returns: the documentation associated to this schema, if any, or None.""" - return self._props.get('doc', None) + return self._props.get("doc", None) @property def props(self): @@ -193,20 +204,19 @@ def __str__(self): # Schema types that have names (records, enums, and fixed) must be aware of not # re-defining schemas that are already listed in the parameter names. @abc.abstractmethod - def to_json(self, names): - ... + def to_json(self, names): ... 
# ------------------------------------------------------------------------------ -_RE_NAME = re.compile(r'[A-Za-z_][A-Za-z0-9_]*') +_RE_NAME = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") _RE_FULL_NAME = re.compile( - r'^' - r'[.]?(?:[A-Za-z_][A-Za-z0-9_]*[.])*' # optional namespace - r'([A-Za-z_][A-Za-z0-9_]*)' # name - r'$' + r"^" + r"[.]?(?:[A-Za-z_][A-Za-z0-9_]*[.])*" # optional namespace + r"([A-Za-z_][A-Za-z0-9_]*)" # name + r"$" ) @@ -222,32 +232,31 @@ def __init__(self, name, namespace=None): """ # Normalize: namespace is always defined as a string, possibly empty. if namespace is None: - namespace = '' + namespace = "" - if '.' in name: + if "." in name: # name is absolute, namespace is ignored: self._fullname = name match = _RE_FULL_NAME.match(self._fullname) if match is None: - raise SchemaParseException( - f'Invalid absolute schema name: {self._fullname!r}.') + raise SchemaParseException(f"Invalid absolute schema name: {self._fullname!r}.") self._name = match.group(1) - self._namespace = self._fullname[:-(len(self._name) + 1)] + self._namespace = self._fullname[: -(len(self._name) + 1)] else: # name is relative, combine with explicit namespace: self._name = name self._namespace = namespace - self._fullname = (self._name - if (not self._namespace) else - f'{self._namespace}.{self._name}') + self._fullname = self._name if (not self._namespace) else f"{self._namespace}.{self._name}" # Validate the fullname: if _RE_FULL_NAME.match(self._fullname) is None: - raise SchemaParseException(f"Invalid schema name {self._fullname!r} inferred from " - f"name {self._name!r} and namespace {self._namespace!r}.") + raise SchemaParseException( + f"Invalid schema name {self._fullname!r} inferred from " + f"name {self._name!r} and namespace {self._namespace!r}." + ) def __eq__(self, other): if not isinstance(other, Name): @@ -302,7 +311,7 @@ def new_with_default_namespace(self, namespace): """Creates a new name tracker from this tracker, but with a new default ns. :param Any namespace: New default namespace to use. - :returns: New name tracker with the specified default namespace. + :return: New name tracker with the specified default namespace. :rtype: Names """ return Names(names=self._names, default_namespace=namespace) @@ -312,7 +321,7 @@ def get_name(self, name, namespace=None): :param Any name: Name to resolve (absolute or relative). :param Optional[Any] namespace: Optional explicit namespace. - :returns: The specified name, resolved according to this tracker. + :return: The specified name, resolved according to this tracker. :rtype: Name """ if namespace is None: @@ -324,7 +333,7 @@ def get_schema(self, name, namespace=None): :param Any name: Name (absolute or relative) of the Avro schema to look up. :param Optional[Any] namespace: Optional explicit namespace. - :returns: The schema with the specified name, if any, or None + :return: The schema with the specified name, if any, or None :rtype: Union[Any, None] """ avro_name = self.get_name(name=name, namespace=namespace) @@ -335,15 +344,15 @@ def prune_namespace(self, properties): if self.default_namespace is None: # I have no default -- no change return properties - if 'namespace' not in properties: + if "namespace" not in properties: # he has no namespace - no change return properties - if properties['namespace'] != self.default_namespace: + if properties["namespace"] != self.default_namespace: # we're different - leave his stuff alone return properties # we each have a namespace and it's redundant. delete his. 
prunable = properties.copy() - del prunable['namespace'] + del prunable["namespace"] return prunable def register(self, schema): @@ -352,13 +361,11 @@ def register(self, schema): :param Any schema: Named Avro schema to register in this tracker. """ if schema.fullname in VALID_TYPES: - raise SchemaParseException( - f'{schema.fullname} is a reserved type name.') + raise SchemaParseException(f"{schema.fullname} is a reserved type name.") if schema.fullname in self.names: - raise SchemaParseException( - f'Avro name {schema.fullname!r} already exists.') + raise SchemaParseException(f"Avro name {schema.fullname!r} already exists.") - logger.log(DEBUG_VERBOSE, 'Register new name for %r', schema.fullname) + logger.log(DEBUG_VERBOSE, "Register new name for %r", schema.fullname) self._names[schema.fullname] = schema @@ -372,12 +379,12 @@ class NamedSchema(Schema): """ def __init__( - self, - data_type, - name=None, - namespace=None, - names=None, - other_props=None, + self, + data_type, + name=None, + namespace=None, + names=None, + other_props=None, ): """Initializes a new named schema object. @@ -388,16 +395,16 @@ def __init__( names: Tracker to resolve and register Avro names. other_props: Optional map of additional properties of the schema. """ - assert (data_type in NAMED_TYPES), (f'Invalid named type: {data_type!r}') + assert data_type in NAMED_TYPES, f"Invalid named type: {data_type!r}" self._avro_name = names.get_name(name=name, namespace=namespace) super(NamedSchema, self).__init__(data_type, other_props) names.register(self) - self._props['name'] = self.name + self._props["name"] = self.name if self.namespace: - self._props['namespace'] = self.namespace + self._props["namespace"] = self.namespace @property def avro_name(self): @@ -420,7 +427,7 @@ def name_ref(self, names): """Reports this schema name relative to the specified name tracker. :param Any names: Avro name tracker to relativize this schema name against. - :returns: This schema name, relativized against the specified name tracker. + :return: This schema name, relativized against the specified name tracker. :rtype: Any """ if self.namespace == names.default_namespace: @@ -432,8 +439,8 @@ def name_ref(self, names): # Schema types that have names (records, enums, and fixed) must be aware # of not re-defining schemas that are already listed in the parameter names. @abc.abstractmethod - def to_json(self, names): - ... + def to_json(self, names): ... + # ------------------------------------------------------------------------------ @@ -445,15 +452,7 @@ class Field(object): """Representation of the schema of a field in a record.""" def __init__( - self, - data_type, - name, - index, - has_default, - default=_NO_DEFAULT, - order=None, - doc=None, - other_props=None + self, data_type, name, index, has_default, default=_NO_DEFAULT, order=None, doc=None, other_props=None ): """Initializes a new Field object. 
@@ -468,9 +467,9 @@ def __init__( other_props: """ if (not isinstance(name, str)) or (not name): - raise SchemaParseException(f'Invalid record field name: {name!r}.') + raise SchemaParseException(f"Invalid record field name: {name!r}.") if (order is not None) and (order not in VALID_FIELD_SORT_ORDERS): - raise SchemaParseException(f'Invalid record field order: {order!r}.') + raise SchemaParseException(f"Invalid record field order: {order!r}.") # All properties of this record field: self._props = {} @@ -480,17 +479,17 @@ def __init__( self._props.update(other_props) self._index = index - self._type = self._props['type'] = data_type - self._name = self._props['name'] = name + self._type = self._props["type"] = data_type + self._name = self._props["name"] = name if has_default: - self._props['default'] = default + self._props["default"] = default if order is not None: - self._props['order'] = order + self._props["order"] = order if doc is not None: - self._props['doc'] = doc + self._props["doc"] = doc @property def type(self): @@ -509,7 +508,7 @@ def index(self): @property def default(self): - return self._props['default'] + return self._props["default"] @property def has_default(self): @@ -517,11 +516,11 @@ def has_default(self): @property def order(self): - return self._props.get('order', None) + return self._props.get("order", None) @property def doc(self): - return self._props.get('doc', None) + return self._props.get("doc", None) @property def props(self): @@ -538,7 +537,7 @@ def to_json(self, names=None): if names is None: names = Names() to_dump = self.props.copy() - to_dump['type'] = self.type.to_json(names) + to_dump["type"] = self.type.to_json(names) return to_dump def __eq__(self, that): @@ -563,7 +562,7 @@ def __init__(self, data_type, other_props=None): data_type: Type of the schema to construct. Must be primitive. """ if data_type not in PRIMITIVE_TYPES: - raise AvroException(f'{data_type!r} is not a valid primitive type.') + raise AvroException(f"{data_type!r} is not a valid primitive type.") super(PrimitiveSchema, self).__init__(data_type, other_props=other_props) @property @@ -593,16 +592,16 @@ def __eq__(self, that): class FixedSchema(NamedSchema): def __init__( - self, - name, - namespace, - size, - names=None, - other_props=None, + self, + name, + namespace, + size, + names=None, + other_props=None, ): # Ensure valid ctor args if not isinstance(size, int): - fail_msg = 'Fixed Schema requires a valid integer for size property.' + fail_msg = "Fixed Schema requires a valid integer for size property." raise AvroException(fail_msg) super(FixedSchema, self).__init__( @@ -612,12 +611,12 @@ def __init__( names=names, other_props=other_props, ) - self._props['size'] = size + self._props["size"] = size @property def size(self): """Returns: the size of this fixed schema, in bytes.""" - return self._props['size'] + return self._props["size"] def to_json(self, names=None): if names is None: @@ -636,13 +635,13 @@ def __eq__(self, that): class EnumSchema(NamedSchema): def __init__( - self, - name, - namespace, - symbols, - names=None, - doc=None, - other_props=None, + self, + name, + namespace, + symbols, + names=None, + doc=None, + other_props=None, ): """Initializes a new enumeration schema object. 
@@ -656,10 +655,8 @@ def __init__( """ symbols = tuple(symbols) symbol_set = frozenset(symbols) - if (len(symbol_set) != len(symbols) - or not all(map(lambda symbol: isinstance(symbol, str), symbols))): - raise AvroException( - f'Invalid symbols for enum schema: {symbols!r}.') + if len(symbol_set) != len(symbols) or not all(map(lambda symbol: isinstance(symbol, str), symbols)): + raise AvroException(f"Invalid symbols for enum schema: {symbols!r}.") super(EnumSchema, self).__init__( data_type=ENUM, @@ -669,14 +666,14 @@ def __init__( other_props=other_props, ) - self._props['symbols'] = symbols + self._props["symbols"] = symbols if doc is not None: - self._props['doc'] = doc + self._props["doc"] = doc @property def symbols(self): """Returns: the symbols defined in this enum.""" - return self._props['symbols'] + return self._props["symbols"] def to_json(self, names=None): if names is None: @@ -709,7 +706,7 @@ def __init__(self, items, other_props=None): other_props=other_props, ) self._items_schema = items - self._props['items'] = items + self._props["items"] = items @property def items(self): @@ -721,7 +718,7 @@ def to_json(self, names=None): names = Names() to_dump = self.props.copy() item_schema = self.items - to_dump['items'] = item_schema.to_json(names) + to_dump["items"] = item_schema.to_json(names) return to_dump def __eq__(self, that): @@ -747,7 +744,7 @@ def __init__(self, values, other_props=None): other_props=other_props, ) self._values_schema = values - self._props['values'] = values + self._props["values"] = values @property def values(self): @@ -758,7 +755,7 @@ def to_json(self, names=None): if names is None: names = Names() to_dump = self.props.copy() - to_dump['values'] = self.values.to_json(names) + to_dump["values"] = self.values.to_json(names) return to_dump def __eq__(self, that): @@ -784,23 +781,21 @@ def __init__(self, schemas): # Validate the schema branches: # All named schema names are unique: - named_branches = tuple( - filter(lambda schema: schema.type in NAMED_TYPES, self._schemas)) + named_branches = tuple(filter(lambda schema: schema.type in NAMED_TYPES, self._schemas)) unique_names = frozenset(map(lambda schema: schema.fullname, named_branches)) if len(unique_names) != len(named_branches): - schemas = ''.join(map(lambda schema: (f'\n\t - {schema}'), self._schemas)) - raise AvroException(f'Invalid union branches with duplicate schema name:{schemas}') + schemas = "".join(map(lambda schema: (f"\n\t - {schema}"), self._schemas)) + raise AvroException(f"Invalid union branches with duplicate schema name:{schemas}") # Types are unique within unnamed schemas, and union is not allowed: - unnamed_branches = tuple( - filter(lambda schema: schema.type not in NAMED_TYPES, self._schemas)) + unnamed_branches = tuple(filter(lambda schema: schema.type not in NAMED_TYPES, self._schemas)) unique_types = frozenset(map(lambda schema: schema.type, unnamed_branches)) if UNION in unique_types: - schemas = ''.join(map(lambda schema: (f'\n\t - {schema}'), self._schemas)) - raise AvroException(f'Invalid union branches contain other unions:{schemas}') + schemas = "".join(map(lambda schema: (f"\n\t - {schema}"), self._schemas)) + raise AvroException(f"Invalid union branches contain other unions:{schemas}") if len(unique_types) != len(unnamed_branches): - schemas = ''.join(map(lambda schema: (f'\n\t - {schema}'), self._schemas)) - raise AvroException(f'Invalid union branches with duplicate type:{schemas}') + schemas = "".join(map(lambda schema: (f"\n\t - {schema}"), self._schemas)) + raise 
AvroException(f"Invalid union branches with duplicate type:{schemas}") @property def schemas(self): @@ -861,23 +856,22 @@ def _make_field(index, field_desc, names): :param int index: 0-based index of the field in the record. :param Any field_desc: JSON descriptors of a record field. :param Any names: The names for this schema. - :returns: The field schema. + :return: The field schema. :rtype: Field """ field_schema = schema_from_json_data( - json_data=field_desc['type'], + json_data=field_desc["type"], names=names, ) - other_props = ( - dict(filter_keys_out(items=field_desc, keys=FIELD_RESERVED_PROPS))) + other_props = dict(filter_keys_out(items=field_desc, keys=FIELD_RESERVED_PROPS)) return Field( data_type=field_schema, - name=field_desc['name'], + name=field_desc["name"], index=index, - has_default=('default' in field_desc), - default=field_desc.get('default', _NO_DEFAULT), - order=field_desc.get('order', None), - doc=field_desc.get('doc', None), + has_default=("default" in field_desc), + default=field_desc.get("default", _NO_DEFAULT), + order=field_desc.get("order", None), + doc=field_desc.get("doc", None), other_props=other_props, ) @@ -888,7 +882,7 @@ def make_field_list(field_desc_list, names): :param Any field_desc_list: Collection of field JSON descriptors. :param Any names: The names for this schema. - :returns: Field schemas. + :return: Field schemas. :rtype: Field """ for index, field_desc in enumerate(field_desc_list): @@ -900,27 +894,18 @@ def _make_field_map(fields): Guarantees field name unicity. :param Any fields: Iterable of field schema. - :returns: A map of field schemas, indexed by name. + :return: A map of field schemas, indexed by name. :rtype: Dict[Any, Any] """ field_map = {} for field in fields: if field.name in field_map: - raise SchemaParseException( - f'Duplicate record field name {field.name!r}.') + raise SchemaParseException(f"Duplicate record field name {field.name!r}.") field_map[field.name] = field return field_map def __init__( - self, - name, - namespace, - fields=None, - make_fields=None, - names=None, - record_type=RECORD, - doc=None, - other_props=None + self, name, namespace, fields=None, make_fields=None, names=None, record_type=RECORD, doc=None, other_props=None ): """Initializes a new record schema object. @@ -954,8 +939,7 @@ def __init__( other_props=other_props, ) else: - raise SchemaParseException( - f'Invalid record type: {record_type!r}.') + raise SchemaParseException(f"Invalid record type: {record_type!r}.") nested_names = [] if record_type in [RECORD, ERROR]: @@ -973,9 +957,9 @@ def __init__( self._field_map = RecordSchema._make_field_map(self._fields) - self._props['fields'] = fields + self._props["fields"] = fields if doc is not None: - self._props['doc'] = doc + self._props["doc"] = doc @property def fields(self): @@ -999,7 +983,7 @@ def to_json(self, names=None): names.names[self.fullname] = self to_dump = names.prune_namespace(self.props.copy()) - to_dump['fields'] = [f.to_json(names) for f in self.fields] + to_dump["fields"] = [f.to_json(names) for f in self.fields] return to_dump def __eq__(self, that): @@ -1017,7 +1001,7 @@ def filter_keys_out(items, keys): :param Dict[Any, Any] items: Dictionary of items to filter the keys out of. :param Dict[Any, Any] keys: Dictionary of keys to filter the extracted keys against. - :returns: Filtered items. + :return: Filtered items. 
:rtype: Tuple(Any, Any) """ for key, value in items.items(): @@ -1048,31 +1032,29 @@ def MakeSchema(desc): def _schema_from_json_object(json_object, names): - data_type = json_object.get('type') + data_type = json_object.get("type") if data_type is None: - raise SchemaParseException( - f'Avro schema JSON descriptor has no "type" property: {json_object!r}') + raise SchemaParseException(f'Avro schema JSON descriptor has no "type" property: {json_object!r}') - other_props = dict( - filter_keys_out(items=json_object, keys=SCHEMA_RESERVED_PROPS)) + other_props = dict(filter_keys_out(items=json_object, keys=SCHEMA_RESERVED_PROPS)) if data_type in PRIMITIVE_TYPES: # FIXME should not ignore other properties result = PrimitiveSchema(data_type, other_props=other_props) elif data_type in NAMED_TYPES: - name = json_object.get('name') - namespace = json_object.get('namespace', names.default_namespace) + name = json_object.get("name") + namespace = json_object.get("namespace", names.default_namespace) if data_type == FIXED: - size = json_object.get('size') + size = json_object.get("size") result = FixedSchema(name, namespace, size, names, other_props) elif data_type == ENUM: - symbols = json_object.get('symbols') - doc = json_object.get('doc') + symbols = json_object.get("symbols") + doc = json_object.get("doc") result = EnumSchema(name, namespace, symbols, names, doc, other_props) elif data_type in [RECORD, ERROR]: - field_desc_list = json_object.get('fields', ()) + field_desc_list = json_object.get("fields", ()) def MakeFields(names): return tuple(RecordSchema.make_field_list(field_desc_list, names)) @@ -1083,17 +1065,17 @@ def MakeFields(names): make_fields=MakeFields, names=names, record_type=data_type, - doc=json_object.get('doc'), + doc=json_object.get("doc"), other_props=other_props, ) else: - raise ValueError(f'Internal error: unknown type {data_type!r}.') + raise ValueError(f"Internal error: unknown type {data_type!r}.") elif data_type in VALID_TYPES: # Unnamed, non-primitive Avro type: if data_type == ARRAY: - items_desc = json_object.get('items') + items_desc = json_object.get("items") if items_desc is None: raise SchemaParseException(f'Invalid array schema descriptor with no "items" : {json_object!r}.') result = ArraySchema( @@ -1102,7 +1084,7 @@ def MakeFields(names): ) elif data_type == MAP: - values_desc = json_object.get('values') + values_desc = json_object.get("values") if values_desc is None: raise SchemaParseException(f'Invalid map schema descriptor with no "values" : {json_object!r}.') result = MapSchema( @@ -1111,17 +1093,15 @@ def MakeFields(names): ) elif data_type == ERROR_UNION: - error_desc_list = json_object.get('declared_errors') + error_desc_list = json_object.get("declared_errors") assert error_desc_list is not None - error_schemas = map( - lambda desc: schema_from_json_data(desc, names=names), - error_desc_list) + error_schemas = map(lambda desc: schema_from_json_data(desc, names=names), error_desc_list) result = ErrorUnionSchema(schemas=error_schemas) else: - raise ValueError(f'Internal error: unknown type {data_type!r}.') + raise ValueError(f"Internal error: unknown type {data_type!r}.") else: - raise SchemaParseException(f'Invalid JSON descriptor for an Avro schema: {json_object!r}') + raise SchemaParseException(f"Invalid JSON descriptor for an Avro schema: {json_object!r}") return result @@ -1139,7 +1119,7 @@ def schema_from_json_data(json_data, names=None): :param Any json_data: JSON data representing the descriptor of the Avro schema. 
:param Any names: Optional tracker for Avro named schemas. - :returns: The Avro schema parsed from the JSON descriptor. + :return: The Avro schema parsed from the JSON descriptor. :rtype: Any """ if names is None: @@ -1148,8 +1128,7 @@ def schema_from_json_data(json_data, names=None): # Select the appropriate parser based on the JSON data type: parser = _JSONDataParserTypeMap.get(type(json_data)) if parser is None: - raise SchemaParseException( - f'Invalid JSON descriptor for an Avro schema: {json_data!r}.') + raise SchemaParseException(f"Invalid JSON descriptor for an Avro schema: {json_data!r}.") return parser(json_data, names=names) @@ -1161,15 +1140,15 @@ def parse(json_string): Raises SchemaParseException if a JSON parsing error is met, or if the JSON descriptor is invalid. :param str json_string: String representation of the JSON descriptor of the schema. - :returns: The parsed schema. + :return: The parsed schema. :rtype: Any """ try: json_data = json.loads(json_string) except Exception as exn: raise SchemaParseException( - f'Error parsing schema from JSON: {json_string!r}. ' - f'Error message: {exn!r}.') from exn + f"Error parsing schema from JSON: {json_string!r}. " f"Error message: {exn!r}." + ) from exn # Initialize the names object names = Names() diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 7de14050b963..217eb2110f15 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -20,7 +20,10 @@ from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline -from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module +from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module + HttpTransport, + RequestsTransport, +) from azure.core.pipeline.policies import ( AzureSasCredentialPolicy, ContentDecodePolicy, @@ -73,8 +76,17 @@ def __init__( self, parsed_url: Any, service: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + credential: Optional[ + Union[ + str, + Dict[str, str], + AzureNamedKeyCredential, + AzureSasCredential, + "AsyncTokenCredential", + TokenCredential, + ] + ] = None, + **kwargs: Any, ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts", {}) @@ -83,12 +95,15 @@ def __init__( if service not in ["blob", "queue", "file-share", "dfs"]: raise ValueError(f"Invalid service: {service}") - service_name = service.split('-')[0] + service_name = service.split("-")[0] account = parsed_url.netloc.split(f".{service_name}.core.") self.account_name = account[0] if len(account) > 1 else None - if not self.account_name and parsed_url.netloc.startswith("localhost") \ - or parsed_url.netloc.startswith("127.0.0.1"): + if ( + not self.account_name + and parsed_url.netloc.startswith("localhost") + or parsed_url.netloc.startswith("127.0.0.1") + ): self._is_localhost = True self.account_name = parsed_url.path.strip("/") @@ -106,7 +121,7 @@ def __init__( secondary_hostname = 
parsed_url.netloc.replace(account[0], account[0] + "-secondary") if kwargs.get("secondary_hostname"): secondary_hostname = kwargs["secondary_hostname"] - primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/') + primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip("/") self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname} self._sdk_moniker = f"storage-{service}/{VERSION}" @@ -119,71 +134,76 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): - """ This method is to close the sockets opened by the client. + def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. """ self._client.close() @property - def url(self): + def url(self) -> str: """The full endpoint URL to this entity, including SAS token if used. This could be either the primary endpoint, or the secondary endpoint depending on the current :func:`location_mode`. - :returns: The full endpoint URL to this entity, including SAS token if used. + :return: The full endpoint URL to this entity, including SAS token if used. :rtype: str """ - return self._format_url(self._hosts[self._location_mode]) + return self._format_url(self._hosts[self._location_mode]) # type: ignore @property - def primary_endpoint(self): + def primary_endpoint(self) -> str: """The full primary endpoint URL. + :return: The full primary endpoint URL. :rtype: str """ - return self._format_url(self._hosts[LocationMode.PRIMARY]) + return self._format_url(self._hosts[LocationMode.PRIMARY]) # type: ignore @property - def primary_hostname(self): + def primary_hostname(self) -> str: """The hostname of the primary endpoint. + :return: The hostname of the primary endpoint. :rtype: str """ return self._hosts[LocationMode.PRIMARY] @property - def secondary_endpoint(self): + def secondary_endpoint(self) -> str: """The full secondary endpoint URL if configured. If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The full secondary endpoint URL. :rtype: str - :raise ValueError: + :raise ValueError: If no secondary endpoint is configured. """ if not self._hosts[LocationMode.SECONDARY]: raise ValueError("No secondary host configured.") - return self._format_url(self._hosts[LocationMode.SECONDARY]) + return self._format_url(self._hosts[LocationMode.SECONDARY]) # type: ignore @property - def secondary_hostname(self): + def secondary_hostname(self) -> Optional[str]: """The hostname of the secondary endpoint. If not available this will be None. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The hostname of the secondary endpoint, or None if not configured. :rtype: Optional[str] """ return self._hosts[LocationMode.SECONDARY] @property - def location_mode(self): + def location_mode(self) -> str: """The location mode that the client is currently using. By default this will be "primary". Options include "primary" and "secondary". + :return: The current location mode. 
:rtype: str """ @@ -206,11 +226,16 @@ def api_version(self): return self._client._config.version # pylint: disable=protected-access def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -218,7 +243,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." + ) if _is_credential_sastoken(credential): credential = cast(str, credential) query_str += credential.lstrip("?") @@ -228,13 +254,16 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None if hasattr(credential, "get_token"): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience) @@ -268,22 +297,18 @@ def _create_pipeline( config.logging_policy, StorageResponseHook(**kwargs), DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs) + HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore return config, Pipeline(transport, policies=policies) - def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> Iterator["HttpResponse"]: + def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> Iterator["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An iterator of HttpResponse objects. + :return: An iterator of HttpResponse objects. 
:rtype: Iterator[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg @@ -292,25 +317,21 @@ def _batch_send( request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), headers={ - 'x-ms-version': self.api_version, - "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) - } + "x-ms-version": self.api_version, + "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False), + }, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) Pipeline._prepare_multipart_mixed_request(request) # pylint: disable=protected-access body = serialize_batch_body(request.multipart_mixed_info[0], batch_id) @@ -318,9 +339,7 @@ def _batch_send( temp = request.multipart_mixed_info request.multipart_mixed_info = None - pipeline_response = self._pipeline.run( - request, **kwargs - ) + pipeline_response = self._pipeline.run(request, **kwargs) response = pipeline_response.http_response request.multipart_mixed_info = temp @@ -332,8 +351,7 @@ def _batch_send( parts = list(response.parts()) if any(p for p in parts if not 200 <= p.status_code < 300): error = PartialBatchErrorException( - message="There is a partial failure in the batch operation.", - response=response, parts=parts + message="There is a partial failure in the batch operation.", response=response, parts=parts ) raise error return iter(parts) @@ -347,6 +365,7 @@ class TransportWrapper(HttpTransport): by a `get_client` method does not close the outer transport for the parent when used in a context manager. 
""" + def __init__(self, transport): self._transport = transport @@ -368,7 +387,9 @@ def __exit__(self, *args): def _format_shared_key_credential( account_name: Optional[str], - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential] + ] = None, ) -> Any: if isinstance(credential, str): if not account_name: @@ -388,8 +409,12 @@ def _format_shared_key_credential( def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -411,14 +436,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -438,7 +460,7 @@ def parse_connection_str( def create_configuration(**kwargs: Any) -> StorageConfiguration: - # Backwards compatibility if someone is not passing sdk_moniker + # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}" config = StorageConfiguration(**kwargs) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 6186b29db107..f39a57b24943 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -64,18 +64,26 @@ async def __aenter__(self): async def __aexit__(self, *args): await self._client.__aexit__(*args) - async def close(self): - """ This method is to close the sockets opened by the client. + async def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. 
+ + :return: None + :rtype: None """ await self._client.close() def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -83,7 +91,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." + ) if _is_credential_sastoken(credential): query_str += credential.lstrip("?") # type: ignore [union-attr] credential = None @@ -92,35 +101,40 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[ - Union[AsyncStorageBearerTokenCredentialPolicy, - SharedKeyCredentialPolicy, - AzureSasCredentialPolicy]] = None - if hasattr(credential, 'get_token'): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + Union[AsyncStorageBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy] + ] = None + if hasattr(credential, "get_token"): + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = AsyncStorageBearerTokenCredentialPolicy( - cast(AsyncTokenCredential, credential), audience) + cast(AsyncTokenCredential, credential), audience + ) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): self._credential_policy = AzureSasCredentialPolicy(credential) elif credential is not None: raise TypeError(f"Unsupported credential: {type(credential)}") - config = kwargs.get('_configuration') or create_configuration(**kwargs) - if kwargs.get('_pipeline'): - return config, kwargs['_pipeline'] - transport = kwargs.get('transport') + config = kwargs.get("_configuration") or create_configuration(**kwargs) + if kwargs.get("_pipeline"): + return config, kwargs["_pipeline"] + transport = kwargs.get("transport") kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) if not transport: try: - from azure.core.pipeline.transport import AioHttpTransport # pylint: 
disable=non-abstract-transport-import + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + AioHttpTransport, + ) except ImportError as exc: raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc transport = AioHttpTransport(**kwargs) @@ -143,53 +157,41 @@ def _create_pipeline( HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore - config.transport = transport #type: ignore - return config, AsyncPipeline(transport, policies=policies) #type: ignore + policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore + config.transport = transport # type: ignore + return config, AsyncPipeline(transport, policies=policies) # type: ignore - async def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> AsyncList["HttpResponse"]: + async def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> AsyncList["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An AsyncList of HttpResponse objects. + :return: An AsyncList of HttpResponse objects. :rtype: AsyncList[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), - headers={ - 'x-ms-version': self.api_version - } + headers={"x-ms-version": self.api_version}, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) # type: ignore - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) - pipeline_response = await self._pipeline.run( - request, **kwargs - ) + pipeline_response = await self._pipeline.run(request, **kwargs) response = pipeline_response.http_response try: if response.status_code not in [202]: raise HttpResponseError(response=response) - parts = response.parts() # Return an AsyncIterator + parts = response.parts() # Return an AsyncIterator if raise_on_any_failure: parts_list = [] async for part in parts: @@ -197,7 +199,8 @@ async def _batch_send( if any(p for p in parts_list if not 200 <= p.status_code < 300): error = PartialBatchErrorException( message="There is a partial failure in the batch operation.", - response=response, parts=parts_list + response=response, + parts=parts_list, ) raise error return AsyncList(parts_list) @@ -205,11 +208,16 @@ async def _batch_send( except HttpResponseError as error: process_storage_error(error) + def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = 
[s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -231,14 +239,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -256,11 +261,13 @@ def parse_connection_str( secondary = secondary.replace(".blob.", ".dfs.") return primary, secondary, credential + class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner client created by a `get_client` method does not close the outer transport for the parent when used in a context manager. """ + def __init__(self, async_transport): self._transport = async_transport diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/constants.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/constants.py index 0b4b029a2d1b..0926f04c4081 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/constants.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/constants.py @@ -16,4 +16,4 @@ DEFAULT_OAUTH_SCOPE = "/.default" STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default" -SERVICE_HOST_BASE = 'core.windows.net' +SERVICE_HOST_BASE = "core.windows.net" diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py index d78cd9113133..a446a5f9a514 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py @@ -22,6 +22,7 @@ def get_enum_value(value): class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Error codes returned by the service.""" # Generic storage values ACCOUNT_ALREADY_EXISTS = "AccountAlreadyExists" @@ -172,26 +173,26 @@ class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): CONTAINER_QUOTA_DOWNGRADE_NOT_ALLOWED = "ContainerQuotaDowngradeNotAllowed" # DataLake values - CONTENT_LENGTH_MUST_BE_ZERO = 'ContentLengthMustBeZero' - PATH_ALREADY_EXISTS = 'PathAlreadyExists' - INVALID_FLUSH_POSITION = 'InvalidFlushPosition' - INVALID_PROPERTY_NAME = 'InvalidPropertyName' - INVALID_SOURCE_URI = 'InvalidSourceUri' - UNSUPPORTED_REST_VERSION = 'UnsupportedRestVersion' - FILE_SYSTEM_NOT_FOUND = 'FilesystemNotFound' - PATH_NOT_FOUND = 'PathNotFound' - RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = 'RenameDestinationParentPathNotFound' - SOURCE_PATH_NOT_FOUND = 'SourcePathNotFound' - DESTINATION_PATH_IS_BEING_DELETED = 'DestinationPathIsBeingDeleted' - FILE_SYSTEM_ALREADY_EXISTS = 'FilesystemAlreadyExists' - FILE_SYSTEM_BEING_DELETED = 'FilesystemBeingDeleted' - INVALID_DESTINATION_PATH = 'InvalidDestinationPath' - INVALID_RENAME_SOURCE_PATH = 'InvalidRenameSourcePath' - INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = 'InvalidSourceOrDestinationResourceType' - LEASE_IS_ALREADY_BROKEN = 'LeaseIsAlreadyBroken' - LEASE_NAME_MISMATCH = 'LeaseNameMismatch' - PATH_CONFLICT = 'PathConflict' - SOURCE_PATH_IS_BEING_DELETED = 'SourcePathIsBeingDeleted' + CONTENT_LENGTH_MUST_BE_ZERO = "ContentLengthMustBeZero" + 
PATH_ALREADY_EXISTS = "PathAlreadyExists" + INVALID_FLUSH_POSITION = "InvalidFlushPosition" + INVALID_PROPERTY_NAME = "InvalidPropertyName" + INVALID_SOURCE_URI = "InvalidSourceUri" + UNSUPPORTED_REST_VERSION = "UnsupportedRestVersion" + FILE_SYSTEM_NOT_FOUND = "FilesystemNotFound" + PATH_NOT_FOUND = "PathNotFound" + RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = "RenameDestinationParentPathNotFound" + SOURCE_PATH_NOT_FOUND = "SourcePathNotFound" + DESTINATION_PATH_IS_BEING_DELETED = "DestinationPathIsBeingDeleted" + FILE_SYSTEM_ALREADY_EXISTS = "FilesystemAlreadyExists" + FILE_SYSTEM_BEING_DELETED = "FilesystemBeingDeleted" + INVALID_DESTINATION_PATH = "InvalidDestinationPath" + INVALID_RENAME_SOURCE_PATH = "InvalidRenameSourcePath" + INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = "InvalidSourceOrDestinationResourceType" + LEASE_IS_ALREADY_BROKEN = "LeaseIsAlreadyBroken" + LEASE_NAME_MISMATCH = "LeaseNameMismatch" + PATH_CONFLICT = "PathConflict" + SOURCE_PATH_IS_BEING_DELETED = "SourcePathIsBeingDeleted" class DictMixin(object): @@ -222,7 +223,7 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')}) + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) def __contains__(self, key): return key in self.__dict__ @@ -234,13 +235,13 @@ def update(self, *args, **kwargs): return self.__dict__.update(*args, **kwargs) def keys(self): - return [k for k in self.__dict__ if not k.startswith('_')] + return [k for k in self.__dict__ if not k.startswith("_")] def values(self): - return [v for k, v in self.__dict__.items() if not k.startswith('_')] + return [v for k, v in self.__dict__.items() if not k.startswith("_")] def items(self): - return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')] + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] def get(self, key, default=None): if key in self.__dict__: @@ -255,8 +256,8 @@ class LocationMode(object): must use PRIMARY. """ - PRIMARY = 'primary' #: Requests should be sent to the primary location. - SECONDARY = 'secondary' #: Requests should be sent to the secondary location, if possible. + PRIMARY = "primary" #: Requests should be sent to the primary location. + SECONDARY = "secondary" #: Requests should be sent to the secondary location, if possible. 
class ResourceTypes(object): @@ -281,17 +282,12 @@ class ResourceTypes(object): _str: str def __init__( - self, - service: bool = False, - container: bool = False, - object: bool = False # pylint: disable=redefined-builtin + self, service: bool = False, container: bool = False, object: bool = False # pylint: disable=redefined-builtin ) -> None: self.service = service self.container = container self.object = object - self._str = (('s' if self.service else '') + - ('c' if self.container else '') + - ('o' if self.object else '')) + self._str = ("s" if self.service else "") + ("c" if self.container else "") + ("o" if self.object else "") def __str__(self): return self._str @@ -309,9 +305,9 @@ def from_string(cls, string): :return: A ResourceTypes object :rtype: ~azure.storage.blob.ResourceTypes """ - res_service = 's' in string - res_container = 'c' in string - res_object = 'o' in string + res_service = "s" in string + res_container = "c" in string + res_object = "o" in string parsed = cls(res_service, res_container, res_object) parsed._str = string @@ -392,29 +388,30 @@ def __init__( self.write = write self.delete = delete self.delete_previous_version = delete_previous_version - self.permanent_delete = kwargs.pop('permanent_delete', False) + self.permanent_delete = kwargs.pop("permanent_delete", False) self.list = list self.add = add self.create = create self.update = update self.process = process - self.tag = kwargs.pop('tag', False) - self.filter_by_tags = kwargs.pop('filter_by_tags', False) - self.set_immutability_policy = kwargs.pop('set_immutability_policy', False) - self._str = (('r' if self.read else '') + - ('w' if self.write else '') + - ('d' if self.delete else '') + - ('x' if self.delete_previous_version else '') + - ('y' if self.permanent_delete else '') + - ('l' if self.list else '') + - ('a' if self.add else '') + - ('c' if self.create else '') + - ('u' if self.update else '') + - ('p' if self.process else '') + - ('f' if self.filter_by_tags else '') + - ('t' if self.tag else '') + - ('i' if self.set_immutability_policy else '') - ) + self.tag = kwargs.pop("tag", False) + self.filter_by_tags = kwargs.pop("filter_by_tags", False) + self.set_immutability_policy = kwargs.pop("set_immutability_policy", False) + self._str = ( + ("r" if self.read else "") + + ("w" if self.write else "") + + ("d" if self.delete else "") + + ("x" if self.delete_previous_version else "") + + ("y" if self.permanent_delete else "") + + ("l" if self.list else "") + + ("a" if self.add else "") + + ("c" if self.create else "") + + ("u" if self.update else "") + + ("p" if self.process else "") + + ("f" if self.filter_by_tags else "") + + ("t" if self.tag else "") + + ("i" if self.set_immutability_policy else "") + ) def __str__(self): return self._str @@ -432,23 +429,34 @@ def from_string(cls, permission): :return: An AccountSasPermissions object :rtype: ~azure.storage.blob.AccountSasPermissions """ - p_read = 'r' in permission - p_write = 'w' in permission - p_delete = 'd' in permission - p_delete_previous_version = 'x' in permission - p_permanent_delete = 'y' in permission - p_list = 'l' in permission - p_add = 'a' in permission - p_create = 'c' in permission - p_update = 'u' in permission - p_process = 'p' in permission - p_tag = 't' in permission - p_filter_by_tags = 'f' in permission - p_set_immutability_policy = 'i' in permission - parsed = cls(read=p_read, write=p_write, delete=p_delete, delete_previous_version=p_delete_previous_version, - list=p_list, add=p_add, create=p_create, update=p_update, 
process=p_process, tag=p_tag, - filter_by_tags=p_filter_by_tags, set_immutability_policy=p_set_immutability_policy, - permanent_delete=p_permanent_delete) + p_read = "r" in permission + p_write = "w" in permission + p_delete = "d" in permission + p_delete_previous_version = "x" in permission + p_permanent_delete = "y" in permission + p_list = "l" in permission + p_add = "a" in permission + p_create = "c" in permission + p_update = "u" in permission + p_process = "p" in permission + p_tag = "t" in permission + p_filter_by_tags = "f" in permission + p_set_immutability_policy = "i" in permission + parsed = cls( + read=p_read, + write=p_write, + delete=p_delete, + delete_previous_version=p_delete_previous_version, + list=p_list, + add=p_add, + create=p_create, + update=p_update, + process=p_process, + tag=p_tag, + filter_by_tags=p_filter_by_tags, + set_immutability_policy=p_set_immutability_policy, + permanent_delete=p_permanent_delete, + ) return parsed @@ -464,18 +472,11 @@ class Services(object): Access for the `~azure.storage.fileshare.ShareServiceClient`. Default is False. """ - def __init__( - self, *, - blob: bool = False, - queue: bool = False, - fileshare: bool = False - ) -> None: + def __init__(self, *, blob: bool = False, queue: bool = False, fileshare: bool = False) -> None: self.blob = blob self.queue = queue self.fileshare = fileshare - self._str = (('b' if self.blob else '') + - ('q' if self.queue else '') + - ('f' if self.fileshare else '')) + self._str = ("b" if self.blob else "") + ("q" if self.queue else "") + ("f" if self.fileshare else "") def __str__(self): return self._str @@ -493,9 +494,9 @@ def from_string(cls, string): :return: A Services object :rtype: ~azure.storage.blob.Services """ - res_blob = 'b' in string - res_queue = 'q' in string - res_file = 'f' in string + res_blob = "b" in string + res_queue = "q" in string + res_file = "f" in string parsed = cls(blob=res_blob, queue=res_queue, fileshare=res_file) parsed._str = string @@ -573,13 +574,13 @@ class StorageConfiguration(Configuration): def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) - self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024) + self.max_single_put_size = kwargs.pop("max_single_put_size", 64 * 1024 * 1024) self.copy_polling_interval = 15 - self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024) - self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1) - self.use_byte_buffer = kwargs.pop('use_byte_buffer', False) - self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024) - self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) - self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) - self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) - self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024) + self.max_block_size = kwargs.pop("max_block_size", 4 * 1024 * 1024) + self.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) + self.use_byte_buffer = kwargs.pop("use_byte_buffer", False) + self.max_page_size = kwargs.pop("max_page_size", 4 * 1024 * 1024) + self.min_large_chunk_upload_threshold = kwargs.pop("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) + self.max_single_get_size = kwargs.pop("max_single_get_size", 32 * 1024 * 1024) + self.max_chunk_get_size = kwargs.pop("max_chunk_get_size", 4 * 1024 
* 1024) + self.max_range_size = kwargs.pop("max_range_size", 4 * 1024 * 1024) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/parser.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/parser.py index 112c1984f4fb..e4fcb8f041ba 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/parser.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/parser.py @@ -12,14 +12,14 @@ def _to_utc_datetime(value: datetime) -> str: - return value.strftime('%Y-%m-%dT%H:%M:%SZ') + return value.strftime("%Y-%m-%dT%H:%M:%SZ") def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]: """Converts an RFC 1123 date string to a UTC datetime. :param str rfc_1123: The time and date in RFC 1123 format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not rfc_1123: @@ -33,7 +33,7 @@ def _filetime_to_datetime(filetime: str) -> Optional[datetime]: If parsing MS Filetime fails, tries RFC 1123 as backup. :param str filetime: The time and date in MS filetime format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not filetime: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index ee75cd5a466c..a08fee7afaac 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -28,7 +28,7 @@ HTTPPolicy, NetworkTraceLoggingPolicy, RequestHistory, - SansIOHTTPPolicy + SansIOHTTPPolicy, ) from .authentication import AzureSigningError, StorageHttpChallenge @@ -39,7 +39,7 @@ from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -48,14 +48,14 @@ def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status']) + retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -63,8 +63,8 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): - if settings['hook']: - settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs) + if settings["hook"]: + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? 
(Based on allowlists and control @@ -95,40 +95,39 @@ def is_retry(response, mode): def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False def urljoin(base_url, stub_url): parsed = urlparse(base_url) - parsed = parsed._replace(path=parsed.path + '/' + stub_url) + parsed = parsed._replace(path=parsed.path + "/" + stub_url) return parsed.geturl() class QueueMessagePolicy(SansIOHTTPPolicy): def on_request(self, request): - message_id = request.context.options.pop('queue_message_id', None) + message_id = request.context.options.pop("queue_message_id", None) if message_id: - request.http_request.url = urljoin( - request.http_request.url, - message_id) + request.http_request.url = urljoin(request.http_request.url, message_id) class StorageHeadersPolicy(HeadersPolicy): - request_id_header_name = 'x-ms-client-request-id' + request_id_header_name = "x-ms-client-request-id" def on_request(self, request: "PipelineRequest") -> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) - request.http_request.headers['x-ms-date'] = current_time + request.http_request.headers["x-ms-date"] = current_time - custom_id = request.context.options.pop('client_request_id', None) - request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1()) + custom_id = request.context.options.pop("client_request_id", None) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -153,7 +152,7 @@ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument super(StorageHosts, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request.context.options['hosts'] = self.hosts + request.context.options["hosts"] = self.hosts parsed_url = urlparse(request.http_request.url) # Detect what location mode we're currently requesting with @@ -163,10 +162,10 @@ def on_request(self, request: "PipelineRequest") -> None: location_mode = key # See if a specific location mode has been specified, and if so, redirect - use_location = request.context.options.pop('use_location', None) + use_location = request.context.options.pop("use_location", None) if use_location: # Lock retries to the specific location - request.context.options['retry_to_secondary'] = False + request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != location_mode: @@ -175,7 +174,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.url = updated.geturl() location_mode = use_location - request.context.options['location_mode'] = 
location_mode + request.context.options["location_mode"] = location_mode class StorageLoggingPolicy(NetworkTraceLoggingPolicy): @@ -200,19 +199,19 @@ def on_request(self, request: "PipelineRequest") -> None: try: log_url = http_request.url query_params = http_request.query - if 'sig' in query_params: - log_url = log_url.replace(query_params['sig'], "sig=*****") + if "sig" in query_params: + log_url = log_url.replace(query_params["sig"], "sig=*****") _LOGGER.debug("Request URL: %r", log_url) _LOGGER.debug("Request method: %r", http_request.method) _LOGGER.debug("Request headers:") for header, value in http_request.headers.items(): - if header.lower() == 'authorization': - value = '*****' - elif header.lower() == 'x-ms-copy-source' and 'sig' in value: + if header.lower() == "authorization": + value = "*****" + elif header.lower() == "x-ms-copy-source" and "sig" in value: # take the url apart and scrub away the signed signature scheme, netloc, path, params, query, fragment = urlparse(value) parsed_qs = dict(parse_qsl(query)) - parsed_qs['sig'] = '*****' + parsed_qs["sig"] = "*****" # the SAS needs to be put back together value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) @@ -242,11 +241,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") # We don't want to log binary data if the response is a file. _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) - header = response.http_response.headers.get('content-disposition') + header = response.http_response.headers.get("content-disposition") resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): - filename = header.partition('=')[2] + filename = header.partition("=")[2] _LOGGER.debug("File attachments: %s", filename) elif resp_content_type.endswith("octet-stream"): _LOGGER.debug("Body contains binary data.") @@ -268,11 +267,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") class StorageRequestHook(SansIOHTTPPolicy): def __init__(self, **kwargs): - self._request_callback = kwargs.get('raw_request_hook') + self._request_callback = kwargs.get("raw_request_hook") super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop('raw_request_hook', self._request_callback) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -280,49 +279,50 @@ def on_request(self, request: "PipelineRequest") -> None: class StorageResponseHook(HTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(StorageResponseHook, self).__init__() def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = 
request.context.options.pop('download_stream_current', None) - upload_stream_current = request.context.get('upload_stream_current') + download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response @@ -332,7 +332,8 @@ class StorageContentValidation(SansIOHTTPPolicy): This will overwrite any headers already defined in the request. """ - header_name = 'Content-MD5' + + header_name = "Content-MD5" def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super(StorageContentValidation, self).__init__() @@ -342,10 +343,10 @@ def get_content_md5(data): # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. 
data = data or b"" - md5 = hashlib.md5() # nosec + md5 = hashlib.md5() # nosec if isinstance(data, bytes): md5.update(data) - elif hasattr(data, 'read'): + elif hasattr(data, "read"): pos = 0 try: pos = data.tell() @@ -363,22 +364,25 @@ def get_content_md5(data): return md5.digest() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop('validate_content', False) - if validate_content and request.http_request.method != 'GET': + validate_content = request.context.options.pop("validate_content", False) + if validate_content and request.http_request.method != "GET": computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) request.http_request.headers[self.header_name] = computed_md5 - request.context['validate_content_md5'] = computed_md5 - request.context['validate_content'] = validate_content + request.context["validate_content_md5"] = computed_md5 + request.context["validate_content"] = validate_content def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = request.context.get('validate_content_md5') or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: - raise AzureError(( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'."), - response=response.http_response + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, ) @@ -399,33 +403,41 @@ class StorageRetryPolicy(HTTPPolicy): """Whether the secondary endpoint should be retried.""" def __init__(self, **kwargs: Any) -> None: - self.total_retries = kwargs.pop('retry_total', 10) - self.connect_retries = kwargs.pop('retry_connect', 3) - self.read_retries = kwargs.pop('retry_read', 3) - self.status_retries = kwargs.pop('retry_status', 3) - self.retry_to_secondary = kwargs.pop('retry_to_secondary', False) + self.total_retries = kwargs.pop("retry_total", 10) + self.connect_retries = kwargs.pop("retry_connect", 3) + self.read_retries = kwargs.pop("retry_read", 3) + self.status_retries = kwargs.pop("retry_status", 3) + self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. - :param Dict[str, Any]] settings: The configurable values pertaining to the next host location. + :param Dict[str, Any] settings: The configurable values pertaining to the next host location. :param PipelineRequest request: A pipeline request object. 
""" - if settings['hosts'] and all(settings['hosts'].values()): + if settings["hosts"] and all(settings["hosts"].values()): url = urlparse(request.url) # If there's more than one possible location, retry to the alternative - if settings['mode'] == LocationMode.PRIMARY: - settings['mode'] = LocationMode.SECONDARY + if settings["mode"] == LocationMode.PRIMARY: + settings["mode"] = LocationMode.SECONDARY else: - settings['mode'] = LocationMode.PRIMARY - updated = url._replace(netloc=settings['hosts'].get(settings['mode'])) + settings["mode"] = LocationMode.PRIMARY + updated = url._replace(netloc=settings["hosts"].get(settings["mode"])) request.url = updated.geturl() def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: + """ + Configure the retry settings for the request. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A dictionary containing the retry settings. + :rtype: Dict[str, Any] + """ body_position = None - if hasattr(request.http_request.body, 'read'): + if hasattr(request.http_request.body, "read"): try: body_position = request.http_request.body.tell() except (AttributeError, UnsupportedOperation): @@ -433,129 +445,140 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: pass options = request.context.options return { - 'total': options.pop("retry_total", self.total_retries), - 'connect': options.pop("retry_connect", self.connect_retries), - 'read': options.pop("retry_read", self.read_retries), - 'status': options.pop("retry_status", self.status_retries), - 'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary), - 'mode': options.pop("location_mode", LocationMode.PRIMARY), - 'hosts': options.pop("hosts", None), - 'hook': options.pop("retry_hook", None), - 'body_position': body_position, - 'count': 0, - 'history': [] + "total": options.pop("retry_total", self.total_retries), + "connect": options.pop("retry_connect", self.connect_retries), + "read": options.pop("retry_read", self.read_retries), + "status": options.pop("retry_status", self.status_retries), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "mode": options.pop("location_mode", LocationMode.PRIMARY), + "hosts": options.pop("hosts", None), + "hook": options.pop("retry_hook", None), + "body_position": body_position, + "count": 0, + "history": [], } def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument - """ Formula for computing the current backoff. + """Formula for computing the current backoff. Should be calculated by child class. :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. - :returns: The backoff time. + :return: The backoff time. :rtype: float """ return 0 def sleep(self, settings, transport): + """Sleep for the backoff time. + + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. + :param transport: The transport to use for sleeping. 
+ :type transport: + ~azure.core.pipeline.transport.AsyncioBaseTransport or + ~azure.core.pipeline.transport.BaseTransport + """ backoff = self.get_backoff_time(settings) if not backoff or backoff < 0: return transport.sleep(backoff) def increment( - self, settings: Dict[str, Any], + self, + settings: Dict[str, Any], request: "PipelineRequest", response: Optional["PipelineResponse"] = None, - error: Optional[AzureError] = None + error: Optional[AzureError] = None, ) -> bool: """Increment the retry counters. :param Dict[str, Any] settings: The configurable values pertaining to the increment operation. - :param PipelineRequest request: A pipeline request object. - :param Optional[PipelineResponse] response: A pipeline response object. - :param Optional[AzureError] error: An error encountered during the request, or + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: A pipeline response object. + :type response: ~azure.core.pipeline.PipelineResponse or None + :param error: An error encountered during the request, or None if the response was received successfully. - :returns: Whether the retry attempts are exhausted. + :type error: ~azure.core.exceptions.AzureError or None + :return: Whether the retry attempts are exhausted. :rtype: bool """ - settings['total'] -= 1 + settings["total"] -= 1 if error and isinstance(error, ServiceRequestError): # Errors when we're fairly sure that the server did not receive the # request, so it should be safe to retry. - settings['connect'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["connect"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) elif error and isinstance(error, ServiceResponseError): # Errors that occur after the request has been started, so we should # assume that the server began processing it. - settings['read'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["read"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) else: # Incrementing because of a server error like a 500 in # status_forcelist and a the given method is in the allowlist if response: - settings['status'] -= 1 - settings['history'].append(RequestHistory(request, http_response=response)) + settings["status"] -= 1 + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): - if request.method not in ['PUT'] and settings['retry_secondary']: + if request.method not in ["PUT"] and settings["retry_secondary"]: self._set_next_host_location(settings, request) # rewind the request body if it is a stream - if request.body and hasattr(request.body, 'read'): + if request.body and hasattr(request.body, "read"): # no position was saved, then retry would not work - if settings['body_position'] is None: + if settings["body_position"] is None: return False try: # attempt to rewind the body to the initial position - request.body.seek(settings['body_position'], SEEK_SET) + request.body.seek(settings["body_position"], SEEK_SET) except (UnsupportedOperation, ValueError): # if body is not seekable, then retry would not work return False - settings['count'] += 1 + settings["count"] += 1 return True return False def send(self, request): + """Send the request with retry logic. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A pipeline response object. 
+ :rtype: ~azure.core.pipeline.PipelineResponse + """ retries_remaining = True response = None retry_settings = self.configure_retries(request) while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings['mode']) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + retry_hook(retry_settings, request=request.http_request, response=None, error=err) self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -571,12 +594,13 @@ class ExponentialRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, initial_backoff: int = 15, + self, + initial_backoff: int = 15, increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for @@ -601,21 +625,20 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to get backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to get backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
:rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -630,11 +653,12 @@ class LinearRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs a Linear retry object. @@ -653,15 +677,14 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. :rtype: float @@ -669,19 +692,27 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + """Handle the challenge from the service and authorize the request. + + :param request: The request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The response object. + :type response: ~azure.core.pipeline.PipelineResponse + :return: True if the request was authorized, False otherwise. 
+ :rtype: bool + """ try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index c44e19ca06ea..4cb32f23248b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -21,7 +21,7 @@ from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -29,29 +29,25 @@ async def retry_hook(settings, **kwargs): - if settings['hook']: - if asyncio.iscoroutine(settings['hook']): - await settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + if settings["hook"]: + if asyncio.iscoroutine(settings["hook"]): + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() # Load the body in memory and close the socket + except (StreamClosedError, StreamConsumedError): + pass + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -59,54 +55,56 @@ async def is_checksum_retry(response): class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(AsyncStorageResponseHook, self).__init__() async def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop('download_stream_current', None) - upload_stream_current = request.context.get('upload_stream_current') + 
download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) - will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): - await response_callback(response) # type: ignore + await response_callback(response) # type: ignore else: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response + class AsyncStorageRetryPolicy(StorageRetryPolicy): """ The base class for Exponential and Linear retries containing shared code. 
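Illustrative usage sketch (not part of this diff): the retry policies being reformatted here are normally supplied when a client is constructed. The snippet below is a minimal sketch assuming `ExponentialRetry` is importable from `azure.storage.blob` and that the client constructor accepts a `retry_policy` keyword; the account URL and credential are placeholders.

    # Hypothetical example: wiring ExponentialRetry into a client.
    # Assumes `ExponentialRetry` is exported by azure.storage.blob and that
    # `retry_policy` is an accepted client keyword; URL/credential are placeholders.
    from azure.storage.blob import BlobServiceClient, ExponentialRetry

    retry_policy = ExponentialRetry(initial_backoff=10, increment_base=2, retry_total=5)
    client = BlobServiceClient(
        account_url="https://<storage-account>.blob.core.windows.net",
        credential="<sas-token-or-account-key>",
        retry_policy=retry_policy,
    )

The keyword values shown (`initial_backoff`, `increment_base`, `retry_total`) correspond to the constructor parameters defined in this file.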
@@ -125,37 +123,29 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: await retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) await self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - await retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + await retry_hook(retry_settings, request=request.http_request, response=None, error=err) await self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -176,7 +166,8 @@ def __init__( increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, - random_jitter_range: int = 3, **kwargs + random_jitter_range: int = 3, + **kwargs ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for @@ -203,21 +194,20 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
:rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -232,7 +222,8 @@ class LinearRetry(AsyncStorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, @@ -255,14 +246,13 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. :return: An integer indicating how long to wait before retrying the request, or None to indicate no retry should be performed. @@ -271,14 +261,13 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py index af500c8727fa..b23f65859690 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py @@ -6,7 +6,7 @@ import logging import stat -from io import (SEEK_END, SEEK_SET, UnsupportedOperation) +from io import SEEK_END, SEEK_SET, UnsupportedOperation from os import fstat from typing import Dict, Optional @@ -37,12 +37,13 @@ def serialize_iso(attr): raise OverflowError("Hit max or min date") date = f"{utc.tm_year:04}-{utc.tm_mon:02}-{utc.tm_mday:02}T{utc.tm_hour:02}:{utc.tm_min:02}:{utc.tm_sec:02}" - return date + 'Z' + return date + "Z" except (ValueError, OverflowError) as err: raise 
ValueError("Unable to serialize datetime object.") from err except AttributeError as err: raise TypeError("ISO-8601 object must be valid datetime object.") from err + def get_length(data): length = None # Check if object implements the __len__ method, covers most input cases such as bytearray. @@ -62,7 +63,7 @@ def get_length(data): try: mode = fstat(fileno).st_mode if stat.S_ISREG(mode) or stat.S_ISLNK(mode): - #st_size only meaningful if regular file or symlink, other types + # st_size only meaningful if regular file or symlink, other types # e.g. sockets may return misleading sizes like 0 return fstat(fileno).st_size except OSError: @@ -84,13 +85,13 @@ def get_length(data): def read_length(data): try: - if hasattr(data, 'read'): - read_data = b'' + if hasattr(data, "read"): + read_data = b"" for chunk in iter(lambda: data.read(4096), b""): read_data += chunk return len(read_data), read_data - if hasattr(data, '__iter__'): - read_data = b'' + if hasattr(data, "__iter__"): + read_data = b"" for chunk in data: read_data += chunk return len(read_data), read_data @@ -100,8 +101,13 @@ def read_length(data): def validate_and_format_range_headers( - start_range, end_range, start_range_required=True, - end_range_required=True, check_content_md5=False, align_to_page=False): + start_range, + end_range, + start_range_required=True, + end_range_required=True, + check_content_md5=False, + align_to_page=False, +): # If end range is provided, start range must be provided if (start_range_required or end_range is not None) and start_range is None: raise ValueError("start_range value cannot be None.") @@ -111,16 +117,18 @@ def validate_and_format_range_headers( # Page ranges must be 512 aligned if align_to_page: if start_range is not None and start_range % 512 != 0: - raise ValueError(f"Invalid page blob start_range: {start_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob start_range: {start_range}. " "The size must be aligned to a 512-byte boundary." + ) if end_range is not None and end_range % 512 != 511: - raise ValueError(f"Invalid page blob end_range: {end_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob end_range: {end_range}. " "The size must be aligned to a 512-byte boundary." + ) # Format based on whether end_range is present range_header = None if end_range is not None: - range_header = f'bytes={start_range}-{end_range}' + range_header = f"bytes={start_range}-{end_range}" elif start_range is not None: range_header = f"bytes={start_range}-" @@ -131,7 +139,7 @@ def validate_and_format_range_headers( raise ValueError("Both start and end range required for MD5 content validation.") if end_range - start_range > 4 * 1024 * 1024: raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.") - range_validation = 'true' + range_validation = "true" return range_header, range_validation @@ -140,7 +148,7 @@ def add_metadata_headers(metadata: Optional[Dict[str, str]] = None) -> Dict[str, headers = {} if metadata: for key, value in metadata.items(): - headers[f'x-ms-meta-{key.strip()}'] = value.strip() if value else value + headers[f"x-ms-meta-{key.strip()}"] = value.strip() if value else value return headers @@ -158,29 +166,26 @@ def serialize_batch_body(requests, batch_id): a list of sub-request for the batch request :param str batch_id: to be embedded in batch sub-request delimiter - :returns: The body bytes for this batch. + :return: The body bytes for this batch. 
:rtype: bytes """ if requests is None or len(requests) == 0: - raise ValueError('Please provide sub-request(s) for this batch request') + raise ValueError("Please provide sub-request(s) for this batch request") - delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode('utf-8') - newline_bytes = _HTTP_LINE_ENDING.encode('utf-8') + delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode("utf-8") + newline_bytes = _HTTP_LINE_ENDING.encode("utf-8") batch_body = [] content_index = 0 for request in requests: - request.headers.update({ - "Content-ID": str(content_index), - "Content-Length": str(0) - }) + request.headers.update({"Content-ID": str(content_index), "Content-Length": str(0)}) batch_body.append(delimiter_bytes) batch_body.append(_make_body_from_sub_request(request)) batch_body.append(newline_bytes) content_index += 1 - batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode('utf-8')) + batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode("utf-8")) # final line of body MUST have \r\n at the end, or it will not be properly read by the service batch_body.append(newline_bytes) @@ -197,35 +202,35 @@ def _get_batch_request_delimiter(batch_id, is_prepend_dashes=False, is_append_da Whether to include the starting dashes. Used in the body, but non on defining the delimiter. :param bool is_append_dashes: Whether to include the ending dashes. Used in the body on the closing delimiter only. - :returns: The delimiter, WITHOUT a trailing newline. + :return: The delimiter, WITHOUT a trailing newline. :rtype: str """ - prepend_dashes = '--' if is_prepend_dashes else '' - append_dashes = '--' if is_append_dashes else '' + prepend_dashes = "--" if is_prepend_dashes else "" + append_dashes = "--" if is_append_dashes else "" return prepend_dashes + _REQUEST_DELIMITER_PREFIX + batch_id + append_dashes def _make_body_from_sub_request(sub_request): """ - Content-Type: application/http - Content-ID: - Content-Transfer-Encoding: (if present) + Content-Type: application/http + Content-ID: + Content-Transfer-Encoding: (if present) - HTTP/ -
<header key>: <header value> (repeated as necessary)
-     Content-Length: <value>
-     (newline if content length > 0)
-     <body> (if content length > 0)
+    <verb> <path><query> HTTP/<version>
+    <header key>: <header value>
(repeated as necessary) + Content-Length: + (newline if content length > 0) + (if content length > 0) - Serializes an http request. + Serializes an http request. - :param ~azure.core.pipeline.transport.HttpRequest sub_request: - Request to serialize. - :returns: The serialized sub-request in bytes - :rtype: bytes - """ + :param ~azure.core.pipeline.transport.HttpRequest sub_request: + Request to serialize. + :return: The serialized sub-request in bytes + :rtype: bytes + """ # put the sub-request's headers into a list for efficient str concatenation sub_request_body = [] @@ -249,9 +254,9 @@ def _make_body_from_sub_request(sub_request): # append HTTP verb and path and query and HTTP version sub_request_body.append(sub_request.method) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(sub_request.url) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(_HTTP1_1_IDENTIFIER) sub_request_body.append(_HTTP_LINE_ENDING) @@ -266,4 +271,4 @@ def _make_body_from_sub_request(sub_request): # append blank line sub_request_body.append(_HTTP_LINE_ENDING) - return ''.join(sub_request_body).encode() + return "".join(sub_request_body).encode() diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py index af9a2fcdcdc2..bcfa4147763e 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/response_handlers.py @@ -46,23 +46,23 @@ def parse_length_from_content_range(content_range): # First, split in space and take the second half: '1-3/65537' # Next, split on slash and take the second half: '65537' # Finally, convert to an int: 65537 - return int(content_range.split(' ', 1)[1].split('/', 1)[1]) + return int(content_range.split(" ", 1)[1].split("/", 1)[1]) def normalize_headers(headers): normalized = {} for key, value in headers.items(): - if key.startswith('x-ms-'): + if key.startswith("x-ms-"): key = key[5:] - normalized[key.lower().replace('-', '_')] = get_enum_value(value) + normalized[key.lower().replace("-", "_")] = get_enum_value(value) return normalized def deserialize_metadata(response, obj, headers): # pylint: disable=unused-argument try: - raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith("x-ms-meta-")} except AttributeError: - raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith("x-ms-meta-")} return {k[10:]: v for k, v in raw_metadata.items()} @@ -82,19 +82,23 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches +def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches raise_error = HttpResponseError serialized = False if isinstance(storage_error, AzureSigningError): - storage_error.message = storage_error.message + \ - '. This is likely due to an invalid shared key. Please check your shared key and try again.' 
+ storage_error.message = ( + storage_error.message + + ". This is likely due to an invalid shared key. Please check your shared key and try again." + ) if not storage_error.response or storage_error.response.status_code in [200, 204]: raise storage_error # If it is one of those three then it has been serialized prior by the generated layer. - if isinstance(storage_error, (PartialBatchErrorException, - ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError)): + if isinstance( + storage_error, + (PartialBatchErrorException, ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError), + ): serialized = True - error_code = storage_error.response.headers.get('x-ms-error-code') + error_code = storage_error.response.headers.get("x-ms-error-code") error_message = storage_error.message additional_data = {} error_dict = {} @@ -104,27 +108,25 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py if error_body is None or len(error_body) == 0: error_body = storage_error.response.reason except AttributeError: - error_body = '' + error_body = "" # If it is an XML response if isinstance(error_body, Element): - error_dict = { - child.tag.lower(): child.text - for child in error_body - } + error_dict = {child.tag.lower(): child.text for child in error_body} # If it is a JSON response elif isinstance(error_body, dict): - error_dict = error_body.get('error', {}) + error_dict = error_body.get("error", {}) elif not error_code: _LOGGER.warning( - 'Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body)) - error_dict = {'message': str(error_body)} + "Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.", type(error_body) + ) + error_dict = {"message": str(error_body)} # If we extracted from a Json or XML response # There is a chance error_dict is just a string if error_dict and isinstance(error_dict, dict): - error_code = error_dict.get('code') - error_message = error_dict.get('message') - additional_data = {k: v for k, v in error_dict.items() if k not in {'code', 'message'}} + error_code = error_dict.get("code") + error_message = error_dict.get("message") + additional_data = {k: v for k, v in error_dict.items() if k not in {"code", "message"}} except DecodeError: pass @@ -132,31 +134,33 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py # This check would be unnecessary if we have already serialized the error if error_code and not serialized: error_code = StorageErrorCode(error_code) - if error_code in [StorageErrorCode.condition_not_met, - StorageErrorCode.blob_overwritten]: + if error_code in [StorageErrorCode.condition_not_met, StorageErrorCode.blob_overwritten]: raise_error = ResourceModifiedError - if error_code in [StorageErrorCode.invalid_authentication_info, - StorageErrorCode.authentication_failed]: + if error_code in [StorageErrorCode.invalid_authentication_info, StorageErrorCode.authentication_failed]: raise_error = ClientAuthenticationError - if error_code in [StorageErrorCode.resource_not_found, - StorageErrorCode.cannot_verify_copy_source, - StorageErrorCode.blob_not_found, - StorageErrorCode.queue_not_found, - StorageErrorCode.container_not_found, - StorageErrorCode.parent_not_found, - StorageErrorCode.share_not_found]: + if error_code in [ + StorageErrorCode.resource_not_found, + StorageErrorCode.cannot_verify_copy_source, + StorageErrorCode.blob_not_found, + StorageErrorCode.queue_not_found, + StorageErrorCode.container_not_found, + 
StorageErrorCode.parent_not_found, + StorageErrorCode.share_not_found, + ]: raise_error = ResourceNotFoundError - if error_code in [StorageErrorCode.account_already_exists, - StorageErrorCode.account_being_created, - StorageErrorCode.resource_already_exists, - StorageErrorCode.resource_type_mismatch, - StorageErrorCode.blob_already_exists, - StorageErrorCode.queue_already_exists, - StorageErrorCode.container_already_exists, - StorageErrorCode.container_being_deleted, - StorageErrorCode.queue_being_deleted, - StorageErrorCode.share_already_exists, - StorageErrorCode.share_being_deleted]: + if error_code in [ + StorageErrorCode.account_already_exists, + StorageErrorCode.account_being_created, + StorageErrorCode.resource_already_exists, + StorageErrorCode.resource_type_mismatch, + StorageErrorCode.blob_already_exists, + StorageErrorCode.queue_already_exists, + StorageErrorCode.container_already_exists, + StorageErrorCode.container_being_deleted, + StorageErrorCode.queue_being_deleted, + StorageErrorCode.share_already_exists, + StorageErrorCode.share_being_deleted, + ]: raise_error = ResourceExistsError except ValueError: # Got an unknown error code @@ -183,7 +187,7 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py error.args = (error.message,) try: # `from None` prevents us from double printing the exception (suppresses generated layer error context) - exec("raise error from None") # pylint: disable=exec-used # nosec + exec("raise error from None") # pylint: disable=exec-used # nosec except SyntaxError as exc: raise error from exc diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/shared_access_signature.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/shared_access_signature.py index 3a0530a58bdb..ac20533d393c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/shared_access_signature.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/shared_access_signature.py @@ -11,44 +11,45 @@ from .constants import X_MS_VERSION from . import sign_string, url_quote + # cspell:ignoreRegExp rsc. 
# cspell:ignoreRegExp s..?id class QueryStringConstants(object): - SIGNED_SIGNATURE = 'sig' - SIGNED_PERMISSION = 'sp' - SIGNED_START = 'st' - SIGNED_EXPIRY = 'se' - SIGNED_RESOURCE = 'sr' - SIGNED_IDENTIFIER = 'si' - SIGNED_IP = 'sip' - SIGNED_PROTOCOL = 'spr' - SIGNED_VERSION = 'sv' - SIGNED_CACHE_CONTROL = 'rscc' - SIGNED_CONTENT_DISPOSITION = 'rscd' - SIGNED_CONTENT_ENCODING = 'rsce' - SIGNED_CONTENT_LANGUAGE = 'rscl' - SIGNED_CONTENT_TYPE = 'rsct' - START_PK = 'spk' - START_RK = 'srk' - END_PK = 'epk' - END_RK = 'erk' - SIGNED_RESOURCE_TYPES = 'srt' - SIGNED_SERVICES = 'ss' - SIGNED_OID = 'skoid' - SIGNED_TID = 'sktid' - SIGNED_KEY_START = 'skt' - SIGNED_KEY_EXPIRY = 'ske' - SIGNED_KEY_SERVICE = 'sks' - SIGNED_KEY_VERSION = 'skv' - SIGNED_ENCRYPTION_SCOPE = 'ses' - SIGNED_KEY_DELEGATED_USER_TID = 'skdutid' - SIGNED_DELEGATED_USER_OID = 'sduoid' + SIGNED_SIGNATURE = "sig" + SIGNED_PERMISSION = "sp" + SIGNED_START = "st" + SIGNED_EXPIRY = "se" + SIGNED_RESOURCE = "sr" + SIGNED_IDENTIFIER = "si" + SIGNED_IP = "sip" + SIGNED_PROTOCOL = "spr" + SIGNED_VERSION = "sv" + SIGNED_CACHE_CONTROL = "rscc" + SIGNED_CONTENT_DISPOSITION = "rscd" + SIGNED_CONTENT_ENCODING = "rsce" + SIGNED_CONTENT_LANGUAGE = "rscl" + SIGNED_CONTENT_TYPE = "rsct" + START_PK = "spk" + START_RK = "srk" + END_PK = "epk" + END_RK = "erk" + SIGNED_RESOURCE_TYPES = "srt" + SIGNED_SERVICES = "ss" + SIGNED_OID = "skoid" + SIGNED_TID = "sktid" + SIGNED_KEY_START = "skt" + SIGNED_KEY_EXPIRY = "ske" + SIGNED_KEY_SERVICE = "sks" + SIGNED_KEY_VERSION = "skv" + SIGNED_ENCRYPTION_SCOPE = "ses" + SIGNED_KEY_DELEGATED_USER_TID = "skdutid" + SIGNED_DELEGATED_USER_OID = "sduoid" # for ADLS - SIGNED_AUTHORIZED_OID = 'saoid' - SIGNED_UNAUTHORIZED_OID = 'suoid' - SIGNED_CORRELATION_ID = 'scid' - SIGNED_DIRECTORY_DEPTH = 'sdd' + SIGNED_AUTHORIZED_OID = "saoid" + SIGNED_UNAUTHORIZED_OID = "suoid" + SIGNED_CORRELATION_ID = "scid" + SIGNED_DIRECTORY_DEPTH = "sdd" @staticmethod def to_list(): @@ -91,38 +92,30 @@ def to_list(): class SharedAccessSignature(object): - ''' + """ Provides a factory for creating account access signature tokens with an account name and account key. Users can either use the factory or can construct the appropriate service and use the generate_*_shared_access_signature method directly. - ''' + """ def __init__(self, account_name, account_key, x_ms_version=X_MS_VERSION): - ''' + """ :param str account_name: The storage account name used to generate the shared access signatures. :param str account_key: The access key to generate the shares access signatures. :param str x_ms_version: The service version used to generate the shared access signatures. - ''' + """ self.account_name = account_name self.account_key = account_key self.x_ms_version = x_ms_version def generate_account( - self, services, - resource_types, - permission, - expiry, - start=None, - ip=None, - protocol=None, - sts_hook=None, - **kwargs + self, services, resource_types, permission, expiry, start=None, ip=None, protocol=None, sts_hook=None, **kwargs ) -> str: - ''' + """ Generates a shared access signature for the account. Use the returned signature with the sas_token parameter of the service or to create a new account object. @@ -169,9 +162,9 @@ def generate_account( For debugging purposes only. If provided, the hook is called with the string to sign that was used to generate the SAS. :type sts_hook: Optional[Callable[[str], None]] - :returns: The generated SAS token for the account. + :return: The generated SAS token for the account. 
:rtype: str - ''' + """ sas = _SharedAccessHelper() sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version) sas.add_account(services, resource_types) @@ -194,7 +187,7 @@ def _add_query(self, name, val): self.query_dict[name] = str(val) if val is not None else None def add_encryption_scope(self, **kwargs): - self._add_query(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE, kwargs.pop('encryption_scope', None)) + self._add_query(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE, kwargs.pop("encryption_scope", None)) def add_base(self, permission, expiry, start, ip, protocol, x_ms_version): if isinstance(start, date): @@ -220,11 +213,9 @@ def add_account(self, services, resource_types): self._add_query(QueryStringConstants.SIGNED_SERVICES, services) self._add_query(QueryStringConstants.SIGNED_RESOURCE_TYPES, resource_types) - def add_override_response_headers(self, cache_control, - content_disposition, - content_encoding, - content_language, - content_type): + def add_override_response_headers( + self, cache_control, content_disposition, content_encoding, content_language, content_type + ): self._add_query(QueryStringConstants.SIGNED_CACHE_CONTROL, cache_control) self._add_query(QueryStringConstants.SIGNED_CONTENT_DISPOSITION, content_disposition) self._add_query(QueryStringConstants.SIGNED_CONTENT_ENCODING, content_encoding) @@ -233,24 +224,25 @@ def add_override_response_headers(self, cache_control, def add_account_signature(self, account_name, account_key): def get_value_to_append(query): - return_value = self.query_dict.get(query) or '' - return return_value + '\n' - - string_to_sign = \ - (account_name + '\n' + - get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + - get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + - get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + - get_value_to_append(QueryStringConstants.SIGNED_START) + - get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + - get_value_to_append(QueryStringConstants.SIGNED_IP) + - get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + - get_value_to_append(QueryStringConstants.SIGNED_VERSION) + - get_value_to_append(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE)) - - self._add_query(QueryStringConstants.SIGNED_SIGNATURE, - sign_string(account_key, string_to_sign)) + return_value = self.query_dict.get(query) or "" + return return_value + "\n" + + string_to_sign = ( + account_name + + "\n" + + get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + + get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + + get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + + get_value_to_append(QueryStringConstants.SIGNED_START) + + get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + + get_value_to_append(QueryStringConstants.SIGNED_IP) + + get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + + get_value_to_append(QueryStringConstants.SIGNED_VERSION) + + get_value_to_append(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE) + ) + + self._add_query(QueryStringConstants.SIGNED_SIGNATURE, sign_string(account_key, string_to_sign)) self.string_to_sign = string_to_sign def get_token(self) -> str: - return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None]) + return "&".join([f"{n}={url_quote(v)}" for n, v in self.query_dict.items() if v is not None]) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads.py index b31cfb3291d9..7a5fb3f3dc91 
100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads.py @@ -12,7 +12,7 @@ from azure.core.tracing.common import with_current_context -from .import encode_base64, url_quote +from . import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers @@ -41,20 +41,21 @@ def _parallel_uploads(executor, uploader, pending, running): def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - validate_content=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + validate_content=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -64,7 +65,8 @@ def upload_data_chunks( parallel=parallel, validate_content=validate_content, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: upload_tasks = uploader.get_chunk_streams() @@ -81,18 +83,19 @@ def upload_data_chunks( def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -100,7 +103,8 @@ def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: @@ -120,15 +124,17 @@ def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size self.chunk_size = chunk_size @@ -253,7 +259,7 @@ def __init__(self, *args, **kwargs): def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. 
- index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) self.service.stage_block( block_id, @@ -261,20 +267,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) return index, block_id def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -302,11 +308,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): pass @@ -326,19 +332,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) def _upload_substream_block(self, index, block_stream): @@ -356,11 +363,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): try: @@ -371,7 +378,7 @@ def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -388,9 +395,9 @@ def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - return f'bytes={chunk_offset}-{chunk_end}', response + return f"bytes={chunk_offset}-{chunk_end}", response # TODO: Implement this method. 
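# Illustrative sketch, not part of the diff: the block-id scheme used by
# BlockBlobChunkUploader._upload_chunk() above. encode_base64/url_quote live in the _shared
# package; standard-library equivalents are used here instead.
import base64
from urllib.parse import quote

def block_id_for_offset(chunk_offset: int) -> str:
    index = f"{chunk_offset:032d}"                                   # fixed-width, zero-padded
    inner = base64.b64encode(index.encode("utf-8")).decode()         # encode_base64(index)
    return base64.b64encode(quote(inner).encode("utf-8")).decode()   # encode_base64(url_quote(...))

# Every block id committed into a single blob must have the same length, which is why the
# offset is padded to 32 digits before being encoded.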
def _upload_substream_block(self, index, block_stream): diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads_async.py index a056cd290230..6ed5ba1d0f91 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/uploads_async.py @@ -12,7 +12,7 @@ from math import ceil from typing import AsyncGenerator, Union -from .import encode_base64, url_quote +from . import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers from .uploads import SubStream, IterStreamer # pylint: disable=unused-import @@ -59,19 +59,20 @@ async def _parallel_uploads(uploader, pending, running): async def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -80,7 +81,8 @@ async def upload_data_chunks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_chunk_streams() @@ -104,18 +106,19 @@ async def upload_data_chunks( async def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -123,13 +126,13 @@ async def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_substream_blocks() running_futures = [ - asyncio.ensure_future(uploader.process_substream_block(u)) - for u in islice(upload_tasks, 0, max_concurrency) + asyncio.ensure_future(uploader.process_substream_block(u)) for u in islice(upload_tasks, 0, max_concurrency) ] range_ids = await _parallel_uploads(uploader.process_substream_block, upload_tasks, running_futures) else: @@ -144,15 +147,17 @@ async def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size self.chunk_size = chunk_size @@ -178,7 +183,7 @@ def __init__( async def 
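# Illustrative sketch, not part of the diff: the bounded-concurrency pattern the async
# upload_substream_blocks() above uses -- prime the pool with at most max_concurrency tasks via
# islice, then let the drain loop (the _parallel_uploads helper in this file) start a new task
# each time one finishes. run_bounded and coro_fn are hypothetical names.
import asyncio
from itertools import islice

async def run_bounded(coro_fn, items, max_concurrency):
    items = iter(items)
    running = {asyncio.ensure_future(coro_fn(item)) for item in islice(items, max_concurrency)}
    results = []
    while running:
        done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
        results.extend(task.result() for task in done)
        for item in islice(items, len(done)):
            running.add(asyncio.ensure_future(coro_fn(item)))
    return results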
get_chunk_streams(self): index = 0 while True: - data = b'' + data = b"" read_size = self.chunk_size # Buffer until we either reach the end of the stream or get a whole chunk. @@ -189,12 +194,12 @@ async def get_chunk_streams(self): if inspect.isawaitable(temp): temp = await temp if not isinstance(temp, bytes): - raise TypeError('Blob data should be of type bytes.') + raise TypeError("Blob data should be of type bytes.") data += temp or b"" # We have read an empty string and so are at the end # of the buffer or we have read a full chunk. - if temp == b'' or len(data) == self.chunk_size: + if temp == b"" or len(data) == self.chunk_size: break if len(data) == self.chunk_size: @@ -273,13 +278,13 @@ def set_response_properties(self, resp): class BlockBlobChunkUploader(_ChunkUploader): def __init__(self, *args, **kwargs): - kwargs.pop('modified_access_conditions', None) + kwargs.pop("modified_access_conditions", None) super(BlockBlobChunkUploader, self).__init__(*args, **kwargs) self.current_length = None async def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. - index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) await self.service.stage_block( block_id, @@ -287,19 +292,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): body=chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) return index, block_id async def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" await self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) finally: block_stream.close() return block_id @@ -311,7 +318,7 @@ def _is_chunk_empty(self, chunk_data): # read until non-zero byte is encountered # if reached the end without returning, then chunk_data is all 0's for each_byte in chunk_data: - if each_byte not in [0, b'\x00']: + if each_byte not in [0, b"\x00"]: return False return True @@ -319,7 +326,7 @@ async def _upload_chunk(self, chunk_offset, chunk_data): # avoid uploading the empty pages if not self._is_chunk_empty(chunk_data): chunk_end = chunk_offset + len(chunk_data) - 1 - content_range = f'bytes={chunk_offset}-{chunk_end}' + content_range = f"bytes={chunk_offset}-{chunk_end}" computed_md5 = None self.response_headers = await self.service.upload_pages( body=chunk_data, @@ -329,10 +336,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): pass @@ -352,18 +360,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) - 
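# Illustrative sketch, not part of the diff: the "skip all-zero pages" check and the Range
# header that PageBlobChunkUploader above builds. Page blobs are sparse, so a chunk that is
# entirely zeros is simply not uploaded.
def is_chunk_empty(chunk_data: bytes) -> bool:
    return not any(chunk_data)          # True only when every byte is 0x00

def page_range_header(chunk_offset: int, chunk_data: bytes) -> str:
    chunk_end = chunk_offset + len(chunk_data) - 1
    return f"bytes={chunk_offset}-{chunk_end}"

assert is_chunk_empty(b"\x00" * 512)
assert not is_chunk_empty(b"\x00" * 511 + b"\x01")
assert page_range_header(0, b"\x00" * 512) == "bytes=0-511"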
self.current_length = int(self.response_headers['blob_append_offset']) + **self.request_options, + ) + self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = await self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) async def _upload_substream_block(self, index, block_stream): pass @@ -379,11 +390,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): try: @@ -394,7 +405,7 @@ async def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -411,9 +422,9 @@ async def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - range_id = f'bytes={chunk_offset}-{chunk_end}' + range_id = f"bytes={chunk_offset}-{chunk_end}" return range_id, response # TODO: Implement this method. @@ -421,10 +432,11 @@ async def _upload_substream_block(self, index, block_stream): pass -class AsyncIterStreamer(): +class AsyncIterStreamer: """ File-like streaming object for AsyncGenerators. """ + def __init__(self, generator: AsyncGenerator[Union[bytes, str], None], encoding: str = "UTF-8"): self.iterator = generator.__aiter__() self.leftover = b"" diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/__init__.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/__init__.py index a755e6a2d59b..09f8adb73a9a 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/__init__.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/__init__.py @@ -73,7 +73,7 @@ async def upload_blob_to_url( entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. :keyword str encoding: Encoding to use if text is supplied as input. Defaults to UTF-8. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: dict[str, Any] """ async with BlobClient.from_blob_url(blob_url, credential=credential) as client: @@ -102,7 +102,7 @@ async def download_blob_from_url( :param output: Where the data should be downloaded to. This could be either a file path to write to, or an open IO handle to write to. - :type output: str or writable stream + :type output: str or IO :param credential: The credentials with which to authenticate. This is optional if the blob URL already has a SAS token or the blob is public. 
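# Illustrative sketch, not part of the diff: how AppendBlobChunkUploader above keeps appends
# ordered. The first append_block records the returned blob_append_offset; later chunks carry
# an append-position precondition derived from it, so an out-of-order or replayed append is
# rejected by the service instead of silently landing in the wrong place. append_block_stub is
# a hypothetical stand-in for self.service.append_block(...), and the first-chunk condition is
# inferred from the surrounding code.
def append_chunks_in_order(chunks, append_block_stub):
    current_length = None
    for chunk_offset, chunk_data in chunks:
        if current_length is None:
            headers = append_block_stub(chunk_data, append_position=None)
            current_length = int(headers["blob_append_offset"])
        else:
            append_block_stub(chunk_data, append_position=current_length + chunk_offset)
    return current_length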
The value can be a SAS token string, @@ -139,6 +139,7 @@ async def download_blob_from_url( blob. Also note that if enabled, the memory-efficient upload algorithm will not be used, because computing the MD5 hash requires buffering entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :return: None :rtype: None """ overwrite = kwargs.pop('overwrite', False) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index a88fe9980655..84fe1347910f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -228,7 +228,11 @@ def from_blob_url( - except in the case of AzureSasCredential, where the conflicting SAS tokens will raise a ValueError. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :type credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :type credential: + ~azure.core.credentials.AzureNamedKeyCredential or + ~azure.core.credentials.AzureSasCredential or + ~azure.core.credentials_async.AsyncTokenCredential or + str or dict[str, str] or None :param str snapshot: The optional blob snapshot on which to operate. This can be the snapshot ID string or the response returned from :func:`create_snapshot`. If specified, this will override @@ -238,7 +242,7 @@ def from_blob_url( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A Blob client. + :return: A Blob client. :rtype: ~azure.storage.blob.BlobClient """ account_url, container_name, blob_name, path_snapshot = _from_blob_url(blob_url=blob_url, snapshot=snapshot) @@ -276,13 +280,17 @@ def from_connection_string( Credentials provided here will take precedence over those in the connection string. If using an instance of AzureNamedKeyCredential, "name" should be the storage account name, and "key" should be the storage account key. - :type credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] # pylint: disable=line-too-long + :type credential: + ~azure.core.credentials.AzureNamedKeyCredential or + ~azure.core.credentials.AzureSasCredential or + ~azure.core.credentials_async.AsyncTokenCredential or + str or dict[str, str] or None :keyword str version_id: The version id parameter is an opaque DateTime value that, when present, specifies the version of the blob to operate on. :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A Blob client. + :return: A Blob client. :rtype: ~azure.storage.blob.BlobClient .. admonition:: Example: @@ -309,7 +317,7 @@ async def get_account_information(self, **kwargs: Any) -> Dict[str, str]: The information can also be retrieved if the user has a SAS to a container or blob. The keys in the returned dictionary include 'sku_name' and 'account_kind'. 
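# Illustrative sketch, not part of the diff: the credential shapes the reworked from_blob_url
# docstring above enumerates. The URL, key, and SAS values are placeholders.
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential
from azure.storage.blob.aio import BlobClient

url = "https://myaccount.blob.core.windows.net/mycontainer/myblob"

from_sas_in_url = BlobClient.from_blob_url(url + "?<sas-token>")   # SAS already on the URL
from_named_key = BlobClient.from_blob_url(url, credential=AzureNamedKeyCredential("myaccount", "<key>"))
from_sas_cred = BlobClient.from_blob_url(url, credential=AzureSasCredential("<sas-token>"))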
- :returns: A dict of account information (SKU and account type). + :return: A dict of account information (SKU and account type). :rtype: dict(str, str) """ try: @@ -430,7 +438,7 @@ async def upload_blob_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. :paramtype source_token_intent: Literal['backup'] - :returns: Response from creating a new block blob for a given URL. + :return: Response from creating a new block blob for a given URL. :rtype: Dict[str, Any] """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -581,7 +589,7 @@ async def upload_blob( the timeout will apply to each call individually. multiple calls to the Azure service and the timeout will apply to each call individually. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: dict[str, Any] .. admonition:: Example: @@ -726,7 +734,7 @@ async def download_blob( the timeout will apply to each call individually. multiple calls to the Azure service and the timeout will apply to each call individually. - :returns: A streaming object (StorageStreamDownloader) + :return: A streaming object (StorageStreamDownloader) :rtype: ~azure.storage.blob.aio.StorageStreamDownloader .. admonition:: Example: @@ -853,7 +861,7 @@ async def query_blob( This value is not tracked or validated on the client. To configure client-side network timeouts see `here `__. - :returns: A streaming object (BlobQueryReader) + :return: A streaming object (BlobQueryReader) :rtype: ~azure.storage.blob.aio.BlobQueryReader """ error_cls = kwargs.pop("error_cls", BlobQueryError) @@ -905,7 +913,7 @@ async def delete_blob(self, delete_snapshots: Optional[str] = None, **kwargs: An and retains the blob for a specified number of days. After the specified number of days, the blob's data is removed from the service during garbage collection. Soft deleted blob is accessible through :func:`~ContainerClient.list_blobs()` specifying `include=['deleted']` - option. Soft-deleted blob can be restored using :func:`undelete` operation. + option. Soft-deleted blob can be restored using :func:`~BlobClient.undelete_blob()` operation. :param str delete_snapshots: Required if the blob has associated snapshots. Values include: @@ -953,6 +961,7 @@ async def delete_blob(self, delete_snapshots: Optional[str] = None, **kwargs: An This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -991,6 +1000,7 @@ async def undelete_blob(self, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -1022,7 +1032,7 @@ async def exists(self, **kwargs: Any) -> bool: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: boolean + :return: boolean :rtype: bool """ version_id = get_version_id(self.version_id, kwargs) @@ -1092,7 +1102,7 @@ async def get_blob_properties(self, **kwargs: Any) -> BlobProperties: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: BlobProperties + :return: BlobProperties :rtype: ~azure.storage.blob.BlobProperties .. admonition:: Example: @@ -1180,7 +1190,7 @@ async def set_http_headers( This value is not tracked or validated on the client. 
To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: Dict[str, Any] """ options = _set_http_headers_options(content_settings=content_settings, **kwargs) @@ -1247,7 +1257,7 @@ async def set_blob_metadata( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: Dict[str, Union[str, datetime]] """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -1283,7 +1293,7 @@ async def set_immutability_policy( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, str] """ @@ -1309,7 +1319,7 @@ async def delete_immutability_policy(self, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, str] """ @@ -1334,7 +1344,7 @@ async def set_legal_hold(self, legal_hold: bool, **kwargs: Any) -> Dict[str, Uni This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, Union[str, datetime, bool]] """ @@ -1431,7 +1441,7 @@ async def create_page_blob( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict[str, Any] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -1527,7 +1537,7 @@ async def create_append_blob( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict[str, Any] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -1607,7 +1617,7 @@ async def create_snapshot( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Snapshot ID, Etag, and last modified). + :return: Blob-updated property dict (Snapshot ID, Etag, and last modified). :rtype: dict[str, Any] .. admonition:: Example: @@ -1813,7 +1823,7 @@ async def start_copy_from_url( .. versionadded:: 12.10.0 - :returns: A dictionary of copy properties (etag, last_modified, copy_id, copy_status). + :return: A dictionary of copy properties (etag, last_modified, copy_id, copy_status). :rtype: dict[str, Union[str, ~datetime.datetime]] .. admonition:: Example: @@ -1852,6 +1862,7 @@ async def abort_copy( The copy operation to abort. This can be either an ID, or an instance of BlobProperties. :type copy_id: str or ~azure.storage.blob.BlobProperties + :return: None :rtype: None .. admonition:: Example: @@ -1918,7 +1929,7 @@ async def acquire_lease( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A BlobLeaseClient object. 
+ :return: A BlobLeaseClient object. :rtype: ~azure.storage.blob.aio.BlobLeaseClient .. admonition:: Example: @@ -1967,6 +1978,7 @@ async def set_standard_blob_tier(self, standard_blob_tier: Union[str, "StandardB Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.aio.BlobLeaseClient or str + :return: None :rtype: None """ access_conditions = get_access_conditions(kwargs.pop('lease', None)) @@ -2034,7 +2046,7 @@ async def stage_block( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob property dict. + :return: Blob property dict. :rtype: Dict[str, Any] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -2109,7 +2121,7 @@ async def stage_block_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. :paramtype source_token_intent: Literal['backup'] - :returns: Blob property dict. + :return: Blob property dict. :rtype: Dict[str, Any] """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -2155,7 +2167,7 @@ async def get_block_list( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A tuple of two lists - committed and uncommitted blocks + :return: A tuple of two lists - committed and uncommitted blocks :rtype: Tuple[List[BlobBlock], List[BlobBlock]] """ access_conditions = get_access_conditions(kwargs.pop('lease', None)) @@ -2268,7 +2280,7 @@ async def commit_block_list( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -2310,6 +2322,7 @@ async def set_premium_page_blob_tier(self, premium_page_blob_tier: "PremiumPageB Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.aio.BlobLeaseClient or str + :return: None :rtype: None """ access_conditions = get_access_conditions(kwargs.pop('lease', None)) @@ -2365,7 +2378,7 @@ async def set_blob_tags(self, tags: Optional[Dict[str, str]] = None, **kwargs: A This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified) + :return: Blob-updated property dict (Etag and last modified) :rtype: Dict[str, Any] """ version_id = get_version_id(self.version_id, kwargs) @@ -2398,7 +2411,7 @@ async def get_blob_tags(self, **kwargs: Any) -> Dict[str, str]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Key value pairs of blob tags. + :return: Key value pairs of blob tags. :rtype: Dict[str, str] """ version_id = get_version_id(self.version_id, kwargs) @@ -2470,7 +2483,7 @@ async def get_page_ranges( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: + :return: A tuple of two lists of page ranges as dictionaries with 'start' and 'end' keys. The first element are filled page ranges, the 2nd element is cleared page ranges. 
:rtype: tuple(list(dict(str, str), list(dict(str, str)) @@ -2563,7 +2576,7 @@ def list_page_ranges( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) of PageRange. + :return: An iterable (auto-paging) of PageRange. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.PageRange] """ results_per_page = kwargs.pop('results_per_page', None) @@ -2646,7 +2659,7 @@ async def get_page_range_diff_for_managed_disk( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: + :return: A tuple of two lists of page ranges as dictionaries with 'start' and 'end' keys. The first element are filled page ranges, the 2nd element is cleared page ranges. :rtype: tuple(list(dict(str, str), list(dict(str, str)) @@ -2711,7 +2724,7 @@ async def set_sequence_number( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ options = _set_sequence_number_options(sequence_number_action, sequence_number=sequence_number, **kwargs) @@ -2767,7 +2780,7 @@ async def resize_blob(self, size: int, **kwargs: Any) -> Dict[str, Union[str, da This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ if kwargs.get('cpk') and self.scheme.lower() != 'https': @@ -2863,7 +2876,7 @@ async def upload_page( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -2994,7 +3007,7 @@ async def upload_pages_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. :paramtype source_token_intent: Literal['backup'] - :returns: Response after uploading pages from specified URL. + :return: Response after uploading pages from specified URL. :rtype: Dict[str, Any] """ @@ -3075,7 +3088,7 @@ async def clear_page(self, offset: int, length: int, **kwargs: Any) -> Dict[str, This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag and last modified). + :return: Blob-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3172,7 +3185,7 @@ async def append_block( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag, last modified, append offset, committed block count). + :return: Blob-updated property dict (Etag, last modified, append offset, committed block count). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3296,7 +3309,7 @@ async def append_block_from_url( ACLs are bypassed and full permissions are granted. User must also have required RBAC permission. 
:paramtype source_token_intent: Literal['backup'] - :returns: Result after appending a new block. + :return: Result after appending a new block. :rtype: Dict[str, Union[str, datetime, int]] """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3354,7 +3367,7 @@ async def seal_append_blob(self, **kwargs: Any) -> Dict[str, Union[str, datetime This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Blob-updated property dict (Etag, last modified, append offset, committed block count). + :return: Blob-updated property dict (Etag, last modified, append offset, committed block count). :rtype: dict(str, Any) """ if self.require_encryption or (self.key_encryption_key is not None): @@ -3370,7 +3383,7 @@ def _get_container_client(self) -> "ContainerClient": The container need not already exist. Defaults to current blob's credentials. - :returns: A ContainerClient. + :return: A ContainerClient. :rtype: ~azure.storage.blob.ContainerClient .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py index 8f76aa98c8cf..0f0f783f192f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_service_client_async.py @@ -145,7 +145,7 @@ def _format_url(self, hostname): :param str hostname: The hostname of the current location mode. - :returns: A formatted endpoint URL including current location mode hostname. + :return: A formatted endpoint URL including current location mode hostname. :rtype: str """ return f"{self.scheme}://{hostname}/{self._query_str}" @@ -177,7 +177,7 @@ def from_connection_string( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A Blob service client. + :return: A Blob service client. :rtype: ~azure.storage.blob.BlobServiceClient .. admonition:: Example: @@ -235,7 +235,7 @@ async def get_account_information(self, **kwargs: Any) -> Dict[str, str]: The information can also be retrieved if the user has a SAS to a container or blob. The keys in the returned dictionary include 'sku_name' and 'account_kind'. - :returns: A dict of account information (SKU and account type). + :return: A dict of account information (SKU and account type). :rtype: dict(str, str) .. admonition:: Example: @@ -309,7 +309,7 @@ async def get_service_properties(self, **kwargs: Any) -> Dict[str, Any]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An object containing blob service properties such as + :return: An object containing blob service properties such as analytics logging, hour/minute metrics, cors rules, etc. :rtype: Dict[str, Any] @@ -379,6 +379,7 @@ async def set_service_properties( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -443,7 +444,7 @@ def list_containers( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) of ContainerProperties. 
+ :return: An iterable (auto-paging) of ContainerProperties. :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.storage.blob.ContainerProperties] .. admonition:: Example: @@ -496,7 +497,7 @@ def find_blobs_by_tags(self, filter_expression: str, **kwargs: Any) -> AsyncItem This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of BlobProperties. + :return: An iterable (auto-paging) response of BlobProperties. :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.storage.blob.FilteredBlob] """ @@ -545,7 +546,7 @@ async def create_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A container client to interact with the newly created container. + :return: A container client to interact with the newly created container. :rtype: ~azure.storage.blob.aio.ContainerClient .. admonition:: Example: @@ -607,6 +608,7 @@ async def delete_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -646,7 +648,7 @@ async def _rename_container(self, name: str, new_name: str, **kwargs: Any) -> Co This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A container client for the renamed container. + :return: A container client for the renamed container. :rtype: ~azure.storage.blob.ContainerClient """ renamed_container = self.get_container_client(new_name) @@ -685,7 +687,7 @@ async def undelete_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: The recovered soft-deleted ContainerClient. + :return: The recovered soft-deleted ContainerClient. :rtype: ~azure.storage.blob.aio.ContainerClient """ new_name = kwargs.pop('new_name', None) @@ -709,7 +711,7 @@ def get_container_client(self, container: Union[ContainerProperties, str]) -> Co The container. This can either be the name of the container, or an instance of ContainerProperties. :type container: str or ~azure.storage.blob.ContainerProperties - :returns: A ContainerClient. + :return: A ContainerClient. :rtype: ~azure.storage.blob.aio.ContainerClient .. admonition:: Example: @@ -760,7 +762,7 @@ def get_blob_client( :type snapshot: str or dict(str, Any) :keyword str version_id: The version id parameter is an opaque DateTime value that, when present, specifies the version of the blob to operate on. - :returns: A BlobClient. + :return: A BlobClient. :rtype: ~azure.storage.blob.aio.BlobClient .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 306e3acf5519..551d2ed723eb 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -186,7 +186,7 @@ def from_container_url( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A container client. + :return: A container client. 
:rtype: ~azure.storage.blob.ContainerClient """ try: @@ -239,7 +239,7 @@ def from_connection_string( :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. - :returns: A container client. + :return: A container client. :rtype: ~azure.storage.blob.ContainerClient .. admonition:: Example: @@ -286,7 +286,7 @@ async def create_container( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A dictionary of response headers. + :return: A dictionary of response headers. :rtype: Dict[str, Union[str, datetime]] .. admonition:: Example: @@ -331,7 +331,7 @@ async def _rename_container(self, new_name: str, **kwargs: Any) -> "ContainerCli This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: The renamed container. + :return: The renamed container. :rtype: ~azure.storage.blob.ContainerClient """ lease = kwargs.pop('lease', None) @@ -385,6 +385,7 @@ async def delete_container(self, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None .. admonition:: Example: @@ -451,7 +452,7 @@ async def acquire_lease( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: A BlobLeaseClient object, that can be run in a context manager. + :return: A BlobLeaseClient object, that can be run in a context manager. :rtype: ~azure.storage.blob.aio.BlobLeaseClient .. admonition:: Example: @@ -476,7 +477,7 @@ async def get_account_information(self, **kwargs: Any) -> Dict[str, str]: The information can also be retrieved if the user has a SAS to a container or blob. The keys in the returned dictionary include 'sku_name' and 'account_kind'. - :returns: A dict of account information (SKU and account type). + :return: A dict of account information (SKU and account type). :rtype: dict(str, str) """ try: @@ -536,7 +537,7 @@ async def exists(self, **kwargs: Any) -> bool: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: boolean + :return: boolean :rtype: bool """ try: @@ -578,7 +579,7 @@ async def set_container_metadata( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Container-updated property dict (Etag and last modified). + :return: Container-updated property dict (Etag and last modified). :rtype: Dict[str, Union[str, datetime]] .. admonition:: Example: @@ -613,7 +614,7 @@ def _get_blob_service_client(self) -> "BlobServiceClient": Defaults to current container's credentials. - :returns: A BlobServiceClient. + :return: A BlobServiceClient. :rtype: ~azure.storage.blob.BlobServiceClient .. admonition:: Example: @@ -656,7 +657,7 @@ async def get_container_access_policy(self, **kwargs: Any) -> Dict[str, Any]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: Access policy information in a dict. + :return: Access policy information in a dict. :rtype: dict[str, Any] .. admonition:: Example: @@ -723,7 +724,7 @@ async def set_container_access_policy( This value is not tracked or validated on the client. 
To configure client-side network timesouts see `here `__. - :returns: Container-updated property dict (Etag and last modified). + :return: Container-updated property dict (Etag and last modified). :rtype: dict[str, str or ~datetime.datetime] .. admonition:: Example: @@ -787,7 +788,7 @@ def list_blobs( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of BlobProperties. + :return: An iterable (auto-paging) response of BlobProperties. :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.storage.blob.BlobProperties] .. admonition:: Example: @@ -840,7 +841,7 @@ def list_blob_names(self, **kwargs: Any) -> AsyncItemPaged[str]: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of blob names as strings. + :return: An iterable (auto-paging) response of blob names as strings. :rtype: ~azure.core.async_paging.AsyncItemPaged[str] """ if kwargs.pop('prefix', None): @@ -873,7 +874,7 @@ def walk_blobs( include: Optional[Union[List[str], str]] = None, delimiter: str = "/", **kwargs: Any - ) -> AsyncItemPaged[BlobProperties]: + ) -> AsyncItemPaged[Union[BlobProperties, BlobPrefix]]: """Returns a generator to list the blobs under the specified container. The generator will lazily follow the continuation tokens returned by the service. This operation will list blobs in accordance with a hierarchy, @@ -898,8 +899,9 @@ def walk_blobs( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of BlobProperties. - :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.storage.blob.BlobProperties] + :return: An iterable (auto-paging) response of BlobProperties. + :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.storage.blob.BlobProperties or + ~azure.storage.blob.aio.BlobPrefix] """ if kwargs.pop('prefix', None): raise ValueError("Passing 'prefix' has no effect on filtering, " + @@ -944,7 +946,7 @@ def find_blobs_by_tags( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. - :returns: An iterable (auto-paging) response of FilteredBlob. + :return: An iterable (auto-paging) response of FilteredBlob. :rtype: ~azure.core.paging.ItemPaged[~azure.storage.blob.BlobProperties] """ results_per_page = kwargs.pop('results_per_page', None) @@ -1071,7 +1073,7 @@ async def upload_blob( function(current: int, total: Optional[int]) where current is the number of bytes transferred so far, and total is the size of the blob or None if the size is unknown. :paramtype progress_hook: Callable[[int, Optional[int]], Awaitable[None]] - :returns: A BlobClient to interact with the newly uploaded blob. + :return: A BlobClient to interact with the newly uploaded blob. :rtype: ~azure.storage.blob.aio.BlobClient .. admonition:: Example: @@ -1121,7 +1123,7 @@ async def delete_blob( and retains the blob or snapshot for specified number of days. After specified number of days, blob's data is removed from the service during garbage collection. 
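# Illustrative sketch, not part of the diff: why walk_blobs() above is now annotated as
# yielding BlobProperties or BlobPrefix. With a delimiter, each virtual "directory" comes back
# as a BlobPrefix and can be walked again. Container and prefix values are placeholders.
from azure.storage.blob.aio import BlobPrefix, ContainerClient

async def print_blob_tree(container: ContainerClient, prefix: str = "", depth: int = 0) -> None:
    async for item in container.walk_blobs(name_starts_with=prefix, delimiter="/"):
        print("  " * depth + item.name)
        if isinstance(item, BlobPrefix):          # virtual directory -> recurse one level down
            await print_blob_tree(container, item.name, depth + 1)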
Soft deleted blobs or snapshots are accessible through :func:`list_blobs()` specifying `include=["deleted"]` - Soft-deleted blob or snapshot can be restored using :func:`~azure.storage.blob.aio.BlobClient.undelete()` + Soft-deleted blob or snapshot can be restored using :func:`~azure.storage.blob.aio.BlobClient.undelete_blob()` :param str blob: The blob with which to interact. :param str delete_snapshots: @@ -1169,6 +1171,7 @@ async def delete_blob( This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None """ if isinstance(blob, BlobProperties): @@ -1295,7 +1298,7 @@ async def download_blob( the timeout will apply to each call individually. multiple calls to the Azure service and the timeout will apply to each call individually. - :returns: A streaming object. (StorageStreamDownloader) + :return: A streaming object. (StorageStreamDownloader) :rtype: ~azure.storage.blob.aio.StorageStreamDownloader """ if isinstance(blob, BlobProperties): @@ -1327,7 +1330,7 @@ async def delete_blobs( and retains the blobs or snapshots for specified number of days. After specified number of days, blobs' data is removed from the service during garbage collection. Soft deleted blobs or snapshots are accessible through :func:`list_blobs()` specifying `include=["deleted"]` - Soft-deleted blobs or snapshots can be restored using :func:`~azure.storage.blob.aio.BlobClient.undelete()` + Soft-deleted blobs or snapshots can be restored using :func:`~azure.storage.blob.aio.BlobClient.undelete_blob()` The maximum number of blobs that can be deleted in a single request is 256. @@ -1393,7 +1396,7 @@ async def delete_blobs( see `here `__. :return: An async iterator of responses, one for each blob in order - :rtype: asynciterator[~azure.core.pipeline.transport.AsyncHttpResponse] + :rtype: AsyncIterator[~azure.core.pipeline.transport.AsyncHttpResponse] .. admonition:: Example: @@ -1485,7 +1488,7 @@ async def set_standard_blob_tier_blobs( is raised even if there is a single operation failure. For optimal performance, this should be set to False. :return: An async iterator of responses, one for each blob in order - :rtype: asynciterator[~azure.core.pipeline.transport.AsyncHttpResponse] + :rtype: AsyncIterator[~azure.core.pipeline.transport.AsyncHttpResponse] """ if self._is_localhost: kwargs['url_prepend'] = self.account_name @@ -1546,7 +1549,7 @@ async def set_premium_page_blob_tier_blobs( is raised even if there is a single operation failure. For optimal performance, this should be set to False. :return: An async iterator of responses, one for each blob in order - :rtype: asynciterator[~azure.core.pipeline.transport.AsyncHttpResponse] + :rtype: AsyncIterator[~azure.core.pipeline.transport.AsyncHttpResponse] """ if self._is_localhost: kwargs['url_prepend'] = self.account_name @@ -1577,7 +1580,7 @@ def get_blob_client( or the response returned from :func:`~BlobClient.create_snapshot()`. :keyword str version_id: The version id parameter is an opaque DateTime value that, when present, specifies the version of the blob to operate on. - :returns: A BlobClient. + :return: A BlobClient. :rtype: ~azure.storage.blob.aio.BlobClient .. 
admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index 8a929647e78f..a620883a1a64 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -46,8 +46,10 @@ async def process_content(data: Any, start_offset: int, end_offset: int, encryption: Dict[str, Any]) -> bytes: if data is None: raise ValueError("Response cannot be None.") - await data.response.load_body() - content = cast(bytes, data.response.body()) + if hasattr(data.response, "is_stream_consumed") and data.response.is_stream_consumed: + content = data.response.content + else: + content = b"".join([d async for d in data]) if encryption.get('key') is not None or encryption.get('resolver') is not None: try: return decrypt_blob( @@ -57,12 +59,14 @@ async def process_content(data: Any, start_offset: int, end_offset: int, encrypt content, start_offset, end_offset, - data.response.headers) + data.response.headers + ) except Exception as error: raise HttpResponseError( message="Decryption failed.", response=data.response, - error=error) from error + error=error + ) from error return content @@ -449,7 +453,7 @@ def chunks(self) -> AsyncIterator[bytes]: NOTE: If the stream has been partially read, some data may be re-downloaded by the iterator. - :returns: An async iterator of the chunks in the download stream. + :return: An async iterator of the chunks in the download stream. :rtype: AsyncIterator[bytes] .. admonition:: Example: @@ -523,7 +527,7 @@ async def read(self, size: int = -1, *, chars: Optional[int] = None) -> T: The number of chars to download from the stream. Leave unspecified or set negative to download all chars. Note, this can only be used when encoding is specified on `download_blob`. - :returns: + :return: The requested data as bytes or a string if encoding was specified. If the return value is empty, there is no more data to read. :rtype: T @@ -676,7 +680,7 @@ async def readall(self) -> T: Read the entire contents of this blob. This operation is blocking until all data is downloaded. - :returns: The requested data as bytes or a string if encoding was specified. + :return: The requested data as bytes or a string if encoding was specified. :rtype: T """ return await self.read() @@ -688,7 +692,7 @@ async def readinto(self, stream: IO[bytes]) -> int: The stream to download to. This can be an open file-handle, or any writable stream. The stream must be seekable if the download uses more than one parallel connection. - :returns: The number of bytes read. + :return: The number of bytes read. :rtype: int """ if self._text_mode: @@ -805,7 +809,7 @@ async def content_as_bytes(self, max_concurrency=1): :param int max_concurrency: The number of parallel connections with which to download. - :returns: The contents of the file as bytes. + :return: The contents of the file as bytes. :rtype: bytes """ warnings.warn( @@ -830,7 +834,7 @@ async def content_as_text(self, max_concurrency=1, encoding="UTF-8"): The number of parallel connections with which to download. :param str encoding: Test encoding to decode the downloaded bytes. Default is UTF-8. - :returns: The content of the file as a str. + :return: The content of the file as a str. :rtype: str """ warnings.warn( @@ -856,7 +860,7 @@ async def download_to_stream(self, stream, max_concurrency=1): uses more than one parallel connection. 
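# Illustrative sketch, not part of the diff: the branch the reworked process_content() above
# now takes. The helper uses the already-buffered response content when the pipeline has
# consumed the stream, and otherwise drains the async download iterator itself before any
# decryption step; this replaces the old load_body()/body() calls.
async def read_downloaded_content(data) -> bytes:
    response = data.response
    if hasattr(response, "is_stream_consumed") and response.is_stream_consumed:
        return response.content                         # body already buffered by the pipeline
    return b"".join([chunk async for chunk in data])    # drain the stream ourselves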
:param int max_concurrency: The number of parallel connections with which to download. - :returns: The properties of the downloaded blob. + :return: The properties of the downloaded blob. :rtype: Any """ warnings.warn( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_lease_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_lease_async.py index b5bfad95f53f..e09dce54b05d 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_lease_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_lease_async.py @@ -107,6 +107,7 @@ async def acquire(self, lease_duration: int = -1, **kwargs: Any) -> None: This value is not tracked or validated on the client. To configure client-side network timesouts see `here `__. + :return: None :rtype: None """ mod_conditions = get_modify_conditions(kwargs) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_quick_query_helper_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_quick_query_helper_async.py index 3d05d2e771e9..cd90a8212d38 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_quick_query_helper_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_quick_query_helper_async.py @@ -88,7 +88,7 @@ async def readall(self) -> bytes: This operation is blocking until all data is downloaded. - :returns: The query results. + :return: The query results. :rtype: bytes """ stream = BytesIO() @@ -104,7 +104,7 @@ async def readinto(self, stream: IO) -> None: :param IO stream: The stream to download to. This can be an open file-handle, or any writable stream. - :returns: None + :return: None """ async for record in self._aiter_stream(): stream.write(record) @@ -114,7 +114,7 @@ async def records(self) -> AsyncIterable[bytes]: Records will be returned line by line. - :returns: A record generator for the query result. + :return: A record generator for the query result. 
:rtype: AsyncIterable[bytes] """ delimiter = self.record_delimiter.encode('utf-8') diff --git a/sdk/storage/azure-storage-blob/tests/test_common_blob.py b/sdk/storage/azure-storage-blob/tests/test_common_blob.py index c2d225ec4d74..090596925873 100644 --- a/sdk/storage/azure-storage-blob/tests/test_common_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_common_blob.py @@ -53,11 +53,7 @@ from devtools_testutils import FakeTokenCredential, recorded_by_proxy from devtools_testutils.storage import StorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers import ( - MockStorageTransport, - _build_base_file_share_headers, - _create_file_share_oauth, -) +from test_helpers import _build_base_file_share_headers, _create_file_share_oauth # ------------------------------------------------------------------------------ SMALL_BLOB_SIZE = 1024 @@ -3597,57 +3593,4 @@ def test_upload_blob_partial_stream_chunked(self, **kwargs): result = blob.download_blob().readall() assert result == data[:length] - @BlobPreparer() - def test_mock_transport_no_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - content = blob_client.download_blob() - assert content is not None - - props = blob_client.get_blob_properties() - assert props is not None - - data = b"Hello World!" - resp = blob_client.upload_blob(data, overwrite=True) - assert resp is not None - - blob_data = blob_client.download_blob().read() - assert blob_data == b"Hello World!" # data is fixed by mock transport - - resp = blob_client.delete_blob() - assert resp is None - - @BlobPreparer() - def test_mock_transport_with_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - data = b"Hello World!" - resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) - assert resp is not None - - blob_data = blob_client.download_blob(validate_content=True).read() - assert blob_data == b"Hello World!" 
# data is fixed by mock transport - # ------------------------------------------------------------------------------ \ No newline at end of file diff --git a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py index 96e5f3260de0..25661d616dc5 100644 --- a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py @@ -55,7 +55,6 @@ from settings.testcase import BlobPreparer from test_helpers_async import ( AsyncStream, - MockStorageTransport, _build_base_file_share_headers, _create_file_share_oauth ) @@ -3525,58 +3524,4 @@ async def test_upload_blob_partial_stream_chunked(self, **kwargs): result = await (await blob.download_blob()).readall() assert result == data[:length] - @BlobPreparer() - async def test_mock_transport_no_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - content = await blob_client.download_blob() - assert content is not None - - props = await blob_client.get_blob_properties() - assert props is not None - - data = b"Hello Async World!" - stream = AsyncStream(data) - resp = await blob_client.upload_blob(stream, overwrite=True) - assert resp is not None - - blob_data = await (await blob_client.download_blob()).read() - assert blob_data == b"Hello Async World!" # data is fixed by mock transport - - resp = await blob_client.delete_blob() - assert resp is None - - @BlobPreparer() - async def test_mock_transport_with_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - data = b"Hello Async World!" - stream = AsyncStream(data) - resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) - assert resp is not None - - blob_data = await (await blob_client.download_blob(validate_content=True)).read() - assert blob_data == b"Hello Async World!" 
# data is fixed by mock transport # ------------------------------------------------------------------------------ diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers.py b/sdk/storage/azure-storage-blob/tests/test_helpers.py index c51e4c606d36..c62b98afd53a 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Tuple from typing_extensions import Self -from azure.core.pipeline.transport import HttpTransport, RequestsTransportResponse +from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse from azure.core.rest import HttpRequest from azure.storage.blob._serialize import get_api_version from requests import Response @@ -92,7 +92,7 @@ def tell(self): return self.wrapped_stream.tell() -class MockHttpClientResponse(Response): +class MockClientResponse(Response): def __init__( self, url: str, body_bytes: bytes, @@ -100,7 +100,7 @@ def __init__( status: int = 200, reason: str = "OK" ) -> None: - super(MockHttpClientResponse).__init__() + super(MockClientResponse).__init__() self._url = url self._body = body_bytes self._content = body_bytes @@ -113,9 +113,9 @@ def __init__( self.raw = HTTPResponse() -class MockStorageTransport(HttpTransport): +class MockLegacyTransport(RequestsTransport): """ - This transport returns legacy http response objects from azure core and is + This transport returns http response objects from azure core pipelines and is intended only to test our backwards compatibility support. """ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse: @@ -132,7 +132,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"Hello World!", headers, @@ -142,7 +142,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # get_blob_properties rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -155,7 +155,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # upload_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -169,7 +169,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # delete_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -180,7 +180,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse ) ) else: - raise ValueError("The request is not accepted as part of MockStorageTransport.") + raise ValueError("The request is not accepted as part of MockLegacyTransport.") return rest_response def __enter__(self) -> Self: diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py index f183b389f42a..f95d832e4c22 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py @@ -3,16 +3,20 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- - +import asyncio import aiohttp +from collections import deque from datetime import datetime, timezone from io import IOBase, UnsupportedOperation from typing import Any, Dict, Optional, Tuple +from unittest.mock import Mock, AsyncMock from azure.core.pipeline.transport import AioHttpTransportResponse, AsyncHttpTransport from azure.core.rest import HttpRequest from azure.storage.blob._serialize import get_api_version from aiohttp import ClientResponse +from aiohttp.streams import StreamReader +from aiohttp.client_proto import ResponseHandler def _build_base_file_share_headers(bearer_token_string: str, content_length: int = 0) -> Dict[str, Any]: @@ -126,11 +130,15 @@ def __init__( self._loop = None self.status = status self.reason = reason + self.content = StreamReader(ResponseHandler(asyncio.get_event_loop()), 65535) + self.content.total_bytes = len(body_bytes) + self.content._buffer = deque([body_bytes]) + self.content._eof = True -class MockStorageTransport(AsyncHttpTransport): +class MockLegacyTransport(AsyncHttpTransport): """ - This transport returns legacy http response objects from azure core and is + This transport returns legacy http response objects from azure core and is intended only to test our backwards compatibility support. """ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportResponse: @@ -199,7 +207,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes decompress=False ) else: - raise ValueError("The request is not accepted as part of MockStorageTransport.") + raise ValueError("The request is not accepted as part of MockLegacyTransport.") await rest_response.load_body() return rest_response diff --git a/sdk/storage/azure-storage-blob/tests/test_page_blob.py b/sdk/storage/azure-storage-blob/tests/test_page_blob.py index 160bc46081a3..19bb1642f8ca 100644 --- a/sdk/storage/azure-storage-blob/tests/test_page_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_page_blob.py @@ -1981,49 +1981,51 @@ def test_create_blob_with_md5_large(self, **kwargs): # Assert - @pytest.mark.skip(reason="Requires further investigation. 
Failing for unexpected kwarg seal_blob") @BlobPreparer() + @recorded_by_proxy def test_incremental_copy_blob(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=storage_account_key, max_page_size=4 * 1024) self._setup(bsc) - source_blob = self._create_blob(bsc, length=2048) - data = self.get_random_bytes(512) - resp1 = source_blob.upload_page(data, offset=0, length=512) - resp2 = source_blob.upload_page(data, offset=1024, length=512) - source_snapshot_blob = source_blob.create_snapshot() - - snapshot_blob = BlobClient.from_blob_url( - source_blob.url, credential=source_blob.credential, snapshot=source_snapshot_blob) - sas_token = self.generate_sas( - generate_blob_sas, - snapshot_blob.account_name, - snapshot_blob.container_name, - snapshot_blob.blob_name, - snapshot=snapshot_blob.snapshot, - account_key=snapshot_blob.credential.account_key, - permission=BlobSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1), - ) - sas_blob = BlobClient.from_blob_url(snapshot_blob.url, credential=sas_token) - # Act - dest_blob = bsc.get_blob_client(self.container_name, 'dest_blob') - copy = dest_blob.start_copy_from_url(sas_blob.url, incremental_copy=True) + try: + source_blob = self._create_blob(bsc, length=2048) + data = self.get_random_bytes(512) + resp1 = source_blob.upload_page(data, offset=0, length=512) + resp2 = source_blob.upload_page(data, offset=1024, length=512) + source_snapshot_blob = source_blob.create_snapshot() + + snapshot_blob = BlobClient.from_blob_url( + source_blob.url, credential=source_blob.credential, snapshot=source_snapshot_blob) + sas_token = self.generate_sas( + generate_blob_sas, + snapshot_blob.account_name, + snapshot_blob.container_name, + snapshot_blob.blob_name, + snapshot=snapshot_blob.snapshot, + account_key=snapshot_blob.credential.account_key, + permission=BlobSasPermissions(read=True), + expiry=datetime.utcnow() + timedelta(hours=1), + ) + sas_blob = BlobClient.from_blob_url(snapshot_blob.url, credential=sas_token) - # Assert - assert copy is not None - assert copy['copy_id'] is not None - assert copy['copy_status'] == 'pending' + # Act + dest_blob = bsc.get_blob_client(self.container_name, 'dest_blob') + copy = dest_blob.start_copy_from_url(sas_blob.url, incremental_copy=True) - copy_blob = self._wait_for_async_copy(dest_blob) - assert copy_blob.copy.status == 'success' - assert copy_blob.copy.destination_snapshot is not None + # Assert + assert copy is not None + assert copy['copy_id'] is not None + assert copy['copy_status'] == 'pending' - # strip off protocol - assert copy_blob.copy.source.endswith(sas_blob.url[5:]) + copy_blob = self._wait_for_async_copy(dest_blob) + assert copy_blob.copy.status == 'success' + assert copy_blob.copy.destination_snapshot is not None + finally: + bsc.delete_container(self.container_name) + bsc.delete_container(self.source_container_name) @BlobPreparer() @recorded_by_proxy diff --git a/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py index 0b3c9d16f6b8..11fa91a70a3c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py @@ -1965,7 +1965,6 @@ async def test_create_blob_with_md5_large(self, **kwargs): # Assert - @pytest.mark.skip(reason="Requires further investigation. 
Failing for unexpected kwarg seal_blob") @BlobPreparer() @recorded_by_proxy_async async def test_incremental_copy_blob(self, **kwargs): @@ -1974,42 +1973,44 @@ async def test_incremental_copy_blob(self, **kwargs): bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=storage_account_key, max_page_size=4 * 1024) await self._setup(bsc) - source_blob = await self._create_blob(bsc, 2048) - data = self.get_random_bytes(512) - resp1 = await source_blob.upload_page(data, offset=0, length=512) - resp2 = await source_blob.upload_page(data, offset=1024, length=512) - source_snapshot_blob = await source_blob.create_snapshot() - - snapshot_blob = BlobClient.from_blob_url( - source_blob.url, credential=source_blob.credential, snapshot=source_snapshot_blob) - sas_token = self.generate_sas( - generate_blob_sas, - snapshot_blob.account_name, - snapshot_blob.container_name, - snapshot_blob.blob_name, - snapshot=snapshot_blob.snapshot, - account_key=snapshot_blob.credential.account_key, - permission=BlobSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1), - ) - sas_blob = BlobClient.from_blob_url(snapshot_blob.url, credential=sas_token) + try: + source_blob = await self._create_blob(bsc, 2048) + data = self.get_random_bytes(512) + resp1 = await source_blob.upload_page(data, offset=0, length=512) + resp2 = await source_blob.upload_page(data, offset=1024, length=512) + source_snapshot_blob = await source_blob.create_snapshot() + + snapshot_blob = BlobClient.from_blob_url( + source_blob.url, credential=source_blob.credential, snapshot=source_snapshot_blob) + sas_token = self.generate_sas( + generate_blob_sas, + snapshot_blob.account_name, + snapshot_blob.container_name, + snapshot_blob.blob_name, + snapshot=snapshot_blob.snapshot, + account_key=snapshot_blob.credential.account_key, + permission=BlobSasPermissions(read=True), + expiry=datetime.utcnow() + timedelta(hours=1), + ) + sas_blob = BlobClient.from_blob_url(snapshot_blob.url, credential=sas_token) - # Act - dest_blob = bsc.get_blob_client(self.container_name, 'dest_blob') - copy = await dest_blob.start_copy_from_url(sas_blob.url, incremental_copy=True) - # Assert - assert copy is not None - assert copy['copy_id'] is not None - assert copy['copy_status'] == 'pending' + # Act + dest_blob = bsc.get_blob_client(self.container_name, 'dest_blob') + copy = await dest_blob.start_copy_from_url(sas_blob.url, incremental_copy=True) - copy_blob = await self._wait_for_async_copy(dest_blob) - assert copy_blob.copy.status == 'success' - assert copy_blob.copy.destination_snapshot is not None + # Assert + assert copy is not None + assert copy['copy_id'] is not None + assert copy['copy_status'] == 'pending' - # strip off protocol - assert copy_blob.copy.source.endswith(sas_blob.url[5:]) + copy_blob = await self._wait_for_async_copy(dest_blob) + assert copy_blob.copy.status == 'success' + assert copy_blob.copy.destination_snapshot is not None + finally: + await bsc.delete_container(self.container_name) + await bsc.delete_container(self.source_container_name) @BlobPreparer() @recorded_by_proxy_async diff --git a/sdk/storage/azure-storage-blob/tests/test_transports.py b/sdk/storage/azure-storage-blob/tests/test_transports.py new file mode 100644 index 000000000000..4ad076ef7543 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_transports.py @@ -0,0 +1,132 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from azure.storage.blob import BlobClient, BlobServiceClient +from azure.core.exceptions import ResourceExistsError +from azure.core.pipeline.transport import RequestsTransport + +from devtools_testutils import recorded_by_proxy +from devtools_testutils.storage import StorageRecordedTestCase +from settings.testcase import BlobPreparer +from test_helpers import MockLegacyTransport + + +class TestStorageTransports(StorageRecordedTestCase): + def _setup(self, storage_account_name, key): + self.bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=key) + self.container_name = self.get_resource_name('utcontainer') + if self.is_live: + try: + self.bsc.create_container(self.container_name, timeout=5) + except ResourceExistsError: + pass + + @BlobPreparer() + def test_legacy_transport_old_response(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + props = blob_client.get_blob_properties() + assert props is not None + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True) + assert resp is not None + + blob_data = blob_client.download_blob().read() + assert blob_data == b"Hello World!" # data is fixed by mock transport + + resp = blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + def test_legacy_transport_old_response_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = blob_client.download_blob(validate_content=True).read() + assert blob_data == b"Hello World!" # data is fixed by mock transport + + resp = blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + @recorded_by_proxy + def test_legacy_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + self._setup(storage_account_name, storage_account_key) + + transport = RequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name=self.container_name, + blob_name=self.get_resource_name('blob'), + credential=storage_account_key, + transport=transport + ) + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True) + assert resp is not None + + blob_data = blob_client.download_blob().read() + assert blob_data == b"Hello World!" 
+ + resp = blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + @recorded_by_proxy + def test_legacy_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + self._setup(storage_account_name, storage_account_key) + + transport = RequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name=self.container_name, + blob_name=self.get_resource_name('blob'), + credential=storage_account_key, + transport=transport + ) + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = blob_client.download_blob(validate_content=True).read() + assert blob_data == b"Hello World!" + + resp = blob_client.delete_blob() + assert resp is None diff --git a/sdk/storage/azure-storage-blob/tests/test_transports_async.py b/sdk/storage/azure-storage-blob/tests/test_transports_async.py new file mode 100644 index 000000000000..0d6f03eb02b2 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_transports_async.py @@ -0,0 +1,192 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest + +from azure.storage.blob.aio import BlobClient, BlobServiceClient +from azure.core.exceptions import ResourceExistsError +from azure.core.pipeline.transport import AioHttpTransport, AsyncioRequestsTransport + +from devtools_testutils.aio import recorded_by_proxy_async +from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase +from settings.testcase import BlobPreparer +from test_helpers_async import AsyncStream, MockLegacyTransport + + +class TestStorageTransportsAsync(AsyncStorageRecordedTestCase): + async def _setup(self, storage_account_name, key): + self.bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=key) + self.container_name = self.get_resource_name('utcontainer') + self.byte_data = self.get_random_bytes(1024) + if self.is_live: + try: + await self.bsc.create_container(self.container_name) + except ResourceExistsError: + pass + + @BlobPreparer() + async def test_legacy_transport_old_response(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob()).read() + assert blob_data == b"Hello Async World!" 
# data is fixed by mock transport + + resp = await blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + async def test_legacy_transport_old_response_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob(validate_content=True)).read() + assert blob_data == b"Hello Async World!" # data is fixed by mock transport + + resp = await blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + @recorded_by_proxy_async + async def test_legacy_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + await self._setup(storage_account_name, storage_account_key) + + transport = AioHttpTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name=self.container_name, + blob_name=self.get_resource_name('blob'), + credential=storage_account_key, + transport=transport + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob()).read() + assert blob_data == b"Hello Async World!" + + resp = await blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + @recorded_by_proxy_async + async def test_legacy_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + await self._setup(storage_account_name, storage_account_key) + + transport = AioHttpTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name=self.container_name, + blob_name=self.get_resource_name('blob'), + credential=storage_account_key, + transport=transport + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob(validate_content=True)).read() + assert blob_data == b"Hello Async World!" + + resp = await blob_client.delete_blob() + assert resp is None + + @pytest.mark.live_test_only + @BlobPreparer() + async def test_asyncio_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + await self._setup(storage_account_name, storage_account_key) + + transport = AsyncioRequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name=self.container_name, + blob_name=self.get_resource_name('blob'), + credential=storage_account_key, + transport=transport + ) + + data = b"Hello Async World!" 
+ stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob()).read() + assert blob_data == b"Hello Async World!" + + resp = await blob_client.delete_blob() + assert resp is None + + @pytest.mark.live_test_only + @BlobPreparer() + async def test_asyncio_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + await self._setup(storage_account_name, storage_account_key) + + transport = AsyncioRequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name=self.container_name, + blob_name=self.get_resource_name('blob'), + credential=storage_account_key, + transport=transport + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob(validate_content=True)).read() + assert blob_data == b"Hello Async World!" + + resp = await blob_client.delete_blob() + assert resp is None diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/__init__.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/__init__.py index a8b1a27d48f9..4dbbb7ed7b09 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/__init__.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/__init__.py @@ -11,7 +11,7 @@ try: from urllib.parse import quote, unquote except ImportError: - from urllib2 import quote, unquote # type: ignore + from urllib2 import quote, unquote # type: ignore def url_quote(url): @@ -24,20 +24,20 @@ def url_unquote(url): def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") def decode_base64_to_bytes(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") return base64.b64decode(data) def decode_base64_to_text(data): decoded_bytes = decode_base64_to_bytes(data) - return decoded_bytes.decode('utf-8') + return decoded_bytes.decode("utf-8") def sign_string(key, string_to_sign, key_is_base64=True): @@ -45,9 +45,9 @@ def sign_string(key, string_to_sign, key_is_base64=True): key = decode_base64_to_bytes(key) else: if isinstance(key, str): - key = key.encode('utf-8') + key = key.encode("utf-8") if isinstance(string_to_sign, str): - string_to_sign = string_to_sign.encode('utf-8') + string_to_sign = string_to_sign.encode("utf-8") signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) digest = signed_hmac_sha256.digest() encoded_digest = encode_base64(digest) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/authentication.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/authentication.py index b41f2391ed4a..f778dc71eec4 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/authentication.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/authentication.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) +# fmt: off table_lv0 = [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, @@ -51,6 +52,8 @@ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, ] +# fmt: on + def compare(lhs: str, rhs: str) -> int: # pylint:disable=too-many-return-statements tables = [table_lv0, table_lv4] @@ -95,6 +98,7 @@ def _wrap_exception(ex, desired_type): msg = ex.args[0] return desired_type(msg) + # This method attempts to emulate the sorting done by the service def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]: @@ -135,38 +139,42 @@ def __init__(self, account_name, account_key): @staticmethod def _get_headers(request, headers_to_sign): headers = dict((name.lower(), value) for name, value in request.http_request.headers.items() if value) - if 'content-length' in headers and headers['content-length'] == '0': - del headers['content-length'] - return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n' + if "content-length" in headers and headers["content-length"] == "0": + del headers["content-length"] + return "\n".join(headers.get(x, "") for x in headers_to_sign) + "\n" @staticmethod def _get_verb(request): - return request.http_request.method + '\n' + return request.http_request.method + "\n" def _get_canonicalized_resource(self, request): uri_path = urlparse(request.http_request.url).path try: - if isinstance(request.context.transport, AioHttpTransport) or \ - isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) or \ - isinstance(getattr(getattr(request.context.transport, "_transport", None), "_transport", None), - AioHttpTransport): + if ( + isinstance(request.context.transport, AioHttpTransport) + or isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) + or isinstance( + getattr(getattr(request.context.transport, "_transport", None), "_transport", None), + AioHttpTransport, + ) + ): uri_path = URL(uri_path) - return '/' + self.account_name + str(uri_path) + return "/" + self.account_name + str(uri_path) except TypeError: pass - return '/' + self.account_name + uri_path + return "/" + self.account_name + uri_path @staticmethod def _get_canonicalized_headers(request): - string_to_sign = '' + string_to_sign = "" x_ms_headers = [] for name, value in request.http_request.headers.items(): - if name.startswith('x-ms-'): + if name.startswith("x-ms-"): x_ms_headers.append((name.lower(), value)) x_ms_headers = _storage_header_sort(x_ms_headers) for name, value in x_ms_headers: if value is not None: - string_to_sign += ''.join([name, ':', value, '\n']) + string_to_sign += "".join([name, ":", value, "\n"]) return string_to_sign @staticmethod @@ -174,37 +182,46 @@ def _get_canonicalized_resource_query(request): sorted_queries = list(request.http_request.query.items()) sorted_queries.sort() - string_to_sign = '' + string_to_sign = "" for name, value in sorted_queries: if value is not None: - string_to_sign += '\n' + name.lower() + ':' + unquote(value) + string_to_sign += "\n" + name.lower() + ":" + unquote(value) return string_to_sign def _add_authorization_header(self, request, string_to_sign): try: signature = sign_string(self.account_key, string_to_sign) - auth_string = 'SharedKey ' + self.account_name + ':' + signature - request.http_request.headers['Authorization'] = auth_string + auth_string = "SharedKey " + self.account_name + ":" + signature + request.http_request.headers["Authorization"] = auth_string except Exception as ex: # Wrap 
any error that occurred as signing error # Doing so will clarify/locate the source of problem raise _wrap_exception(ex, AzureSigningError) from ex def on_request(self, request): - string_to_sign = \ - self._get_verb(request) + \ - self._get_headers( + string_to_sign = ( + self._get_verb(request) + + self._get_headers( request, [ - 'content-encoding', 'content-language', 'content-length', - 'content-md5', 'content-type', 'date', 'if-modified-since', - 'if-match', 'if-none-match', 'if-unmodified-since', 'byte_range' - ] - ) + \ - self._get_canonicalized_headers(request) + \ - self._get_canonicalized_resource(request) + \ - self._get_canonicalized_resource_query(request) + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", + ], + ) + + self._get_canonicalized_headers(request) + + self._get_canonicalized_resource(request) + + self._get_canonicalized_resource_query(request) + ) self._add_authorization_header(request, string_to_sign) # logger.debug("String_to_sign=%s", string_to_sign) @@ -212,7 +229,7 @@ def on_request(self, request): class StorageHttpChallenge(object): def __init__(self, challenge): - """ Parses an HTTP WWW-Authentication Bearer challenge from the Storage service. """ + """Parses an HTTP WWW-Authentication Bearer challenge from the Storage service.""" if not challenge: raise ValueError("Challenge cannot be empty") @@ -221,7 +238,7 @@ def __init__(self, challenge): # name=value pairs either comma or space separated with values possibly being # enclosed in quotes - for item in re.split('[, ]', trimmed_challenge): + for item in re.split("[, ]", trimmed_challenge): comps = item.split("=") if len(comps) == 2: key = comps[0].strip(' "') @@ -230,11 +247,11 @@ def __init__(self, challenge): self._parameters[key] = value # Extract and verify required parameters - self.authorization_uri = self._parameters.get('authorization_uri') + self.authorization_uri = self._parameters.get("authorization_uri") if not self.authorization_uri: raise ValueError("Authorization Uri not found") - self.resource_id = self._parameters.get('resource_id') + self.resource_id = self._parameters.get("resource_id") if not self.resource_id: raise ValueError("Resource id not found") diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py index 7de14050b963..217eb2110f15 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py @@ -20,7 +20,10 @@ from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline -from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module +from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module + HttpTransport, + RequestsTransport, +) from azure.core.pipeline.policies import ( AzureSasCredentialPolicy, ContentDecodePolicy, @@ -73,8 +76,17 @@ def __init__( self, parsed_url: Any, service: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, 
"AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + credential: Optional[ + Union[ + str, + Dict[str, str], + AzureNamedKeyCredential, + AzureSasCredential, + "AsyncTokenCredential", + TokenCredential, + ] + ] = None, + **kwargs: Any, ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts", {}) @@ -83,12 +95,15 @@ def __init__( if service not in ["blob", "queue", "file-share", "dfs"]: raise ValueError(f"Invalid service: {service}") - service_name = service.split('-')[0] + service_name = service.split("-")[0] account = parsed_url.netloc.split(f".{service_name}.core.") self.account_name = account[0] if len(account) > 1 else None - if not self.account_name and parsed_url.netloc.startswith("localhost") \ - or parsed_url.netloc.startswith("127.0.0.1"): + if ( + not self.account_name + and parsed_url.netloc.startswith("localhost") + or parsed_url.netloc.startswith("127.0.0.1") + ): self._is_localhost = True self.account_name = parsed_url.path.strip("/") @@ -106,7 +121,7 @@ def __init__( secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary") if kwargs.get("secondary_hostname"): secondary_hostname = kwargs["secondary_hostname"] - primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/') + primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip("/") self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname} self._sdk_moniker = f"storage-{service}/{VERSION}" @@ -119,71 +134,76 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): - """ This method is to close the sockets opened by the client. + def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. """ self._client.close() @property - def url(self): + def url(self) -> str: """The full endpoint URL to this entity, including SAS token if used. This could be either the primary endpoint, or the secondary endpoint depending on the current :func:`location_mode`. - :returns: The full endpoint URL to this entity, including SAS token if used. + :return: The full endpoint URL to this entity, including SAS token if used. :rtype: str """ - return self._format_url(self._hosts[self._location_mode]) + return self._format_url(self._hosts[self._location_mode]) # type: ignore @property - def primary_endpoint(self): + def primary_endpoint(self) -> str: """The full primary endpoint URL. + :return: The full primary endpoint URL. :rtype: str """ - return self._format_url(self._hosts[LocationMode.PRIMARY]) + return self._format_url(self._hosts[LocationMode.PRIMARY]) # type: ignore @property - def primary_hostname(self): + def primary_hostname(self) -> str: """The hostname of the primary endpoint. + :return: The hostname of the primary endpoint. :rtype: str """ return self._hosts[LocationMode.PRIMARY] @property - def secondary_endpoint(self): + def secondary_endpoint(self) -> str: """The full secondary endpoint URL if configured. If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The full secondary endpoint URL. :rtype: str - :raise ValueError: + :raise ValueError: If no secondary endpoint is configured. 
""" if not self._hosts[LocationMode.SECONDARY]: raise ValueError("No secondary host configured.") - return self._format_url(self._hosts[LocationMode.SECONDARY]) + return self._format_url(self._hosts[LocationMode.SECONDARY]) # type: ignore @property - def secondary_hostname(self): + def secondary_hostname(self) -> Optional[str]: """The hostname of the secondary endpoint. If not available this will be None. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The hostname of the secondary endpoint, or None if not configured. :rtype: Optional[str] """ return self._hosts[LocationMode.SECONDARY] @property - def location_mode(self): + def location_mode(self) -> str: """The location mode that the client is currently using. By default this will be "primary". Options include "primary" and "secondary". + :return: The current location mode. :rtype: str """ @@ -206,11 +226,16 @@ def api_version(self): return self._client._config.version # pylint: disable=protected-access def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -218,7 +243,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." 
+ ) if _is_credential_sastoken(credential): credential = cast(str, credential) query_str += credential.lstrip("?") @@ -228,13 +254,16 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None if hasattr(credential, "get_token"): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience) @@ -268,22 +297,18 @@ def _create_pipeline( config.logging_policy, StorageResponseHook(**kwargs), DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs) + HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore return config, Pipeline(transport, policies=policies) - def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> Iterator["HttpResponse"]: + def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> Iterator["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An iterator of HttpResponse objects. + :return: An iterator of HttpResponse objects. 
:rtype: Iterator[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg @@ -292,25 +317,21 @@ def _batch_send( request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), headers={ - 'x-ms-version': self.api_version, - "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) - } + "x-ms-version": self.api_version, + "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False), + }, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) Pipeline._prepare_multipart_mixed_request(request) # pylint: disable=protected-access body = serialize_batch_body(request.multipart_mixed_info[0], batch_id) @@ -318,9 +339,7 @@ def _batch_send( temp = request.multipart_mixed_info request.multipart_mixed_info = None - pipeline_response = self._pipeline.run( - request, **kwargs - ) + pipeline_response = self._pipeline.run(request, **kwargs) response = pipeline_response.http_response request.multipart_mixed_info = temp @@ -332,8 +351,7 @@ def _batch_send( parts = list(response.parts()) if any(p for p in parts if not 200 <= p.status_code < 300): error = PartialBatchErrorException( - message="There is a partial failure in the batch operation.", - response=response, parts=parts + message="There is a partial failure in the batch operation.", response=response, parts=parts ) raise error return iter(parts) @@ -347,6 +365,7 @@ class TransportWrapper(HttpTransport): by a `get_client` method does not close the outer transport for the parent when used in a context manager. 
""" + def __init__(self, transport): self._transport = transport @@ -368,7 +387,9 @@ def __exit__(self, *args): def _format_shared_key_credential( account_name: Optional[str], - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential] + ] = None, ) -> Any: if isinstance(credential, str): if not account_name: @@ -388,8 +409,12 @@ def _format_shared_key_credential( def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -411,14 +436,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -438,7 +460,7 @@ def parse_connection_str( def create_configuration(**kwargs: Any) -> StorageConfiguration: - # Backwards compatibility if someone is not passing sdk_moniker + # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}" config = StorageConfiguration(**kwargs) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py index 6186b29db107..f39a57b24943 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py @@ -64,18 +64,26 @@ async def __aenter__(self): async def __aexit__(self, *args): await self._client.__aexit__(*args) - async def close(self): - """ This method is to close the sockets opened by the client. + async def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. 
+ + :return: None + :rtype: None """ await self._client.close() def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -83,7 +91,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." + ) if _is_credential_sastoken(credential): query_str += credential.lstrip("?") # type: ignore [union-attr] credential = None @@ -92,35 +101,40 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[ - Union[AsyncStorageBearerTokenCredentialPolicy, - SharedKeyCredentialPolicy, - AzureSasCredentialPolicy]] = None - if hasattr(credential, 'get_token'): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + Union[AsyncStorageBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy] + ] = None + if hasattr(credential, "get_token"): + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = AsyncStorageBearerTokenCredentialPolicy( - cast(AsyncTokenCredential, credential), audience) + cast(AsyncTokenCredential, credential), audience + ) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): self._credential_policy = AzureSasCredentialPolicy(credential) elif credential is not None: raise TypeError(f"Unsupported credential: {type(credential)}") - config = kwargs.get('_configuration') or create_configuration(**kwargs) - if kwargs.get('_pipeline'): - return config, kwargs['_pipeline'] - transport = kwargs.get('transport') + config = kwargs.get("_configuration") or create_configuration(**kwargs) + if kwargs.get("_pipeline"): + return config, kwargs["_pipeline"] + transport = kwargs.get("transport") kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) if not transport: try: - from azure.core.pipeline.transport import AioHttpTransport # pylint: 
disable=non-abstract-transport-import + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + AioHttpTransport, + ) except ImportError as exc: raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc transport = AioHttpTransport(**kwargs) @@ -143,53 +157,41 @@ def _create_pipeline( HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore - config.transport = transport #type: ignore - return config, AsyncPipeline(transport, policies=policies) #type: ignore + policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore + config.transport = transport # type: ignore + return config, AsyncPipeline(transport, policies=policies) # type: ignore - async def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> AsyncList["HttpResponse"]: + async def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> AsyncList["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An AsyncList of HttpResponse objects. + :return: An AsyncList of HttpResponse objects. :rtype: AsyncList[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), - headers={ - 'x-ms-version': self.api_version - } + headers={"x-ms-version": self.api_version}, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) # type: ignore - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) - pipeline_response = await self._pipeline.run( - request, **kwargs - ) + pipeline_response = await self._pipeline.run(request, **kwargs) response = pipeline_response.http_response try: if response.status_code not in [202]: raise HttpResponseError(response=response) - parts = response.parts() # Return an AsyncIterator + parts = response.parts() # Return an AsyncIterator if raise_on_any_failure: parts_list = [] async for part in parts: @@ -197,7 +199,8 @@ async def _batch_send( if any(p for p in parts_list if not 200 <= p.status_code < 300): error = PartialBatchErrorException( message="There is a partial failure in the batch operation.", - response=response, parts=parts_list + response=response, + parts=parts_list, ) raise error return AsyncList(parts_list) @@ -205,11 +208,16 @@ async def _batch_send( except HttpResponseError as error: process_storage_error(error) + def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = 
[s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -231,14 +239,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -256,11 +261,13 @@ def parse_connection_str( secondary = secondary.replace(".blob.", ".dfs.") return primary, secondary, credential + class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner client created by a `get_client` method does not close the outer transport for the parent when used in a context manager. """ + def __init__(self, async_transport): self._transport = async_transport diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/constants.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/constants.py index 0b4b029a2d1b..0926f04c4081 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/constants.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/constants.py @@ -16,4 +16,4 @@ DEFAULT_OAUTH_SCOPE = "/.default" STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default" -SERVICE_HOST_BASE = 'core.windows.net' +SERVICE_HOST_BASE = "core.windows.net" diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/models.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/models.py index 183d6f64a8be..e61148718712 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/models.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/models.py @@ -22,6 +22,7 @@ def get_enum_value(value): class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Error codes returned by the service.""" # Generic storage values ACCOUNT_ALREADY_EXISTS = "AccountAlreadyExists" @@ -172,26 +173,26 @@ class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): CONTAINER_QUOTA_DOWNGRADE_NOT_ALLOWED = "ContainerQuotaDowngradeNotAllowed" # DataLake values - CONTENT_LENGTH_MUST_BE_ZERO = 'ContentLengthMustBeZero' - PATH_ALREADY_EXISTS = 'PathAlreadyExists' - INVALID_FLUSH_POSITION = 'InvalidFlushPosition' - INVALID_PROPERTY_NAME = 'InvalidPropertyName' - INVALID_SOURCE_URI = 'InvalidSourceUri' - UNSUPPORTED_REST_VERSION = 'UnsupportedRestVersion' - FILE_SYSTEM_NOT_FOUND = 'FilesystemNotFound' - PATH_NOT_FOUND = 'PathNotFound' - RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = 'RenameDestinationParentPathNotFound' - SOURCE_PATH_NOT_FOUND = 'SourcePathNotFound' - DESTINATION_PATH_IS_BEING_DELETED = 'DestinationPathIsBeingDeleted' - FILE_SYSTEM_ALREADY_EXISTS = 'FilesystemAlreadyExists' - FILE_SYSTEM_BEING_DELETED = 'FilesystemBeingDeleted' - INVALID_DESTINATION_PATH = 'InvalidDestinationPath' - INVALID_RENAME_SOURCE_PATH = 'InvalidRenameSourcePath' - INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = 'InvalidSourceOrDestinationResourceType' - LEASE_IS_ALREADY_BROKEN = 'LeaseIsAlreadyBroken' - LEASE_NAME_MISMATCH = 'LeaseNameMismatch' - PATH_CONFLICT = 'PathConflict' - 
SOURCE_PATH_IS_BEING_DELETED = 'SourcePathIsBeingDeleted' + CONTENT_LENGTH_MUST_BE_ZERO = "ContentLengthMustBeZero" + PATH_ALREADY_EXISTS = "PathAlreadyExists" + INVALID_FLUSH_POSITION = "InvalidFlushPosition" + INVALID_PROPERTY_NAME = "InvalidPropertyName" + INVALID_SOURCE_URI = "InvalidSourceUri" + UNSUPPORTED_REST_VERSION = "UnsupportedRestVersion" + FILE_SYSTEM_NOT_FOUND = "FilesystemNotFound" + PATH_NOT_FOUND = "PathNotFound" + RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = "RenameDestinationParentPathNotFound" + SOURCE_PATH_NOT_FOUND = "SourcePathNotFound" + DESTINATION_PATH_IS_BEING_DELETED = "DestinationPathIsBeingDeleted" + FILE_SYSTEM_ALREADY_EXISTS = "FilesystemAlreadyExists" + FILE_SYSTEM_BEING_DELETED = "FilesystemBeingDeleted" + INVALID_DESTINATION_PATH = "InvalidDestinationPath" + INVALID_RENAME_SOURCE_PATH = "InvalidRenameSourcePath" + INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = "InvalidSourceOrDestinationResourceType" + LEASE_IS_ALREADY_BROKEN = "LeaseIsAlreadyBroken" + LEASE_NAME_MISMATCH = "LeaseNameMismatch" + PATH_CONFLICT = "PathConflict" + SOURCE_PATH_IS_BEING_DELETED = "SourcePathIsBeingDeleted" class DictMixin(object): @@ -222,7 +223,7 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')}) + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) def __contains__(self, key): return key in self.__dict__ @@ -234,13 +235,13 @@ def update(self, *args, **kwargs): return self.__dict__.update(*args, **kwargs) def keys(self): - return [k for k in self.__dict__ if not k.startswith('_')] + return [k for k in self.__dict__ if not k.startswith("_")] def values(self): - return [v for k, v in self.__dict__.items() if not k.startswith('_')] + return [v for k, v in self.__dict__.items() if not k.startswith("_")] def items(self): - return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')] + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] def get(self, key, default=None): if key in self.__dict__: @@ -255,8 +256,8 @@ class LocationMode(object): must use PRIMARY. """ - PRIMARY = 'primary' #: Requests should be sent to the primary location. - SECONDARY = 'secondary' #: Requests should be sent to the secondary location, if possible. + PRIMARY = "primary" #: Requests should be sent to the primary location. + SECONDARY = "secondary" #: Requests should be sent to the secondary location, if possible. 
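The `DictMixin` changes above are quote-style only; the reformatted `keys()`, `values()`, `items()`, and `__str__` still exclude attributes whose names start with an underscore. A minimal sketch of that behavior, assuming the internal `DictMixin` from the `_shared/models.py` file touched in this diff (the `ExampleEntry` subclass is hypothetical and for illustration only):

```python
# Illustrative sketch only: ExampleEntry is a hypothetical subclass used to show
# the underscore-filtering behavior of the internal DictMixin in _shared/models.py.
from azure.storage.filedatalake._shared.models import DictMixin

class ExampleEntry(DictMixin):
    def __init__(self):
        self.name = "sample.txt"        # public attribute, surfaced by keys()/items()
        self._internal = "not exposed"  # underscore-prefixed, filtered out

entry = ExampleEntry()
print(entry.keys())       # ['name']
print(entry.get("name"))  # 'sample.txt'
print(entry)              # "{'name': 'sample.txt'}" -- __str__ skips _internal
```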
class ResourceTypes(object): @@ -281,17 +282,12 @@ class ResourceTypes(object): _str: str def __init__( - self, - service: bool = False, - container: bool = False, - object: bool = False # pylint: disable=redefined-builtin + self, service: bool = False, container: bool = False, object: bool = False # pylint: disable=redefined-builtin ) -> None: self.service = service self.container = container self.object = object - self._str = (('s' if self.service else '') + - ('c' if self.container else '') + - ('o' if self.object else '')) + self._str = ("s" if self.service else "") + ("c" if self.container else "") + ("o" if self.object else "") def __str__(self): return self._str @@ -309,9 +305,9 @@ def from_string(cls, string): :return: A ResourceTypes object :rtype: ~azure.storage.blob.ResourceTypes """ - res_service = 's' in string - res_container = 'c' in string - res_object = 'o' in string + res_service = "s" in string + res_container = "c" in string + res_object = "o" in string parsed = cls(res_service, res_container, res_object) parsed._str = string @@ -392,29 +388,30 @@ def __init__( self.write = write self.delete = delete self.delete_previous_version = delete_previous_version - self.permanent_delete = kwargs.pop('permanent_delete', False) + self.permanent_delete = kwargs.pop("permanent_delete", False) self.list = list self.add = add self.create = create self.update = update self.process = process - self.tag = kwargs.pop('tag', False) - self.filter_by_tags = kwargs.pop('filter_by_tags', False) - self.set_immutability_policy = kwargs.pop('set_immutability_policy', False) - self._str = (('r' if self.read else '') + - ('w' if self.write else '') + - ('d' if self.delete else '') + - ('x' if self.delete_previous_version else '') + - ('y' if self.permanent_delete else '') + - ('l' if self.list else '') + - ('a' if self.add else '') + - ('c' if self.create else '') + - ('u' if self.update else '') + - ('p' if self.process else '') + - ('f' if self.filter_by_tags else '') + - ('t' if self.tag else '') + - ('i' if self.set_immutability_policy else '') - ) + self.tag = kwargs.pop("tag", False) + self.filter_by_tags = kwargs.pop("filter_by_tags", False) + self.set_immutability_policy = kwargs.pop("set_immutability_policy", False) + self._str = ( + ("r" if self.read else "") + + ("w" if self.write else "") + + ("d" if self.delete else "") + + ("x" if self.delete_previous_version else "") + + ("y" if self.permanent_delete else "") + + ("l" if self.list else "") + + ("a" if self.add else "") + + ("c" if self.create else "") + + ("u" if self.update else "") + + ("p" if self.process else "") + + ("f" if self.filter_by_tags else "") + + ("t" if self.tag else "") + + ("i" if self.set_immutability_policy else "") + ) def __str__(self): return self._str @@ -432,23 +429,34 @@ def from_string(cls, permission): :return: An AccountSasPermissions object :rtype: ~azure.storage.filedatalake.AccountSasPermissions """ - p_read = 'r' in permission - p_write = 'w' in permission - p_delete = 'd' in permission - p_delete_previous_version = 'x' in permission - p_permanent_delete = 'y' in permission - p_list = 'l' in permission - p_add = 'a' in permission - p_create = 'c' in permission - p_update = 'u' in permission - p_process = 'p' in permission - p_tag = 't' in permission - p_filter_by_tags = 'f' in permission - p_set_immutability_policy = 'i' in permission - parsed = cls(read=p_read, write=p_write, delete=p_delete, delete_previous_version=p_delete_previous_version, - list=p_list, add=p_add, create=p_create, 
update=p_update, process=p_process, tag=p_tag, - filter_by_tags=p_filter_by_tags, set_immutability_policy=p_set_immutability_policy, - permanent_delete=p_permanent_delete) + p_read = "r" in permission + p_write = "w" in permission + p_delete = "d" in permission + p_delete_previous_version = "x" in permission + p_permanent_delete = "y" in permission + p_list = "l" in permission + p_add = "a" in permission + p_create = "c" in permission + p_update = "u" in permission + p_process = "p" in permission + p_tag = "t" in permission + p_filter_by_tags = "f" in permission + p_set_immutability_policy = "i" in permission + parsed = cls( + read=p_read, + write=p_write, + delete=p_delete, + delete_previous_version=p_delete_previous_version, + list=p_list, + add=p_add, + create=p_create, + update=p_update, + process=p_process, + tag=p_tag, + filter_by_tags=p_filter_by_tags, + set_immutability_policy=p_set_immutability_policy, + permanent_delete=p_permanent_delete, + ) return parsed @@ -464,18 +472,11 @@ class Services(object): Access for the `~azure.storage.fileshare.ShareServiceClient`. Default is False. """ - def __init__( - self, *, - blob: bool = False, - queue: bool = False, - fileshare: bool = False - ) -> None: + def __init__(self, *, blob: bool = False, queue: bool = False, fileshare: bool = False) -> None: self.blob = blob self.queue = queue self.fileshare = fileshare - self._str = (('b' if self.blob else '') + - ('q' if self.queue else '') + - ('f' if self.fileshare else '')) + self._str = ("b" if self.blob else "") + ("q" if self.queue else "") + ("f" if self.fileshare else "") def __str__(self): return self._str @@ -493,9 +494,9 @@ def from_string(cls, string): :return: A Services object :rtype: ~azure.storage.blob.Services """ - res_blob = 'b' in string - res_queue = 'q' in string - res_file = 'f' in string + res_blob = "b" in string + res_queue = "q" in string + res_file = "f" in string parsed = cls(blob=res_blob, queue=res_queue, fileshare=res_file) parsed._str = string @@ -573,13 +574,13 @@ class StorageConfiguration(Configuration): def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) - self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024) + self.max_single_put_size = kwargs.pop("max_single_put_size", 64 * 1024 * 1024) self.copy_polling_interval = 15 - self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024) - self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1) - self.use_byte_buffer = kwargs.pop('use_byte_buffer', False) - self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024) - self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) - self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) - self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) - self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024) + self.max_block_size = kwargs.pop("max_block_size", 4 * 1024 * 1024) + self.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) + self.use_byte_buffer = kwargs.pop("use_byte_buffer", False) + self.max_page_size = kwargs.pop("max_page_size", 4 * 1024 * 1024) + self.min_large_chunk_upload_threshold = kwargs.pop("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) + self.max_single_get_size = kwargs.pop("max_single_get_size", 32 * 1024 * 1024) + self.max_chunk_get_size = 
kwargs.pop("max_chunk_get_size", 4 * 1024 * 1024) + self.max_range_size = kwargs.pop("max_range_size", 4 * 1024 * 1024) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/parser.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/parser.py index 112c1984f4fb..e4fcb8f041ba 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/parser.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/parser.py @@ -12,14 +12,14 @@ def _to_utc_datetime(value: datetime) -> str: - return value.strftime('%Y-%m-%dT%H:%M:%SZ') + return value.strftime("%Y-%m-%dT%H:%M:%SZ") def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]: """Converts an RFC 1123 date string to a UTC datetime. :param str rfc_1123: The time and date in RFC 1123 format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not rfc_1123: @@ -33,7 +33,7 @@ def _filetime_to_datetime(filetime: str) -> Optional[datetime]: If parsing MS Filetime fails, tries RFC 1123 as backup. :param str filetime: The time and date in MS filetime format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not filetime: diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py index ee75cd5a466c..a08fee7afaac 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py @@ -28,7 +28,7 @@ HTTPPolicy, NetworkTraceLoggingPolicy, RequestHistory, - SansIOHTTPPolicy + SansIOHTTPPolicy, ) from .authentication import AzureSigningError, StorageHttpChallenge @@ -39,7 +39,7 @@ from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -48,14 +48,14 @@ def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status']) + retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -63,8 +63,8 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): - if settings['hook']: - settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs) + if settings["hook"]: + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? 
(Based on allowlists and control @@ -95,40 +95,39 @@ def is_retry(response, mode): def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False def urljoin(base_url, stub_url): parsed = urlparse(base_url) - parsed = parsed._replace(path=parsed.path + '/' + stub_url) + parsed = parsed._replace(path=parsed.path + "/" + stub_url) return parsed.geturl() class QueueMessagePolicy(SansIOHTTPPolicy): def on_request(self, request): - message_id = request.context.options.pop('queue_message_id', None) + message_id = request.context.options.pop("queue_message_id", None) if message_id: - request.http_request.url = urljoin( - request.http_request.url, - message_id) + request.http_request.url = urljoin(request.http_request.url, message_id) class StorageHeadersPolicy(HeadersPolicy): - request_id_header_name = 'x-ms-client-request-id' + request_id_header_name = "x-ms-client-request-id" def on_request(self, request: "PipelineRequest") -> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) - request.http_request.headers['x-ms-date'] = current_time + request.http_request.headers["x-ms-date"] = current_time - custom_id = request.context.options.pop('client_request_id', None) - request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1()) + custom_id = request.context.options.pop("client_request_id", None) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -153,7 +152,7 @@ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument super(StorageHosts, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request.context.options['hosts'] = self.hosts + request.context.options["hosts"] = self.hosts parsed_url = urlparse(request.http_request.url) # Detect what location mode we're currently requesting with @@ -163,10 +162,10 @@ def on_request(self, request: "PipelineRequest") -> None: location_mode = key # See if a specific location mode has been specified, and if so, redirect - use_location = request.context.options.pop('use_location', None) + use_location = request.context.options.pop("use_location", None) if use_location: # Lock retries to the specific location - request.context.options['retry_to_secondary'] = False + request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != location_mode: @@ -175,7 +174,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.url = updated.geturl() location_mode = use_location - request.context.options['location_mode'] = 
location_mode + request.context.options["location_mode"] = location_mode class StorageLoggingPolicy(NetworkTraceLoggingPolicy): @@ -200,19 +199,19 @@ def on_request(self, request: "PipelineRequest") -> None: try: log_url = http_request.url query_params = http_request.query - if 'sig' in query_params: - log_url = log_url.replace(query_params['sig'], "sig=*****") + if "sig" in query_params: + log_url = log_url.replace(query_params["sig"], "sig=*****") _LOGGER.debug("Request URL: %r", log_url) _LOGGER.debug("Request method: %r", http_request.method) _LOGGER.debug("Request headers:") for header, value in http_request.headers.items(): - if header.lower() == 'authorization': - value = '*****' - elif header.lower() == 'x-ms-copy-source' and 'sig' in value: + if header.lower() == "authorization": + value = "*****" + elif header.lower() == "x-ms-copy-source" and "sig" in value: # take the url apart and scrub away the signed signature scheme, netloc, path, params, query, fragment = urlparse(value) parsed_qs = dict(parse_qsl(query)) - parsed_qs['sig'] = '*****' + parsed_qs["sig"] = "*****" # the SAS needs to be put back together value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) @@ -242,11 +241,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") # We don't want to log binary data if the response is a file. _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) - header = response.http_response.headers.get('content-disposition') + header = response.http_response.headers.get("content-disposition") resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): - filename = header.partition('=')[2] + filename = header.partition("=")[2] _LOGGER.debug("File attachments: %s", filename) elif resp_content_type.endswith("octet-stream"): _LOGGER.debug("Body contains binary data.") @@ -268,11 +267,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") class StorageRequestHook(SansIOHTTPPolicy): def __init__(self, **kwargs): - self._request_callback = kwargs.get('raw_request_hook') + self._request_callback = kwargs.get("raw_request_hook") super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop('raw_request_hook', self._request_callback) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -280,49 +279,50 @@ def on_request(self, request: "PipelineRequest") -> None: class StorageResponseHook(HTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(StorageResponseHook, self).__init__() def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = 
request.context.options.pop('download_stream_current', None) - upload_stream_current = request.context.get('upload_stream_current') + download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response @@ -332,7 +332,8 @@ class StorageContentValidation(SansIOHTTPPolicy): This will overwrite any headers already defined in the request. """ - header_name = 'Content-MD5' + + header_name = "Content-MD5" def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super(StorageContentValidation, self).__init__() @@ -342,10 +343,10 @@ def get_content_md5(data): # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. 
data = data or b"" - md5 = hashlib.md5() # nosec + md5 = hashlib.md5() # nosec if isinstance(data, bytes): md5.update(data) - elif hasattr(data, 'read'): + elif hasattr(data, "read"): pos = 0 try: pos = data.tell() @@ -363,22 +364,25 @@ def get_content_md5(data): return md5.digest() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop('validate_content', False) - if validate_content and request.http_request.method != 'GET': + validate_content = request.context.options.pop("validate_content", False) + if validate_content and request.http_request.method != "GET": computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) request.http_request.headers[self.header_name] = computed_md5 - request.context['validate_content_md5'] = computed_md5 - request.context['validate_content'] = validate_content + request.context["validate_content_md5"] = computed_md5 + request.context["validate_content"] = validate_content def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = request.context.get('validate_content_md5') or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: - raise AzureError(( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'."), - response=response.http_response + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, ) @@ -399,33 +403,41 @@ class StorageRetryPolicy(HTTPPolicy): """Whether the secondary endpoint should be retried.""" def __init__(self, **kwargs: Any) -> None: - self.total_retries = kwargs.pop('retry_total', 10) - self.connect_retries = kwargs.pop('retry_connect', 3) - self.read_retries = kwargs.pop('retry_read', 3) - self.status_retries = kwargs.pop('retry_status', 3) - self.retry_to_secondary = kwargs.pop('retry_to_secondary', False) + self.total_retries = kwargs.pop("retry_total", 10) + self.connect_retries = kwargs.pop("retry_connect", 3) + self.read_retries = kwargs.pop("retry_read", 3) + self.status_retries = kwargs.pop("retry_status", 3) + self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. - :param Dict[str, Any]] settings: The configurable values pertaining to the next host location. + :param Dict[str, Any] settings: The configurable values pertaining to the next host location. :param PipelineRequest request: A pipeline request object. 
""" - if settings['hosts'] and all(settings['hosts'].values()): + if settings["hosts"] and all(settings["hosts"].values()): url = urlparse(request.url) # If there's more than one possible location, retry to the alternative - if settings['mode'] == LocationMode.PRIMARY: - settings['mode'] = LocationMode.SECONDARY + if settings["mode"] == LocationMode.PRIMARY: + settings["mode"] = LocationMode.SECONDARY else: - settings['mode'] = LocationMode.PRIMARY - updated = url._replace(netloc=settings['hosts'].get(settings['mode'])) + settings["mode"] = LocationMode.PRIMARY + updated = url._replace(netloc=settings["hosts"].get(settings["mode"])) request.url = updated.geturl() def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: + """ + Configure the retry settings for the request. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A dictionary containing the retry settings. + :rtype: Dict[str, Any] + """ body_position = None - if hasattr(request.http_request.body, 'read'): + if hasattr(request.http_request.body, "read"): try: body_position = request.http_request.body.tell() except (AttributeError, UnsupportedOperation): @@ -433,129 +445,140 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: pass options = request.context.options return { - 'total': options.pop("retry_total", self.total_retries), - 'connect': options.pop("retry_connect", self.connect_retries), - 'read': options.pop("retry_read", self.read_retries), - 'status': options.pop("retry_status", self.status_retries), - 'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary), - 'mode': options.pop("location_mode", LocationMode.PRIMARY), - 'hosts': options.pop("hosts", None), - 'hook': options.pop("retry_hook", None), - 'body_position': body_position, - 'count': 0, - 'history': [] + "total": options.pop("retry_total", self.total_retries), + "connect": options.pop("retry_connect", self.connect_retries), + "read": options.pop("retry_read", self.read_retries), + "status": options.pop("retry_status", self.status_retries), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "mode": options.pop("location_mode", LocationMode.PRIMARY), + "hosts": options.pop("hosts", None), + "hook": options.pop("retry_hook", None), + "body_position": body_position, + "count": 0, + "history": [], } def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument - """ Formula for computing the current backoff. + """Formula for computing the current backoff. Should be calculated by child class. :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. - :returns: The backoff time. + :return: The backoff time. :rtype: float """ return 0 def sleep(self, settings, transport): + """Sleep for the backoff time. + + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. + :param transport: The transport to use for sleeping. 
+ :type transport: + ~azure.core.pipeline.transport.AsyncioBaseTransport or + ~azure.core.pipeline.transport.BaseTransport + """ backoff = self.get_backoff_time(settings) if not backoff or backoff < 0: return transport.sleep(backoff) def increment( - self, settings: Dict[str, Any], + self, + settings: Dict[str, Any], request: "PipelineRequest", response: Optional["PipelineResponse"] = None, - error: Optional[AzureError] = None + error: Optional[AzureError] = None, ) -> bool: """Increment the retry counters. :param Dict[str, Any] settings: The configurable values pertaining to the increment operation. - :param PipelineRequest request: A pipeline request object. - :param Optional[PipelineResponse] response: A pipeline response object. - :param Optional[AzureError] error: An error encountered during the request, or + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: A pipeline response object. + :type response: ~azure.core.pipeline.PipelineResponse or None + :param error: An error encountered during the request, or None if the response was received successfully. - :returns: Whether the retry attempts are exhausted. + :type error: ~azure.core.exceptions.AzureError or None + :return: Whether the retry attempts are exhausted. :rtype: bool """ - settings['total'] -= 1 + settings["total"] -= 1 if error and isinstance(error, ServiceRequestError): # Errors when we're fairly sure that the server did not receive the # request, so it should be safe to retry. - settings['connect'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["connect"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) elif error and isinstance(error, ServiceResponseError): # Errors that occur after the request has been started, so we should # assume that the server began processing it. - settings['read'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["read"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) else: # Incrementing because of a server error like a 500 in # status_forcelist and a the given method is in the allowlist if response: - settings['status'] -= 1 - settings['history'].append(RequestHistory(request, http_response=response)) + settings["status"] -= 1 + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): - if request.method not in ['PUT'] and settings['retry_secondary']: + if request.method not in ["PUT"] and settings["retry_secondary"]: self._set_next_host_location(settings, request) # rewind the request body if it is a stream - if request.body and hasattr(request.body, 'read'): + if request.body and hasattr(request.body, "read"): # no position was saved, then retry would not work - if settings['body_position'] is None: + if settings["body_position"] is None: return False try: # attempt to rewind the body to the initial position - request.body.seek(settings['body_position'], SEEK_SET) + request.body.seek(settings["body_position"], SEEK_SET) except (UnsupportedOperation, ValueError): # if body is not seekable, then retry would not work return False - settings['count'] += 1 + settings["count"] += 1 return True return False def send(self, request): + """Send the request with retry logic. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A pipeline response object. 
+ :rtype: ~azure.core.pipeline.PipelineResponse + """ retries_remaining = True response = None retry_settings = self.configure_retries(request) while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings['mode']) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + retry_hook(retry_settings, request=request.http_request, response=None, error=err) self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -571,12 +594,13 @@ class ExponentialRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, initial_backoff: int = 15, + self, + initial_backoff: int = 15, increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for @@ -601,21 +625,20 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to get backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to get backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
:rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -630,11 +653,12 @@ class LinearRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs a Linear retry object. @@ -653,15 +677,14 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. :rtype: float @@ -669,19 +692,27 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + """Handle the challenge from the service and authorize the request. + + :param request: The request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The response object. + :type response: ~azure.core.pipeline.PipelineResponse + :return: True if the request was authorized, False otherwise. 
+ :rtype: bool + """ try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index 807a51dd297c..4cb32f23248b 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -21,7 +21,7 @@ from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -29,29 +29,25 @@ async def retry_hook(settings, **kwargs): - if settings['hook']: - if asyncio.iscoroutine(settings['hook']): - await settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + if settings["hook"]: + if asyncio.iscoroutine(settings["hook"]): + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() # Load the body in memory and close the socket + except (StreamClosedError, StreamConsumedError): + pass + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -59,54 +55,56 @@ async def is_checksum_retry(response): class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(AsyncStorageResponseHook, self).__init__() async def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop('download_stream_current', None) - 
upload_stream_current = request.context.get('upload_stream_current') + download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) - will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): - await response_callback(response) # type: ignore + await response_callback(response) # type: ignore else: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response + class AsyncStorageRetryPolicy(StorageRetryPolicy): """ The base class for Exponential and Linear retries containing shared code. 
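The retry-policy hunks in this area only reformat the code; the backoff arithmetic is unchanged: `ExponentialRetry` sleeps `initial_backoff` seconds when the retry count is 0 and `initial_backoff + increment_base ** count` seconds afterwards, jittered within `random_jitter_range`. A standalone sketch of that formula, re-implemented purely for illustration (the helper name and the 15/3/3 defaults mirror the values visible in this diff; it is not part of the SDK's public API):

```python
import random

def example_exponential_backoff(count: int,
                                initial_backoff: int = 15,
                                increment_base: int = 3,
                                random_jitter_range: int = 3) -> float:
    # Same arithmetic as ExponentialRetry.get_backoff_time in the reformatted policies.
    backoff = initial_backoff + (0 if count == 0 else pow(increment_base, count))
    start = backoff - random_jitter_range if backoff > random_jitter_range else 0
    end = backoff + random_jitter_range
    return random.Random().uniform(start, end)

# count=0 -> ~15s, count=1 -> ~18s, count=2 -> ~24s, each jittered by up to +/-3s.
for count in range(3):
    print(f"count={count}: sleep {example_exponential_backoff(count):.1f}s")
```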
@@ -125,37 +123,29 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: await retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) await self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - await retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + await retry_hook(retry_settings, request=request.http_request, response=None, error=err) await self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -176,7 +166,8 @@ def __init__( increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, - random_jitter_range: int = 3, **kwargs + random_jitter_range: int = 3, + **kwargs ) -> None: """ Constructs an Exponential retry object. 
The initial_backoff is used for @@ -203,8 +194,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -217,7 +207,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -232,7 +222,8 @@ class LinearRetry(AsyncStorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, @@ -255,8 +246,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -271,14 +261,13 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/request_handlers.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/request_handlers.py index af500c8727fa..b23f65859690 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/request_handlers.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/request_handlers.py @@ -6,7 +6,7 @@ import logging import stat -from io import (SEEK_END, SEEK_SET, UnsupportedOperation) +from io import SEEK_END, SEEK_SET, UnsupportedOperation from os import fstat from typing import Dict, Optional @@ 
-37,12 +37,13 @@ def serialize_iso(attr): raise OverflowError("Hit max or min date") date = f"{utc.tm_year:04}-{utc.tm_mon:02}-{utc.tm_mday:02}T{utc.tm_hour:02}:{utc.tm_min:02}:{utc.tm_sec:02}" - return date + 'Z' + return date + "Z" except (ValueError, OverflowError) as err: raise ValueError("Unable to serialize datetime object.") from err except AttributeError as err: raise TypeError("ISO-8601 object must be valid datetime object.") from err + def get_length(data): length = None # Check if object implements the __len__ method, covers most input cases such as bytearray. @@ -62,7 +63,7 @@ def get_length(data): try: mode = fstat(fileno).st_mode if stat.S_ISREG(mode) or stat.S_ISLNK(mode): - #st_size only meaningful if regular file or symlink, other types + # st_size only meaningful if regular file or symlink, other types # e.g. sockets may return misleading sizes like 0 return fstat(fileno).st_size except OSError: @@ -84,13 +85,13 @@ def get_length(data): def read_length(data): try: - if hasattr(data, 'read'): - read_data = b'' + if hasattr(data, "read"): + read_data = b"" for chunk in iter(lambda: data.read(4096), b""): read_data += chunk return len(read_data), read_data - if hasattr(data, '__iter__'): - read_data = b'' + if hasattr(data, "__iter__"): + read_data = b"" for chunk in data: read_data += chunk return len(read_data), read_data @@ -100,8 +101,13 @@ def read_length(data): def validate_and_format_range_headers( - start_range, end_range, start_range_required=True, - end_range_required=True, check_content_md5=False, align_to_page=False): + start_range, + end_range, + start_range_required=True, + end_range_required=True, + check_content_md5=False, + align_to_page=False, +): # If end range is provided, start range must be provided if (start_range_required or end_range is not None) and start_range is None: raise ValueError("start_range value cannot be None.") @@ -111,16 +117,18 @@ def validate_and_format_range_headers( # Page ranges must be 512 aligned if align_to_page: if start_range is not None and start_range % 512 != 0: - raise ValueError(f"Invalid page blob start_range: {start_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob start_range: {start_range}. " "The size must be aligned to a 512-byte boundary." + ) if end_range is not None and end_range % 512 != 511: - raise ValueError(f"Invalid page blob end_range: {end_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob end_range: {end_range}. " "The size must be aligned to a 512-byte boundary." 
+ ) # Format based on whether end_range is present range_header = None if end_range is not None: - range_header = f'bytes={start_range}-{end_range}' + range_header = f"bytes={start_range}-{end_range}" elif start_range is not None: range_header = f"bytes={start_range}-" @@ -131,7 +139,7 @@ def validate_and_format_range_headers( raise ValueError("Both start and end range required for MD5 content validation.") if end_range - start_range > 4 * 1024 * 1024: raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.") - range_validation = 'true' + range_validation = "true" return range_header, range_validation @@ -140,7 +148,7 @@ def add_metadata_headers(metadata: Optional[Dict[str, str]] = None) -> Dict[str, headers = {} if metadata: for key, value in metadata.items(): - headers[f'x-ms-meta-{key.strip()}'] = value.strip() if value else value + headers[f"x-ms-meta-{key.strip()}"] = value.strip() if value else value return headers @@ -158,29 +166,26 @@ def serialize_batch_body(requests, batch_id): a list of sub-request for the batch request :param str batch_id: to be embedded in batch sub-request delimiter - :returns: The body bytes for this batch. + :return: The body bytes for this batch. :rtype: bytes """ if requests is None or len(requests) == 0: - raise ValueError('Please provide sub-request(s) for this batch request') + raise ValueError("Please provide sub-request(s) for this batch request") - delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode('utf-8') - newline_bytes = _HTTP_LINE_ENDING.encode('utf-8') + delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode("utf-8") + newline_bytes = _HTTP_LINE_ENDING.encode("utf-8") batch_body = [] content_index = 0 for request in requests: - request.headers.update({ - "Content-ID": str(content_index), - "Content-Length": str(0) - }) + request.headers.update({"Content-ID": str(content_index), "Content-Length": str(0)}) batch_body.append(delimiter_bytes) batch_body.append(_make_body_from_sub_request(request)) batch_body.append(newline_bytes) content_index += 1 - batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode('utf-8')) + batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode("utf-8")) # final line of body MUST have \r\n at the end, or it will not be properly read by the service batch_body.append(newline_bytes) @@ -197,35 +202,35 @@ def _get_batch_request_delimiter(batch_id, is_prepend_dashes=False, is_append_da Whether to include the starting dashes. Used in the body, but non on defining the delimiter. :param bool is_append_dashes: Whether to include the ending dashes. Used in the body on the closing delimiter only. - :returns: The delimiter, WITHOUT a trailing newline. + :return: The delimiter, WITHOUT a trailing newline. :rtype: str """ - prepend_dashes = '--' if is_prepend_dashes else '' - append_dashes = '--' if is_append_dashes else '' + prepend_dashes = "--" if is_prepend_dashes else "" + append_dashes = "--" if is_append_dashes else "" return prepend_dashes + _REQUEST_DELIMITER_PREFIX + batch_id + append_dashes def _make_body_from_sub_request(sub_request): """ - Content-Type: application/http - Content-ID: - Content-Transfer-Encoding: (if present) + Content-Type: application/http + Content-ID: + Content-Transfer-Encoding: (if present) - HTTP/ -
:
(repeated as necessary) - Content-Length: - (newline if content length > 0) - (if content length > 0) + HTTP/ +
:
(repeated as necessary) + Content-Length: + (newline if content length > 0) + (if content length > 0) - Serializes an http request. + Serializes an http request. - :param ~azure.core.pipeline.transport.HttpRequest sub_request: - Request to serialize. - :returns: The serialized sub-request in bytes - :rtype: bytes - """ + :param ~azure.core.pipeline.transport.HttpRequest sub_request: + Request to serialize. + :return: The serialized sub-request in bytes + :rtype: bytes + """ # put the sub-request's headers into a list for efficient str concatenation sub_request_body = [] @@ -249,9 +254,9 @@ def _make_body_from_sub_request(sub_request): # append HTTP verb and path and query and HTTP version sub_request_body.append(sub_request.method) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(sub_request.url) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(_HTTP1_1_IDENTIFIER) sub_request_body.append(_HTTP_LINE_ENDING) @@ -266,4 +271,4 @@ def _make_body_from_sub_request(sub_request): # append blank line sub_request_body.append(_HTTP_LINE_ENDING) - return ''.join(sub_request_body).encode() + return "".join(sub_request_body).encode() diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/response_handlers.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/response_handlers.py index af9a2fcdcdc2..bcfa4147763e 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/response_handlers.py @@ -46,23 +46,23 @@ def parse_length_from_content_range(content_range): # First, split in space and take the second half: '1-3/65537' # Next, split on slash and take the second half: '65537' # Finally, convert to an int: 65537 - return int(content_range.split(' ', 1)[1].split('/', 1)[1]) + return int(content_range.split(" ", 1)[1].split("/", 1)[1]) def normalize_headers(headers): normalized = {} for key, value in headers.items(): - if key.startswith('x-ms-'): + if key.startswith("x-ms-"): key = key[5:] - normalized[key.lower().replace('-', '_')] = get_enum_value(value) + normalized[key.lower().replace("-", "_")] = get_enum_value(value) return normalized def deserialize_metadata(response, obj, headers): # pylint: disable=unused-argument try: - raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith("x-ms-meta-")} except AttributeError: - raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith("x-ms-meta-")} return {k[10:]: v for k, v in raw_metadata.items()} @@ -82,19 +82,23 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches +def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches raise_error = HttpResponseError serialized = False if isinstance(storage_error, AzureSigningError): - storage_error.message = storage_error.message + \ - '. This is likely due to an invalid shared key. 
Please check your shared key and try again.' + storage_error.message = ( + storage_error.message + + ". This is likely due to an invalid shared key. Please check your shared key and try again." + ) if not storage_error.response or storage_error.response.status_code in [200, 204]: raise storage_error # If it is one of those three then it has been serialized prior by the generated layer. - if isinstance(storage_error, (PartialBatchErrorException, - ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError)): + if isinstance( + storage_error, + (PartialBatchErrorException, ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError), + ): serialized = True - error_code = storage_error.response.headers.get('x-ms-error-code') + error_code = storage_error.response.headers.get("x-ms-error-code") error_message = storage_error.message additional_data = {} error_dict = {} @@ -104,27 +108,25 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py if error_body is None or len(error_body) == 0: error_body = storage_error.response.reason except AttributeError: - error_body = '' + error_body = "" # If it is an XML response if isinstance(error_body, Element): - error_dict = { - child.tag.lower(): child.text - for child in error_body - } + error_dict = {child.tag.lower(): child.text for child in error_body} # If it is a JSON response elif isinstance(error_body, dict): - error_dict = error_body.get('error', {}) + error_dict = error_body.get("error", {}) elif not error_code: _LOGGER.warning( - 'Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body)) - error_dict = {'message': str(error_body)} + "Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.", type(error_body) + ) + error_dict = {"message": str(error_body)} # If we extracted from a Json or XML response # There is a chance error_dict is just a string if error_dict and isinstance(error_dict, dict): - error_code = error_dict.get('code') - error_message = error_dict.get('message') - additional_data = {k: v for k, v in error_dict.items() if k not in {'code', 'message'}} + error_code = error_dict.get("code") + error_message = error_dict.get("message") + additional_data = {k: v for k, v in error_dict.items() if k not in {"code", "message"}} except DecodeError: pass @@ -132,31 +134,33 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py # This check would be unnecessary if we have already serialized the error if error_code and not serialized: error_code = StorageErrorCode(error_code) - if error_code in [StorageErrorCode.condition_not_met, - StorageErrorCode.blob_overwritten]: + if error_code in [StorageErrorCode.condition_not_met, StorageErrorCode.blob_overwritten]: raise_error = ResourceModifiedError - if error_code in [StorageErrorCode.invalid_authentication_info, - StorageErrorCode.authentication_failed]: + if error_code in [StorageErrorCode.invalid_authentication_info, StorageErrorCode.authentication_failed]: raise_error = ClientAuthenticationError - if error_code in [StorageErrorCode.resource_not_found, - StorageErrorCode.cannot_verify_copy_source, - StorageErrorCode.blob_not_found, - StorageErrorCode.queue_not_found, - StorageErrorCode.container_not_found, - StorageErrorCode.parent_not_found, - StorageErrorCode.share_not_found]: + if error_code in [ + StorageErrorCode.resource_not_found, + StorageErrorCode.cannot_verify_copy_source, + StorageErrorCode.blob_not_found, + 
StorageErrorCode.queue_not_found, + StorageErrorCode.container_not_found, + StorageErrorCode.parent_not_found, + StorageErrorCode.share_not_found, + ]: raise_error = ResourceNotFoundError - if error_code in [StorageErrorCode.account_already_exists, - StorageErrorCode.account_being_created, - StorageErrorCode.resource_already_exists, - StorageErrorCode.resource_type_mismatch, - StorageErrorCode.blob_already_exists, - StorageErrorCode.queue_already_exists, - StorageErrorCode.container_already_exists, - StorageErrorCode.container_being_deleted, - StorageErrorCode.queue_being_deleted, - StorageErrorCode.share_already_exists, - StorageErrorCode.share_being_deleted]: + if error_code in [ + StorageErrorCode.account_already_exists, + StorageErrorCode.account_being_created, + StorageErrorCode.resource_already_exists, + StorageErrorCode.resource_type_mismatch, + StorageErrorCode.blob_already_exists, + StorageErrorCode.queue_already_exists, + StorageErrorCode.container_already_exists, + StorageErrorCode.container_being_deleted, + StorageErrorCode.queue_being_deleted, + StorageErrorCode.share_already_exists, + StorageErrorCode.share_being_deleted, + ]: raise_error = ResourceExistsError except ValueError: # Got an unknown error code @@ -183,7 +187,7 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py error.args = (error.message,) try: # `from None` prevents us from double printing the exception (suppresses generated layer error context) - exec("raise error from None") # pylint: disable=exec-used # nosec + exec("raise error from None") # pylint: disable=exec-used # nosec except SyntaxError as exc: raise error from exc diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/shared_access_signature.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/shared_access_signature.py index 3a0530a58bdb..414608a15371 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/shared_access_signature.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/shared_access_signature.py @@ -11,44 +11,45 @@ from .constants import X_MS_VERSION from . import sign_string, url_quote + # cspell:ignoreRegExp rsc. 
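The reflowed branches of process_storage_error above keep the same behavior: the service's x-ms-error-code header is parsed into a StorageErrorCode and used to pick a more specific azure-core exception, falling back to the generic HttpResponseError for unknown codes. As a rough, hand-simplified sketch of that dispatch pattern (only a small subset of codes; the real function also rewrites the message, attaches additional_data, and re-raises with the original context suppressed):

from azure.core.exceptions import (
    ClientAuthenticationError,
    HttpResponseError,
    ResourceExistsError,
    ResourceModifiedError,
    ResourceNotFoundError,
)

# Illustrative subset only; the SDK drives this from the StorageErrorCode enum.
_CODE_TO_ERROR = {
    "ConditionNotMet": ResourceModifiedError,
    "AuthenticationFailed": ClientAuthenticationError,
    "InvalidAuthenticationInfo": ClientAuthenticationError,
    "BlobNotFound": ResourceNotFoundError,
    "ContainerNotFound": ResourceNotFoundError,
    "BlobAlreadyExists": ResourceExistsError,
    "ContainerAlreadyExists": ResourceExistsError,
}

def pick_exception_type(error_code):
    # Unknown or missing codes fall back to the generic HttpResponseError.
    return _CODE_TO_ERROR.get(error_code, HttpResponseError)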
# cspell:ignoreRegExp s..?id class QueryStringConstants(object): - SIGNED_SIGNATURE = 'sig' - SIGNED_PERMISSION = 'sp' - SIGNED_START = 'st' - SIGNED_EXPIRY = 'se' - SIGNED_RESOURCE = 'sr' - SIGNED_IDENTIFIER = 'si' - SIGNED_IP = 'sip' - SIGNED_PROTOCOL = 'spr' - SIGNED_VERSION = 'sv' - SIGNED_CACHE_CONTROL = 'rscc' - SIGNED_CONTENT_DISPOSITION = 'rscd' - SIGNED_CONTENT_ENCODING = 'rsce' - SIGNED_CONTENT_LANGUAGE = 'rscl' - SIGNED_CONTENT_TYPE = 'rsct' - START_PK = 'spk' - START_RK = 'srk' - END_PK = 'epk' - END_RK = 'erk' - SIGNED_RESOURCE_TYPES = 'srt' - SIGNED_SERVICES = 'ss' - SIGNED_OID = 'skoid' - SIGNED_TID = 'sktid' - SIGNED_KEY_START = 'skt' - SIGNED_KEY_EXPIRY = 'ske' - SIGNED_KEY_SERVICE = 'sks' - SIGNED_KEY_VERSION = 'skv' - SIGNED_ENCRYPTION_SCOPE = 'ses' - SIGNED_KEY_DELEGATED_USER_TID = 'skdutid' - SIGNED_DELEGATED_USER_OID = 'sduoid' + SIGNED_SIGNATURE = "sig" + SIGNED_PERMISSION = "sp" + SIGNED_START = "st" + SIGNED_EXPIRY = "se" + SIGNED_RESOURCE = "sr" + SIGNED_IDENTIFIER = "si" + SIGNED_IP = "sip" + SIGNED_PROTOCOL = "spr" + SIGNED_VERSION = "sv" + SIGNED_CACHE_CONTROL = "rscc" + SIGNED_CONTENT_DISPOSITION = "rscd" + SIGNED_CONTENT_ENCODING = "rsce" + SIGNED_CONTENT_LANGUAGE = "rscl" + SIGNED_CONTENT_TYPE = "rsct" + START_PK = "spk" + START_RK = "srk" + END_PK = "epk" + END_RK = "erk" + SIGNED_RESOURCE_TYPES = "srt" + SIGNED_SERVICES = "ss" + SIGNED_OID = "skoid" + SIGNED_TID = "sktid" + SIGNED_KEY_START = "skt" + SIGNED_KEY_EXPIRY = "ske" + SIGNED_KEY_SERVICE = "sks" + SIGNED_KEY_VERSION = "skv" + SIGNED_ENCRYPTION_SCOPE = "ses" + SIGNED_KEY_DELEGATED_USER_TID = "skdutid" + SIGNED_DELEGATED_USER_OID = "sduoid" # for ADLS - SIGNED_AUTHORIZED_OID = 'saoid' - SIGNED_UNAUTHORIZED_OID = 'suoid' - SIGNED_CORRELATION_ID = 'scid' - SIGNED_DIRECTORY_DEPTH = 'sdd' + SIGNED_AUTHORIZED_OID = "saoid" + SIGNED_UNAUTHORIZED_OID = "suoid" + SIGNED_CORRELATION_ID = "scid" + SIGNED_DIRECTORY_DEPTH = "sdd" @staticmethod def to_list(): @@ -91,38 +92,30 @@ def to_list(): class SharedAccessSignature(object): - ''' + """ Provides a factory for creating account access signature tokens with an account name and account key. Users can either use the factory or can construct the appropriate service and use the generate_*_shared_access_signature method directly. - ''' + """ def __init__(self, account_name, account_key, x_ms_version=X_MS_VERSION): - ''' + """ :param str account_name: The storage account name used to generate the shared access signatures. :param str account_key: The access key to generate the shares access signatures. :param str x_ms_version: The service version used to generate the shared access signatures. - ''' + """ self.account_name = account_name self.account_key = account_key self.x_ms_version = x_ms_version def generate_account( - self, services, - resource_types, - permission, - expiry, - start=None, - ip=None, - protocol=None, - sts_hook=None, - **kwargs + self, services, resource_types, permission, expiry, start=None, ip=None, protocol=None, sts_hook=None, **kwargs ) -> str: - ''' + """ Generates a shared access signature for the account. Use the returned signature with the sas_token parameter of the service or to create a new account object. @@ -168,10 +161,10 @@ def generate_account( :param sts_hook: For debugging purposes only. If provided, the hook is called with the string to sign that was used to generate the SAS. - :type sts_hook: Optional[Callable[[str], None]] - :returns: The generated SAS token for the account. 
+ :type sts_hook: Optional[~typing.Callable[[str], None]] + :return: The generated SAS token for the account. :rtype: str - ''' + """ sas = _SharedAccessHelper() sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version) sas.add_account(services, resource_types) @@ -194,7 +187,7 @@ def _add_query(self, name, val): self.query_dict[name] = str(val) if val is not None else None def add_encryption_scope(self, **kwargs): - self._add_query(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE, kwargs.pop('encryption_scope', None)) + self._add_query(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE, kwargs.pop("encryption_scope", None)) def add_base(self, permission, expiry, start, ip, protocol, x_ms_version): if isinstance(start, date): @@ -220,11 +213,9 @@ def add_account(self, services, resource_types): self._add_query(QueryStringConstants.SIGNED_SERVICES, services) self._add_query(QueryStringConstants.SIGNED_RESOURCE_TYPES, resource_types) - def add_override_response_headers(self, cache_control, - content_disposition, - content_encoding, - content_language, - content_type): + def add_override_response_headers( + self, cache_control, content_disposition, content_encoding, content_language, content_type + ): self._add_query(QueryStringConstants.SIGNED_CACHE_CONTROL, cache_control) self._add_query(QueryStringConstants.SIGNED_CONTENT_DISPOSITION, content_disposition) self._add_query(QueryStringConstants.SIGNED_CONTENT_ENCODING, content_encoding) @@ -233,24 +224,25 @@ def add_override_response_headers(self, cache_control, def add_account_signature(self, account_name, account_key): def get_value_to_append(query): - return_value = self.query_dict.get(query) or '' - return return_value + '\n' - - string_to_sign = \ - (account_name + '\n' + - get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + - get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + - get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + - get_value_to_append(QueryStringConstants.SIGNED_START) + - get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + - get_value_to_append(QueryStringConstants.SIGNED_IP) + - get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + - get_value_to_append(QueryStringConstants.SIGNED_VERSION) + - get_value_to_append(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE)) - - self._add_query(QueryStringConstants.SIGNED_SIGNATURE, - sign_string(account_key, string_to_sign)) + return_value = self.query_dict.get(query) or "" + return return_value + "\n" + + string_to_sign = ( + account_name + + "\n" + + get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + + get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + + get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + + get_value_to_append(QueryStringConstants.SIGNED_START) + + get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + + get_value_to_append(QueryStringConstants.SIGNED_IP) + + get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + + get_value_to_append(QueryStringConstants.SIGNED_VERSION) + + get_value_to_append(QueryStringConstants.SIGNED_ENCRYPTION_SCOPE) + ) + + self._add_query(QueryStringConstants.SIGNED_SIGNATURE, sign_string(account_key, string_to_sign)) self.string_to_sign = string_to_sign def get_token(self) -> str: - return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None]) + return "&".join([f"{n}={url_quote(v)}" for n, v in self.query_dict.items() if v is not None]) diff --git 
a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads.py index b31cfb3291d9..7a5fb3f3dc91 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads.py @@ -12,7 +12,7 @@ from azure.core.tracing.common import with_current_context -from .import encode_base64, url_quote +from . import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers @@ -41,20 +41,21 @@ def _parallel_uploads(executor, uploader, pending, running): def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - validate_content=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + validate_content=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -64,7 +65,8 @@ def upload_data_chunks( parallel=parallel, validate_content=validate_content, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: upload_tasks = uploader.get_chunk_streams() @@ -81,18 +83,19 @@ def upload_data_chunks( def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -100,7 +103,8 @@ def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: @@ -120,15 +124,17 @@ def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size self.chunk_size = chunk_size @@ -253,7 +259,7 @@ def __init__(self, *args, **kwargs): def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. 
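The upload_data_chunks and upload_substream_blocks hunks above only change layout, not logic: when max_concurrency > 1 the chunk tasks are fanned out over a concurrent.futures.ThreadPoolExecutor, and modified_access_conditions is cleared because, as the inline comment notes, access conditions do not work with parallelism. A simplified, self-contained sketch of that fan-out (the chunk list and upload function here are stand-ins, not the SDK's uploader classes):

from concurrent import futures

def upload_chunks(chunks, upload_one, max_concurrency=1):
    """Upload each chunk, optionally in parallel, returning results in order."""
    if max_concurrency > 1:
        with futures.ThreadPoolExecutor(max_concurrency) as executor:
            return list(executor.map(upload_one, chunks))
    # Sequential path: one request at a time.
    return [upload_one(chunk) for chunk in chunks]

# Stand-in data and a fake "service call" for illustration.
data = b"x" * 10_000
chunk_size = 4096
chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]
results = upload_chunks(chunks, lambda c: f"uploaded-{len(c)}-bytes", max_concurrency=4)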
- index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) self.service.stage_block( block_id, @@ -261,20 +267,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) return index, block_id def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -302,11 +308,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): pass @@ -326,19 +332,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) def _upload_substream_block(self, index, block_stream): @@ -356,11 +363,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): try: @@ -371,7 +378,7 @@ def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -388,9 +395,9 @@ def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - return f'bytes={chunk_offset}-{chunk_end}', response + return f"bytes={chunk_offset}-{chunk_end}", response # TODO: Implement this method. 
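BlockBlobChunkUploader._upload_chunk above derives a deterministic block id from the chunk offset: the offset is zero-padded to 32 digits, base64-encoded, URL-quoted, then base64-encoded again (the source flags this scheme with a TODO, but it is what ships). A standalone sketch using only the standard library, since the SDK's encode_base64 and url_quote helpers are thin wrappers over the same calls:

import base64
from urllib.parse import quote

def block_id_for_offset(chunk_offset: int) -> str:
    # Zero-pad the offset so ids have a fixed width and compare consistently.
    index = f"{chunk_offset:032d}"
    first_pass = base64.b64encode(index.encode("utf-8")).decode("utf-8")
    return base64.b64encode(quote(first_pass).encode("utf-8")).decode("utf-8")

print(block_id_for_offset(0))         # block id for the first chunk
print(block_id_for_offset(4 * 1024))  # block id for the chunk staged at offset 4096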
def _upload_substream_block(self, index, block_stream): diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads_async.py index a056cd290230..6ed5ba1d0f91 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/uploads_async.py @@ -12,7 +12,7 @@ from math import ceil from typing import AsyncGenerator, Union -from .import encode_base64, url_quote +from . import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers from .uploads import SubStream, IterStreamer # pylint: disable=unused-import @@ -59,19 +59,20 @@ async def _parallel_uploads(uploader, pending, running): async def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -80,7 +81,8 @@ async def upload_data_chunks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_chunk_streams() @@ -104,18 +106,19 @@ async def upload_data_chunks( async def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -123,13 +126,13 @@ async def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_substream_blocks() running_futures = [ - asyncio.ensure_future(uploader.process_substream_block(u)) - for u in islice(upload_tasks, 0, max_concurrency) + asyncio.ensure_future(uploader.process_substream_block(u)) for u in islice(upload_tasks, 0, max_concurrency) ] range_ids = await _parallel_uploads(uploader.process_substream_block, upload_tasks, running_futures) else: @@ -144,15 +147,17 @@ async def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size 
self.chunk_size = chunk_size @@ -178,7 +183,7 @@ def __init__( async def get_chunk_streams(self): index = 0 while True: - data = b'' + data = b"" read_size = self.chunk_size # Buffer until we either reach the end of the stream or get a whole chunk. @@ -189,12 +194,12 @@ async def get_chunk_streams(self): if inspect.isawaitable(temp): temp = await temp if not isinstance(temp, bytes): - raise TypeError('Blob data should be of type bytes.') + raise TypeError("Blob data should be of type bytes.") data += temp or b"" # We have read an empty string and so are at the end # of the buffer or we have read a full chunk. - if temp == b'' or len(data) == self.chunk_size: + if temp == b"" or len(data) == self.chunk_size: break if len(data) == self.chunk_size: @@ -273,13 +278,13 @@ def set_response_properties(self, resp): class BlockBlobChunkUploader(_ChunkUploader): def __init__(self, *args, **kwargs): - kwargs.pop('modified_access_conditions', None) + kwargs.pop("modified_access_conditions", None) super(BlockBlobChunkUploader, self).__init__(*args, **kwargs) self.current_length = None async def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. - index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) await self.service.stage_block( block_id, @@ -287,19 +292,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): body=chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) return index, block_id async def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" await self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) finally: block_stream.close() return block_id @@ -311,7 +318,7 @@ def _is_chunk_empty(self, chunk_data): # read until non-zero byte is encountered # if reached the end without returning, then chunk_data is all 0's for each_byte in chunk_data: - if each_byte not in [0, b'\x00']: + if each_byte not in [0, b"\x00"]: return False return True @@ -319,7 +326,7 @@ async def _upload_chunk(self, chunk_offset, chunk_data): # avoid uploading the empty pages if not self._is_chunk_empty(chunk_data): chunk_end = chunk_offset + len(chunk_data) - 1 - content_range = f'bytes={chunk_offset}-{chunk_end}' + content_range = f"bytes={chunk_offset}-{chunk_end}" computed_md5 = None self.response_headers = await self.service.upload_pages( body=chunk_data, @@ -329,10 +336,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): pass @@ -352,18 +360,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, 
upload_stream_current=self.progress_total, - **self.request_options) - self.current_length = int(self.response_headers['blob_append_offset']) + **self.request_options, + ) + self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = await self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) async def _upload_substream_block(self, index, block_stream): pass @@ -379,11 +390,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): try: @@ -394,7 +405,7 @@ async def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -411,9 +422,9 @@ async def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - range_id = f'bytes={chunk_offset}-{chunk_end}' + range_id = f"bytes={chunk_offset}-{chunk_end}" return range_id, response # TODO: Implement this method. @@ -421,10 +432,11 @@ async def _upload_substream_block(self, index, block_stream): pass -class AsyncIterStreamer(): +class AsyncIterStreamer: """ File-like streaming object for AsyncGenerators. """ + def __init__(self, generator: AsyncGenerator[Union[bytes, str], None], encoding: str = "UTF-8"): self.iterator = generator.__aiter__() self.leftover = b"" diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_service_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_service_client_async.py index 6cf5feffee89..214d1c791250 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_service_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_service_client_async.py @@ -134,9 +134,12 @@ async def __aexit__(self, *args: Any) -> None: await self._blob_service_client.close() await super(DataLakeServiceClient, self).__aexit__(*args) - async def close(self) -> None: + async def close(self) -> None: # type: ignore """ This method is to close the sockets opened by the client. It need not be used when using with a context manager. 
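The close() additions in these async clients only document the existing contract: close the client explicitly, or let the async context manager release the sockets on exit. A brief usage sketch; the connection string below is fabricated (made-up account name and key) so the snippet parses without contacting the service:

import asyncio
from azure.storage.filedatalake.aio import DataLakeServiceClient

CONN_STR = (
    "DefaultEndpointsProtocol=https;AccountName=mystorageaccount;"
    "AccountKey=bXktZmFrZS1rZXk=;EndpointSuffix=core.windows.net"
)

async def main():
    # Preferred: the context manager closes the underlying transport on exit.
    async with DataLakeServiceClient.from_connection_string(CONN_STR) as client:
        file_system = client.get_file_system_client("my-file-system")
        print(file_system.url)

    # Equivalent without a context manager: close the client explicitly.
    client = DataLakeServiceClient.from_connection_string(CONN_STR)
    try:
        print(client.primary_hostname)
    finally:
        await client.close()

asyncio.run(main())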
+ + :return: None + :rtype: None """ await self.__aexit__() diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_file_system_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_file_system_client_async.py index c0535531b393..b8e49294fa5f 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_file_system_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_file_system_client_async.py @@ -158,9 +158,12 @@ async def __aexit__(self, *args: Any) -> None: await self._datalake_client_for_blob_operation.close() await super(FileSystemClient, self).__aexit__(*args) - async def close(self) -> None: + async def close(self) -> None: # type: ignore """This method is to close the sockets opened by the client. It need not be used when using with a context manager. + + :return: None + :rtype: None """ await self.__aexit__() diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_path_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_path_client_async.py index 630599174d45..e0fa5fd3b4bb 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_path_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_path_client_async.py @@ -155,10 +155,13 @@ async def __aexit__(self, *args: Any) -> None: await self._datalake_client_for_blob_operation.close() await super(PathClient, self).__aexit__(*args) - async def close(self) -> None: + async def close(self) -> None: # type: ignore """ This method is to close the sockets opened by the client. It need not be used when using with a context manager. + + :return: None + :rtype: None """ await self.__aexit__() diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_helpers_async.py b/sdk/storage/azure-storage-file-datalake/tests/test_helpers_async.py index 8b0185e6eb85..964a63da4016 100644 --- a/sdk/storage/azure-storage-file-datalake/tests/test_helpers_async.py +++ b/sdk/storage/azure-storage-file-datalake/tests/test_helpers_async.py @@ -3,12 +3,16 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +import asyncio +from collections import deque from typing import Any, Dict, Optional from urllib.parse import urlparse from azure.core.pipeline.transport import AioHttpTransportResponse, AsyncHttpTransport from azure.core.rest import HttpRequest from aiohttp import ClientResponse +from aiohttp.streams import StreamReader +from aiohttp.client_proto import ResponseHandler class ProgressTracker: @@ -65,6 +69,10 @@ def __init__( self._loop = None self.status = status self.reason = reason + self.content = StreamReader(ResponseHandler(asyncio.new_event_loop()), 65535) + self.content.total_bytes = len(body_bytes) + self.content._buffer = deque([body_bytes]) + self.content._eof = True class MockStorageTransport(AsyncHttpTransport): diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/__init__.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/__init__.py index a8b1a27d48f9..4dbbb7ed7b09 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/__init__.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/__init__.py @@ -11,7 +11,7 @@ try: from urllib.parse import quote, unquote except ImportError: - from urllib2 import quote, unquote # type: ignore + from urllib2 import quote, unquote # type: ignore def url_quote(url): @@ -24,20 +24,20 @@ def url_unquote(url): def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") def decode_base64_to_bytes(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") return base64.b64decode(data) def decode_base64_to_text(data): decoded_bytes = decode_base64_to_bytes(data) - return decoded_bytes.decode('utf-8') + return decoded_bytes.decode("utf-8") def sign_string(key, string_to_sign, key_is_base64=True): @@ -45,9 +45,9 @@ def sign_string(key, string_to_sign, key_is_base64=True): key = decode_base64_to_bytes(key) else: if isinstance(key, str): - key = key.encode('utf-8') + key = key.encode("utf-8") if isinstance(string_to_sign, str): - string_to_sign = string_to_sign.encode('utf-8') + string_to_sign = string_to_sign.encode("utf-8") signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) digest = signed_hmac_sha256.digest() encoded_digest = encode_base64(digest) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/authentication.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/authentication.py index 44c563d8c75e..f778dc71eec4 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/authentication.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/authentication.py @@ -27,6 +27,8 @@ logger = logging.getLogger(__name__) + +# fmt: off table_lv0 = [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, @@ -50,6 +52,8 @@ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, ] +# fmt: on + def compare(lhs: str, rhs: str) -> int: # pylint:disable=too-many-return-statements tables = [table_lv0, table_lv4] @@ -94,6 +98,7 @@ def _wrap_exception(ex, desired_type): msg = ex.args[0] return desired_type(msg) + # This method 
attempts to emulate the sorting done by the service def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]: @@ -134,38 +139,42 @@ def __init__(self, account_name, account_key): @staticmethod def _get_headers(request, headers_to_sign): headers = dict((name.lower(), value) for name, value in request.http_request.headers.items() if value) - if 'content-length' in headers and headers['content-length'] == '0': - del headers['content-length'] - return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n' + if "content-length" in headers and headers["content-length"] == "0": + del headers["content-length"] + return "\n".join(headers.get(x, "") for x in headers_to_sign) + "\n" @staticmethod def _get_verb(request): - return request.http_request.method + '\n' + return request.http_request.method + "\n" def _get_canonicalized_resource(self, request): uri_path = urlparse(request.http_request.url).path try: - if isinstance(request.context.transport, AioHttpTransport) or \ - isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) or \ - isinstance(getattr(getattr(request.context.transport, "_transport", None), "_transport", None), - AioHttpTransport): + if ( + isinstance(request.context.transport, AioHttpTransport) + or isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) + or isinstance( + getattr(getattr(request.context.transport, "_transport", None), "_transport", None), + AioHttpTransport, + ) + ): uri_path = URL(uri_path) - return '/' + self.account_name + str(uri_path) + return "/" + self.account_name + str(uri_path) except TypeError: pass - return '/' + self.account_name + uri_path + return "/" + self.account_name + uri_path @staticmethod def _get_canonicalized_headers(request): - string_to_sign = '' + string_to_sign = "" x_ms_headers = [] for name, value in request.http_request.headers.items(): - if name.startswith('x-ms-'): + if name.startswith("x-ms-"): x_ms_headers.append((name.lower(), value)) x_ms_headers = _storage_header_sort(x_ms_headers) for name, value in x_ms_headers: if value is not None: - string_to_sign += ''.join([name, ':', value, '\n']) + string_to_sign += "".join([name, ":", value, "\n"]) return string_to_sign @staticmethod @@ -173,37 +182,46 @@ def _get_canonicalized_resource_query(request): sorted_queries = list(request.http_request.query.items()) sorted_queries.sort() - string_to_sign = '' + string_to_sign = "" for name, value in sorted_queries: if value is not None: - string_to_sign += '\n' + name.lower() + ':' + unquote(value) + string_to_sign += "\n" + name.lower() + ":" + unquote(value) return string_to_sign def _add_authorization_header(self, request, string_to_sign): try: signature = sign_string(self.account_key, string_to_sign) - auth_string = 'SharedKey ' + self.account_name + ':' + signature - request.http_request.headers['Authorization'] = auth_string + auth_string = "SharedKey " + self.account_name + ":" + signature + request.http_request.headers["Authorization"] = auth_string except Exception as ex: # Wrap any error that occurred as signing error # Doing so will clarify/locate the source of problem raise _wrap_exception(ex, AzureSigningError) from ex def on_request(self, request): - string_to_sign = \ - self._get_verb(request) + \ - self._get_headers( + string_to_sign = ( + self._get_verb(request) + + self._get_headers( request, [ - 'content-encoding', 'content-language', 'content-length', - 'content-md5', 'content-type', 'date', 'if-modified-since', - 
'if-match', 'if-none-match', 'if-unmodified-since', 'byte_range' - ] - ) + \ - self._get_canonicalized_headers(request) + \ - self._get_canonicalized_resource(request) + \ - self._get_canonicalized_resource_query(request) + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", + ], + ) + + self._get_canonicalized_headers(request) + + self._get_canonicalized_resource(request) + + self._get_canonicalized_resource_query(request) + ) self._add_authorization_header(request, string_to_sign) # logger.debug("String_to_sign=%s", string_to_sign) @@ -211,7 +229,7 @@ def on_request(self, request): class StorageHttpChallenge(object): def __init__(self, challenge): - """ Parses an HTTP WWW-Authentication Bearer challenge from the Storage service. """ + """Parses an HTTP WWW-Authentication Bearer challenge from the Storage service.""" if not challenge: raise ValueError("Challenge cannot be empty") @@ -220,7 +238,7 @@ def __init__(self, challenge): # name=value pairs either comma or space separated with values possibly being # enclosed in quotes - for item in re.split('[, ]', trimmed_challenge): + for item in re.split("[, ]", trimmed_challenge): comps = item.split("=") if len(comps) == 2: key = comps[0].strip(' "') @@ -229,11 +247,11 @@ def __init__(self, challenge): self._parameters[key] = value # Extract and verify required parameters - self.authorization_uri = self._parameters.get('authorization_uri') + self.authorization_uri = self._parameters.get("authorization_uri") if not self.authorization_uri: raise ValueError("Authorization Uri not found") - self.resource_id = self._parameters.get('resource_id') + self.resource_id = self._parameters.get("resource_id") if not self.resource_id: raise ValueError("Resource id not found") diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py index 7de14050b963..217eb2110f15 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py @@ -20,7 +20,10 @@ from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline -from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module +from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module + HttpTransport, + RequestsTransport, +) from azure.core.pipeline.policies import ( AzureSasCredentialPolicy, ContentDecodePolicy, @@ -73,8 +76,17 @@ def __init__( self, parsed_url: Any, service: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + credential: Optional[ + Union[ + str, + Dict[str, str], + AzureNamedKeyCredential, + AzureSasCredential, + "AsyncTokenCredential", + TokenCredential, + ] + ] = None, + **kwargs: Any, ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts", {}) @@ -83,12 +95,15 @@ def __init__( if service not in ["blob", "queue", "file-share", "dfs"]: 
raise ValueError(f"Invalid service: {service}") - service_name = service.split('-')[0] + service_name = service.split("-")[0] account = parsed_url.netloc.split(f".{service_name}.core.") self.account_name = account[0] if len(account) > 1 else None - if not self.account_name and parsed_url.netloc.startswith("localhost") \ - or parsed_url.netloc.startswith("127.0.0.1"): + if ( + not self.account_name + and parsed_url.netloc.startswith("localhost") + or parsed_url.netloc.startswith("127.0.0.1") + ): self._is_localhost = True self.account_name = parsed_url.path.strip("/") @@ -106,7 +121,7 @@ def __init__( secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary") if kwargs.get("secondary_hostname"): secondary_hostname = kwargs["secondary_hostname"] - primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/') + primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip("/") self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname} self._sdk_moniker = f"storage-{service}/{VERSION}" @@ -119,71 +134,76 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): - """ This method is to close the sockets opened by the client. + def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. """ self._client.close() @property - def url(self): + def url(self) -> str: """The full endpoint URL to this entity, including SAS token if used. This could be either the primary endpoint, or the secondary endpoint depending on the current :func:`location_mode`. - :returns: The full endpoint URL to this entity, including SAS token if used. + :return: The full endpoint URL to this entity, including SAS token if used. :rtype: str """ - return self._format_url(self._hosts[self._location_mode]) + return self._format_url(self._hosts[self._location_mode]) # type: ignore @property - def primary_endpoint(self): + def primary_endpoint(self) -> str: """The full primary endpoint URL. + :return: The full primary endpoint URL. :rtype: str """ - return self._format_url(self._hosts[LocationMode.PRIMARY]) + return self._format_url(self._hosts[LocationMode.PRIMARY]) # type: ignore @property - def primary_hostname(self): + def primary_hostname(self) -> str: """The hostname of the primary endpoint. + :return: The hostname of the primary endpoint. :rtype: str """ return self._hosts[LocationMode.PRIMARY] @property - def secondary_endpoint(self): + def secondary_endpoint(self) -> str: """The full secondary endpoint URL if configured. If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The full secondary endpoint URL. :rtype: str - :raise ValueError: + :raise ValueError: If no secondary endpoint is configured. """ if not self._hosts[LocationMode.SECONDARY]: raise ValueError("No secondary host configured.") - return self._format_url(self._hosts[LocationMode.SECONDARY]) + return self._format_url(self._hosts[LocationMode.SECONDARY]) # type: ignore @property - def secondary_hostname(self): + def secondary_hostname(self) -> Optional[str]: """The hostname of the secondary endpoint. If not available this will be None. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The hostname of the secondary endpoint, or None if not configured. 
:rtype: Optional[str] """ return self._hosts[LocationMode.SECONDARY] @property - def location_mode(self): + def location_mode(self) -> str: """The location mode that the client is currently using. By default this will be "primary". Options include "primary" and "secondary". + :return: The current location mode. :rtype: str """ @@ -206,11 +226,16 @@ def api_version(self): return self._client._config.version # pylint: disable=protected-access def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -218,7 +243,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." + ) if _is_credential_sastoken(credential): credential = cast(str, credential) query_str += credential.lstrip("?") @@ -228,13 +254,16 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None if hasattr(credential, "get_token"): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience) @@ -268,22 +297,18 @@ def _create_pipeline( config.logging_policy, StorageResponseHook(**kwargs), DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs) + HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore return config, Pipeline(transport, policies=policies) - def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> Iterator["HttpResponse"]: + def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> Iterator["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An iterator of HttpResponse objects. + :return: An iterator of HttpResponse objects. 
:rtype: Iterator[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg @@ -292,25 +317,21 @@ def _batch_send( request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), headers={ - 'x-ms-version': self.api_version, - "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) - } + "x-ms-version": self.api_version, + "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False), + }, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) Pipeline._prepare_multipart_mixed_request(request) # pylint: disable=protected-access body = serialize_batch_body(request.multipart_mixed_info[0], batch_id) @@ -318,9 +339,7 @@ def _batch_send( temp = request.multipart_mixed_info request.multipart_mixed_info = None - pipeline_response = self._pipeline.run( - request, **kwargs - ) + pipeline_response = self._pipeline.run(request, **kwargs) response = pipeline_response.http_response request.multipart_mixed_info = temp @@ -332,8 +351,7 @@ def _batch_send( parts = list(response.parts()) if any(p for p in parts if not 200 <= p.status_code < 300): error = PartialBatchErrorException( - message="There is a partial failure in the batch operation.", - response=response, parts=parts + message="There is a partial failure in the batch operation.", response=response, parts=parts ) raise error return iter(parts) @@ -347,6 +365,7 @@ class TransportWrapper(HttpTransport): by a `get_client` method does not close the outer transport for the parent when used in a context manager. 
""" + def __init__(self, transport): self._transport = transport @@ -368,7 +387,9 @@ def __exit__(self, *args): def _format_shared_key_credential( account_name: Optional[str], - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential] + ] = None, ) -> Any: if isinstance(credential, str): if not account_name: @@ -388,8 +409,12 @@ def _format_shared_key_credential( def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -411,14 +436,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -438,7 +460,7 @@ def parse_connection_str( def create_configuration(**kwargs: Any) -> StorageConfiguration: - # Backwards compatibility if someone is not passing sdk_moniker + # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}" config = StorageConfiguration(**kwargs) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py index 6186b29db107..f39a57b24943 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py @@ -64,18 +64,26 @@ async def __aenter__(self): async def __aexit__(self, *args): await self._client.__aexit__(*args) - async def close(self): - """ This method is to close the sockets opened by the client. + async def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. 
+ + :return: None + :rtype: None """ await self._client.close() def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -83,7 +91,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." + ) if _is_credential_sastoken(credential): query_str += credential.lstrip("?") # type: ignore [union-attr] credential = None @@ -92,35 +101,40 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[ - Union[AsyncStorageBearerTokenCredentialPolicy, - SharedKeyCredentialPolicy, - AzureSasCredentialPolicy]] = None - if hasattr(credential, 'get_token'): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + Union[AsyncStorageBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy] + ] = None + if hasattr(credential, "get_token"): + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = AsyncStorageBearerTokenCredentialPolicy( - cast(AsyncTokenCredential, credential), audience) + cast(AsyncTokenCredential, credential), audience + ) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): self._credential_policy = AzureSasCredentialPolicy(credential) elif credential is not None: raise TypeError(f"Unsupported credential: {type(credential)}") - config = kwargs.get('_configuration') or create_configuration(**kwargs) - if kwargs.get('_pipeline'): - return config, kwargs['_pipeline'] - transport = kwargs.get('transport') + config = kwargs.get("_configuration") or create_configuration(**kwargs) + if kwargs.get("_pipeline"): + return config, kwargs["_pipeline"] + transport = kwargs.get("transport") kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) if not transport: try: - from azure.core.pipeline.transport import AioHttpTransport # pylint: 
disable=non-abstract-transport-import + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + AioHttpTransport, + ) except ImportError as exc: raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc transport = AioHttpTransport(**kwargs) @@ -143,53 +157,41 @@ def _create_pipeline( HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore - config.transport = transport #type: ignore - return config, AsyncPipeline(transport, policies=policies) #type: ignore + policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore + config.transport = transport # type: ignore + return config, AsyncPipeline(transport, policies=policies) # type: ignore - async def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> AsyncList["HttpResponse"]: + async def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> AsyncList["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An AsyncList of HttpResponse objects. + :return: An AsyncList of HttpResponse objects. :rtype: AsyncList[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), - headers={ - 'x-ms-version': self.api_version - } + headers={"x-ms-version": self.api_version}, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) # type: ignore - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) - pipeline_response = await self._pipeline.run( - request, **kwargs - ) + pipeline_response = await self._pipeline.run(request, **kwargs) response = pipeline_response.http_response try: if response.status_code not in [202]: raise HttpResponseError(response=response) - parts = response.parts() # Return an AsyncIterator + parts = response.parts() # Return an AsyncIterator if raise_on_any_failure: parts_list = [] async for part in parts: @@ -197,7 +199,8 @@ async def _batch_send( if any(p for p in parts_list if not 200 <= p.status_code < 300): error = PartialBatchErrorException( message="There is a partial failure in the batch operation.", - response=response, parts=parts_list + response=response, + parts=parts_list, ) raise error return AsyncList(parts_list) @@ -205,11 +208,16 @@ async def _batch_send( except HttpResponseError as error: process_storage_error(error) + def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = 
[s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -231,14 +239,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -256,11 +261,13 @@ def parse_connection_str( secondary = secondary.replace(".blob.", ".dfs.") return primary, secondary, credential + class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner client created by a `get_client` method does not close the outer transport for the parent when used in a context manager. """ + def __init__(self, async_transport): self._transport = async_transport diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/constants.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/constants.py index 0b4b029a2d1b..0926f04c4081 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/constants.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/constants.py @@ -16,4 +16,4 @@ DEFAULT_OAUTH_SCOPE = "/.default" STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default" -SERVICE_HOST_BASE = 'core.windows.net' +SERVICE_HOST_BASE = "core.windows.net" diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/models.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/models.py index 403e6b8bea37..185d58860fae 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/models.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/models.py @@ -22,6 +22,7 @@ def get_enum_value(value): class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Error codes returned by the service.""" # Generic storage values ACCOUNT_ALREADY_EXISTS = "AccountAlreadyExists" @@ -172,26 +173,26 @@ class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): CONTAINER_QUOTA_DOWNGRADE_NOT_ALLOWED = "ContainerQuotaDowngradeNotAllowed" # DataLake values - CONTENT_LENGTH_MUST_BE_ZERO = 'ContentLengthMustBeZero' - PATH_ALREADY_EXISTS = 'PathAlreadyExists' - INVALID_FLUSH_POSITION = 'InvalidFlushPosition' - INVALID_PROPERTY_NAME = 'InvalidPropertyName' - INVALID_SOURCE_URI = 'InvalidSourceUri' - UNSUPPORTED_REST_VERSION = 'UnsupportedRestVersion' - FILE_SYSTEM_NOT_FOUND = 'FilesystemNotFound' - PATH_NOT_FOUND = 'PathNotFound' - RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = 'RenameDestinationParentPathNotFound' - SOURCE_PATH_NOT_FOUND = 'SourcePathNotFound' - DESTINATION_PATH_IS_BEING_DELETED = 'DestinationPathIsBeingDeleted' - FILE_SYSTEM_ALREADY_EXISTS = 'FilesystemAlreadyExists' - FILE_SYSTEM_BEING_DELETED = 'FilesystemBeingDeleted' - INVALID_DESTINATION_PATH = 'InvalidDestinationPath' - INVALID_RENAME_SOURCE_PATH = 'InvalidRenameSourcePath' - INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = 'InvalidSourceOrDestinationResourceType' - LEASE_IS_ALREADY_BROKEN = 'LeaseIsAlreadyBroken' - LEASE_NAME_MISMATCH = 'LeaseNameMismatch' - PATH_CONFLICT = 'PathConflict' - SOURCE_PATH_IS_BEING_DELETED = 
'SourcePathIsBeingDeleted' + CONTENT_LENGTH_MUST_BE_ZERO = "ContentLengthMustBeZero" + PATH_ALREADY_EXISTS = "PathAlreadyExists" + INVALID_FLUSH_POSITION = "InvalidFlushPosition" + INVALID_PROPERTY_NAME = "InvalidPropertyName" + INVALID_SOURCE_URI = "InvalidSourceUri" + UNSUPPORTED_REST_VERSION = "UnsupportedRestVersion" + FILE_SYSTEM_NOT_FOUND = "FilesystemNotFound" + PATH_NOT_FOUND = "PathNotFound" + RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = "RenameDestinationParentPathNotFound" + SOURCE_PATH_NOT_FOUND = "SourcePathNotFound" + DESTINATION_PATH_IS_BEING_DELETED = "DestinationPathIsBeingDeleted" + FILE_SYSTEM_ALREADY_EXISTS = "FilesystemAlreadyExists" + FILE_SYSTEM_BEING_DELETED = "FilesystemBeingDeleted" + INVALID_DESTINATION_PATH = "InvalidDestinationPath" + INVALID_RENAME_SOURCE_PATH = "InvalidRenameSourcePath" + INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = "InvalidSourceOrDestinationResourceType" + LEASE_IS_ALREADY_BROKEN = "LeaseIsAlreadyBroken" + LEASE_NAME_MISMATCH = "LeaseNameMismatch" + PATH_CONFLICT = "PathConflict" + SOURCE_PATH_IS_BEING_DELETED = "SourcePathIsBeingDeleted" class DictMixin(object): @@ -222,7 +223,7 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')}) + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) def __contains__(self, key): return key in self.__dict__ @@ -234,13 +235,13 @@ def update(self, *args, **kwargs): return self.__dict__.update(*args, **kwargs) def keys(self): - return [k for k in self.__dict__ if not k.startswith('_')] + return [k for k in self.__dict__ if not k.startswith("_")] def values(self): - return [v for k, v in self.__dict__.items() if not k.startswith('_')] + return [v for k, v in self.__dict__.items() if not k.startswith("_")] def items(self): - return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')] + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] def get(self, key, default=None): if key in self.__dict__: @@ -255,8 +256,8 @@ class LocationMode(object): must use PRIMARY. """ - PRIMARY = 'primary' #: Requests should be sent to the primary location. - SECONDARY = 'secondary' #: Requests should be sent to the secondary location, if possible. + PRIMARY = "primary" #: Requests should be sent to the primary location. + SECONDARY = "secondary" #: Requests should be sent to the secondary location, if possible. 
class ResourceTypes(object): @@ -281,17 +282,12 @@ class ResourceTypes(object): _str: str def __init__( - self, - service: bool = False, - container: bool = False, - object: bool = False # pylint: disable=redefined-builtin + self, service: bool = False, container: bool = False, object: bool = False # pylint: disable=redefined-builtin ) -> None: self.service = service self.container = container self.object = object - self._str = (('s' if self.service else '') + - ('c' if self.container else '') + - ('o' if self.object else '')) + self._str = ("s" if self.service else "") + ("c" if self.container else "") + ("o" if self.object else "") def __str__(self): return self._str @@ -309,9 +305,9 @@ def from_string(cls, string): :return: A ResourceTypes object :rtype: ~azure.storage.fileshare.ResourceTypes """ - res_service = 's' in string - res_container = 'c' in string - res_object = 'o' in string + res_service = "s" in string + res_container = "c" in string + res_object = "o" in string parsed = cls(res_service, res_container, res_object) parsed._str = string @@ -392,29 +388,30 @@ def __init__( self.write = write self.delete = delete self.delete_previous_version = delete_previous_version - self.permanent_delete = kwargs.pop('permanent_delete', False) + self.permanent_delete = kwargs.pop("permanent_delete", False) self.list = list self.add = add self.create = create self.update = update self.process = process - self.tag = kwargs.pop('tag', False) - self.filter_by_tags = kwargs.pop('filter_by_tags', False) - self.set_immutability_policy = kwargs.pop('set_immutability_policy', False) - self._str = (('r' if self.read else '') + - ('w' if self.write else '') + - ('d' if self.delete else '') + - ('x' if self.delete_previous_version else '') + - ('y' if self.permanent_delete else '') + - ('l' if self.list else '') + - ('a' if self.add else '') + - ('c' if self.create else '') + - ('u' if self.update else '') + - ('p' if self.process else '') + - ('f' if self.filter_by_tags else '') + - ('t' if self.tag else '') + - ('i' if self.set_immutability_policy else '') - ) + self.tag = kwargs.pop("tag", False) + self.filter_by_tags = kwargs.pop("filter_by_tags", False) + self.set_immutability_policy = kwargs.pop("set_immutability_policy", False) + self._str = ( + ("r" if self.read else "") + + ("w" if self.write else "") + + ("d" if self.delete else "") + + ("x" if self.delete_previous_version else "") + + ("y" if self.permanent_delete else "") + + ("l" if self.list else "") + + ("a" if self.add else "") + + ("c" if self.create else "") + + ("u" if self.update else "") + + ("p" if self.process else "") + + ("f" if self.filter_by_tags else "") + + ("t" if self.tag else "") + + ("i" if self.set_immutability_policy else "") + ) def __str__(self): return self._str @@ -432,23 +429,34 @@ def from_string(cls, permission): :return: An AccountSasPermissions object :rtype: ~azure.storage.fileshare.AccountSasPermissions """ - p_read = 'r' in permission - p_write = 'w' in permission - p_delete = 'd' in permission - p_delete_previous_version = 'x' in permission - p_permanent_delete = 'y' in permission - p_list = 'l' in permission - p_add = 'a' in permission - p_create = 'c' in permission - p_update = 'u' in permission - p_process = 'p' in permission - p_tag = 't' in permission - p_filter_by_tags = 'f' in permission - p_set_immutability_policy = 'i' in permission - parsed = cls(read=p_read, write=p_write, delete=p_delete, delete_previous_version=p_delete_previous_version, - list=p_list, add=p_add, create=p_create, 
update=p_update, process=p_process, tag=p_tag, - filter_by_tags=p_filter_by_tags, set_immutability_policy=p_set_immutability_policy, - permanent_delete=p_permanent_delete) + p_read = "r" in permission + p_write = "w" in permission + p_delete = "d" in permission + p_delete_previous_version = "x" in permission + p_permanent_delete = "y" in permission + p_list = "l" in permission + p_add = "a" in permission + p_create = "c" in permission + p_update = "u" in permission + p_process = "p" in permission + p_tag = "t" in permission + p_filter_by_tags = "f" in permission + p_set_immutability_policy = "i" in permission + parsed = cls( + read=p_read, + write=p_write, + delete=p_delete, + delete_previous_version=p_delete_previous_version, + list=p_list, + add=p_add, + create=p_create, + update=p_update, + process=p_process, + tag=p_tag, + filter_by_tags=p_filter_by_tags, + set_immutability_policy=p_set_immutability_policy, + permanent_delete=p_permanent_delete, + ) return parsed @@ -464,18 +472,11 @@ class Services(object): Access for the `~azure.storage.fileshare.ShareServiceClient`. Default is False. """ - def __init__( - self, *, - blob: bool = False, - queue: bool = False, - fileshare: bool = False - ) -> None: + def __init__(self, *, blob: bool = False, queue: bool = False, fileshare: bool = False) -> None: self.blob = blob self.queue = queue self.fileshare = fileshare - self._str = (('b' if self.blob else '') + - ('q' if self.queue else '') + - ('f' if self.fileshare else '')) + self._str = ("b" if self.blob else "") + ("q" if self.queue else "") + ("f" if self.fileshare else "") def __str__(self): return self._str @@ -493,9 +494,9 @@ def from_string(cls, string): :return: A Services object :rtype: ~azure.storage.fileshare.Services """ - res_blob = 'b' in string - res_queue = 'q' in string - res_file = 'f' in string + res_blob = "b" in string + res_queue = "q" in string + res_file = "f" in string parsed = cls(blob=res_blob, queue=res_queue, fileshare=res_file) parsed._str = string @@ -573,13 +574,13 @@ class StorageConfiguration(Configuration): def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) - self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024) + self.max_single_put_size = kwargs.pop("max_single_put_size", 64 * 1024 * 1024) self.copy_polling_interval = 15 - self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024) - self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1) - self.use_byte_buffer = kwargs.pop('use_byte_buffer', False) - self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024) - self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) - self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) - self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) - self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024) + self.max_block_size = kwargs.pop("max_block_size", 4 * 1024 * 1024) + self.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) + self.use_byte_buffer = kwargs.pop("use_byte_buffer", False) + self.max_page_size = kwargs.pop("max_page_size", 4 * 1024 * 1024) + self.min_large_chunk_upload_threshold = kwargs.pop("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) + self.max_single_get_size = kwargs.pop("max_single_get_size", 32 * 1024 * 1024) + self.max_chunk_get_size = 
kwargs.pop("max_chunk_get_size", 4 * 1024 * 1024) + self.max_range_size = kwargs.pop("max_range_size", 4 * 1024 * 1024) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/parser.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/parser.py index 112c1984f4fb..e4fcb8f041ba 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/parser.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/parser.py @@ -12,14 +12,14 @@ def _to_utc_datetime(value: datetime) -> str: - return value.strftime('%Y-%m-%dT%H:%M:%SZ') + return value.strftime("%Y-%m-%dT%H:%M:%SZ") def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]: """Converts an RFC 1123 date string to a UTC datetime. :param str rfc_1123: The time and date in RFC 1123 format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not rfc_1123: @@ -33,7 +33,7 @@ def _filetime_to_datetime(filetime: str) -> Optional[datetime]: If parsing MS Filetime fails, tries RFC 1123 as backup. :param str filetime: The time and date in MS filetime format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not filetime: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py index ee75cd5a466c..a08fee7afaac 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py @@ -28,7 +28,7 @@ HTTPPolicy, NetworkTraceLoggingPolicy, RequestHistory, - SansIOHTTPPolicy + SansIOHTTPPolicy, ) from .authentication import AzureSigningError, StorageHttpChallenge @@ -39,7 +39,7 @@ from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -48,14 +48,14 @@ def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status']) + retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -63,8 +63,8 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): - if settings['hook']: - settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs) + if settings["hook"]: + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? 
(Based on allowlists and control @@ -95,40 +95,39 @@ def is_retry(response, mode): def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False def urljoin(base_url, stub_url): parsed = urlparse(base_url) - parsed = parsed._replace(path=parsed.path + '/' + stub_url) + parsed = parsed._replace(path=parsed.path + "/" + stub_url) return parsed.geturl() class QueueMessagePolicy(SansIOHTTPPolicy): def on_request(self, request): - message_id = request.context.options.pop('queue_message_id', None) + message_id = request.context.options.pop("queue_message_id", None) if message_id: - request.http_request.url = urljoin( - request.http_request.url, - message_id) + request.http_request.url = urljoin(request.http_request.url, message_id) class StorageHeadersPolicy(HeadersPolicy): - request_id_header_name = 'x-ms-client-request-id' + request_id_header_name = "x-ms-client-request-id" def on_request(self, request: "PipelineRequest") -> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) - request.http_request.headers['x-ms-date'] = current_time + request.http_request.headers["x-ms-date"] = current_time - custom_id = request.context.options.pop('client_request_id', None) - request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1()) + custom_id = request.context.options.pop("client_request_id", None) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -153,7 +152,7 @@ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument super(StorageHosts, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request.context.options['hosts'] = self.hosts + request.context.options["hosts"] = self.hosts parsed_url = urlparse(request.http_request.url) # Detect what location mode we're currently requesting with @@ -163,10 +162,10 @@ def on_request(self, request: "PipelineRequest") -> None: location_mode = key # See if a specific location mode has been specified, and if so, redirect - use_location = request.context.options.pop('use_location', None) + use_location = request.context.options.pop("use_location", None) if use_location: # Lock retries to the specific location - request.context.options['retry_to_secondary'] = False + request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != location_mode: @@ -175,7 +174,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.url = updated.geturl() location_mode = use_location - request.context.options['location_mode'] = 
location_mode + request.context.options["location_mode"] = location_mode class StorageLoggingPolicy(NetworkTraceLoggingPolicy): @@ -200,19 +199,19 @@ def on_request(self, request: "PipelineRequest") -> None: try: log_url = http_request.url query_params = http_request.query - if 'sig' in query_params: - log_url = log_url.replace(query_params['sig'], "sig=*****") + if "sig" in query_params: + log_url = log_url.replace(query_params["sig"], "sig=*****") _LOGGER.debug("Request URL: %r", log_url) _LOGGER.debug("Request method: %r", http_request.method) _LOGGER.debug("Request headers:") for header, value in http_request.headers.items(): - if header.lower() == 'authorization': - value = '*****' - elif header.lower() == 'x-ms-copy-source' and 'sig' in value: + if header.lower() == "authorization": + value = "*****" + elif header.lower() == "x-ms-copy-source" and "sig" in value: # take the url apart and scrub away the signed signature scheme, netloc, path, params, query, fragment = urlparse(value) parsed_qs = dict(parse_qsl(query)) - parsed_qs['sig'] = '*****' + parsed_qs["sig"] = "*****" # the SAS needs to be put back together value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) @@ -242,11 +241,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") # We don't want to log binary data if the response is a file. _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) - header = response.http_response.headers.get('content-disposition') + header = response.http_response.headers.get("content-disposition") resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): - filename = header.partition('=')[2] + filename = header.partition("=")[2] _LOGGER.debug("File attachments: %s", filename) elif resp_content_type.endswith("octet-stream"): _LOGGER.debug("Body contains binary data.") @@ -268,11 +267,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") class StorageRequestHook(SansIOHTTPPolicy): def __init__(self, **kwargs): - self._request_callback = kwargs.get('raw_request_hook') + self._request_callback = kwargs.get("raw_request_hook") super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop('raw_request_hook', self._request_callback) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -280,49 +279,50 @@ def on_request(self, request: "PipelineRequest") -> None: class StorageResponseHook(HTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(StorageResponseHook, self).__init__() def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = 
request.context.options.pop('download_stream_current', None) - upload_stream_current = request.context.get('upload_stream_current') + download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response @@ -332,7 +332,8 @@ class StorageContentValidation(SansIOHTTPPolicy): This will overwrite any headers already defined in the request. """ - header_name = 'Content-MD5' + + header_name = "Content-MD5" def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super(StorageContentValidation, self).__init__() @@ -342,10 +343,10 @@ def get_content_md5(data): # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. 
data = data or b"" - md5 = hashlib.md5() # nosec + md5 = hashlib.md5() # nosec if isinstance(data, bytes): md5.update(data) - elif hasattr(data, 'read'): + elif hasattr(data, "read"): pos = 0 try: pos = data.tell() @@ -363,22 +364,25 @@ def get_content_md5(data): return md5.digest() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop('validate_content', False) - if validate_content and request.http_request.method != 'GET': + validate_content = request.context.options.pop("validate_content", False) + if validate_content and request.http_request.method != "GET": computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) request.http_request.headers[self.header_name] = computed_md5 - request.context['validate_content_md5'] = computed_md5 - request.context['validate_content'] = validate_content + request.context["validate_content_md5"] = computed_md5 + request.context["validate_content"] = validate_content def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = request.context.get('validate_content_md5') or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: - raise AzureError(( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'."), - response=response.http_response + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, ) @@ -399,33 +403,41 @@ class StorageRetryPolicy(HTTPPolicy): """Whether the secondary endpoint should be retried.""" def __init__(self, **kwargs: Any) -> None: - self.total_retries = kwargs.pop('retry_total', 10) - self.connect_retries = kwargs.pop('retry_connect', 3) - self.read_retries = kwargs.pop('retry_read', 3) - self.status_retries = kwargs.pop('retry_status', 3) - self.retry_to_secondary = kwargs.pop('retry_to_secondary', False) + self.total_retries = kwargs.pop("retry_total", 10) + self.connect_retries = kwargs.pop("retry_connect", 3) + self.read_retries = kwargs.pop("retry_read", 3) + self.status_retries = kwargs.pop("retry_status", 3) + self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. - :param Dict[str, Any]] settings: The configurable values pertaining to the next host location. + :param Dict[str, Any] settings: The configurable values pertaining to the next host location. :param PipelineRequest request: A pipeline request object. 
""" - if settings['hosts'] and all(settings['hosts'].values()): + if settings["hosts"] and all(settings["hosts"].values()): url = urlparse(request.url) # If there's more than one possible location, retry to the alternative - if settings['mode'] == LocationMode.PRIMARY: - settings['mode'] = LocationMode.SECONDARY + if settings["mode"] == LocationMode.PRIMARY: + settings["mode"] = LocationMode.SECONDARY else: - settings['mode'] = LocationMode.PRIMARY - updated = url._replace(netloc=settings['hosts'].get(settings['mode'])) + settings["mode"] = LocationMode.PRIMARY + updated = url._replace(netloc=settings["hosts"].get(settings["mode"])) request.url = updated.geturl() def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: + """ + Configure the retry settings for the request. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A dictionary containing the retry settings. + :rtype: Dict[str, Any] + """ body_position = None - if hasattr(request.http_request.body, 'read'): + if hasattr(request.http_request.body, "read"): try: body_position = request.http_request.body.tell() except (AttributeError, UnsupportedOperation): @@ -433,129 +445,140 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: pass options = request.context.options return { - 'total': options.pop("retry_total", self.total_retries), - 'connect': options.pop("retry_connect", self.connect_retries), - 'read': options.pop("retry_read", self.read_retries), - 'status': options.pop("retry_status", self.status_retries), - 'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary), - 'mode': options.pop("location_mode", LocationMode.PRIMARY), - 'hosts': options.pop("hosts", None), - 'hook': options.pop("retry_hook", None), - 'body_position': body_position, - 'count': 0, - 'history': [] + "total": options.pop("retry_total", self.total_retries), + "connect": options.pop("retry_connect", self.connect_retries), + "read": options.pop("retry_read", self.read_retries), + "status": options.pop("retry_status", self.status_retries), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "mode": options.pop("location_mode", LocationMode.PRIMARY), + "hosts": options.pop("hosts", None), + "hook": options.pop("retry_hook", None), + "body_position": body_position, + "count": 0, + "history": [], } def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument - """ Formula for computing the current backoff. + """Formula for computing the current backoff. Should be calculated by child class. :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. - :returns: The backoff time. + :return: The backoff time. :rtype: float """ return 0 def sleep(self, settings, transport): + """Sleep for the backoff time. + + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. + :param transport: The transport to use for sleeping. 
+ :type transport: + ~azure.core.pipeline.transport.AsyncioBaseTransport or + ~azure.core.pipeline.transport.BaseTransport + """ backoff = self.get_backoff_time(settings) if not backoff or backoff < 0: return transport.sleep(backoff) def increment( - self, settings: Dict[str, Any], + self, + settings: Dict[str, Any], request: "PipelineRequest", response: Optional["PipelineResponse"] = None, - error: Optional[AzureError] = None + error: Optional[AzureError] = None, ) -> bool: """Increment the retry counters. :param Dict[str, Any] settings: The configurable values pertaining to the increment operation. - :param PipelineRequest request: A pipeline request object. - :param Optional[PipelineResponse] response: A pipeline response object. - :param Optional[AzureError] error: An error encountered during the request, or + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: A pipeline response object. + :type response: ~azure.core.pipeline.PipelineResponse or None + :param error: An error encountered during the request, or None if the response was received successfully. - :returns: Whether the retry attempts are exhausted. + :type error: ~azure.core.exceptions.AzureError or None + :return: Whether the retry attempts are exhausted. :rtype: bool """ - settings['total'] -= 1 + settings["total"] -= 1 if error and isinstance(error, ServiceRequestError): # Errors when we're fairly sure that the server did not receive the # request, so it should be safe to retry. - settings['connect'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["connect"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) elif error and isinstance(error, ServiceResponseError): # Errors that occur after the request has been started, so we should # assume that the server began processing it. - settings['read'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["read"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) else: # Incrementing because of a server error like a 500 in # status_forcelist and a the given method is in the allowlist if response: - settings['status'] -= 1 - settings['history'].append(RequestHistory(request, http_response=response)) + settings["status"] -= 1 + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): - if request.method not in ['PUT'] and settings['retry_secondary']: + if request.method not in ["PUT"] and settings["retry_secondary"]: self._set_next_host_location(settings, request) # rewind the request body if it is a stream - if request.body and hasattr(request.body, 'read'): + if request.body and hasattr(request.body, "read"): # no position was saved, then retry would not work - if settings['body_position'] is None: + if settings["body_position"] is None: return False try: # attempt to rewind the body to the initial position - request.body.seek(settings['body_position'], SEEK_SET) + request.body.seek(settings["body_position"], SEEK_SET) except (UnsupportedOperation, ValueError): # if body is not seekable, then retry would not work return False - settings['count'] += 1 + settings["count"] += 1 return True return False def send(self, request): + """Send the request with retry logic. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A pipeline response object. 
+ :rtype: ~azure.core.pipeline.PipelineResponse + """ retries_remaining = True response = None retry_settings = self.configure_retries(request) while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings['mode']) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + retry_hook(retry_settings, request=request.http_request, response=None, error=err) self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -571,12 +594,13 @@ class ExponentialRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, initial_backoff: int = 15, + self, + initial_backoff: int = 15, increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for @@ -601,21 +625,20 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to get backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to get backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
:rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -630,11 +653,12 @@ class LinearRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs a Linear retry object. @@ -653,15 +677,14 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. :rtype: float @@ -669,19 +692,27 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + """Handle the challenge from the service and authorize the request. + + :param request: The request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The response object. + :type response: ~azure.core.pipeline.PipelineResponse + :return: True if the request was authorized, False otherwise. 
+ :rtype: bool + """ try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index 807a51dd297c..4cb32f23248b 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -21,7 +21,7 @@ from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -29,29 +29,25 @@ async def retry_hook(settings, **kwargs): - if settings['hook']: - if asyncio.iscoroutine(settings['hook']): - await settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + if settings["hook"]: + if asyncio.iscoroutine(settings["hook"]): + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() # Load the body in memory and close the socket + except (StreamClosedError, StreamConsumedError): + pass + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -59,54 +55,56 @@ async def is_checksum_retry(response): class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(AsyncStorageResponseHook, self).__init__() async def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop('download_stream_current', None) - upload_stream_current = 
request.context.get('upload_stream_current') + download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) - will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): - await response_callback(response) # type: ignore + await response_callback(response) # type: ignore else: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response + class AsyncStorageRetryPolicy(StorageRetryPolicy): """ The base class for Exponential and Linear retries containing shared code. 
@@ -125,37 +123,29 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: await retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) await self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - await retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + await retry_hook(retry_settings, request=request.http_request, response=None, error=err) await self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -176,7 +166,8 @@ def __init__( increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, - random_jitter_range: int = 3, **kwargs + random_jitter_range: int = 3, + **kwargs ) -> None: """ Constructs an Exponential retry object. 
The initial_backoff is used for @@ -203,8 +194,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -217,7 +207,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -232,7 +222,8 @@ class LinearRetry(AsyncStorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, @@ -255,8 +246,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -271,14 +261,13 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/request_handlers.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/request_handlers.py index af500c8727fa..b23f65859690 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/request_handlers.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/request_handlers.py @@ -6,7 +6,7 @@ import logging import stat -from io import (SEEK_END, SEEK_SET, UnsupportedOperation) +from io import SEEK_END, SEEK_SET, UnsupportedOperation from os import fstat from typing import Dict, Optional @@ -37,12 +37,13 @@ def 
serialize_iso(attr): raise OverflowError("Hit max or min date") date = f"{utc.tm_year:04}-{utc.tm_mon:02}-{utc.tm_mday:02}T{utc.tm_hour:02}:{utc.tm_min:02}:{utc.tm_sec:02}" - return date + 'Z' + return date + "Z" except (ValueError, OverflowError) as err: raise ValueError("Unable to serialize datetime object.") from err except AttributeError as err: raise TypeError("ISO-8601 object must be valid datetime object.") from err + def get_length(data): length = None # Check if object implements the __len__ method, covers most input cases such as bytearray. @@ -62,7 +63,7 @@ def get_length(data): try: mode = fstat(fileno).st_mode if stat.S_ISREG(mode) or stat.S_ISLNK(mode): - #st_size only meaningful if regular file or symlink, other types + # st_size only meaningful if regular file or symlink, other types # e.g. sockets may return misleading sizes like 0 return fstat(fileno).st_size except OSError: @@ -84,13 +85,13 @@ def get_length(data): def read_length(data): try: - if hasattr(data, 'read'): - read_data = b'' + if hasattr(data, "read"): + read_data = b"" for chunk in iter(lambda: data.read(4096), b""): read_data += chunk return len(read_data), read_data - if hasattr(data, '__iter__'): - read_data = b'' + if hasattr(data, "__iter__"): + read_data = b"" for chunk in data: read_data += chunk return len(read_data), read_data @@ -100,8 +101,13 @@ def read_length(data): def validate_and_format_range_headers( - start_range, end_range, start_range_required=True, - end_range_required=True, check_content_md5=False, align_to_page=False): + start_range, + end_range, + start_range_required=True, + end_range_required=True, + check_content_md5=False, + align_to_page=False, +): # If end range is provided, start range must be provided if (start_range_required or end_range is not None) and start_range is None: raise ValueError("start_range value cannot be None.") @@ -111,16 +117,18 @@ def validate_and_format_range_headers( # Page ranges must be 512 aligned if align_to_page: if start_range is not None and start_range % 512 != 0: - raise ValueError(f"Invalid page blob start_range: {start_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob start_range: {start_range}. " "The size must be aligned to a 512-byte boundary." + ) if end_range is not None and end_range % 512 != 511: - raise ValueError(f"Invalid page blob end_range: {end_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob end_range: {end_range}. " "The size must be aligned to a 512-byte boundary." 
+ ) # Format based on whether end_range is present range_header = None if end_range is not None: - range_header = f'bytes={start_range}-{end_range}' + range_header = f"bytes={start_range}-{end_range}" elif start_range is not None: range_header = f"bytes={start_range}-" @@ -131,7 +139,7 @@ def validate_and_format_range_headers( raise ValueError("Both start and end range required for MD5 content validation.") if end_range - start_range > 4 * 1024 * 1024: raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.") - range_validation = 'true' + range_validation = "true" return range_header, range_validation @@ -140,7 +148,7 @@ def add_metadata_headers(metadata: Optional[Dict[str, str]] = None) -> Dict[str, headers = {} if metadata: for key, value in metadata.items(): - headers[f'x-ms-meta-{key.strip()}'] = value.strip() if value else value + headers[f"x-ms-meta-{key.strip()}"] = value.strip() if value else value return headers @@ -158,29 +166,26 @@ def serialize_batch_body(requests, batch_id): a list of sub-request for the batch request :param str batch_id: to be embedded in batch sub-request delimiter - :returns: The body bytes for this batch. + :return: The body bytes for this batch. :rtype: bytes """ if requests is None or len(requests) == 0: - raise ValueError('Please provide sub-request(s) for this batch request') + raise ValueError("Please provide sub-request(s) for this batch request") - delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode('utf-8') - newline_bytes = _HTTP_LINE_ENDING.encode('utf-8') + delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode("utf-8") + newline_bytes = _HTTP_LINE_ENDING.encode("utf-8") batch_body = [] content_index = 0 for request in requests: - request.headers.update({ - "Content-ID": str(content_index), - "Content-Length": str(0) - }) + request.headers.update({"Content-ID": str(content_index), "Content-Length": str(0)}) batch_body.append(delimiter_bytes) batch_body.append(_make_body_from_sub_request(request)) batch_body.append(newline_bytes) content_index += 1 - batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode('utf-8')) + batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode("utf-8")) # final line of body MUST have \r\n at the end, or it will not be properly read by the service batch_body.append(newline_bytes) @@ -197,35 +202,35 @@ def _get_batch_request_delimiter(batch_id, is_prepend_dashes=False, is_append_da Whether to include the starting dashes. Used in the body, but non on defining the delimiter. :param bool is_append_dashes: Whether to include the ending dashes. Used in the body on the closing delimiter only. - :returns: The delimiter, WITHOUT a trailing newline. + :return: The delimiter, WITHOUT a trailing newline. :rtype: str """ - prepend_dashes = '--' if is_prepend_dashes else '' - append_dashes = '--' if is_append_dashes else '' + prepend_dashes = "--" if is_prepend_dashes else "" + append_dashes = "--" if is_append_dashes else "" return prepend_dashes + _REQUEST_DELIMITER_PREFIX + batch_id + append_dashes def _make_body_from_sub_request(sub_request): """ - Content-Type: application/http - Content-ID: - Content-Transfer-Encoding: (if present) + Content-Type: application/http + Content-ID: + Content-Transfer-Encoding: (if present) - HTTP/ -
<header key>: <header value> (repeated as necessary) - Content-Length: <value> - (newline if content length > 0) - <body> (if content length > 0) + <verb> <path><query> HTTP/<version> + <header key>: <header value>
(repeated as necessary) + Content-Length: + (newline if content length > 0) + (if content length > 0) - Serializes an http request. + Serializes an http request. - :param ~azure.core.pipeline.transport.HttpRequest sub_request: - Request to serialize. - :returns: The serialized sub-request in bytes - :rtype: bytes - """ + :param ~azure.core.pipeline.transport.HttpRequest sub_request: + Request to serialize. + :return: The serialized sub-request in bytes + :rtype: bytes + """ # put the sub-request's headers into a list for efficient str concatenation sub_request_body = [] @@ -249,9 +254,9 @@ def _make_body_from_sub_request(sub_request): # append HTTP verb and path and query and HTTP version sub_request_body.append(sub_request.method) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(sub_request.url) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(_HTTP1_1_IDENTIFIER) sub_request_body.append(_HTTP_LINE_ENDING) @@ -266,4 +271,4 @@ def _make_body_from_sub_request(sub_request): # append blank line sub_request_body.append(_HTTP_LINE_ENDING) - return ''.join(sub_request_body).encode() + return "".join(sub_request_body).encode() diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/response_handlers.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/response_handlers.py index af9a2fcdcdc2..bcfa4147763e 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/response_handlers.py @@ -46,23 +46,23 @@ def parse_length_from_content_range(content_range): # First, split in space and take the second half: '1-3/65537' # Next, split on slash and take the second half: '65537' # Finally, convert to an int: 65537 - return int(content_range.split(' ', 1)[1].split('/', 1)[1]) + return int(content_range.split(" ", 1)[1].split("/", 1)[1]) def normalize_headers(headers): normalized = {} for key, value in headers.items(): - if key.startswith('x-ms-'): + if key.startswith("x-ms-"): key = key[5:] - normalized[key.lower().replace('-', '_')] = get_enum_value(value) + normalized[key.lower().replace("-", "_")] = get_enum_value(value) return normalized def deserialize_metadata(response, obj, headers): # pylint: disable=unused-argument try: - raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith("x-ms-meta-")} except AttributeError: - raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith("x-ms-meta-")} return {k[10:]: v for k, v in raw_metadata.items()} @@ -82,19 +82,23 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches +def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches raise_error = HttpResponseError serialized = False if isinstance(storage_error, AzureSigningError): - storage_error.message = storage_error.message + \ - '. This is likely due to an invalid shared key. 
Please check your shared key and try again.' + storage_error.message = ( + storage_error.message + + ". This is likely due to an invalid shared key. Please check your shared key and try again." + ) if not storage_error.response or storage_error.response.status_code in [200, 204]: raise storage_error # If it is one of those three then it has been serialized prior by the generated layer. - if isinstance(storage_error, (PartialBatchErrorException, - ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError)): + if isinstance( + storage_error, + (PartialBatchErrorException, ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError), + ): serialized = True - error_code = storage_error.response.headers.get('x-ms-error-code') + error_code = storage_error.response.headers.get("x-ms-error-code") error_message = storage_error.message additional_data = {} error_dict = {} @@ -104,27 +108,25 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py if error_body is None or len(error_body) == 0: error_body = storage_error.response.reason except AttributeError: - error_body = '' + error_body = "" # If it is an XML response if isinstance(error_body, Element): - error_dict = { - child.tag.lower(): child.text - for child in error_body - } + error_dict = {child.tag.lower(): child.text for child in error_body} # If it is a JSON response elif isinstance(error_body, dict): - error_dict = error_body.get('error', {}) + error_dict = error_body.get("error", {}) elif not error_code: _LOGGER.warning( - 'Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body)) - error_dict = {'message': str(error_body)} + "Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.", type(error_body) + ) + error_dict = {"message": str(error_body)} # If we extracted from a Json or XML response # There is a chance error_dict is just a string if error_dict and isinstance(error_dict, dict): - error_code = error_dict.get('code') - error_message = error_dict.get('message') - additional_data = {k: v for k, v in error_dict.items() if k not in {'code', 'message'}} + error_code = error_dict.get("code") + error_message = error_dict.get("message") + additional_data = {k: v for k, v in error_dict.items() if k not in {"code", "message"}} except DecodeError: pass @@ -132,31 +134,33 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py # This check would be unnecessary if we have already serialized the error if error_code and not serialized: error_code = StorageErrorCode(error_code) - if error_code in [StorageErrorCode.condition_not_met, - StorageErrorCode.blob_overwritten]: + if error_code in [StorageErrorCode.condition_not_met, StorageErrorCode.blob_overwritten]: raise_error = ResourceModifiedError - if error_code in [StorageErrorCode.invalid_authentication_info, - StorageErrorCode.authentication_failed]: + if error_code in [StorageErrorCode.invalid_authentication_info, StorageErrorCode.authentication_failed]: raise_error = ClientAuthenticationError - if error_code in [StorageErrorCode.resource_not_found, - StorageErrorCode.cannot_verify_copy_source, - StorageErrorCode.blob_not_found, - StorageErrorCode.queue_not_found, - StorageErrorCode.container_not_found, - StorageErrorCode.parent_not_found, - StorageErrorCode.share_not_found]: + if error_code in [ + StorageErrorCode.resource_not_found, + StorageErrorCode.cannot_verify_copy_source, + StorageErrorCode.blob_not_found, + 
StorageErrorCode.queue_not_found, + StorageErrorCode.container_not_found, + StorageErrorCode.parent_not_found, + StorageErrorCode.share_not_found, + ]: raise_error = ResourceNotFoundError - if error_code in [StorageErrorCode.account_already_exists, - StorageErrorCode.account_being_created, - StorageErrorCode.resource_already_exists, - StorageErrorCode.resource_type_mismatch, - StorageErrorCode.blob_already_exists, - StorageErrorCode.queue_already_exists, - StorageErrorCode.container_already_exists, - StorageErrorCode.container_being_deleted, - StorageErrorCode.queue_being_deleted, - StorageErrorCode.share_already_exists, - StorageErrorCode.share_being_deleted]: + if error_code in [ + StorageErrorCode.account_already_exists, + StorageErrorCode.account_being_created, + StorageErrorCode.resource_already_exists, + StorageErrorCode.resource_type_mismatch, + StorageErrorCode.blob_already_exists, + StorageErrorCode.queue_already_exists, + StorageErrorCode.container_already_exists, + StorageErrorCode.container_being_deleted, + StorageErrorCode.queue_being_deleted, + StorageErrorCode.share_already_exists, + StorageErrorCode.share_being_deleted, + ]: raise_error = ResourceExistsError except ValueError: # Got an unknown error code @@ -183,7 +187,7 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py error.args = (error.message,) try: # `from None` prevents us from double printing the exception (suppresses generated layer error context) - exec("raise error from None") # pylint: disable=exec-used # nosec + exec("raise error from None") # pylint: disable=exec-used # nosec except SyntaxError as exc: raise error from exc diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/shared_access_signature.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/shared_access_signature.py index fb5b98735d8a..959a5ac5762d 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/shared_access_signature.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/shared_access_signature.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +# pylint: disable=docstring-keyword-should-match-keyword-only from datetime import date @@ -10,44 +11,45 @@ from .constants import X_MS_VERSION from . import sign_string, url_quote + # cspell:ignoreRegExp rsc. 
# cspell:ignoreRegExp s..?id class QueryStringConstants(object): - SIGNED_SIGNATURE = 'sig' - SIGNED_PERMISSION = 'sp' - SIGNED_START = 'st' - SIGNED_EXPIRY = 'se' - SIGNED_RESOURCE = 'sr' - SIGNED_IDENTIFIER = 'si' - SIGNED_IP = 'sip' - SIGNED_PROTOCOL = 'spr' - SIGNED_VERSION = 'sv' - SIGNED_CACHE_CONTROL = 'rscc' - SIGNED_CONTENT_DISPOSITION = 'rscd' - SIGNED_CONTENT_ENCODING = 'rsce' - SIGNED_CONTENT_LANGUAGE = 'rscl' - SIGNED_CONTENT_TYPE = 'rsct' - START_PK = 'spk' - START_RK = 'srk' - END_PK = 'epk' - END_RK = 'erk' - SIGNED_RESOURCE_TYPES = 'srt' - SIGNED_SERVICES = 'ss' - SIGNED_OID = 'skoid' - SIGNED_TID = 'sktid' - SIGNED_KEY_START = 'skt' - SIGNED_KEY_EXPIRY = 'ske' - SIGNED_KEY_SERVICE = 'sks' - SIGNED_KEY_VERSION = 'skv' - SIGNED_ENCRYPTION_SCOPE = 'ses' - SIGNED_KEY_DELEGATED_USER_TID = 'skdutid' - SIGNED_DELEGATED_USER_OID = 'sduoid' + SIGNED_SIGNATURE = "sig" + SIGNED_PERMISSION = "sp" + SIGNED_START = "st" + SIGNED_EXPIRY = "se" + SIGNED_RESOURCE = "sr" + SIGNED_IDENTIFIER = "si" + SIGNED_IP = "sip" + SIGNED_PROTOCOL = "spr" + SIGNED_VERSION = "sv" + SIGNED_CACHE_CONTROL = "rscc" + SIGNED_CONTENT_DISPOSITION = "rscd" + SIGNED_CONTENT_ENCODING = "rsce" + SIGNED_CONTENT_LANGUAGE = "rscl" + SIGNED_CONTENT_TYPE = "rsct" + START_PK = "spk" + START_RK = "srk" + END_PK = "epk" + END_RK = "erk" + SIGNED_RESOURCE_TYPES = "srt" + SIGNED_SERVICES = "ss" + SIGNED_OID = "skoid" + SIGNED_TID = "sktid" + SIGNED_KEY_START = "skt" + SIGNED_KEY_EXPIRY = "ske" + SIGNED_KEY_SERVICE = "sks" + SIGNED_KEY_VERSION = "skv" + SIGNED_ENCRYPTION_SCOPE = "ses" + SIGNED_KEY_DELEGATED_USER_TID = "skdutid" + SIGNED_DELEGATED_USER_OID = "sduoid" # for ADLS - SIGNED_AUTHORIZED_OID = 'saoid' - SIGNED_UNAUTHORIZED_OID = 'suoid' - SIGNED_CORRELATION_ID = 'scid' - SIGNED_DIRECTORY_DEPTH = 'sdd' + SIGNED_AUTHORIZED_OID = "saoid" + SIGNED_UNAUTHORIZED_OID = "suoid" + SIGNED_CORRELATION_ID = "scid" + SIGNED_DIRECTORY_DEPTH = "sdd" @staticmethod def to_list(): @@ -90,37 +92,30 @@ def to_list(): class SharedAccessSignature(object): - ''' + """ Provides a factory for creating account access signature tokens with an account name and account key. Users can either use the factory or can construct the appropriate service and use the generate_*_shared_access_signature method directly. - ''' + """ def __init__(self, account_name, account_key, x_ms_version=X_MS_VERSION): - ''' + """ :param str account_name: The storage account name used to generate the shared access signatures. :param str account_key: The access key to generate the shares access signatures. :param str x_ms_version: The service version used to generate the shared access signatures. - ''' + """ self.account_name = account_name self.account_key = account_key self.x_ms_version = x_ms_version def generate_account( - self, services, - resource_types, - permission, - expiry, - start=None, - ip=None, - protocol=None, - sts_hook=None + self, services, resource_types, permission, expiry, start=None, ip=None, protocol=None, sts_hook=None ) -> str: - ''' + """ Generates a shared access signature for the account. Use the returned signature with the sas_token parameter of the service or to create a new account object. @@ -164,9 +159,9 @@ def generate_account( For debugging purposes only. If provided, the hook is called with the string to sign that was used to generate the SAS. :type sts_hook: Optional[Callable[[str], None]] - :returns: The generated SAS token for the account. + :return: The generated SAS token for the account. 
:rtype: str - ''' + """ sas = _SharedAccessHelper() sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version) sas.add_account(services, resource_types) @@ -211,11 +206,9 @@ def add_account(self, services, resource_types): self._add_query(QueryStringConstants.SIGNED_SERVICES, services) self._add_query(QueryStringConstants.SIGNED_RESOURCE_TYPES, resource_types) - def add_override_response_headers(self, cache_control, - content_disposition, - content_encoding, - content_language, - content_type): + def add_override_response_headers( + self, cache_control, content_disposition, content_encoding, content_language, content_type + ): self._add_query(QueryStringConstants.SIGNED_CACHE_CONTROL, cache_control) self._add_query(QueryStringConstants.SIGNED_CONTENT_DISPOSITION, content_disposition) self._add_query(QueryStringConstants.SIGNED_CONTENT_ENCODING, content_encoding) @@ -224,24 +217,24 @@ def add_override_response_headers(self, cache_control, def add_account_signature(self, account_name, account_key): def get_value_to_append(query): - return_value = self.query_dict.get(query) or '' - return return_value + '\n' - - self.string_to_sign = \ - (account_name + '\n' + - get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + - get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + - get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + - get_value_to_append(QueryStringConstants.SIGNED_START) + - get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + - get_value_to_append(QueryStringConstants.SIGNED_IP) + - get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + - get_value_to_append(QueryStringConstants.SIGNED_VERSION) + - '\n' # Signed Encryption Scope - always empty for fileshare - ) - - self._add_query(QueryStringConstants.SIGNED_SIGNATURE, - sign_string(account_key, self.string_to_sign)) + return_value = self.query_dict.get(query) or "" + return return_value + "\n" + + self.string_to_sign = ( + account_name + + "\n" + + get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + + get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + + get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + + get_value_to_append(QueryStringConstants.SIGNED_START) + + get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + + get_value_to_append(QueryStringConstants.SIGNED_IP) + + get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + + get_value_to_append(QueryStringConstants.SIGNED_VERSION) + + "\n" # Signed Encryption Scope - always empty for fileshare + ) + + self._add_query(QueryStringConstants.SIGNED_SIGNATURE, sign_string(account_key, self.string_to_sign)) def get_token(self) -> str: - return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None]) + return "&".join([f"{n}={url_quote(v)}" for n, v in self.query_dict.items() if v is not None]) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads.py index b31cfb3291d9..7a5fb3f3dc91 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads.py @@ -12,7 +12,7 @@ from azure.core.tracing.common import with_current_context -from .import encode_base64, url_quote +from . 
import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers @@ -41,20 +41,21 @@ def _parallel_uploads(executor, uploader, pending, running): def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - validate_content=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + validate_content=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -64,7 +65,8 @@ def upload_data_chunks( parallel=parallel, validate_content=validate_content, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: upload_tasks = uploader.get_chunk_streams() @@ -81,18 +83,19 @@ def upload_data_chunks( def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -100,7 +103,8 @@ def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: @@ -120,15 +124,17 @@ def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size self.chunk_size = chunk_size @@ -253,7 +259,7 @@ def __init__(self, *args, **kwargs): def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. 
- index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) self.service.stage_block( block_id, @@ -261,20 +267,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) return index, block_id def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -302,11 +308,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): pass @@ -326,19 +332,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) def _upload_substream_block(self, index, block_stream): @@ -356,11 +363,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): try: @@ -371,7 +378,7 @@ def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -388,9 +395,9 @@ def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - return f'bytes={chunk_offset}-{chunk_end}', response + return f"bytes={chunk_offset}-{chunk_end}", response # TODO: Implement this method. 
def _upload_substream_block(self, index, block_stream): diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads_async.py index a056cd290230..6ed5ba1d0f91 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/uploads_async.py @@ -12,7 +12,7 @@ from math import ceil from typing import AsyncGenerator, Union -from .import encode_base64, url_quote +from . import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers from .uploads import SubStream, IterStreamer # pylint: disable=unused-import @@ -59,19 +59,20 @@ async def _parallel_uploads(uploader, pending, running): async def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -80,7 +81,8 @@ async def upload_data_chunks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_chunk_streams() @@ -104,18 +106,19 @@ async def upload_data_chunks( async def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -123,13 +126,13 @@ async def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_substream_blocks() running_futures = [ - asyncio.ensure_future(uploader.process_substream_block(u)) - for u in islice(upload_tasks, 0, max_concurrency) + asyncio.ensure_future(uploader.process_substream_block(u)) for u in islice(upload_tasks, 0, max_concurrency) ] range_ids = await _parallel_uploads(uploader.process_substream_block, upload_tasks, running_futures) else: @@ -144,15 +147,17 @@ async def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size self.chunk_size = chunk_size 
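For context, a minimal sketch of the bounded-concurrency scheduling that upload_data_chunks / upload_substream_blocks above rely on: seed up to max_concurrency tasks, then start one new task for each task that completes. Names here are hypothetical; the real _parallel_uploads helper carries additional bookkeeping.

import asyncio
from itertools import islice

async def _run_bounded(worker, items, max_concurrency):
    # Seed up to max_concurrency tasks, then replace each completed task with a new one.
    items = iter(items)
    running = {asyncio.ensure_future(worker(item)) for item in islice(items, max_concurrency)}
    results = []
    while running:
        done, running = await asyncio.wait(running, return_when=asyncio.FIRST_COMPLETED)
        results.extend(task.result() for task in done)
        for item in islice(items, len(done)):
            running.add(asyncio.ensure_future(worker(item)))
    return results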
@@ -178,7 +183,7 @@ def __init__( async def get_chunk_streams(self): index = 0 while True: - data = b'' + data = b"" read_size = self.chunk_size # Buffer until we either reach the end of the stream or get a whole chunk. @@ -189,12 +194,12 @@ async def get_chunk_streams(self): if inspect.isawaitable(temp): temp = await temp if not isinstance(temp, bytes): - raise TypeError('Blob data should be of type bytes.') + raise TypeError("Blob data should be of type bytes.") data += temp or b"" # We have read an empty string and so are at the end # of the buffer or we have read a full chunk. - if temp == b'' or len(data) == self.chunk_size: + if temp == b"" or len(data) == self.chunk_size: break if len(data) == self.chunk_size: @@ -273,13 +278,13 @@ def set_response_properties(self, resp): class BlockBlobChunkUploader(_ChunkUploader): def __init__(self, *args, **kwargs): - kwargs.pop('modified_access_conditions', None) + kwargs.pop("modified_access_conditions", None) super(BlockBlobChunkUploader, self).__init__(*args, **kwargs) self.current_length = None async def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. - index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) await self.service.stage_block( block_id, @@ -287,19 +292,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): body=chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) return index, block_id async def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" await self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) finally: block_stream.close() return block_id @@ -311,7 +318,7 @@ def _is_chunk_empty(self, chunk_data): # read until non-zero byte is encountered # if reached the end without returning, then chunk_data is all 0's for each_byte in chunk_data: - if each_byte not in [0, b'\x00']: + if each_byte not in [0, b"\x00"]: return False return True @@ -319,7 +326,7 @@ async def _upload_chunk(self, chunk_offset, chunk_data): # avoid uploading the empty pages if not self._is_chunk_empty(chunk_data): chunk_end = chunk_offset + len(chunk_data) - 1 - content_range = f'bytes={chunk_offset}-{chunk_end}' + content_range = f"bytes={chunk_offset}-{chunk_end}" computed_md5 = None self.response_headers = await self.service.upload_pages( body=chunk_data, @@ -329,10 +336,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): pass @@ -352,18 +360,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, 
upload_stream_current=self.progress_total, - **self.request_options) - self.current_length = int(self.response_headers['blob_append_offset']) + **self.request_options, + ) + self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = await self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) async def _upload_substream_block(self, index, block_stream): pass @@ -379,11 +390,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): try: @@ -394,7 +405,7 @@ async def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -411,9 +422,9 @@ async def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - range_id = f'bytes={chunk_offset}-{chunk_end}' + range_id = f"bytes={chunk_offset}-{chunk_end}" return range_id, response # TODO: Implement this method. @@ -421,10 +432,11 @@ async def _upload_substream_block(self, index, block_stream): pass -class AsyncIterStreamer(): +class AsyncIterStreamer: """ File-like streaming object for AsyncGenerators. 
""" + def __init__(self, generator: AsyncGenerator[Union[bytes, str], None], encoding: str = "UTF-8"): self.iterator = generator.__aiter__() self.leftover = b"" diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index 44b45084a670..90f5e9192c6e 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -33,8 +33,9 @@ async def process_content(data: Any) -> bytes: raise ValueError("Response cannot be None.") try: - await data.response.load_body() - return cast(bytes, data.response.body()) + if hasattr(data.response, "is_stream_consumed") and data.response.is_stream_consumed: + return data.response.content + return b"".join([d async for d in data]) except Exception as error: raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error) from error diff --git a/sdk/storage/azure-storage-file-share/tests/test_file.py b/sdk/storage/azure-storage-file-share/tests/test_file.py index 78aa779be7f4..b0ccfe3afdf7 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_file.py +++ b/sdk/storage/azure-storage-file-share/tests/test_file.py @@ -3848,7 +3848,7 @@ def test_file_permission_format(self, **kwargs): file_client.delete_file() @FileSharePreparer() - def test_mock_transport_no_content_validation(self, **kwargs): + def test_legacy_transport(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") @@ -3878,7 +3878,7 @@ def test_mock_transport_no_content_validation(self, **kwargs): assert file_data == b"Hello World!" # data is fixed by mock transport @FileSharePreparer() - def test_mock_transport_with_content_validation(self, **kwargs): + def test_legacy_transport_with_content_validation(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") diff --git a/sdk/storage/azure-storage-file-share/tests/test_file_async.py b/sdk/storage/azure-storage-file-share/tests/test_file_async.py index 178b7d66d7e4..f107f832ac7c 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_file_async.py +++ b/sdk/storage/azure-storage-file-share/tests/test_file_async.py @@ -3962,7 +3962,7 @@ async def test_file_permission_format(self, **kwargs): await file_client.delete_file() @FileSharePreparer() - async def test_mock_transport_no_content_validation(self, **kwargs): + async def test_legacy_transport(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") @@ -3993,7 +3993,7 @@ async def test_mock_transport_no_content_validation(self, **kwargs): assert file_data == b"Hello Async World!" 
# data is fixed by mock transport @FileSharePreparer() - async def test_mock_transport_with_content_validation(self, **kwargs): + async def test_legacy_transport_with_content_validation(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") diff --git a/sdk/storage/azure-storage-file-share/tests/test_helpers_async.py b/sdk/storage/azure-storage-file-share/tests/test_helpers_async.py index 48ca12c94dcc..0a7da86e53aa 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_helpers_async.py +++ b/sdk/storage/azure-storage-file-share/tests/test_helpers_async.py @@ -4,11 +4,15 @@ # license information. # -------------------------------------------------------------------------- +import asyncio +from collections import deque from typing import Any, Dict, Optional from azure.core.pipeline.transport import AioHttpTransportResponse, AsyncHttpTransport from azure.core.rest import HttpRequest from aiohttp import ClientResponse +from aiohttp.streams import StreamReader +from aiohttp.client_proto import ResponseHandler class ProgressTracker: @@ -65,6 +69,10 @@ def __init__( self._loop = None self.status = status self.reason = reason + self.content = StreamReader(ResponseHandler(asyncio.get_event_loop()), 65535) + self.content.total_bytes = len(body_bytes) + self.content._buffer = deque([body_bytes]) + self.content._eof = True class MockStorageTransport(AsyncHttpTransport): diff --git a/sdk/storage/azure-storage-queue/azure/__init__.py b/sdk/storage/azure-storage-queue/azure/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/storage/azure-storage-queue/azure/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-queue/azure/storage/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/__init__.py index 0d1f7edf5dc6..d55ccad1f573 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/storage/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py index 3028519201ba..951b1c1fd0fa 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/__init__.py @@ -9,7 +9,7 @@ from ._queue_service_client import QueueServiceClient from ._shared_access_signature import generate_account_sas, generate_queue_sas from ._shared.policies import ExponentialRetry, LinearRetry -from ._shared.models import( +from ._shared.models import ( LocationMode, ResourceTypes, AccountSasPermissions, @@ -36,27 +36,27 @@ __version__ = VERSION __all__ = [ - 'QueueClient', - 'QueueServiceClient', - 'ExponentialRetry', - 'LinearRetry', - 'LocationMode', - 'ResourceTypes', - 'AccountSasPermissions', - 'StorageErrorCode', - 'QueueMessage', - 'QueueProperties', - 'QueueSasPermissions', - 'AccessPolicy', - 'TextBase64EncodePolicy', - 'TextBase64DecodePolicy', - 'BinaryBase64EncodePolicy', - 'BinaryBase64DecodePolicy', - 'QueueAnalyticsLogging', - 'Metrics', - 'CorsRule', - 'RetentionPolicy', - 'generate_account_sas', - 
'generate_queue_sas', - 'Services' + "QueueClient", + "QueueServiceClient", + "ExponentialRetry", + "LinearRetry", + "LocationMode", + "ResourceTypes", + "AccountSasPermissions", + "StorageErrorCode", + "QueueMessage", + "QueueProperties", + "QueueSasPermissions", + "AccessPolicy", + "TextBase64EncodePolicy", + "TextBase64DecodePolicy", + "BinaryBase64EncodePolicy", + "BinaryBase64DecodePolicy", + "QueueAnalyticsLogging", + "Metrics", + "CorsRule", + "RetentionPolicy", + "generate_account_sas", + "generate_queue_sas", + "Services", ] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py index f2016049827e..355a0053c2dc 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_deserialize.py @@ -16,35 +16,26 @@ from azure.core.pipeline import PipelineResponse -def deserialize_queue_properties( - response: "PipelineResponse", - obj: Any, - headers: Dict[str, Any] -) -> QueueProperties: +def deserialize_queue_properties(response: "PipelineResponse", obj: Any, headers: Dict[str, Any]) -> QueueProperties: metadata = deserialize_metadata(response, obj, headers) - queue_properties = QueueProperties( - metadata=metadata, - **headers - ) + queue_properties = QueueProperties(metadata=metadata, **headers) return queue_properties -def deserialize_queue_creation( - response: "PipelineResponse", - obj: Any, - headers: Dict[str, Any] -) -> Dict[str, Any]: +def deserialize_queue_creation(response: "PipelineResponse", obj: Any, headers: Dict[str, Any]) -> Dict[str, Any]: response = response.http_response - if response.status_code == 204: # type: ignore + if response.status_code == 204: # type: ignore [attr-defined] error_code = StorageErrorCode.queue_already_exists error = ResourceExistsError( message=( "Queue already exists\n" f"RequestId:{headers['x-ms-request-id']}\n" f"Time:{headers['Date']}\n" - f"ErrorCode:{error_code}"), - response=response) # type: ignore - error.error_code = error_code # type: ignore - error.additional_info = {} # type: ignore + f"ErrorCode:{error_code}" + ), + response=response, # type: ignore [arg-type] + ) + error.error_code = error_code # type: ignore [attr-defined] + error.additional_info = {} # type: ignore [attr-defined] raise error return headers diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py index 42f5c51d0762..2153d1da1da6 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_encryption.py @@ -38,51 +38,46 @@ from cryptography.hazmat.primitives.padding import PaddingContext -_ENCRYPTION_PROTOCOL_V1 = '1.0' -_ENCRYPTION_PROTOCOL_V2 = '2.0' -_ENCRYPTION_PROTOCOL_V2_1 = '2.1' +_ENCRYPTION_PROTOCOL_V1 = "1.0" +_ENCRYPTION_PROTOCOL_V2 = "2.0" +_ENCRYPTION_PROTOCOL_V2_1 = "2.1" _VALID_ENCRYPTION_PROTOCOLS = [_ENCRYPTION_PROTOCOL_V1, _ENCRYPTION_PROTOCOL_V2, _ENCRYPTION_PROTOCOL_V2_1] _ENCRYPTION_V2_PROTOCOLS = [_ENCRYPTION_PROTOCOL_V2, _ENCRYPTION_PROTOCOL_V2_1] _GCM_REGION_DATA_LENGTH = 4 * 1024 * 1024 _GCM_NONCE_LENGTH = 12 _GCM_TAG_LENGTH = 16 -_ERROR_OBJECT_INVALID = \ - '{0} does not define a complete interface. Value of {1} is either missing or invalid.' +_ERROR_OBJECT_INVALID = "{0} does not define a complete interface. Value of {1} is either missing or invalid." 
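To make the GCM constants above concrete, here is a rough sketch of how a single v2 encryption region is laid out (12-byte big-endian nonce counter + ciphertext + 16-byte GCM tag); the helper name and sample data are illustrative only, not part of this diff.

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

def _encrypt_region(key: bytes, nonce_counter: int, plaintext: bytes) -> bytes:
    nonce = nonce_counter.to_bytes(12, "big")                         # _GCM_NONCE_LENGTH
    ciphertext_and_tag = AESGCM(key).encrypt(nonce, plaintext, None)  # GCM appends the tag
    return nonce + ciphertext_and_tag                                 # nonce + ciphertext + tag

key = AESGCM.generate_key(bit_length=256)
region = _encrypt_region(key, 0, b"example plaintext")
assert len(region) == 12 + len(b"example plaintext") + 16  # nonce + data + tag lengths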
_ERROR_UNSUPPORTED_METHOD_FOR_ENCRYPTION = ( - 'The require_encryption flag is set, but encryption is not supported' - ' for this method.') + "The require_encryption flag is set, but encryption is not supported for this method." +) class KeyEncryptionKey(Protocol): - def wrap_key(self, key: bytes) -> bytes: - ... + def wrap_key(self, key: bytes) -> bytes: ... - def unwrap_key(self, key: bytes, algorithm: str) -> bytes: - ... + def unwrap_key(self, key: bytes, algorithm: str) -> bytes: ... - def get_kid(self) -> str: - ... + def get_kid(self) -> str: ... - def get_key_wrap_algorithm(self) -> str: - ... + def get_key_wrap_algorithm(self) -> str: ... def _validate_not_none(param_name: str, param: Any): if param is None: - raise ValueError(f'{param_name} should not be None.') + raise ValueError(f"{param_name} should not be None.") def _validate_key_encryption_key_wrap(kek: KeyEncryptionKey): # Note that None is not callable and so will fail the second clause of each check. - if not hasattr(kek, 'wrap_key') or not callable(kek.wrap_key): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key')) - if not hasattr(kek, 'get_kid') or not callable(kek.get_kid): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) - if not hasattr(kek, 'get_key_wrap_algorithm') or not callable(kek.get_key_wrap_algorithm): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_key_wrap_algorithm')) + if not hasattr(kek, "wrap_key") or not callable(kek.wrap_key): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "wrap_key")) + if not hasattr(kek, "get_kid") or not callable(kek.get_kid): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "get_kid")) + if not hasattr(kek, "get_key_wrap_algorithm") or not callable(kek.get_key_wrap_algorithm): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "get_key_wrap_algorithm")) class StorageEncryptionMixin(object): @@ -91,19 +86,22 @@ def _configure_encryption(self, kwargs: Dict[str, Any]): self.encryption_version = kwargs.get("encryption_version", "1.0") self.key_encryption_key = kwargs.get("key_encryption_key") self.key_resolver_function = kwargs.get("key_resolver_function") - if self.key_encryption_key and self.encryption_version == '1.0': - warnings.warn("This client has been configured to use encryption with version 1.0. " + - "Version 1.0 is deprecated and no longer considered secure. It is highly " + - "recommended that you switch to using version 2.0. The version can be " + - "specified using the 'encryption_version' keyword.") + if self.key_encryption_key and self.encryption_version == "1.0": + warnings.warn( + "This client has been configured to use encryption with version 1.0. " + + "Version 1.0 is deprecated and no longer considered secure. It is highly " + + "recommended that you switch to using version 2.0. The version can be " + + "specified using the 'encryption_version' keyword." + ) class _EncryptionAlgorithm(object): """ Specifies which client encryption algorithm is used. """ - AES_CBC_256 = 'AES_CBC_256' - AES_GCM_256 = 'AES_GCM_256' + + AES_CBC_256 = "AES_CBC_256" + AES_GCM_256 = "AES_GCM_256" class _WrappedContentKey: @@ -120,9 +118,9 @@ def __init__(self, algorithm: str, encrypted_key: bytes, key_id: str) -> None: :param str key_id: The key-encryption-key identifier string. 
""" - _validate_not_none('algorithm', algorithm) - _validate_not_none('encrypted_key', encrypted_key) - _validate_not_none('key_id', key_id) + _validate_not_none("algorithm", algorithm) + _validate_not_none("encrypted_key", encrypted_key) + _validate_not_none("key_id", key_id) self.algorithm = algorithm self.encrypted_key = encrypted_key @@ -144,9 +142,9 @@ def __init__(self, data_length: int, nonce_length: int, tag_length: int) -> None :param int tag_length: The length of the encryption tag. """ - _validate_not_none('data_length', data_length) - _validate_not_none('nonce_length', nonce_length) - _validate_not_none('tag_length', tag_length) + _validate_not_none("data_length", data_length) + _validate_not_none("nonce_length", nonce_length) + _validate_not_none("tag_length", tag_length) self.data_length = data_length self.nonce_length = nonce_length @@ -166,8 +164,8 @@ def __init__(self, encryption_algorithm: _EncryptionAlgorithm, protocol: str) -> :param str protocol: The protocol version used for encryption. """ - _validate_not_none('encryption_algorithm', encryption_algorithm) - _validate_not_none('protocol', protocol) + _validate_not_none("encryption_algorithm", encryption_algorithm) + _validate_not_none("protocol", protocol) self.encryption_algorithm = str(encryption_algorithm) self.protocol = protocol @@ -179,11 +177,12 @@ class _EncryptionData: """ def __init__( - self, content_encryption_IV: Optional[bytes], + self, + content_encryption_IV: Optional[bytes], encrypted_region_info: Optional[_EncryptedRegionInfo], encryption_agent: _EncryptionAgent, wrapped_content_key: _WrappedContentKey, - key_wrapping_metadata: Dict[str, Any] + key_wrapping_metadata: Dict[str, Any], ) -> None: """ :param Optional[bytes] content_encryption_IV: @@ -200,14 +199,14 @@ def __init__( :param Dict[str, Any] key_wrapping_metadata: A dict containing metadata related to the key wrapping. """ - _validate_not_none('encryption_agent', encryption_agent) - _validate_not_none('wrapped_content_key', wrapped_content_key) + _validate_not_none("encryption_agent", encryption_agent) + _validate_not_none("wrapped_content_key", wrapped_content_key) # Validate we have the right matching optional parameter for the specified algorithm if encryption_agent.encryption_algorithm == _EncryptionAlgorithm.AES_CBC_256: - _validate_not_none('content_encryption_IV', content_encryption_IV) + _validate_not_none("content_encryption_IV", content_encryption_IV) elif encryption_agent.encryption_algorithm == _EncryptionAlgorithm.AES_GCM_256: - _validate_not_none('encrypted_region_info', encrypted_region_info) + _validate_not_none("encrypted_region_info", encrypted_region_info) else: raise ValueError("Invalid encryption algorithm.") @@ -225,8 +224,10 @@ class GCMBlobEncryptionStream: will use the same encryption key and will generate a guaranteed unique nonce for each encryption region. """ + def __init__( - self, content_encryption_key: bytes, + self, + content_encryption_key: bytes, data_stream: IO[bytes], ) -> None: """ @@ -237,7 +238,7 @@ def __init__( self.data_stream = data_stream self.offset = 0 - self.current = b'' + self.current = b"" self.nonce_counter = 0 def read(self, size: int = -1) -> bytes: @@ -286,7 +287,7 @@ def encrypt_data_v2(data: bytes, nonce: int, key: bytes) -> bytes: :return: The encrypted bytes in the form: nonce + ciphertext + tag. 
:rtype: bytes """ - nonce_bytes = nonce.to_bytes(_GCM_NONCE_LENGTH, 'big') + nonce_bytes = nonce.to_bytes(_GCM_NONCE_LENGTH, "big") aesgcm = AESGCM(key) # Returns ciphertext + tag @@ -307,11 +308,8 @@ def is_encryption_v2(encryption_data: Optional[_EncryptionData]) -> bool: def modify_user_agent_for_encryption( - user_agent: str, - moniker: str, - encryption_version: str, - request_options: Dict[str, Any] - ) -> None: + user_agent: str, moniker: str, encryption_version: str, request_options: Dict[str, Any] +) -> None: """ Modifies the request options to contain a user agent string updated with encryption information. Adds azstorage-clientsideencryption/ immediately preceding the SDK descriptor. @@ -322,7 +320,7 @@ def modify_user_agent_for_encryption( :param Dict[str, Any] request_options: The request options to add the user agent override to. """ # If the user has specified user_agent_overwrite=True, don't make any modifications - if request_options.get('user_agent_overwrite'): + if request_options.get("user_agent_overwrite"): return # If the feature flag is already present, don't add it again @@ -333,11 +331,11 @@ def modify_user_agent_for_encryption( index = user_agent.find(f"azsdk-python-{moniker}") user_agent = f"{user_agent[:index]}{feature_flag} {user_agent[index:]}" # Since we are using user_agent_overwrite=True, we must prepend the user's user_agent if there is one - if request_options.get('user_agent'): + if request_options.get("user_agent"): user_agent = f"{request_options.get('user_agent')} {user_agent}" - request_options['user_agent'] = user_agent - request_options['user_agent_overwrite'] = True + request_options["user_agent"] = user_agent + request_options["user_agent_overwrite"] = True def get_adjusted_upload_size(length: int, encryption_version: str) -> int: @@ -362,10 +360,8 @@ def get_adjusted_upload_size(length: int, encryption_version: str) -> int: def get_adjusted_download_range_and_offset( - start: int, - end: int, - length: Optional[int], - encryption_data: Optional[_EncryptionData]) -> Tuple[Tuple[int, int], Tuple[int, int]]: + start: int, end: int, length: Optional[int], encryption_data: Optional[_EncryptionData] +) -> Tuple[Tuple[int, int], Tuple[int, int]]: """ Gets the new download range and offsets into the decrypted data for the given user-specified range. The new download range will include all @@ -453,7 +449,7 @@ def parse_encryption_data(metadata: Dict[str, Any]) -> Optional[_EncryptionData] try: # Use case insensitive dict as key needs to be case-insensitive case_insensitive_metadata = CaseInsensitiveDict(metadata) - return _dict_to_encryption_data(loads(case_insensitive_metadata['encryptiondata'])) + return _dict_to_encryption_data(loads(case_insensitive_metadata["encryptiondata"])) except: # pylint: disable=bare-except return None @@ -468,9 +464,11 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp :return: The new blob size.
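To make the v2 region layout concrete: `encrypt_data_v2` above emits `nonce || ciphertext || tag` per region, where the nonce is the big-endian region counter. A round-trip sketch, assuming a 12-byte nonce and the standard 16-byte GCM tag (the constants' values are not shown in this hunk):

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

NONCE_LENGTH = 12  # assumed value of _GCM_NONCE_LENGTH
TAG_LENGTH = 16    # assumed value of _GCM_TAG_LENGTH

def encrypt_region(key: bytes, counter: int, plaintext: bytes) -> bytes:
    # Mirrors encrypt_data_v2: counter nonce, then AES-GCM ciphertext + tag
    nonce = counter.to_bytes(NONCE_LENGTH, "big")
    return nonce + AESGCM(key).encrypt(nonce, plaintext, None)

def decrypt_region(key: bytes, region: bytes) -> bytes:
    # Inverse of the layout above: strip the nonce, decrypt ciphertext + tag
    nonce, ciphertext_with_tag = region[:NONCE_LENGTH], region[NONCE_LENGTH:]
    return AESGCM(key).decrypt(nonce, ciphertext_with_tag, None)

key = AESGCM.generate_key(bit_length=256)
region = encrypt_region(key, counter=0, plaintext=b"hello, world")
assert decrypt_region(key, region) == b"hello, world"
assert len(region) == NONCE_LENGTH + len(b"hello, world") + TAG_LENGTH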
:rtype: int """ - if (encryption_data is not None and - encryption_data.encrypted_region_info is not None and - is_encryption_v2(encryption_data)): + if ( + encryption_data is not None + and encryption_data.encrypted_region_info is not None + and is_encryption_v2(encryption_data) + ): nonce_length = encryption_data.encrypted_region_info.nonce_length data_length = encryption_data.encrypted_region_info.data_length @@ -485,11 +483,8 @@ def adjust_blob_size_for_encryption(size: int, encryption_data: Optional[_Encryp def _generate_encryption_data_dict( - kek: KeyEncryptionKey, - cek: bytes, - iv: Optional[bytes], - version: str - ) -> TypedOrderedDict[str, Any]: + kek: KeyEncryptionKey, cek: bytes, iv: Optional[bytes], version: str +) -> TypedOrderedDict[str, Any]: """ Generates and returns the encryption metadata as a dict. @@ -506,7 +501,7 @@ def _generate_encryption_data_dict( # For V2, we include the encryption version in the wrapped key. elif version == _ENCRYPTION_PROTOCOL_V2: # We must pad the version to 8 bytes for AES Keywrap algorithms - to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b'\0') + cek + to_wrap = _ENCRYPTION_PROTOCOL_V2.encode().ljust(8, b"\0") + cek wrapped_cek = kek.wrap_key(to_wrap) else: raise ValueError("Invalid encryption version specified.") @@ -514,31 +509,31 @@ def _generate_encryption_data_dict( # Build the encryption_data dict. # Use OrderedDict to comply with Java's ordering requirement. wrapped_content_key = OrderedDict() - wrapped_content_key['KeyId'] = kek.get_kid() - wrapped_content_key['EncryptedKey'] = encode_base64(wrapped_cek) - wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm() + wrapped_content_key["KeyId"] = kek.get_kid() + wrapped_content_key["EncryptedKey"] = encode_base64(wrapped_cek) + wrapped_content_key["Algorithm"] = kek.get_key_wrap_algorithm() encryption_agent = OrderedDict() - encryption_agent['Protocol'] = version + encryption_agent["Protocol"] = version if version == _ENCRYPTION_PROTOCOL_V1: - encryption_agent['EncryptionAlgorithm'] = _EncryptionAlgorithm.AES_CBC_256 + encryption_agent["EncryptionAlgorithm"] = _EncryptionAlgorithm.AES_CBC_256 elif version == _ENCRYPTION_PROTOCOL_V2: - encryption_agent['EncryptionAlgorithm'] = _EncryptionAlgorithm.AES_GCM_256 + encryption_agent["EncryptionAlgorithm"] = _EncryptionAlgorithm.AES_GCM_256 encrypted_region_info = OrderedDict() - encrypted_region_info['DataLength'] = _GCM_REGION_DATA_LENGTH - encrypted_region_info['NonceLength'] = _GCM_NONCE_LENGTH + encrypted_region_info["DataLength"] = _GCM_REGION_DATA_LENGTH + encrypted_region_info["NonceLength"] = _GCM_NONCE_LENGTH encryption_data_dict: TypedOrderedDict[str, Any] = OrderedDict() - encryption_data_dict['WrappedContentKey'] = wrapped_content_key - encryption_data_dict['EncryptionAgent'] = encryption_agent + encryption_data_dict["WrappedContentKey"] = wrapped_content_key + encryption_data_dict["EncryptionAgent"] = encryption_agent if version == _ENCRYPTION_PROTOCOL_V1: - encryption_data_dict['ContentEncryptionIV'] = encode_base64(iv) + encryption_data_dict["ContentEncryptionIV"] = encode_base64(iv) elif version == _ENCRYPTION_PROTOCOL_V2: - encryption_data_dict['EncryptedRegionInfo'] = encrypted_region_info - encryption_data_dict['KeyWrappingMetadata'] = OrderedDict({'EncryptionLibrary': 'Python ' + VERSION}) + encryption_data_dict["EncryptedRegionInfo"] = encrypted_region_info + encryption_data_dict["KeyWrappingMetadata"] = OrderedDict({"EncryptionLibrary": "Python " + VERSION}) return encryption_data_dict @@ -554,43 +549,42 
@@ def _dict_to_encryption_data(encryption_data_dict: Dict[str, Any]) -> _Encryptio :rtype: _EncryptionData """ try: - protocol = encryption_data_dict['EncryptionAgent']['Protocol'] + protocol = encryption_data_dict["EncryptionAgent"]["Protocol"] if protocol not in _VALID_ENCRYPTION_PROTOCOLS: raise ValueError("Unsupported encryption version.") except KeyError as exc: raise ValueError("Unsupported encryption version.") from exc - wrapped_content_key = encryption_data_dict['WrappedContentKey'] - wrapped_content_key = _WrappedContentKey(wrapped_content_key['Algorithm'], - decode_base64_to_bytes(wrapped_content_key['EncryptedKey']), - wrapped_content_key['KeyId']) - - encryption_agent = encryption_data_dict['EncryptionAgent'] - encryption_agent = _EncryptionAgent(encryption_agent['EncryptionAlgorithm'], - encryption_agent['Protocol']) - - if 'KeyWrappingMetadata' in encryption_data_dict: - key_wrapping_metadata = encryption_data_dict['KeyWrappingMetadata'] + wrapped_content_key = encryption_data_dict["WrappedContentKey"] + wrapped_content_key = _WrappedContentKey( + wrapped_content_key["Algorithm"], + decode_base64_to_bytes(wrapped_content_key["EncryptedKey"]), + wrapped_content_key["KeyId"], + ) + + encryption_agent = encryption_data_dict["EncryptionAgent"] + encryption_agent = _EncryptionAgent(encryption_agent["EncryptionAlgorithm"], encryption_agent["Protocol"]) + + if "KeyWrappingMetadata" in encryption_data_dict: + key_wrapping_metadata = encryption_data_dict["KeyWrappingMetadata"] else: key_wrapping_metadata = None # AES-CBC only encryption_iv = None - if 'ContentEncryptionIV' in encryption_data_dict: - encryption_iv = decode_base64_to_bytes(encryption_data_dict['ContentEncryptionIV']) + if "ContentEncryptionIV" in encryption_data_dict: + encryption_iv = decode_base64_to_bytes(encryption_data_dict["ContentEncryptionIV"]) # AES-GCM only region_info = None - if 'EncryptedRegionInfo' in encryption_data_dict: - encrypted_region_info = encryption_data_dict['EncryptedRegionInfo'] - region_info = _EncryptedRegionInfo(encrypted_region_info['DataLength'], - encrypted_region_info['NonceLength'], - _GCM_TAG_LENGTH) - - encryption_data = _EncryptionData(encryption_iv, - region_info, - encryption_agent, - wrapped_content_key, - key_wrapping_metadata) + if "EncryptedRegionInfo" in encryption_data_dict: + encrypted_region_info = encryption_data_dict["EncryptedRegionInfo"] + region_info = _EncryptedRegionInfo( + encrypted_region_info["DataLength"], encrypted_region_info["NonceLength"], _GCM_TAG_LENGTH + ) + + encryption_data = _EncryptionData( + encryption_iv, region_info, encryption_agent, wrapped_content_key, key_wrapping_metadata + ) return encryption_data @@ -614,7 +608,7 @@ def _generate_AES_CBC_cipher(cek: bytes, iv: bytes) -> Cipher: def _validate_and_unwrap_cek( encryption_data: _EncryptionData, key_encryption_key: Optional[KeyEncryptionKey] = None, - key_resolver: Optional[Callable[[str], KeyEncryptionKey]] = None + key_resolver: Optional[Callable[[str], KeyEncryptionKey]] = None, ) -> bytes: """ Extracts and returns the content_encryption_key stored in the encryption_data object @@ -636,15 +630,15 @@ def _validate_and_unwrap_cek( :rtype: bytes """ - _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key) + _validate_not_none("encrypted_key", encryption_data.wrapped_content_key.encrypted_key) # Validate we have the right info for the specified version if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V1: - 
_validate_not_none('content_encryption_IV', encryption_data.content_encryption_IV) + _validate_not_none("content_encryption_IV", encryption_data.content_encryption_IV) elif encryption_data.encryption_agent.protocol in _ENCRYPTION_V2_PROTOCOLS: - _validate_not_none('encrypted_region_info', encryption_data.encrypted_region_info) + _validate_not_none("encrypted_region_info", encryption_data.encrypted_region_info) else: - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") content_encryption_key: Optional[bytes] = None @@ -654,29 +648,29 @@ def _validate_and_unwrap_cek( if key_encryption_key is None: raise ValueError("Unable to decrypt. key_resolver and key_encryption_key cannot both be None.") - if not hasattr(key_encryption_key, 'get_kid') or not callable(key_encryption_key.get_kid): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid')) - if not hasattr(key_encryption_key, 'unwrap_key') or not callable(key_encryption_key.unwrap_key): - raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'unwrap_key')) + if not hasattr(key_encryption_key, "get_kid") or not callable(key_encryption_key.get_kid): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "get_kid")) + if not hasattr(key_encryption_key, "unwrap_key") or not callable(key_encryption_key.unwrap_key): + raise AttributeError(_ERROR_OBJECT_INVALID.format("key encryption key", "unwrap_key")) if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid(): - raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.') + raise ValueError("Provided or resolved key-encryption-key does not match the id of key used to encrypt.") # Will throw an exception if the specified algorithm is not supported. content_encryption_key = key_encryption_key.unwrap_key( - encryption_data.wrapped_content_key.encrypted_key, - encryption_data.wrapped_content_key.algorithm) + encryption_data.wrapped_content_key.encrypted_key, encryption_data.wrapped_content_key.algorithm + ) # For V2, the version is included with the cek. We need to validate it # and remove it from the actual cek. if encryption_data.encryption_agent.protocol in _ENCRYPTION_V2_PROTOCOLS: - version_2_bytes = encryption_data.encryption_agent.protocol.encode().ljust(8, b'\0') - cek_version_bytes = content_encryption_key[:len(version_2_bytes)] + version_2_bytes = encryption_data.encryption_agent.protocol.encode().ljust(8, b"\0") + cek_version_bytes = content_encryption_key[: len(version_2_bytes)] if cek_version_bytes != version_2_bytes: - raise ValueError('The encryption metadata is not valid and may have been modified.') + raise ValueError("The encryption metadata is not valid and may have been modified.") # Remove version from the start of the cek. 
- content_encryption_key = content_encryption_key[len(version_2_bytes):] + content_encryption_key = content_encryption_key[len(version_2_bytes) :] - _validate_not_none('content_encryption_key', content_encryption_key) + _validate_not_none("content_encryption_key", content_encryption_key) return content_encryption_key @@ -685,7 +679,7 @@ def _decrypt_message( message: bytes, encryption_data: _EncryptionData, key_encryption_key: Optional[KeyEncryptionKey] = None, - resolver: Optional[Callable[[str], KeyEncryptionKey]] = None + resolver: Optional[Callable[[str], KeyEncryptionKey]] = None, ) -> bytes: """ Decrypts the given ciphertext using AES256 in CBC mode with 128 bit padding. @@ -710,7 +704,7 @@ def _decrypt_message( :return: The decrypted plaintext. :rtype: bytes """ - _validate_not_none('message', message) + _validate_not_none("message", message) content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, resolver) if encryption_data.encryption_agent.protocol == _ENCRYPTION_PROTOCOL_V1: @@ -721,11 +715,11 @@ def _decrypt_message( # decrypt data decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(message) + decryptor.finalize()) + decrypted_data = decryptor.update(message) + decryptor.finalize() # unpad data unpadder = PKCS7(128).unpadder() - decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize()) + decrypted_data = unpadder.update(decrypted_data) + unpadder.finalize() elif encryption_data.encryption_agent.protocol in _ENCRYPTION_V2_PROTOCOLS: block_info = encryption_data.encrypted_region_info @@ -745,7 +739,7 @@ def _decrypt_message( decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) else: - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") return decrypted_data @@ -773,8 +767,8 @@ def encrypt_blob(blob: bytes, key_encryption_key: KeyEncryptionKey, version: str :rtype: (str, bytes) """ - _validate_not_none('blob', blob) - _validate_not_none('key_encryption_key', key_encryption_key) + _validate_not_none("blob", blob) + _validate_not_none("key_encryption_key", key_encryption_key) _validate_key_encryption_key_wrap(key_encryption_key) if version == _ENCRYPTION_PROTOCOL_V1: @@ -805,16 +799,16 @@ def encrypt_blob(blob: bytes, key_encryption_key: KeyEncryptionKey, version: str else: raise ValueError("Invalid encryption version specified.") - encryption_data = _generate_encryption_data_dict(key_encryption_key, content_encryption_key, - initialization_vector, version) - encryption_data['EncryptionMode'] = 'FullBlob' + encryption_data = _generate_encryption_data_dict( + key_encryption_key, content_encryption_key, initialization_vector, version + ) + encryption_data["EncryptionMode"] = "FullBlob" return dumps(encryption_data), encrypted_data def generate_blob_encryption_data( - key_encryption_key: Optional[KeyEncryptionKey], - version: str + key_encryption_key: Optional[KeyEncryptionKey], version: str ) -> Tuple[Optional[bytes], Optional[bytes], Optional[str]]: """ Generates the encryption_metadata for the blob. 
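The dictionary assembled by `_generate_encryption_data_dict` is what `parse_encryption_data` later reads back out of the `encryptiondata` metadata key. For version 2.0 it has roughly the shape below; the base64 placeholder, the `A256KW` label, and the 4 MiB `DataLength` are illustrative rather than taken from this diff:

# Illustrative shape only; real values are produced by the key-encryption-key and SDK constants.
encryption_data = {
    "WrappedContentKey": {
        "KeyId": "local-kek",
        "EncryptedKey": "<base64 of wrap_key(b'2.0' padded to 8 bytes + cek)>",
        "Algorithm": "A256KW",
    },
    "EncryptionAgent": {
        "Protocol": "2.0",
        "EncryptionAlgorithm": "AES_GCM_256",
    },
    "EncryptedRegionInfo": {
        "DataLength": 4 * 1024 * 1024,  # assumed _GCM_REGION_DATA_LENGTH
        "NonceLength": 12,              # assumed _GCM_NONCE_LENGTH
    },
    "KeyWrappingMetadata": {"EncryptionLibrary": "Python <sdk version>"},
    "EncryptionMode": "FullBlob",       # appended after the dict is built
}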
@@ -836,24 +830,23 @@ def generate_blob_encryption_data( # Initialization vector only needed for V1 if version == _ENCRYPTION_PROTOCOL_V1: initialization_vector = os.urandom(16) - encryption_data_dict = _generate_encryption_data_dict(key_encryption_key, - content_encryption_key, - initialization_vector, - version) - encryption_data_dict['EncryptionMode'] = 'FullBlob' + encryption_data_dict = _generate_encryption_data_dict( + key_encryption_key, content_encryption_key, initialization_vector, version + ) + encryption_data_dict["EncryptionMode"] = "FullBlob" encryption_data = dumps(encryption_data_dict) return content_encryption_key, initialization_vector, encryption_data def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements - require_encryption: bool, - key_encryption_key: Optional[KeyEncryptionKey], - key_resolver: Optional[Callable[[str], KeyEncryptionKey]], - content: bytes, - start_offset: int, - end_offset: int, - response_headers: Dict[str, Any] + require_encryption: bool, + key_encryption_key: Optional[KeyEncryptionKey], + key_resolver: Optional[Callable[[str], KeyEncryptionKey]], + content: bytes, + start_offset: int, + end_offset: int, + response_headers: Dict[str, Any], ) -> bytes: """ Decrypts the given blob contents and returns only the requested range. @@ -885,39 +878,40 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements :rtype: bytes """ try: - encryption_data = _dict_to_encryption_data(loads(response_headers['x-ms-meta-encryptiondata'])) + encryption_data = _dict_to_encryption_data(loads(response_headers["x-ms-meta-encryptiondata"])) except Exception as exc: # pylint: disable=broad-except if require_encryption: raise ValueError( - 'Encryption required, but received data does not contain appropriate metadata.' + \ - 'Data was either not encrypted or metadata has been lost.') from exc + "Encryption required, but received data does not contain appropriate metadata." + + "Data was either not encrypted or metadata has been lost." 
+ ) from exc return content algorithm = encryption_data.encryption_agent.encryption_algorithm - if algorithm not in(_EncryptionAlgorithm.AES_CBC_256, _EncryptionAlgorithm.AES_GCM_256): - raise ValueError('Specified encryption algorithm is not supported.') + if algorithm not in (_EncryptionAlgorithm.AES_CBC_256, _EncryptionAlgorithm.AES_GCM_256): + raise ValueError("Specified encryption algorithm is not supported.") version = encryption_data.encryption_agent.protocol if version not in _VALID_ENCRYPTION_PROTOCOLS: - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, key_resolver) if version == _ENCRYPTION_PROTOCOL_V1: - blob_type = response_headers['x-ms-blob-type'] + blob_type = response_headers["x-ms-blob-type"] iv: Optional[bytes] = None unpad = False - if 'content-range' in response_headers: - content_range = response_headers['content-range'] + if "content-range" in response_headers: + content_range = response_headers["content-range"] # Format: 'bytes x-y/size' # Ignore the word 'bytes' - content_range = content_range.split(' ') + content_range = content_range.split(" ") - content_range = content_range[1].split('-') - content_range = content_range[1].split('/') + content_range = content_range[1].split("-") + content_range = content_range[1].split("/") end_range = int(content_range[0]) blob_size = int(content_range[1]) @@ -934,7 +928,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements unpad = True iv = encryption_data.content_encryption_IV - if blob_type == 'PageBlob': + if blob_type == "PageBlob": unpad = False if iv is None: @@ -948,7 +942,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements unpadder = PKCS7(128).unpadder() content = unpadder.update(content) + unpadder.finalize() - return content[start_offset: len(content) - end_offset] + return content[start_offset : len(content) - end_offset] if version in _ENCRYPTION_V2_PROTOCOLS: # We assume the content contains only full encryption regions @@ -967,7 +961,7 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements while offset < total_size: # Process one encryption region at a time process_size = min(region_length, total_size) - encrypted_region = content[offset:offset + process_size] + encrypted_region = content[offset : offset + process_size] # First bytes are the nonce nonce = encrypted_region[:nonce_length] @@ -982,13 +976,11 @@ def decrypt_blob( # pylint: disable=too-many-locals,too-many-statements # Read the caller requested data from the decrypted content return decrypted_content[start_offset:end_offset] - raise ValueError('Specified encryption version is not supported.') + raise ValueError("Specified encryption version is not supported.") def get_blob_encryptor_and_padder( - cek: Optional[bytes], - iv: Optional[bytes], - should_pad: bool + cek: Optional[bytes], iv: Optional[bytes], should_pad: bool ) -> Tuple[Optional["AEADEncryptionContext"], Optional["PaddingContext"]]: encryptor = None padder = None @@ -1022,13 +1014,13 @@ def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, ve :rtype: str """ - _validate_not_none('message', message) - _validate_not_none('key_encryption_key', key_encryption_key) + _validate_not_none("message", message) + _validate_not_none("key_encryption_key", key_encryption_key) _validate_key_encryption_key_wrap(key_encryption_key) # Queue 
encoding functions all return unicode strings, and encryption should # operate on binary strings. - message_as_bytes: bytes = message.encode('utf-8') + message_as_bytes: bytes = message.encode("utf-8") if version == _ENCRYPTION_PROTOCOL_V1: # AES256 CBC uses 256 bit (32 byte) keys and always with 16 byte blocks @@ -1062,11 +1054,12 @@ def encrypt_queue_message(message: str, key_encryption_key: KeyEncryptionKey, ve raise ValueError("Invalid encryption version specified.") # Build the dictionary structure. - queue_message = {'EncryptedMessageContents': encode_base64(encrypted_data), - 'EncryptionData': _generate_encryption_data_dict(key_encryption_key, - content_encryption_key, - initialization_vector, - version)} + queue_message = { + "EncryptedMessageContents": encode_base64(encrypted_data), + "EncryptionData": _generate_encryption_data_dict( + key_encryption_key, content_encryption_key, initialization_vector, version + ), + } return dumps(queue_message) @@ -1076,7 +1069,7 @@ def decrypt_queue_message( response: "PipelineResponse", require_encryption: bool, key_encryption_key: Optional[KeyEncryptionKey], - resolver: Optional[Callable[[str], KeyEncryptionKey]] + resolver: Optional[Callable[[str], KeyEncryptionKey]], ) -> str: """ Returns the decrypted message contents from an EncryptedQueueMessage. @@ -1106,22 +1099,22 @@ def decrypt_queue_message( try: deserialized_message: Dict[str, Any] = loads(message) - encryption_data = _dict_to_encryption_data(deserialized_message['EncryptionData']) - decoded_data = decode_base64_to_bytes(deserialized_message['EncryptedMessageContents']) + encryption_data = _dict_to_encryption_data(deserialized_message["EncryptionData"]) + decoded_data = decode_base64_to_bytes(deserialized_message["EncryptedMessageContents"]) except (KeyError, ValueError) as exc: # Message was not json formatted and so was not encrypted # or the user provided a json formatted message # or the metadata was malformed. if require_encryption: raise ValueError( - 'Encryption required, but received message does not contain appropriate metatadata. ' + \ - 'Message was either not encrypted or metadata was incorrect.') from exc + "Encryption required, but received message does not contain appropriate metatadata. " + + "Message was either not encrypted or metadata was incorrect." 
+ ) from exc return message try: - return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode('utf-8') + return _decrypt_message(decoded_data, encryption_data, key_encryption_key, resolver).decode("utf-8") except Exception as error: raise HttpResponseError( - message="Decryption failed.", - response=response, #type: ignore [arg-type] - error=error) from error + message="Decryption failed.", response=response, error=error # type: ignore [arg-type] + ) from error diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py index ce490e354b68..bd62f9933338 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_message_encoding.py @@ -40,10 +40,11 @@ def __call__(self, content: Any) -> str: return content def configure( - self, require_encryption: bool, + self, + require_encryption: bool, key_encryption_key: Optional[KeyEncryptionKey], resolver: Optional[Callable[[str], KeyEncryptionKey]], - encryption_version: str = _ENCRYPTION_PROTOCOL_V1 + encryption_version: str = _ENCRYPTION_PROTOCOL_V1, ) -> None: self.require_encryption = require_encryption self.encryption_version = encryption_version @@ -77,17 +78,16 @@ def __call__(self, response: "PipelineResponse", obj: Iterable, headers: Dict[st content = message.message_text if (self.key_encryption_key is not None) or (self.resolver is not None): content = decrypt_queue_message( - content, response, - self.require_encryption, - self.key_encryption_key, - self.resolver) + content, response, self.require_encryption, self.key_encryption_key, self.resolver + ) message.message_text = self.decode(content, response) return obj def configure( - self, require_encryption: bool, + self, + require_encryption: bool, key_encryption_key: Optional[KeyEncryptionKey], - resolver: Optional[Callable[[str], KeyEncryptionKey]] + resolver: Optional[Callable[[str], KeyEncryptionKey]], ) -> None: self.require_encryption = require_encryption self.key_encryption_key = key_encryption_key @@ -107,7 +107,7 @@ class TextBase64EncodePolicy(MessageEncodePolicy): def encode(self, content: str) -> str: if not isinstance(content, str): raise TypeError("Message content must be text for base 64 encoding.") - return b64encode(content.encode('utf-8')).decode('utf-8') + return b64encode(content.encode("utf-8")).decode("utf-8") class TextBase64DecodePolicy(MessageDecodePolicy): @@ -120,13 +120,12 @@ class TextBase64DecodePolicy(MessageDecodePolicy): def decode(self, content: str, response: "PipelineResponse") -> str: try: - return b64decode(content.encode('utf-8')).decode('utf-8') + return b64decode(content.encode("utf-8")).decode("utf-8") except (ValueError, TypeError) as error: # ValueError for Python 3, TypeError for Python 2 raise DecodeError( - message="Message content is not valid base 64.", - response=response, #type: ignore - error=error) from error + message="Message content is not valid base 64.", response=response, error=error # type: ignore + ) from error class BinaryBase64EncodePolicy(MessageEncodePolicy): @@ -139,7 +138,7 @@ class BinaryBase64EncodePolicy(MessageEncodePolicy): def encode(self, content: bytes) -> str: if not isinstance(content, bytes): raise TypeError("Message content must be bytes for base 64 encoding.") - return b64encode(content).decode('utf-8') + return b64encode(content).decode("utf-8") class 
BinaryBase64DecodePolicy(MessageDecodePolicy): @@ -152,13 +151,12 @@ class BinaryBase64DecodePolicy(MessageDecodePolicy): def decode(self, content: str, response: "PipelineResponse") -> bytes: response = response.http_response try: - return b64decode(content.encode('utf-8')) + return b64decode(content.encode("utf-8")) except (ValueError, TypeError) as error: # ValueError for Python 3, TypeError for Python 2 raise DecodeError( - message="Message content is not valid base 64.", - response=response, #type: ignore - error=error) from error + message="Message content is not valid base 64.", response=response, error=error # type: ignore + ) from error class NoEncodePolicy(MessageEncodePolicy): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py index 3c565ce68f93..6d7faf9e860e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_models.py @@ -72,7 +72,7 @@ class QueueAnalyticsLogging(GeneratedLogging): :keyword ~azure.storage.queue.RetentionPolicy retention_policy: The retention policy for the metrics. """ - version: str = '1.0' + version: str = "1.0" """The version of Storage Analytics to configure.""" delete: bool = False """Indicates whether all delete requests should be logged.""" @@ -84,11 +84,11 @@ class QueueAnalyticsLogging(GeneratedLogging): """The retention policy for the metrics.""" def __init__(self, **kwargs: Any) -> None: - self.version = kwargs.get('version', '1.0') - self.delete = kwargs.get('delete', False) - self.read = kwargs.get('read', False) - self.write = kwargs.get('write', False) - self.retention_policy = kwargs.get('retention_policy') or RetentionPolicy() + self.version = kwargs.get("version", "1.0") + self.delete = kwargs.get("delete", False) + self.read = kwargs.get("read", False) + self.write = kwargs.get("write", False) + self.retention_policy = kwargs.get("retention_policy") or RetentionPolicy() @classmethod def _from_generated(cls, generated: Any) -> Self: @@ -99,7 +99,9 @@ def _from_generated(cls, generated: Any) -> Self: delete=generated.delete, read=generated.read, write=generated.write, - retention_policy=RetentionPolicy._from_generated(generated.retention_policy) # pylint: disable=protected-access + retention_policy=RetentionPolicy._from_generated( # pylint: disable=protected-access + generated.retention_policy + ), ) @@ -115,7 +117,7 @@ class Metrics(GeneratedMetrics): :keyword ~azure.storage.queue.RetentionPolicy retention_policy: The retention policy for the metrics. 
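The four Base64 policies above reduce to plain `base64` round-trips once the type checks pass. A standalone sketch of the text-encode and binary-decode pair (function names are illustrative):

from base64 import b64decode, b64encode

def encode_text(content: str) -> str:
    # TextBase64EncodePolicy.encode: UTF-8 encode, base64, return text
    if not isinstance(content, str):
        raise TypeError("Message content must be text for base 64 encoding.")
    return b64encode(content.encode("utf-8")).decode("utf-8")

def decode_binary(content: str) -> bytes:
    # BinaryBase64DecodePolicy.decode: base64-decode the stored message text
    return b64decode(content.encode("utf-8"))

assert decode_binary(encode_text("hello")) == b"hello"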
""" - version: str = '1.0' + version: str = "1.0" """The version of Storage Analytics to configure.""" enabled: bool = False """Indicates whether metrics are enabled for the service.""" @@ -125,10 +127,10 @@ class Metrics(GeneratedMetrics): """The retention policy for the metrics.""" def __init__(self, **kwargs: Any) -> None: - self.version = kwargs.get('version', '1.0') - self.enabled = kwargs.get('enabled', False) - self.include_apis = kwargs.get('include_apis') - self.retention_policy = kwargs.get('retention_policy') or RetentionPolicy() + self.version = kwargs.get("version", "1.0") + self.enabled = kwargs.get("enabled", False) + self.include_apis = kwargs.get("include_apis") + self.retention_policy = kwargs.get("retention_policy") or RetentionPolicy() @classmethod def _from_generated(cls, generated: Any) -> Self: @@ -138,7 +140,9 @@ def _from_generated(cls, generated: Any) -> Self: version=generated.version, enabled=generated.enabled, include_apis=generated.include_apis, - retention_policy=RetentionPolicy._from_generated(generated.retention_policy) # pylint: disable=protected-access + retention_policy=RetentionPolicy._from_generated( # pylint: disable=protected-access + generated.retention_policy + ), ) @@ -187,11 +191,11 @@ class CorsRule(GeneratedCorsRule): request.""" def __init__(self, allowed_origins: List[str], allowed_methods: List[str], **kwargs: Any) -> None: - self.allowed_origins = ','.join(allowed_origins) - self.allowed_methods = ','.join(allowed_methods) - self.allowed_headers = ','.join(kwargs.get('allowed_headers', [])) - self.exposed_headers = ','.join(kwargs.get('exposed_headers', [])) - self.max_age_in_seconds = kwargs.get('max_age_in_seconds', 0) + self.allowed_origins = ",".join(allowed_origins) + self.allowed_methods = ",".join(allowed_methods) + self.allowed_headers = ",".join(kwargs.get("allowed_headers", [])) + self.exposed_headers = ",".join(kwargs.get("exposed_headers", [])) + self.max_age_in_seconds = kwargs.get("max_age_in_seconds", 0) @staticmethod def _to_generated(rules: Optional[List["CorsRule"]]) -> Optional[List[GeneratedCorsRule]]: @@ -205,7 +209,7 @@ def _to_generated(rules: Optional[List["CorsRule"]]) -> Optional[List[GeneratedC allowed_methods=cors_rule.allowed_methods, allowed_headers=cors_rule.allowed_headers, exposed_headers=cors_rule.exposed_headers, - max_age_in_seconds=cors_rule.max_age_in_seconds + max_age_in_seconds=cors_rule.max_age_in_seconds, ) generated_cors_list.append(generated_cors) @@ -247,20 +251,17 @@ class QueueSasPermissions(object): process: bool = False """Get and delete messages from the queue.""" - def __init__( - self, read: bool = False, - add: bool = False, - update: bool = False, - process: bool = False - ) -> None: + def __init__(self, read: bool = False, add: bool = False, update: bool = False, process: bool = False) -> None: self.read = read self.add = add self.update = update self.process = process - self._str = (('r' if self.read else '') + - ('a' if self.add else '') + - ('u' if self.update else '') + - ('p' if self.process else '')) + self._str = ( + ("r" if self.read else "") + + ("a" if self.add else "") + + ("u" if self.update else "") + + ("p" if self.process else "") + ) def __str__(self): return self._str @@ -278,10 +279,10 @@ def from_string(cls, permission: str) -> Self: :return: A QueueSasPermissions object :rtype: ~azure.storage.queue.QueueSasPermissions """ - p_read = 'r' in permission - p_add = 'a' in permission - p_update = 'u' in permission - p_process = 'p' in permission + p_read = "r" in 
permission + p_add = "a" in permission + p_update = "u" in permission + p_process = "p" in permission parsed = cls(p_read, p_add, p_update, p_process) @@ -328,18 +329,19 @@ class AccessPolicy(GenAccessPolicy): be interpreted as UTC. """ - permission: Optional[Union[QueueSasPermissions, str]] #type: ignore [assignment] + permission: Optional[Union[QueueSasPermissions, str]] # type: ignore [assignment] """The permissions associated with the shared access signature. The user is restricted to operations allowed by the permissions.""" - expiry: Optional[Union["datetime", str]] #type: ignore [assignment] + expiry: Optional[Union["datetime", str]] # type: ignore [assignment] """The time at which the shared access signature becomes invalid.""" - start: Optional[Union["datetime", str]] #type: ignore [assignment] + start: Optional[Union["datetime", str]] # type: ignore [assignment] """The time at which the shared access signature becomes valid.""" def __init__( - self, permission: Optional[Union[QueueSasPermissions, str]] = None, + self, + permission: Optional[Union[QueueSasPermissions, str]] = None, expiry: Optional[Union["datetime", str]] = None, - start: Optional[Union["datetime", str]] = None + start: Optional[Union["datetime", str]] = None, ) -> None: self.start = start self.expiry = expiry @@ -374,13 +376,13 @@ class QueueMessage(DictMixin): Only returned by receive messages operations. Set to None for peek messages.""" def __init__(self, content: Optional[Any] = None, **kwargs: Any) -> None: - self.id = kwargs.pop('id', None) - self.inserted_on = kwargs.pop('inserted_on', None) - self.expires_on = kwargs.pop('expires_on', None) - self.dequeue_count = kwargs.pop('dequeue_count', None) + self.id = kwargs.pop("id", None) + self.inserted_on = kwargs.pop("inserted_on", None) + self.expires_on = kwargs.pop("expires_on", None) + self.dequeue_count = kwargs.pop("dequeue_count", None) self.content = content - self.pop_receipt = kwargs.pop('pop_receipt', None) - self.next_visible_on = kwargs.pop('next_visible_on', None) + self.pop_receipt = kwargs.pop("pop_receipt", None) + self.next_visible_on = kwargs.pop("next_visible_on", None) @classmethod def _from_generated(cls, generated: Any) -> Self: @@ -389,7 +391,7 @@ def _from_generated(cls, generated: Any) -> Self: message.inserted_on = generated.insertion_time message.expires_on = generated.expiration_time message.dequeue_count = generated.dequeue_count - if hasattr(generated, 'pop_receipt'): + if hasattr(generated, "pop_receipt"): message.pop_receipt = generated.pop_receipt message.next_visible_on = generated.time_next_visible return message @@ -413,10 +415,11 @@ class MessagesPaged(PageIterator): """The maximum number of messages to retrieve from the queue.""" def __init__( - self, command: Callable, + self, + command: Callable, results_per_page: Optional[int] = None, continuation_token: Optional[str] = None, - max_messages: Optional[int] = None + max_messages: Optional[int] = None, ) -> None: if continuation_token is not None: raise ValueError("This operation does not support continuation token") @@ -469,9 +472,9 @@ class QueueProperties(DictMixin): def __init__(self, **kwargs: Any) -> None: # The name property will always be set to a non-None value after construction. 
- self.name = None #type: ignore [assignment] - self.metadata = kwargs.get('metadata') - self.approximate_message_count = kwargs.get('x-ms-approximate-messages-count') + self.name = None # type: ignore [assignment] + self.metadata = kwargs.get("metadata") + self.approximate_message_count = kwargs.get("x-ms-approximate-messages-count") @classmethod def _from_generated(cls, generated: Any) -> Self: @@ -510,15 +513,14 @@ class QueuePropertiesPaged(PageIterator): """Function to retrieve the next page of items.""" def __init__( - self, command: Callable, + self, + command: Callable, prefix: Optional[str] = None, results_per_page: Optional[int] = None, - continuation_token: Optional[str] = None + continuation_token: Optional[str] = None, ) -> None: super(QueuePropertiesPaged, self).__init__( - self._get_next_cb, - self._extract_data_cb, #type: ignore - continuation_token=continuation_token or "" + self._get_next_cb, self._extract_data_cb, continuation_token=continuation_token or "" # type: ignore ) self._command = command self.service_endpoint = None @@ -533,7 +535,8 @@ def _get_next_cb(self, continuation_token: Optional[str]) -> Any: marker=continuation_token or None, maxresults=self.results_per_page, cls=return_context_and_deserialized, - use_location=self.location_mode) + use_location=self.location_mode, + ) except HttpResponseError as error: process_storage_error(error) @@ -543,7 +546,9 @@ def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], List[Qu self.prefix = self._response.prefix self.marker = self._response.marker self.results_per_page = self._response.max_results - props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access + props_list = [ + QueueProperties._from_generated(q) for q in self._response.queue_items # pylint: disable=protected-access + ] return self._response.next_marker or None, props_list @@ -555,9 +560,9 @@ def service_stats_deserialize(generated: Any) -> Dict[str, Any]: :rtype: Dict[str, Any] """ return { - 'geo_replication': { - 'status': generated.geo_replication.status, - 'last_sync_time': generated.geo_replication.last_sync_time, + "geo_replication": { + "status": generated.geo_replication.status, + "last_sync_time": generated.geo_replication.last_sync_time, } } @@ -570,8 +575,10 @@ def service_properties_deserialize(generated: Any) -> Dict[str, Any]: :rtype: Dict[str, Any] """ return { - 'analytics_logging': QueueAnalyticsLogging._from_generated(generated.logging), # pylint: disable=protected-access - 'hour_metrics': Metrics._from_generated(generated.hour_metrics), # pylint: disable=protected-access - 'minute_metrics': Metrics._from_generated(generated.minute_metrics), # pylint: disable=protected-access - 'cors': [CorsRule._from_generated(cors) for cors in generated.cors], # pylint: disable=protected-access + "analytics_logging": QueueAnalyticsLogging._from_generated( # pylint: disable=protected-access + generated.logging + ), + "hour_metrics": Metrics._from_generated(generated.hour_metrics), # pylint: disable=protected-access + "minute_metrics": Metrics._from_generated(generated.minute_metrics), # pylint: disable=protected-access + "cors": [CorsRule._from_generated(cors) for cors in generated.cors], # pylint: disable=protected-access } diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py index cfa598d771b3..7f34d6baefe6 100644 --- 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client.py @@ -6,10 +6,7 @@ import functools import warnings -from typing import ( - Any, cast, Dict, List, Optional, - Tuple, TYPE_CHECKING, Union -) +from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import Self from azure.core.exceptions import HttpResponseError @@ -25,10 +22,7 @@ from ._serialize import get_api_version from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin from ._shared.request_handlers import add_metadata_headers, serialize_iso -from ._shared.response_handlers import ( - process_storage_error, - return_headers_and_deserialized, - return_response_headers) +from ._shared.response_handlers import process_storage_error, return_headers_and_deserialized, return_response_headers if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -36,7 +30,7 @@ BinaryBase64DecodePolicy, BinaryBase64EncodePolicy, TextBase64DecodePolicy, - TextBase64EncodePolicy + TextBase64EncodePolicy, ) from ._models import QueueProperties @@ -93,10 +87,14 @@ class QueueClient(StorageAccountHostsMixin, StorageEncryptionMixin): :dedent: 12 :caption: Create the queue client with url and credential. """ + def __init__( - self, account_url: str, + self, + account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, @@ -110,7 +108,7 @@ def __init__( self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueClient, self).__init__( parsed_url, - service='queue', + service="queue", credential=credential, secondary_hostname=secondary_hostname, audience=audience, @@ -134,8 +132,11 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_queue_url( - cls, queue_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + cls, + queue_url: str, + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, @@ -195,9 +196,12 @@ def from_queue_url( @classmethod def from_connection_string( - cls, conn_str: str, + cls, + conn_str: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, @@ -254,7 +258,7 @@ def from_connection_string( :dedent: 8 :caption: Create the queue client from connection string. 
""" - account_url, secondary, credential = parse_connection_str(conn_str, credential, 'queue') + account_url, secondary, credential = parse_connection_str(conn_str, credential, "queue") return cls( account_url, queue_name=queue_name, @@ -269,10 +273,7 @@ def from_connection_string( @distributed_trace def create_queue( - self, *, - metadata: Optional[Dict[str, str]] = None, - timeout: Optional[int] = None, - **kwargs: Any + self, *, metadata: Optional[Dict[str, str]] = None, timeout: Optional[int] = None, **kwargs: Any ) -> None: """Creates a new queue in the storage account. @@ -302,15 +303,11 @@ def create_queue( :dedent: 8 :caption: Create a queue. """ - headers = kwargs.pop('headers', {}) + headers = kwargs.pop("headers", {}) headers.update(add_metadata_headers(metadata)) try: return self._client.queue.create( - metadata=metadata, - timeout=timeout, - headers=headers, - cls=deserialize_queue_creation, - **kwargs + metadata=metadata, timeout=timeout, headers=headers, cls=deserialize_queue_creation, **kwargs ) except HttpResponseError as error: process_storage_error(error) @@ -370,11 +367,10 @@ def get_queue_properties(self, *, timeout: Optional[int] = None, **kwargs: Any) :caption: Get the properties on the queue. """ try: - response = cast("QueueProperties", self._client.queue.get_properties( - timeout=timeout, - cls=deserialize_queue_properties, - **kwargs - )) + response = cast( + "QueueProperties", + self._client.queue.get_properties(timeout=timeout, cls=deserialize_queue_properties, **kwargs), + ) except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name @@ -382,10 +378,7 @@ def get_queue_properties(self, *, timeout: Optional[int] = None, **kwargs: Any) @distributed_trace def set_queue_metadata( - self, metadata: Optional[Dict[str, str]] = None, - *, - timeout: Optional[int] = None, - **kwargs: Any + self, metadata: Optional[Dict[str, str]] = None, *, timeout: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: """Sets user-defined metadata on the specified queue. @@ -412,14 +405,11 @@ def set_queue_metadata( :dedent: 12 :caption: Set metadata on the queue. """ - headers = kwargs.pop('headers', {}) + headers = kwargs.pop("headers", {}) headers.update(add_metadata_headers(metadata)) try: return self._client.queue.set_metadata( - timeout=timeout, - headers=headers, - cls=return_response_headers, - **kwargs + timeout=timeout, headers=headers, cls=return_response_headers, **kwargs ) except HttpResponseError as error: process_storage_error(error) @@ -439,21 +429,17 @@ def get_queue_access_policy(self, *, timeout: Optional[int] = None, **kwargs: An :rtype: Dict[str, ~azure.storage.queue.AccessPolicy] """ try: - _, identifiers = cast(Tuple[Dict, List], self._client.queue.get_access_policy( - timeout=timeout, - cls=return_headers_and_deserialized, - **kwargs - )) + _, identifiers = cast( + Tuple[Dict, List], + self._client.queue.get_access_policy(timeout=timeout, cls=return_headers_and_deserialized, **kwargs), + ) except HttpResponseError as error: process_storage_error(error) return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace def set_queue_access_policy( - self, signed_identifiers: Dict[str, AccessPolicy], - *, - timeout: Optional[int] = None, - **kwargs: Any + self, signed_identifiers: Dict[str, AccessPolicy], *, timeout: Optional[int] = None, **kwargs: Any ) -> None: """Sets stored access policies for the queue that may be used with Shared Access Signatures. 
@@ -508,7 +494,8 @@ def set_queue_access_policy( @distributed_trace def send_message( - self, content: Optional[object], + self, + content: Optional[object], *, visibility_timeout: Optional[int] = None, time_to_live: Optional[int] = None, @@ -566,10 +553,7 @@ def send_message( """ if self.key_encryption_key: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) try: @@ -577,7 +561,7 @@ def send_message( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, - encryption_version=self.encryption_version + encryption_version=self.encryption_version, ) except TypeError: warnings.warn( @@ -589,7 +573,7 @@ def send_message( self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) encoded_content = self._message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) @@ -608,7 +592,7 @@ def send_message( inserted_on=enqueued[0].insertion_time, expires_on=enqueued[0].expiration_time, pop_receipt=enqueued[0].pop_receipt, - next_visible_on=enqueued[0].time_next_visible + next_visible_on=enqueued[0].time_next_visible, ) return queue_message except HttpResponseError as error: @@ -616,10 +600,7 @@ def send_message( @distributed_trace def receive_message( - self, *, - visibility_timeout: Optional[int] = None, - timeout: Optional[int] = None, - **kwargs: Any + self, *, visibility_timeout: Optional[int] = None, timeout: Optional[int] = None, **kwargs: Any ) -> Optional[QueueMessage]: """Removes one message from the front of the queue. 
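The encryption branches in `send_message`, `receive_message`, and `update_message` below only activate when the client was constructed with the keywords read by `StorageEncryptionMixin._configure_encryption`. A sketch assuming the illustrative `LocalAESKeyEncryptionKey` from the earlier aside; connection string and queue name are placeholders:

from azure.storage.queue import QueueClient

kek = LocalAESKeyEncryptionKey()  # illustrative key wrapper defined earlier
queue = QueueClient.from_connection_string(
    "<connection-string>",           # placeholder
    "my-encrypted-queue",
    key_encryption_key=kek,          # turns on client-side encryption
    require_encryption=True,         # refuse to read unencrypted messages
    encryption_version="2.0",        # AES-GCM; "1.0" is deprecated
)
queue.create_queue()
queue.send_message("secret payload")   # wrapped in an EncryptedQueueMessage envelope
msg = queue.receive_message()          # decrypted transparently via kek.unwrap_key
print(msg.content if msg else "queue empty")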
@@ -660,16 +641,13 @@ def receive_message( """ if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) try: message = self._client.messages.dequeue( @@ -679,14 +657,17 @@ def receive_message( cls=self._message_decode_policy, **kwargs ) - wrapped_message = QueueMessage._from_generated(message[0]) if message != [] else None # pylint: disable=protected-access + wrapped_message = ( + QueueMessage._from_generated(message[0]) if message != [] else None # pylint: disable=protected-access + ) return wrapped_message except HttpResponseError as error: process_storage_error(error) @distributed_trace def receive_messages( - self, *, + self, + *, messages_per_page: Optional[int] = None, visibility_timeout: Optional[int] = None, max_messages: Optional[int] = None, @@ -753,16 +734,13 @@ def receive_messages( """ if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) try: command = functools.partial( @@ -779,14 +757,15 @@ def receive_messages( command, results_per_page=messages_per_page, page_iterator_class=MessagesPaged, - max_messages=max_messages + max_messages=max_messages, ) except HttpResponseError as error: process_storage_error(error) @distributed_trace def update_message( - self, message: Union[str, QueueMessage], + self, + message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, content: Optional[object] = None, *, @@ -846,10 +825,7 @@ def update_message( """ if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) if isinstance(message, QueueMessage): @@ -875,7 +851,7 @@ def update_message( self.require_encryption, self.key_encryption_key, self.key_resolver_function, - encryption_version=self.encryption_version + encryption_version=self.encryption_version, ) except TypeError: warnings.warn( @@ -885,32 +861,33 @@ def update_message( Retrying without encryption_version." 
) self._message_encode_policy.configure( - self.require_encryption, - self.key_encryption_key, - self.key_resolver_function + self.require_encryption, self.key_encryption_key, self.key_resolver_function ) encoded_message_text = self._message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: updated = None try: - response = cast(QueueMessage, self._client.message_id.update( - queue_message=updated, - visibilitytimeout=visibility_timeout or 0, - timeout=timeout, - pop_receipt=receipt, - cls=return_response_headers, - queue_message_id=message_id, - **kwargs - )) + response = cast( + QueueMessage, + self._client.message_id.update( + queue_message=updated, + visibilitytimeout=visibility_timeout or 0, + timeout=timeout, + pop_receipt=receipt, + cls=return_response_headers, + queue_message_id=message_id, + **kwargs + ), + ) new_message = QueueMessage( content=message_text, id=message_id, inserted_on=inserted_on, dequeue_count=dequeue_count, expires_on=expires_on, - pop_receipt=response['popreceipt'], - next_visible_on=response['time_next_visible'] + pop_receipt=response["popreceipt"], + next_visible_on=response["time_next_visible"], ) return new_message except HttpResponseError as error: @@ -918,10 +895,7 @@ def update_message( @distributed_trace def peek_messages( - self, max_messages: Optional[int] = None, - *, - timeout: Optional[int] = None, - **kwargs: Any + self, max_messages: Optional[int] = None, *, timeout: Optional[int] = None, **kwargs: Any ) -> List[QueueMessage]: """Retrieves one or more messages from the front of the queue, but does not alter the visibility of the message. @@ -967,23 +941,17 @@ def peek_messages( if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) try: messages = self._client.messages.peek( - number_of_messages=max_messages, - timeout=timeout, - cls=self._message_decode_policy, - **kwargs + number_of_messages=max_messages, timeout=timeout, cls=self._message_decode_policy, **kwargs ) wrapped_messages = [] for peeked in messages: @@ -1019,7 +987,8 @@ def clear_messages(self, *, timeout: Optional[int] = None, **kwargs: Any) -> Non @distributed_trace def delete_message( - self, message: Union[str, QueueMessage], + self, + message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, *, timeout: Optional[int] = None, @@ -1070,11 +1039,6 @@ def delete_message( if receipt is None: raise ValueError("pop_receipt must be present") try: - self._client.message_id.delete( - pop_receipt=receipt, - timeout=timeout, - queue_message_id=message_id, - **kwargs - ) + self._client.message_id.delete(pop_receipt=receipt, timeout=timeout, queue_message_id=message_id, **kwargs) except HttpResponseError as error: process_storage_error(error) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py index 198439d55eda..36e1ddf3a7e7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py +++ 
b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_client_helpers.py @@ -17,7 +17,16 @@ def _parse_url( account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long + credential: Optional[ + Union[ + str, + Dict[str, str], + "AzureNamedKeyCredential", + "AzureSasCredential", + "AsyncTokenCredential", + "TokenCredential", + ] + ], ) -> Tuple["ParseResult", Any]: """Performs initial input validation and returns the parsed URL and SAS token. @@ -41,11 +50,11 @@ def _parse_url( :rtype: Tuple[ParseResult, Any] """ try: - if not account_url.lower().startswith('http'): + if not account_url.lower().startswith("http"): account_url = "https://" + account_url except AttributeError as exc: raise ValueError("Account URL must be a string.") from exc - parsed_url = urlparse(account_url.rstrip('/')) + parsed_url = urlparse(account_url.rstrip("/")) if not queue_name: raise ValueError("Please specify a queue name.") if not parsed_url.netloc: @@ -57,6 +66,7 @@ def _parse_url( return parsed_url, sas_token + def _format_url(queue_name: Union[bytes, str], hostname: str, scheme: str, query_str: str) -> str: """Format the endpoint URL according to the current location mode hostname. @@ -68,12 +78,11 @@ def _format_url(queue_name: Union[bytes, str], hostname: str, scheme: str, query :rtype: str """ if isinstance(queue_name, str): - queue_name = queue_name.encode('UTF-8') + queue_name = queue_name.encode("UTF-8") else: pass - return ( - f"{scheme}://{hostname}" - f"/{quote(queue_name)}{query_str}") + return f"{scheme}://{hostname}" f"/{quote(queue_name)}{query_str}" + def _from_queue_url(queue_url: str) -> Tuple[str, str]: """A client to interact with a specific Queue. @@ -83,23 +92,21 @@ def _from_queue_url(queue_url: str) -> Tuple[str, str]: :rtype: Tuple[str, str] """ try: - if not queue_url.lower().startswith('http'): + if not queue_url.lower().startswith("http"): queue_url = "https://" + queue_url except AttributeError as exc: raise ValueError("Queue URL must be a string.") from exc - parsed_url = urlparse(queue_url.rstrip('/')) + parsed_url = urlparse(queue_url.rstrip("/")) if not parsed_url.netloc: raise ValueError(f"Invalid URL: {queue_url}") - queue_path = parsed_url.path.lstrip('/').split('/') + queue_path = parsed_url.path.lstrip("/").split("/") account_path = "" if len(queue_path) > 1: account_path = "/" + "/".join(queue_path[:-1]) - account_url = ( - f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" - f"{account_path}?{parsed_url.query}") + account_url = f"{parsed_url.scheme}://{parsed_url.netloc.rstrip('/')}" f"{account_path}?{parsed_url.query}" queue_name = unquote(queue_path[-1]) if not queue_name: raise ValueError("Invalid URL. 
Please provide a URL with a valid queue name") - return(account_url, queue_name) + return (account_url, queue_name) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py index a99a3e2c78f3..8ab644be6054 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client.py @@ -5,10 +5,7 @@ # -------------------------------------------------------------------------- import functools -from typing import ( - Any, Dict, List, Optional, - TYPE_CHECKING, Union -) +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union from typing_extensions import Self from azure.core.exceptions import HttpResponseError @@ -94,23 +91,26 @@ class QueueServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): """ def __init__( - self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + self, + account_url: str, + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, audience: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> None: parsed_url, sas_token = _parse_url(account_url=account_url, credential=credential) self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueServiceClient, self).__init__( parsed_url, - service='queue', + service="queue", credential=credential, secondary_hostname=secondary_hostname, audience=audience, - **kwargs + **kwargs, ) self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline) self._client._config.version = get_api_version(api_version) # type: ignore [assignment] @@ -128,13 +128,16 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_connection_string( - cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + cls, + conn_str: str, + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, audience: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> Self: """Create QueueServiceClient from a Connection String. @@ -174,14 +177,14 @@ def from_connection_string( :dedent: 8 :caption: Creating the QueueServiceClient with a connection string. 
""" - account_url, secondary, credential = parse_connection_str(conn_str, credential, 'queue') + account_url, secondary, credential = parse_connection_str(conn_str, credential, "queue") return cls( account_url, credential=credential, api_version=api_version, secondary_hostname=secondary_hostname or secondary, audience=audience, - **kwargs + **kwargs, ) @distributed_trace @@ -210,8 +213,7 @@ def get_service_stats(self, *, timeout: Optional[int] = None, **kwargs: Any) -> :rtype: Dict[str, Any] """ try: - stats = self._client.service.get_statistics( - timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs) + stats = self._client.service.get_statistics(timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs) return service_stats_deserialize(stats) except HttpResponseError as error: process_storage_error(error) @@ -244,13 +246,14 @@ def get_service_properties(self, *, timeout: Optional[int] = None, **kwargs: Any @distributed_trace def set_service_properties( - self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + self, + analytics_logging: Optional["QueueAnalyticsLogging"] = None, hour_metrics: Optional["Metrics"] = None, minute_metrics: Optional["Metrics"] = None, cors: Optional[List[CorsRule]] = None, *, timeout: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -290,7 +293,7 @@ def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=CorsRule._to_generated(cors) # pylint: disable=protected-access + cors=CorsRule._to_generated(cors), # pylint: disable=protected-access ) try: self._client.service.set_properties(props, timeout=timeout, **kwargs) @@ -299,12 +302,13 @@ def set_service_properties( @distributed_trace def list_queues( - self, name_starts_with: Optional[str] = None, + self, + name_starts_with: Optional[str] = None, include_metadata: Optional[bool] = False, *, results_per_page: Optional[int] = None, timeout: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> ItemPaged["QueueProperties"]: """Returns a generator to list the queues under the specified account. @@ -339,28 +343,24 @@ def list_queues( :dedent: 12 :caption: List queues in the service. """ - include = ['metadata'] if include_metadata else None + include = ["metadata"] if include_metadata else None command = functools.partial( self._client.service.list_queues_segment, prefix=name_starts_with, include=include, timeout=timeout, - **kwargs + **kwargs, ) return ItemPaged( command, prefix=name_starts_with, results_per_page=results_per_page, - page_iterator_class=QueuePropertiesPaged + page_iterator_class=QueuePropertiesPaged, ) @distributed_trace def create_queue( - self, name: str, - metadata: Optional[Dict[str, str]] = None, - *, - timeout: Optional[int] = None, - **kwargs: Any + self, name: str, metadata: Optional[Dict[str, str]] = None, *, timeout: Optional[int] = None, **kwargs: Any ) -> QueueClient: """Creates a new queue under the specified account. @@ -387,16 +387,13 @@ def create_queue( :caption: Create a queue in the service. 
""" queue = self.get_queue_client(name) - kwargs.setdefault('merge_span', True) + kwargs.setdefault("merge_span", True) queue.create_queue(metadata=metadata, timeout=timeout, **kwargs) return queue @distributed_trace def delete_queue( - self, queue: Union["QueueProperties", str], - *, - timeout: Optional[int] = None, - **kwargs: Any + self, queue: Union["QueueProperties", str], *, timeout: Optional[int] = None, **kwargs: Any ) -> None: """Deletes the specified queue and any messages it contains. @@ -426,7 +423,7 @@ def delete_queue( :caption: Delete a queue in the service. """ queue_client = self.get_queue_client(queue) - kwargs.setdefault('merge_span', True) + kwargs.setdefault("merge_span", True) queue_client.delete_queue(timeout=timeout, **kwargs) def get_queue_client(self, queue: Union["QueueProperties", str], **kwargs: Any) -> QueueClient: @@ -457,12 +454,21 @@ def get_queue_client(self, queue: Union["QueueProperties", str], **kwargs: Any) _pipeline = Pipeline( transport=TransportWrapper(self._pipeline._transport), # pylint: disable=protected-access - policies=self._pipeline._impl_policies # type: ignore # pylint: disable=protected-access + policies=self._pipeline._impl_policies, # type: ignore # pylint: disable=protected-access ) return QueueClient( - self.url, queue_name=queue_name, credential=self.credential, - key_resolver_function=self.key_resolver_function, require_encryption=self.require_encryption, - encryption_version=self.encryption_version, key_encryption_key=self.key_encryption_key, - api_version=self.api_version, _pipeline=_pipeline, _configuration=self._config, - _location_mode=self._location_mode, _hosts=self._hosts, **kwargs) + self.url, + queue_name=queue_name, + credential=self.credential, + key_resolver_function=self.key_resolver_function, + require_encryption=self.require_encryption, + encryption_version=self.encryption_version, + key_encryption_key=self.key_encryption_key, + api_version=self.api_version, + _pipeline=_pipeline, + _configuration=self._config, + _location_mode=self._location_mode, + _hosts=self._hosts, + **kwargs, + ) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py index a52216834e0f..9e29d00b4dd5 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_queue_service_client_helpers.py @@ -16,7 +16,16 @@ def _parse_url( account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential", "TokenCredential"]] # pylint: disable=line-too-long + credential: Optional[ + Union[ + str, + Dict[str, str], + "AzureNamedKeyCredential", + "AzureSasCredential", + "AsyncTokenCredential", + "TokenCredential", + ] + ], ) -> Tuple["ParseResult", Any]: """Performs initial input validation and returns the parsed URL and SAS token. 
@@ -39,11 +48,11 @@ def _parse_url( :rtype: Tuple[ParseResult, Any] """ try: - if not account_url.lower().startswith('http'): + if not account_url.lower().startswith("http"): account_url = "https://" + account_url except AttributeError as exc: raise ValueError("Account URL must be a string.") from exc - parsed_url = urlparse(account_url.rstrip('/')) + parsed_url = urlparse(account_url.rstrip("/")) if not parsed_url.netloc: raise ValueError(f"Invalid URL: {account_url}") diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_serialize.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_serialize.py index ad090b548469..23485b3f3d03 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_serialize.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_serialize.py @@ -6,37 +6,37 @@ from typing import Optional _SUPPORTED_API_VERSIONS = [ - '2019-02-02', - '2019-07-07', - '2019-10-10', - '2019-12-12', - '2020-02-10', - '2020-04-08', - '2020-06-12', - '2020-08-04', - '2020-10-02', - '2020-12-06', - '2021-02-12', - '2021-04-10', - '2021-06-08', - '2021-08-06', - '2021-12-02', - '2022-11-02', - '2023-01-03', - '2023-05-03', - '2023-08-03', - '2023-11-03', - '2024-05-04', - '2024-08-04', - '2024-11-04', - '2025-01-05', - '2025-05-05', - '2025-07-05', + "2019-02-02", + "2019-07-07", + "2019-10-10", + "2019-12-12", + "2020-02-10", + "2020-04-08", + "2020-06-12", + "2020-08-04", + "2020-10-02", + "2020-12-06", + "2021-02-12", + "2021-04-10", + "2021-06-08", + "2021-08-06", + "2021-12-02", + "2022-11-02", + "2023-01-03", + "2023-05-03", + "2023-08-03", + "2023-11-03", + "2024-05-04", + "2024-08-04", + "2024-11-04", + "2025-01-05", + "2025-05-05", + "2025-07-05", ] def get_api_version(api_version: Optional[str]) -> str: if api_version and api_version not in _SUPPORTED_API_VERSIONS: - versions = '\n'.join(_SUPPORTED_API_VERSIONS) + versions = "\n".join(_SUPPORTED_API_VERSIONS) raise ValueError(f"Unsupported API version '{api_version}'. 
Please select from:\n{versions}") return api_version or _SUPPORTED_API_VERSIONS[-1] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/__init__.py index a8b1a27d48f9..4dbbb7ed7b09 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/__init__.py @@ -11,7 +11,7 @@ try: from urllib.parse import quote, unquote except ImportError: - from urllib2 import quote, unquote # type: ignore + from urllib2 import quote, unquote # type: ignore def url_quote(url): @@ -24,20 +24,20 @@ def url_unquote(url): def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") def decode_base64_to_bytes(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") return base64.b64decode(data) def decode_base64_to_text(data): decoded_bytes = decode_base64_to_bytes(data) - return decoded_bytes.decode('utf-8') + return decoded_bytes.decode("utf-8") def sign_string(key, string_to_sign, key_is_base64=True): @@ -45,9 +45,9 @@ def sign_string(key, string_to_sign, key_is_base64=True): key = decode_base64_to_bytes(key) else: if isinstance(key, str): - key = key.encode('utf-8') + key = key.encode("utf-8") if isinstance(string_to_sign, str): - string_to_sign = string_to_sign.encode('utf-8') + string_to_sign = string_to_sign.encode("utf-8") signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) digest = signed_hmac_sha256.digest() encoded_digest = encode_base64(digest) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py index b41f2391ed4a..f778dc71eec4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/authentication.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) +# fmt: off table_lv0 = [ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, @@ -51,6 +52,8 @@ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, ] +# fmt: on + def compare(lhs: str, rhs: str) -> int: # pylint:disable=too-many-return-statements tables = [table_lv0, table_lv4] @@ -95,6 +98,7 @@ def _wrap_exception(ex, desired_type): msg = ex.args[0] return desired_type(msg) + # This method attempts to emulate the sorting done by the service def _storage_header_sort(input_headers: List[Tuple[str, str]]) -> List[Tuple[str, str]]: @@ -135,38 +139,42 @@ def __init__(self, account_name, account_key): @staticmethod def _get_headers(request, headers_to_sign): headers = dict((name.lower(), value) for name, value in request.http_request.headers.items() if value) - if 'content-length' in headers and headers['content-length'] == '0': - del headers['content-length'] - return '\n'.join(headers.get(x, '') for x in headers_to_sign) + '\n' + if "content-length" in headers and headers["content-length"] == "0": + del headers["content-length"] + return "\n".join(headers.get(x, "") for x in headers_to_sign) + "\n" @staticmethod def _get_verb(request): - return request.http_request.method + '\n' 
+ return request.http_request.method + "\n" def _get_canonicalized_resource(self, request): uri_path = urlparse(request.http_request.url).path try: - if isinstance(request.context.transport, AioHttpTransport) or \ - isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) or \ - isinstance(getattr(getattr(request.context.transport, "_transport", None), "_transport", None), - AioHttpTransport): + if ( + isinstance(request.context.transport, AioHttpTransport) + or isinstance(getattr(request.context.transport, "_transport", None), AioHttpTransport) + or isinstance( + getattr(getattr(request.context.transport, "_transport", None), "_transport", None), + AioHttpTransport, + ) + ): uri_path = URL(uri_path) - return '/' + self.account_name + str(uri_path) + return "/" + self.account_name + str(uri_path) except TypeError: pass - return '/' + self.account_name + uri_path + return "/" + self.account_name + uri_path @staticmethod def _get_canonicalized_headers(request): - string_to_sign = '' + string_to_sign = "" x_ms_headers = [] for name, value in request.http_request.headers.items(): - if name.startswith('x-ms-'): + if name.startswith("x-ms-"): x_ms_headers.append((name.lower(), value)) x_ms_headers = _storage_header_sort(x_ms_headers) for name, value in x_ms_headers: if value is not None: - string_to_sign += ''.join([name, ':', value, '\n']) + string_to_sign += "".join([name, ":", value, "\n"]) return string_to_sign @staticmethod @@ -174,37 +182,46 @@ def _get_canonicalized_resource_query(request): sorted_queries = list(request.http_request.query.items()) sorted_queries.sort() - string_to_sign = '' + string_to_sign = "" for name, value in sorted_queries: if value is not None: - string_to_sign += '\n' + name.lower() + ':' + unquote(value) + string_to_sign += "\n" + name.lower() + ":" + unquote(value) return string_to_sign def _add_authorization_header(self, request, string_to_sign): try: signature = sign_string(self.account_key, string_to_sign) - auth_string = 'SharedKey ' + self.account_name + ':' + signature - request.http_request.headers['Authorization'] = auth_string + auth_string = "SharedKey " + self.account_name + ":" + signature + request.http_request.headers["Authorization"] = auth_string except Exception as ex: # Wrap any error that occurred as signing error # Doing so will clarify/locate the source of problem raise _wrap_exception(ex, AzureSigningError) from ex def on_request(self, request): - string_to_sign = \ - self._get_verb(request) + \ - self._get_headers( + string_to_sign = ( + self._get_verb(request) + + self._get_headers( request, [ - 'content-encoding', 'content-language', 'content-length', - 'content-md5', 'content-type', 'date', 'if-modified-since', - 'if-match', 'if-none-match', 'if-unmodified-since', 'byte_range' - ] - ) + \ - self._get_canonicalized_headers(request) + \ - self._get_canonicalized_resource(request) + \ - self._get_canonicalized_resource_query(request) + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", + ], + ) + + self._get_canonicalized_headers(request) + + self._get_canonicalized_resource(request) + + self._get_canonicalized_resource_query(request) + ) self._add_authorization_header(request, string_to_sign) # logger.debug("String_to_sign=%s", string_to_sign) @@ -212,7 +229,7 @@ def on_request(self, request): class StorageHttpChallenge(object): def __init__(self, challenge): - 
""" Parses an HTTP WWW-Authentication Bearer challenge from the Storage service. """ + """Parses an HTTP WWW-Authentication Bearer challenge from the Storage service.""" if not challenge: raise ValueError("Challenge cannot be empty") @@ -221,7 +238,7 @@ def __init__(self, challenge): # name=value pairs either comma or space separated with values possibly being # enclosed in quotes - for item in re.split('[, ]', trimmed_challenge): + for item in re.split("[, ]", trimmed_challenge): comps = item.split("=") if len(comps) == 2: key = comps[0].strip(' "') @@ -230,11 +247,11 @@ def __init__(self, challenge): self._parameters[key] = value # Extract and verify required parameters - self.authorization_uri = self._parameters.get('authorization_uri') + self.authorization_uri = self._parameters.get("authorization_uri") if not self.authorization_uri: raise ValueError("Authorization Uri not found") - self.resource_id = self._parameters.get('resource_id') + self.resource_id = self._parameters.get("resource_id") if not self.resource_id: raise ValueError("Resource id not found") diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index 7de14050b963..32d0e1bad6ea 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -20,7 +20,10 @@ from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential, TokenCredential from azure.core.exceptions import HttpResponseError from azure.core.pipeline import Pipeline -from azure.core.pipeline.transport import HttpTransport, RequestsTransport # pylint: disable=non-abstract-transport-import, no-name-in-module +from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import, no-name-in-module + HttpTransport, + RequestsTransport, +) from azure.core.pipeline.policies import ( AzureSasCredentialPolicy, ContentDecodePolicy, @@ -73,8 +76,17 @@ def __init__( self, parsed_url: Any, service: str, - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + credential: Optional[ + Union[ + str, + Dict[str, str], + AzureNamedKeyCredential, + AzureSasCredential, + "AsyncTokenCredential", + TokenCredential, + ] + ] = None, + **kwargs: Any, ) -> None: self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY) self._hosts = kwargs.get("_hosts", {}) @@ -83,12 +95,15 @@ def __init__( if service not in ["blob", "queue", "file-share", "dfs"]: raise ValueError(f"Invalid service: {service}") - service_name = service.split('-')[0] + service_name = service.split("-")[0] account = parsed_url.netloc.split(f".{service_name}.core.") self.account_name = account[0] if len(account) > 1 else None - if not self.account_name and parsed_url.netloc.startswith("localhost") \ - or parsed_url.netloc.startswith("127.0.0.1"): + if ( + not self.account_name + and parsed_url.netloc.startswith("localhost") + or parsed_url.netloc.startswith("127.0.0.1") + ): self._is_localhost = True self.account_name = parsed_url.path.strip("/") @@ -106,7 +121,7 @@ def __init__( secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary") if kwargs.get("secondary_hostname"): secondary_hostname = kwargs["secondary_hostname"] - primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/') 
+ primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip("/") self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname} self._sdk_moniker = f"storage-{service}/{VERSION}" @@ -119,71 +134,76 @@ def __enter__(self): def __exit__(self, *args): self._client.__exit__(*args) - def close(self): - """ This method is to close the sockets opened by the client. + def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. """ self._client.close() @property - def url(self): + def url(self) -> str: """The full endpoint URL to this entity, including SAS token if used. This could be either the primary endpoint, or the secondary endpoint depending on the current :func:`location_mode`. - :returns: The full endpoint URL to this entity, including SAS token if used. + :return: The full endpoint URL to this entity, including SAS token if used. :rtype: str """ - return self._format_url(self._hosts[self._location_mode]) + return self._format_url(self._hosts[self._location_mode]) # type: ignore @property - def primary_endpoint(self): + def primary_endpoint(self) -> str: """The full primary endpoint URL. + :return: The full primary endpoint URL. :rtype: str """ - return self._format_url(self._hosts[LocationMode.PRIMARY]) + return self._format_url(self._hosts[LocationMode.PRIMARY]) # type: ignore @property - def primary_hostname(self): + def primary_hostname(self) -> str: """The hostname of the primary endpoint. + :return: The hostname of the primary endpoint. :rtype: str """ return self._hosts[LocationMode.PRIMARY] @property - def secondary_endpoint(self): + def secondary_endpoint(self) -> str: """The full secondary endpoint URL if configured. If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The full secondary endpoint URL. :rtype: str - :raise ValueError: + :raise ValueError: If no secondary endpoint is configured. """ if not self._hosts[LocationMode.SECONDARY]: raise ValueError("No secondary host configured.") - return self._format_url(self._hosts[LocationMode.SECONDARY]) + return self._format_url(self._hosts[LocationMode.SECONDARY]) # type: ignore @property - def secondary_hostname(self): + def secondary_hostname(self) -> Optional[str]: """The hostname of the secondary endpoint. If not available this will be None. To explicitly specify a secondary hostname, use the optional `secondary_hostname` keyword argument on instantiation. + :return: The hostname of the secondary endpoint, or None if not configured. :rtype: Optional[str] """ return self._hosts[LocationMode.SECONDARY] @property - def location_mode(self): + def location_mode(self) -> str: """The location mode that the client is currently using. By default this will be "primary". Options include "primary" and "secondary". + :return: The current location mode. 
:rtype: str """ @@ -206,11 +226,16 @@ def api_version(self): return self._client._config.version # pylint: disable=protected-access def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", TokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -218,7 +243,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." + ) if _is_credential_sastoken(credential): credential = cast(str, credential) query_str += credential.lstrip("?") @@ -228,13 +254,16 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, Pipeline]: self._credential_policy: Any = None if hasattr(credential, "get_token"): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = StorageBearerTokenCredentialPolicy(cast(TokenCredential, credential), audience) @@ -268,22 +297,18 @@ def _create_pipeline( config.logging_policy, StorageResponseHook(**kwargs), DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs) + HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore return config, Pipeline(transport, policies=policies) - def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> Iterator["HttpResponse"]: + def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> Iterator["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An iterator of HttpResponse objects. + :return: An iterator of HttpResponse objects. 
:rtype: Iterator[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg @@ -292,25 +317,21 @@ def _batch_send( request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), headers={ - 'x-ms-version': self.api_version, - "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False) - } + "x-ms-version": self.api_version, + "Content-Type": "multipart/mixed; boundary=" + _get_batch_request_delimiter(batch_id, False, False), + }, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) Pipeline._prepare_multipart_mixed_request(request) # pylint: disable=protected-access body = serialize_batch_body(request.multipart_mixed_info[0], batch_id) @@ -318,9 +339,7 @@ def _batch_send( temp = request.multipart_mixed_info request.multipart_mixed_info = None - pipeline_response = self._pipeline.run( - request, **kwargs - ) + pipeline_response = self._pipeline.run(request, **kwargs) response = pipeline_response.http_response request.multipart_mixed_info = temp @@ -332,8 +351,7 @@ def _batch_send( parts = list(response.parts()) if any(p for p in parts if not 200 <= p.status_code < 300): error = PartialBatchErrorException( - message="There is a partial failure in the batch operation.", - response=response, parts=parts + message="There is a partial failure in the batch operation.", response=response, parts=parts ) raise error return iter(parts) @@ -347,6 +365,7 @@ class TransportWrapper(HttpTransport): by a `get_client` method does not close the outer transport for the parent when used in a context manager. 
""" + def __init__(self, transport): self._transport = transport @@ -368,7 +387,9 @@ def __exit__(self, *args): def _format_shared_key_credential( account_name: Optional[str], - credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential]] = None # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, "AsyncTokenCredential", TokenCredential] + ] = None, ) -> Any: if isinstance(credential, str): if not account_name: @@ -388,8 +409,12 @@ def _format_shared_key_credential( def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, TokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = [s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -411,14 +436,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -438,7 +460,7 @@ def parse_connection_str( def create_configuration(**kwargs: Any) -> StorageConfiguration: - # Backwards compatibility if someone is not passing sdk_moniker + # Backwards compatibility if someone is not passing sdk_moniker if not kwargs.get("sdk_moniker"): kwargs["sdk_moniker"] = f"storage-{kwargs.pop('storage_sdk')}/{VERSION}" config = StorageConfiguration(**kwargs) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 6186b29db107..f39a57b24943 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -64,18 +64,26 @@ async def __aenter__(self): async def __aexit__(self, *args): await self._client.__aexit__(*args) - async def close(self): - """ This method is to close the sockets opened by the client. + async def close(self) -> None: + """This method is to close the sockets opened by the client. It need not be used when using with a context manager. 
+ + :return: None + :rtype: None """ await self._client.close() def _format_query_string( - self, sas_token: Optional[str], - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]], # pylint: disable=line-too-long + self, + sas_token: Optional[str], + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential] + ], snapshot: Optional[str] = None, - share_snapshot: Optional[str] = None - ) -> Tuple[str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]]]: # pylint: disable=line-too-long + share_snapshot: Optional[str] = None, + ) -> Tuple[ + str, Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", AsyncTokenCredential]] + ]: query_str = "?" if snapshot: query_str += f"snapshot={snapshot}&" @@ -83,7 +91,8 @@ def _format_query_string( query_str += f"sharesnapshot={share_snapshot}&" if sas_token and isinstance(credential, AzureSasCredential): raise ValueError( - "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.") + "You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature." + ) if _is_credential_sastoken(credential): query_str += credential.lstrip("?") # type: ignore [union-attr] credential = None @@ -92,35 +101,40 @@ def _format_query_string( return query_str.rstrip("?&"), credential def _create_pipeline( - self, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]] = None, # pylint: disable=line-too-long - **kwargs: Any + self, + credential: Optional[ + Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential] + ] = None, + **kwargs: Any, ) -> Tuple[StorageConfiguration, AsyncPipeline]: self._credential_policy: Optional[ - Union[AsyncStorageBearerTokenCredentialPolicy, - SharedKeyCredentialPolicy, - AzureSasCredentialPolicy]] = None - if hasattr(credential, 'get_token'): - if kwargs.get('audience'): - audience = str(kwargs.pop('audience')).rstrip('/') + DEFAULT_OAUTH_SCOPE + Union[AsyncStorageBearerTokenCredentialPolicy, SharedKeyCredentialPolicy, AzureSasCredentialPolicy] + ] = None + if hasattr(credential, "get_token"): + if kwargs.get("audience"): + audience = str(kwargs.pop("audience")).rstrip("/") + DEFAULT_OAUTH_SCOPE else: audience = STORAGE_OAUTH_SCOPE self._credential_policy = AsyncStorageBearerTokenCredentialPolicy( - cast(AsyncTokenCredential, credential), audience) + cast(AsyncTokenCredential, credential), audience + ) elif isinstance(credential, SharedKeyCredentialPolicy): self._credential_policy = credential elif isinstance(credential, AzureSasCredential): self._credential_policy = AzureSasCredentialPolicy(credential) elif credential is not None: raise TypeError(f"Unsupported credential: {type(credential)}") - config = kwargs.get('_configuration') or create_configuration(**kwargs) - if kwargs.get('_pipeline'): - return config, kwargs['_pipeline'] - transport = kwargs.get('transport') + config = kwargs.get("_configuration") or create_configuration(**kwargs) + if kwargs.get("_pipeline"): + return config, kwargs["_pipeline"] + transport = kwargs.get("transport") kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT) kwargs.setdefault("read_timeout", READ_TIMEOUT) if not transport: try: - from azure.core.pipeline.transport import AioHttpTransport # pylint: 
disable=non-abstract-transport-import + from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import + AioHttpTransport, + ) except ImportError as exc: raise ImportError("Unable to create async transport. Please check aiohttp is installed.") from exc transport = AioHttpTransport(**kwargs) @@ -143,53 +157,41 @@ def _create_pipeline( HttpLoggingPolicy(**kwargs), ] if kwargs.get("_additional_pipeline_policies"): - policies = policies + kwargs.get("_additional_pipeline_policies") #type: ignore - config.transport = transport #type: ignore - return config, AsyncPipeline(transport, policies=policies) #type: ignore + policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore + config.transport = transport # type: ignore + return config, AsyncPipeline(transport, policies=policies) # type: ignore - async def _batch_send( - self, - *reqs: "HttpRequest", - **kwargs: Any - ) -> AsyncList["HttpResponse"]: + async def _batch_send(self, *reqs: "HttpRequest", **kwargs: Any) -> AsyncList["HttpResponse"]: """Given a series of request, do a Storage batch call. :param HttpRequest reqs: A collection of HttpRequest objects. - :returns: An AsyncList of HttpResponse objects. + :return: An AsyncList of HttpResponse objects. :rtype: AsyncList[HttpResponse] """ # Pop it here, so requests doesn't feel bad about additional kwarg raise_on_any_failure = kwargs.pop("raise_on_any_failure", True) request = self._client._client.post( # pylint: disable=protected-access url=( - f'{self.scheme}://{self.primary_hostname}/' + f"{self.scheme}://{self.primary_hostname}/" f"{kwargs.pop('path', '')}?{kwargs.pop('restype', '')}" f"comp=batch{kwargs.pop('sas', '')}{kwargs.pop('timeout', '')}" ), - headers={ - 'x-ms-version': self.api_version - } + headers={"x-ms-version": self.api_version}, ) policies = [StorageHeadersPolicy()] if self._credential_policy: policies.append(self._credential_policy) # type: ignore - request.set_multipart_mixed( - *reqs, - policies=policies, - enforce_https=False - ) + request.set_multipart_mixed(*reqs, policies=policies, enforce_https=False) - pipeline_response = await self._pipeline.run( - request, **kwargs - ) + pipeline_response = await self._pipeline.run(request, **kwargs) response = pipeline_response.http_response try: if response.status_code not in [202]: raise HttpResponseError(response=response) - parts = response.parts() # Return an AsyncIterator + parts = response.parts() # Return an AsyncIterator if raise_on_any_failure: parts_list = [] async for part in parts: @@ -197,7 +199,8 @@ async def _batch_send( if any(p for p in parts_list if not 200 <= p.status_code < 300): error = PartialBatchErrorException( message="There is a partial failure in the batch operation.", - response=response, parts=parts_list + response=response, + parts=parts_list, ) raise error return AsyncList(parts_list) @@ -205,11 +208,16 @@ async def _batch_send( except HttpResponseError as error: process_storage_error(error) + def parse_connection_str( conn_str: str, credential: Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], - service: str -) -> Tuple[str, Optional[str], Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]]]: # pylint: disable=line-too-long + service: str, +) -> Tuple[ + str, + Optional[str], + Optional[Union[str, Dict[str, str], AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential]], +]: conn_str = conn_str.rstrip(";") conn_settings_list = 
[s.split("=", 1) for s in conn_str.split(";")] if any(len(tup) != 2 for tup in conn_settings_list): @@ -231,14 +239,11 @@ def parse_connection_str( if endpoints["secondary"] in conn_settings: raise ValueError("Connection string specifies only secondary endpoint.") try: - primary =( + primary = ( f"{conn_settings['DEFAULTENDPOINTSPROTOCOL']}://" f"{conn_settings['ACCOUNTNAME']}.{service}.{conn_settings['ENDPOINTSUFFIX']}" ) - secondary = ( - f"{conn_settings['ACCOUNTNAME']}-secondary." - f"{service}.{conn_settings['ENDPOINTSUFFIX']}" - ) + secondary = f"{conn_settings['ACCOUNTNAME']}-secondary." f"{service}.{conn_settings['ENDPOINTSUFFIX']}" except KeyError: pass @@ -256,11 +261,13 @@ def parse_connection_str( secondary = secondary.replace(".blob.", ".dfs.") return primary, secondary, credential + class AsyncTransportWrapper(AsyncHttpTransport): """Wrapper class that ensures that an inner client created by a `get_client` method does not close the outer transport for the parent when used in a context manager. """ + def __init__(self, async_transport): self._transport = async_transport diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/constants.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/constants.py index 0b4b029a2d1b..0926f04c4081 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/constants.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/constants.py @@ -16,4 +16,4 @@ DEFAULT_OAUTH_SCOPE = "/.default" STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default" -SERVICE_HOST_BASE = 'core.windows.net' +SERVICE_HOST_BASE = "core.windows.net" diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py index c8949723449b..9ffa6813efbc 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/models.py @@ -22,6 +22,7 @@ def get_enum_value(value): class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Error codes returned by the service.""" # Generic storage values ACCOUNT_ALREADY_EXISTS = "AccountAlreadyExists" @@ -172,26 +173,26 @@ class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): CONTAINER_QUOTA_DOWNGRADE_NOT_ALLOWED = "ContainerQuotaDowngradeNotAllowed" # DataLake values - CONTENT_LENGTH_MUST_BE_ZERO = 'ContentLengthMustBeZero' - PATH_ALREADY_EXISTS = 'PathAlreadyExists' - INVALID_FLUSH_POSITION = 'InvalidFlushPosition' - INVALID_PROPERTY_NAME = 'InvalidPropertyName' - INVALID_SOURCE_URI = 'InvalidSourceUri' - UNSUPPORTED_REST_VERSION = 'UnsupportedRestVersion' - FILE_SYSTEM_NOT_FOUND = 'FilesystemNotFound' - PATH_NOT_FOUND = 'PathNotFound' - RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = 'RenameDestinationParentPathNotFound' - SOURCE_PATH_NOT_FOUND = 'SourcePathNotFound' - DESTINATION_PATH_IS_BEING_DELETED = 'DestinationPathIsBeingDeleted' - FILE_SYSTEM_ALREADY_EXISTS = 'FilesystemAlreadyExists' - FILE_SYSTEM_BEING_DELETED = 'FilesystemBeingDeleted' - INVALID_DESTINATION_PATH = 'InvalidDestinationPath' - INVALID_RENAME_SOURCE_PATH = 'InvalidRenameSourcePath' - INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = 'InvalidSourceOrDestinationResourceType' - LEASE_IS_ALREADY_BROKEN = 'LeaseIsAlreadyBroken' - LEASE_NAME_MISMATCH = 'LeaseNameMismatch' - PATH_CONFLICT = 'PathConflict' - SOURCE_PATH_IS_BEING_DELETED = 'SourcePathIsBeingDeleted' + CONTENT_LENGTH_MUST_BE_ZERO = "ContentLengthMustBeZero" + 
PATH_ALREADY_EXISTS = "PathAlreadyExists" + INVALID_FLUSH_POSITION = "InvalidFlushPosition" + INVALID_PROPERTY_NAME = "InvalidPropertyName" + INVALID_SOURCE_URI = "InvalidSourceUri" + UNSUPPORTED_REST_VERSION = "UnsupportedRestVersion" + FILE_SYSTEM_NOT_FOUND = "FilesystemNotFound" + PATH_NOT_FOUND = "PathNotFound" + RENAME_DESTINATION_PARENT_PATH_NOT_FOUND = "RenameDestinationParentPathNotFound" + SOURCE_PATH_NOT_FOUND = "SourcePathNotFound" + DESTINATION_PATH_IS_BEING_DELETED = "DestinationPathIsBeingDeleted" + FILE_SYSTEM_ALREADY_EXISTS = "FilesystemAlreadyExists" + FILE_SYSTEM_BEING_DELETED = "FilesystemBeingDeleted" + INVALID_DESTINATION_PATH = "InvalidDestinationPath" + INVALID_RENAME_SOURCE_PATH = "InvalidRenameSourcePath" + INVALID_SOURCE_OR_DESTINATION_RESOURCE_TYPE = "InvalidSourceOrDestinationResourceType" + LEASE_IS_ALREADY_BROKEN = "LeaseIsAlreadyBroken" + LEASE_NAME_MISMATCH = "LeaseNameMismatch" + PATH_CONFLICT = "PathConflict" + SOURCE_PATH_IS_BEING_DELETED = "SourcePathIsBeingDeleted" class DictMixin(object): @@ -222,7 +223,7 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')}) + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) def __contains__(self, key): return key in self.__dict__ @@ -234,13 +235,13 @@ def update(self, *args, **kwargs): return self.__dict__.update(*args, **kwargs) def keys(self): - return [k for k in self.__dict__ if not k.startswith('_')] + return [k for k in self.__dict__ if not k.startswith("_")] def values(self): - return [v for k, v in self.__dict__.items() if not k.startswith('_')] + return [v for k, v in self.__dict__.items() if not k.startswith("_")] def items(self): - return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')] + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] def get(self, key, default=None): if key in self.__dict__: @@ -255,8 +256,8 @@ class LocationMode(object): must use PRIMARY. """ - PRIMARY = 'primary' #: Requests should be sent to the primary location. - SECONDARY = 'secondary' #: Requests should be sent to the secondary location, if possible. + PRIMARY = "primary" #: Requests should be sent to the primary location. + SECONDARY = "secondary" #: Requests should be sent to the secondary location, if possible. 
class ResourceTypes(object): @@ -281,17 +282,12 @@ class ResourceTypes(object): _str: str def __init__( - self, - service: bool = False, - container: bool = False, - object: bool = False # pylint: disable=redefined-builtin + self, service: bool = False, container: bool = False, object: bool = False # pylint: disable=redefined-builtin ) -> None: self.service = service self.container = container self.object = object - self._str = (('s' if self.service else '') + - ('c' if self.container else '') + - ('o' if self.object else '')) + self._str = ("s" if self.service else "") + ("c" if self.container else "") + ("o" if self.object else "") def __str__(self): return self._str @@ -309,9 +305,9 @@ def from_string(cls, string): :return: A ResourceTypes object :rtype: ~azure.storage.queue.ResourceTypes """ - res_service = 's' in string - res_container = 'c' in string - res_object = 'o' in string + res_service = "s" in string + res_container = "c" in string + res_object = "o" in string parsed = cls(res_service, res_container, res_object) parsed._str = string @@ -392,29 +388,30 @@ def __init__( self.write = write self.delete = delete self.delete_previous_version = delete_previous_version - self.permanent_delete = kwargs.pop('permanent_delete', False) + self.permanent_delete = kwargs.pop("permanent_delete", False) self.list = list self.add = add self.create = create self.update = update self.process = process - self.tag = kwargs.pop('tag', False) - self.filter_by_tags = kwargs.pop('filter_by_tags', False) - self.set_immutability_policy = kwargs.pop('set_immutability_policy', False) - self._str = (('r' if self.read else '') + - ('w' if self.write else '') + - ('d' if self.delete else '') + - ('x' if self.delete_previous_version else '') + - ('y' if self.permanent_delete else '') + - ('l' if self.list else '') + - ('a' if self.add else '') + - ('c' if self.create else '') + - ('u' if self.update else '') + - ('p' if self.process else '') + - ('f' if self.filter_by_tags else '') + - ('t' if self.tag else '') + - ('i' if self.set_immutability_policy else '') - ) + self.tag = kwargs.pop("tag", False) + self.filter_by_tags = kwargs.pop("filter_by_tags", False) + self.set_immutability_policy = kwargs.pop("set_immutability_policy", False) + self._str = ( + ("r" if self.read else "") + + ("w" if self.write else "") + + ("d" if self.delete else "") + + ("x" if self.delete_previous_version else "") + + ("y" if self.permanent_delete else "") + + ("l" if self.list else "") + + ("a" if self.add else "") + + ("c" if self.create else "") + + ("u" if self.update else "") + + ("p" if self.process else "") + + ("f" if self.filter_by_tags else "") + + ("t" if self.tag else "") + + ("i" if self.set_immutability_policy else "") + ) def __str__(self): return self._str @@ -432,23 +429,34 @@ def from_string(cls, permission): :return: An AccountSasPermissions object :rtype: ~azure.storage.queue.AccountSasPermissions """ - p_read = 'r' in permission - p_write = 'w' in permission - p_delete = 'd' in permission - p_delete_previous_version = 'x' in permission - p_permanent_delete = 'y' in permission - p_list = 'l' in permission - p_add = 'a' in permission - p_create = 'c' in permission - p_update = 'u' in permission - p_process = 'p' in permission - p_tag = 't' in permission - p_filter_by_tags = 'f' in permission - p_set_immutability_policy = 'i' in permission - parsed = cls(read=p_read, write=p_write, delete=p_delete, delete_previous_version=p_delete_previous_version, - list=p_list, add=p_add, create=p_create, update=p_update, 
process=p_process, tag=p_tag, - filter_by_tags=p_filter_by_tags, set_immutability_policy=p_set_immutability_policy, - permanent_delete=p_permanent_delete) + p_read = "r" in permission + p_write = "w" in permission + p_delete = "d" in permission + p_delete_previous_version = "x" in permission + p_permanent_delete = "y" in permission + p_list = "l" in permission + p_add = "a" in permission + p_create = "c" in permission + p_update = "u" in permission + p_process = "p" in permission + p_tag = "t" in permission + p_filter_by_tags = "f" in permission + p_set_immutability_policy = "i" in permission + parsed = cls( + read=p_read, + write=p_write, + delete=p_delete, + delete_previous_version=p_delete_previous_version, + list=p_list, + add=p_add, + create=p_create, + update=p_update, + process=p_process, + tag=p_tag, + filter_by_tags=p_filter_by_tags, + set_immutability_policy=p_set_immutability_policy, + permanent_delete=p_permanent_delete, + ) return parsed @@ -464,18 +472,11 @@ class Services(object): Access for the `~azure.storage.fileshare.ShareServiceClient`. Default is False. """ - def __init__( - self, *, - blob: bool = False, - queue: bool = False, - fileshare: bool = False - ) -> None: + def __init__(self, *, blob: bool = False, queue: bool = False, fileshare: bool = False) -> None: self.blob = blob self.queue = queue self.fileshare = fileshare - self._str = (('b' if self.blob else '') + - ('q' if self.queue else '') + - ('f' if self.fileshare else '')) + self._str = ("b" if self.blob else "") + ("q" if self.queue else "") + ("f" if self.fileshare else "") def __str__(self): return self._str @@ -493,9 +494,9 @@ def from_string(cls, string): :return: A Services object :rtype: ~azure.storage.queue.Services """ - res_blob = 'b' in string - res_queue = 'q' in string - res_file = 'f' in string + res_blob = "b" in string + res_queue = "q" in string + res_file = "f" in string parsed = cls(blob=res_blob, queue=res_queue, fileshare=res_file) parsed._str = string @@ -573,13 +574,13 @@ class StorageConfiguration(Configuration): def __init__(self, **kwargs): super(StorageConfiguration, self).__init__(**kwargs) - self.max_single_put_size = kwargs.pop('max_single_put_size', 64 * 1024 * 1024) + self.max_single_put_size = kwargs.pop("max_single_put_size", 64 * 1024 * 1024) self.copy_polling_interval = 15 - self.max_block_size = kwargs.pop('max_block_size', 4 * 1024 * 1024) - self.min_large_block_upload_threshold = kwargs.get('min_large_block_upload_threshold', 4 * 1024 * 1024 + 1) - self.use_byte_buffer = kwargs.pop('use_byte_buffer', False) - self.max_page_size = kwargs.pop('max_page_size', 4 * 1024 * 1024) - self.min_large_chunk_upload_threshold = kwargs.pop('min_large_chunk_upload_threshold', 100 * 1024 * 1024 + 1) - self.max_single_get_size = kwargs.pop('max_single_get_size', 32 * 1024 * 1024) - self.max_chunk_get_size = kwargs.pop('max_chunk_get_size', 4 * 1024 * 1024) - self.max_range_size = kwargs.pop('max_range_size', 4 * 1024 * 1024) + self.max_block_size = kwargs.pop("max_block_size", 4 * 1024 * 1024) + self.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1) + self.use_byte_buffer = kwargs.pop("use_byte_buffer", False) + self.max_page_size = kwargs.pop("max_page_size", 4 * 1024 * 1024) + self.min_large_chunk_upload_threshold = kwargs.pop("min_large_chunk_upload_threshold", 100 * 1024 * 1024 + 1) + self.max_single_get_size = kwargs.pop("max_single_get_size", 32 * 1024 * 1024) + self.max_chunk_get_size = kwargs.pop("max_chunk_get_size", 4 * 1024 
* 1024) + self.max_range_size = kwargs.pop("max_range_size", 4 * 1024 * 1024) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py index 112c1984f4fb..e4fcb8f041ba 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/parser.py @@ -12,14 +12,14 @@ def _to_utc_datetime(value: datetime) -> str: - return value.strftime('%Y-%m-%dT%H:%M:%SZ') + return value.strftime("%Y-%m-%dT%H:%M:%SZ") def _rfc_1123_to_datetime(rfc_1123: str) -> Optional[datetime]: """Converts an RFC 1123 date string to a UTC datetime. :param str rfc_1123: The time and date in RFC 1123 format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not rfc_1123: @@ -33,7 +33,7 @@ def _filetime_to_datetime(filetime: str) -> Optional[datetime]: If parsing MS Filetime fails, tries RFC 1123 as backup. :param str filetime: The time and date in MS filetime format. - :returns: The time and date in UTC datetime format. + :return: The time and date in UTC datetime format. :rtype: datetime """ if not filetime: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index e1011d0fb832..f12d102be54e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -28,7 +28,7 @@ HTTPPolicy, NetworkTraceLoggingPolicy, RequestHistory, - SansIOHTTPPolicy + SansIOHTTPPolicy, ) from .authentication import AzureSigningError, StorageHttpChallenge @@ -39,7 +39,7 @@ from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -48,14 +48,14 @@ def encode_base64(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") encoded = base64.b64encode(data) - return encoded.decode('utf-8') + return encoded.decode("utf-8") # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings['total'], settings['connect'], settings['read'], settings['status']) + retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -63,8 +63,8 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): - if settings['hook']: - settings['hook'](retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs) + if settings["hook"]: + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? 
(Based on allowlists and control @@ -95,17 +95,18 @@ def is_retry(response, mode): def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False def urljoin(base_url, stub_url): parsed = urlparse(base_url) - parsed = parsed._replace(path=parsed.path + '/' + stub_url) + parsed = parsed._replace(path=parsed.path + "/" + stub_url) return parsed.geturl() @@ -113,28 +114,26 @@ class QueueMessagePolicy(SansIOHTTPPolicy): def on_request(self, request): # Hack to fix generated code adding '/messages' after SAS parameters - includes_messages = request.http_request.url.endswith('/messages') + includes_messages = request.http_request.url.endswith("/messages") if includes_messages: - request.http_request.url = request.http_request.url[:-(len('/messages'))] - request.http_request.url = urljoin(request.http_request.url, 'messages') + request.http_request.url = request.http_request.url[: -(len("/messages"))] + request.http_request.url = urljoin(request.http_request.url, "messages") - message_id = request.context.options.pop('queue_message_id', None) + message_id = request.context.options.pop("queue_message_id", None) if message_id: - request.http_request.url = urljoin( - request.http_request.url, - message_id) + request.http_request.url = urljoin(request.http_request.url, message_id) class StorageHeadersPolicy(HeadersPolicy): - request_id_header_name = 'x-ms-client-request-id' + request_id_header_name = "x-ms-client-request-id" def on_request(self, request: "PipelineRequest") -> None: super(StorageHeadersPolicy, self).on_request(request) current_time = format_date_time(time()) - request.http_request.headers['x-ms-date'] = current_time + request.http_request.headers["x-ms-date"] = current_time - custom_id = request.context.options.pop('client_request_id', None) - request.http_request.headers['x-ms-client-request-id'] = custom_id or str(uuid.uuid1()) + custom_id = request.context.options.pop("client_request_id", None) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -159,7 +158,7 @@ def __init__(self, hosts=None, **kwargs): # pylint: disable=unused-argument super(StorageHosts, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request.context.options['hosts'] = self.hosts + request.context.options["hosts"] = self.hosts parsed_url = urlparse(request.http_request.url) # Detect what location mode we're currently requesting with @@ -169,10 +168,10 @@ def on_request(self, request: "PipelineRequest") -> None: location_mode = key # See if a specific location mode has been specified, and if so, redirect - use_location = request.context.options.pop('use_location', None) + use_location = 
request.context.options.pop("use_location", None) if use_location: # Lock retries to the specific location - request.context.options['retry_to_secondary'] = False + request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != location_mode: @@ -181,7 +180,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.url = updated.geturl() location_mode = use_location - request.context.options['location_mode'] = location_mode + request.context.options["location_mode"] = location_mode class StorageLoggingPolicy(NetworkTraceLoggingPolicy): @@ -206,19 +205,19 @@ def on_request(self, request: "PipelineRequest") -> None: try: log_url = http_request.url query_params = http_request.query - if 'sig' in query_params: - log_url = log_url.replace(query_params['sig'], "sig=*****") + if "sig" in query_params: + log_url = log_url.replace(query_params["sig"], "sig=*****") _LOGGER.debug("Request URL: %r", log_url) _LOGGER.debug("Request method: %r", http_request.method) _LOGGER.debug("Request headers:") for header, value in http_request.headers.items(): - if header.lower() == 'authorization': - value = '*****' - elif header.lower() == 'x-ms-copy-source' and 'sig' in value: + if header.lower() == "authorization": + value = "*****" + elif header.lower() == "x-ms-copy-source" and "sig" in value: # take the url apart and scrub away the signed signature scheme, netloc, path, params, query, fragment = urlparse(value) parsed_qs = dict(parse_qsl(query)) - parsed_qs['sig'] = '*****' + parsed_qs["sig"] = "*****" # the SAS needs to be put back together value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) @@ -248,11 +247,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") # We don't want to log binary data if the response is a file. 
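For context on the `sig` scrubbing that `StorageLoggingPolicy.on_request` performs above, here is a minimal standalone sketch of the same idea using only the standard library (the function name `scrub_sas_signature` is illustrative, not part of the SDK):

```python
from urllib.parse import urlparse, parse_qsl, urlencode, urlunparse

def scrub_sas_signature(url: str) -> str:
    # Replace the SAS 'sig' query value before the URL is written to a log.
    scheme, netloc, path, params, query, fragment = urlparse(url)
    parsed_qs = dict(parse_qsl(query))
    if "sig" in parsed_qs:
        parsed_qs["sig"] = "*****"
    return urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))

print(scrub_sas_signature("https://account.queue.core.windows.net/q?sv=2025-05-05&sig=secret"))
```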
_LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) - header = response.http_response.headers.get('content-disposition') + header = response.http_response.headers.get("content-disposition") resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): - filename = header.partition('=')[2] + filename = header.partition("=")[2] _LOGGER.debug("File attachments: %s", filename) elif resp_content_type.endswith("octet-stream"): _LOGGER.debug("Body contains binary data.") @@ -274,11 +273,11 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") class StorageRequestHook(SansIOHTTPPolicy): def __init__(self, **kwargs): - self._request_callback = kwargs.get('raw_request_hook') + self._request_callback = kwargs.get("raw_request_hook") super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop('raw_request_hook', self._request_callback) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -286,49 +285,50 @@ def on_request(self, request: "PipelineRequest") -> None: class StorageResponseHook(HTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(StorageResponseHook, self).__init__() def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop('download_stream_current', None) - upload_stream_current = request.context.get('upload_stream_current') + download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += 
int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response @@ -338,7 +338,8 @@ class StorageContentValidation(SansIOHTTPPolicy): This will overwrite any headers already defined in the request. """ - header_name = 'Content-MD5' + + header_name = "Content-MD5" def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super(StorageContentValidation, self).__init__() @@ -348,10 +349,10 @@ def get_content_md5(data): # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. data = data or b"" - md5 = hashlib.md5() # nosec + md5 = hashlib.md5() # nosec if isinstance(data, bytes): md5.update(data) - elif hasattr(data, 'read'): + elif hasattr(data, "read"): pos = 0 try: pos = data.tell() @@ -369,22 +370,25 @@ def get_content_md5(data): return md5.digest() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop('validate_content', False) - if validate_content and request.http_request.method != 'GET': + validate_content = request.context.options.pop("validate_content", False) + if validate_content and request.http_request.method != "GET": computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) request.http_request.headers[self.header_name] = computed_md5 - request.context['validate_content_md5'] = computed_md5 - request.context['validate_content'] = validate_content + request.context["validate_content_md5"] = computed_md5 + request.context["validate_content"] = validate_content def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - computed_md5 = request.context.get('validate_content_md5') or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: - raise AzureError(( - f"MD5 mismatch. 
Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'."), - response=response.http_response + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, ) @@ -405,11 +409,11 @@ class StorageRetryPolicy(HTTPPolicy): """Whether the secondary endpoint should be retried.""" def __init__(self, **kwargs: Any) -> None: - self.total_retries = kwargs.pop('retry_total', 10) - self.connect_retries = kwargs.pop('retry_connect', 3) - self.read_retries = kwargs.pop('retry_read', 3) - self.status_retries = kwargs.pop('retry_status', 3) - self.retry_to_secondary = kwargs.pop('retry_to_secondary', False) + self.total_retries = kwargs.pop("retry_total", 10) + self.connect_retries = kwargs.pop("retry_connect", 3) + self.read_retries = kwargs.pop("retry_read", 3) + self.status_retries = kwargs.pop("retry_status", 3) + self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: @@ -419,19 +423,27 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe :param Dict[str, Any] settings: The configurable values pertaining to the next host location. :param PipelineRequest request: A pipeline request object. """ - if settings['hosts'] and all(settings['hosts'].values()): + if settings["hosts"] and all(settings["hosts"].values()): url = urlparse(request.url) # If there's more than one possible location, retry to the alternative - if settings['mode'] == LocationMode.PRIMARY: - settings['mode'] = LocationMode.SECONDARY + if settings["mode"] == LocationMode.PRIMARY: + settings["mode"] = LocationMode.SECONDARY else: - settings['mode'] = LocationMode.PRIMARY - updated = url._replace(netloc=settings['hosts'].get(settings['mode'])) + settings["mode"] = LocationMode.PRIMARY + updated = url._replace(netloc=settings["hosts"].get(settings["mode"])) request.url = updated.geturl() def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: + """ + Configure the retry settings for the request. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A dictionary containing the retry settings. 
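As a quick illustration of the primary/secondary switch that `_set_next_host_location` performs above, the sketch below (with made-up host names) swaps the URL's netloc between the two configured hosts:

```python
from urllib.parse import urlparse

# Hypothetical host map; real values come from the client's configuration.
hosts = {"primary": "account.queue.core.windows.net",
         "secondary": "account-secondary.queue.core.windows.net"}
mode = "primary"

url = urlparse("https://account.queue.core.windows.net/myqueue/messages")
mode = "secondary" if mode == "primary" else "primary"   # flip the location mode
url = url._replace(netloc=hosts[mode])                   # retry against the other host
print(url.geturl())
```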
+ :rtype: Dict[str, Any] + """ body_position = None - if hasattr(request.http_request.body, 'read'): + if hasattr(request.http_request.body, "read"): try: body_position = request.http_request.body.tell() except (AttributeError, UnsupportedOperation): @@ -439,129 +451,140 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: pass options = request.context.options return { - 'total': options.pop("retry_total", self.total_retries), - 'connect': options.pop("retry_connect", self.connect_retries), - 'read': options.pop("retry_read", self.read_retries), - 'status': options.pop("retry_status", self.status_retries), - 'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary), - 'mode': options.pop("location_mode", LocationMode.PRIMARY), - 'hosts': options.pop("hosts", None), - 'hook': options.pop("retry_hook", None), - 'body_position': body_position, - 'count': 0, - 'history': [] + "total": options.pop("retry_total", self.total_retries), + "connect": options.pop("retry_connect", self.connect_retries), + "read": options.pop("retry_read", self.read_retries), + "status": options.pop("retry_status", self.status_retries), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "mode": options.pop("location_mode", LocationMode.PRIMARY), + "hosts": options.pop("hosts", None), + "hook": options.pop("retry_hook", None), + "body_position": body_position, + "count": 0, + "history": [], } def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument - """ Formula for computing the current backoff. + """Formula for computing the current backoff. Should be calculated by child class. - :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. - :returns: The backoff time. + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :return: The backoff time. :rtype: float """ return 0 def sleep(self, settings, transport): + """Sleep for the backoff time. + + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. + :param transport: The transport to use for sleeping. + :type transport: + ~azure.core.pipeline.transport.AsyncioBaseTransport or + ~azure.core.pipeline.transport.BaseTransport + """ backoff = self.get_backoff_time(settings) if not backoff or backoff < 0: return transport.sleep(backoff) def increment( - self, settings: Dict[str, Any], + self, + settings: Dict[str, Any], request: "PipelineRequest", response: Optional["PipelineResponse"] = None, - error: Optional[AzureError] = None + error: Optional[AzureError] = None, ) -> bool: """Increment the retry counters. :param Dict[str, Any] settings: The configurable values pertaining to the increment operation. - :param PipelineRequest request: A pipeline request object. - :param Optional[PipelineResponse] response: A pipeline response object. - :param Optional[AzureError] error: An error encountered during the request, or + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: A pipeline response object. + :type response: ~azure.core.pipeline.PipelineResponse or None + :param error: An error encountered during the request, or None if the response was received successfully. - :returns: Whether the retry attempts are exhausted. + :type error: ~azure.core.exceptions.AzureError or None + :return: Whether the retry attempts are exhausted. 
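The dictionary returned above is the mutable retry state that `increment` decrements and `is_exhausted` (defined earlier in this file) inspects; a small self-contained sketch of that bookkeeping, with illustrative values:

```python
# Illustrative retry state in the same shape configure_retries returns above.
settings = {"total": 3, "connect": 3, "read": 3, "status": 3,
            "retry_secondary": False, "mode": "primary", "hosts": None,
            "hook": None, "body_position": None, "count": 0, "history": []}

def exhausted(s):
    # Mirrors is_exhausted: any counter driven below zero ends the retries.
    counts = list(filter(None, (s["total"], s["connect"], s["read"], s["status"])))
    return bool(counts) and min(counts) < 0

for _ in range(4):                # four failed attempts against retry_total=3
    settings["total"] -= 1
    settings["count"] += 1
print(exhausted(settings))        # True: 3 - 4 == -1
```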
:rtype: bool """ - settings['total'] -= 1 + settings["total"] -= 1 if error and isinstance(error, ServiceRequestError): # Errors when we're fairly sure that the server did not receive the # request, so it should be safe to retry. - settings['connect'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["connect"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) elif error and isinstance(error, ServiceResponseError): # Errors that occur after the request has been started, so we should # assume that the server began processing it. - settings['read'] -= 1 - settings['history'].append(RequestHistory(request, error=error)) + settings["read"] -= 1 + settings["history"].append(RequestHistory(request, error=error)) else: # Incrementing because of a server error like a 500 in # status_forcelist and a the given method is in the allowlist if response: - settings['status'] -= 1 - settings['history'].append(RequestHistory(request, http_response=response)) + settings["status"] -= 1 + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): - if request.method not in ['PUT'] and settings['retry_secondary']: + if request.method not in ["PUT"] and settings["retry_secondary"]: self._set_next_host_location(settings, request) # rewind the request body if it is a stream - if request.body and hasattr(request.body, 'read'): + if request.body and hasattr(request.body, "read"): # no position was saved, then retry would not work - if settings['body_position'] is None: + if settings["body_position"] is None: return False try: # attempt to rewind the body to the initial position - request.body.seek(settings['body_position'], SEEK_SET) + request.body.seek(settings["body_position"], SEEK_SET) except (UnsupportedOperation, ValueError): # if body is not seekable, then retry would not work return False - settings['count'] += 1 + settings["count"] += 1 return True return False def send(self, request): + """Send the request with retry logic. + + :param request: A pipeline request object. + :type request: ~azure.core.pipeline.PipelineRequest + :return: A pipeline response object. 
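The stream-rewind step above is the reason `configure_retries` records `body_position`: a streamed request body can only be retried if it can be seeked back to where the first attempt started. A minimal illustration:

```python
import io

body = io.BytesIO(b"payload")
body_position = body.tell()        # captured before the first attempt (configure_retries)
body.read()                        # the failed attempt consumed the stream
body.seek(body_position)           # increment() rewinds before allowing another attempt
assert body.read() == b"payload"
```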
+ :rtype: ~azure.core.pipeline.PipelineResponse + """ retries_remaining = True response = None retry_settings = self.configure_retries(request) while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings['mode']) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + retry_hook(retry_settings, request=request.http_request, response=None, error=err) self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -577,12 +600,13 @@ class ExponentialRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, initial_backoff: int = 15, + self, + initial_backoff: int = 15, increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for @@ -607,21 +631,20 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to get backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to get backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. 
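The `ExponentialRetry` backoff combines a fixed initial delay, an exponential term based on the attempt count, and a random jitter window; the formula implemented just below in `get_backoff_time` can be restated as a standalone function (parameter defaults match the class defaults):

```python
import random

def exponential_backoff(count, initial_backoff=15, increment_base=3, jitter=3):
    base = initial_backoff + (0 if count == 0 else increment_base ** count)
    lower = base - jitter if base > jitter else 0
    return random.uniform(lower, base + jitter)

# First retry waits ~15s, then ~18s, ~24s, ~42s, each +/- 3s of jitter.
print([round(exponential_backoff(c)) for c in range(4)])
```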
:rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -636,11 +659,12 @@ class LinearRetry(StorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs a Linear retry object. @@ -659,15 +683,14 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ Calculates how long to sleep before retrying. - :param Dict[str, Any]] settings: The configurable values pertaining to the backoff time. - :returns: + :param Dict[str, Any] settings: The configurable values pertaining to the backoff time. + :return: A float indicating how long to wait before retrying the request, or None to indicate no retry should be performed. :rtype: float @@ -675,19 +698,27 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + """Handle the challenge from the service and authorize the request. + + :param request: The request object. + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The response object. + :type response: ~azure.core.pipeline.PipelineResponse + :return: True if the request was authorized, False otherwise. 
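The challenge handling in the handler body that follows relies on `StorageHttpChallenge` to parse the `WWW-Authenticate` header; purely to show the shape of that header, here is a simplified, hypothetical parser (not the SDK's implementation):

```python
header = ('Bearer authorization_uri="https://login.microsoftonline.com/<tenant>/oauth2/authorize", '
          'resource_id="https://storage.azure.com"')

# Naive split into key/value pairs; the real StorageHttpChallenge is more defensive.
params = dict(part.split("=", 1)
              for part in header.removeprefix("Bearer ").replace('"', "").split(", "))
print(params["resource_id"])   # the audience to request a fresh token for
```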
+ :rtype: bool + """ try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 807a51dd297c..4cb32f23248b 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -21,7 +21,7 @@ from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import PipelineRequest, - PipelineResponse + PipelineResponse, ) @@ -29,29 +29,25 @@ async def retry_hook(settings, **kwargs): - if settings['hook']: - if asyncio.iscoroutine(settings['hook']): - await settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + if settings["hook"]: + if asyncio.iscoroutine(settings["hook"]): + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings['hook']( - retry_count=settings['count'] - 1, - location_mode=settings['mode'], - **kwargs) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass - computed_md5 = response.http_request.headers.get('content-md5', None) or \ - encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) - if response.http_response.headers['content-md5'] != computed_md5: + if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() # Load the body in memory and close the socket + except (StreamClosedError, StreamConsumedError): + pass + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + StorageContentValidation.get_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -59,54 +55,56 @@ async def is_checksum_retry(response): class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): - self._response_callback = kwargs.get('raw_response_hook') + self._response_callback = kwargs.get("raw_response_hook") super(AsyncStorageResponseHook, self).__init__() async def send(self, request: "PipelineRequest") -> "PipelineResponse": # Values could be 0 - data_stream_total = request.context.get('data_stream_total') + data_stream_total = request.context.get("data_stream_total") if data_stream_total is None: - data_stream_total = request.context.options.pop('data_stream_total', None) - download_stream_current = request.context.get('download_stream_current') + data_stream_total = request.context.options.pop("data_stream_total", None) + download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop('download_stream_current', None) - upload_stream_current = 
request.context.get('upload_stream_current') + download_stream_current = request.context.options.pop("download_stream_current", None) + upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop('upload_stream_current', None) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get('response_callback') or \ - request.context.options.pop('raw_response_hook', self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) - will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get('Content-Length', 0)) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: - content_range = response.http_response.headers.get('Content-Range') + content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1]) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get('Content-Length', 0)) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: - if hasattr(pipeline_obj, 'context'): - pipeline_obj.context['data_stream_total'] = data_stream_total - pipeline_obj.context['download_stream_current'] = download_stream_current - pipeline_obj.context['upload_stream_current'] = upload_stream_current + if hasattr(pipeline_obj, "context"): + pipeline_obj.context["data_stream_total"] = data_stream_total + pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): - await response_callback(response) # type: ignore + await response_callback(response) # type: ignore else: response_callback(response) - request.context['response_callback'] = response_callback + request.context["response_callback"] = response_callback return response + class AsyncStorageRetryPolicy(StorageRetryPolicy): """ The base class for Exponential and Linear retries containing shared code. 
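The progress bookkeeping derives the total download size from the `Content-Range` header using the same split seen in both the sync and async response hooks above; a quick sanity check of that parsing:

```python
# "bytes <start>-<end>/<total>" -> total size of the object being downloaded.
content_range = "bytes 0-1023/65537"
data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1])
assert data_stream_total == 65537
```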
@@ -125,37 +123,29 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings['mode']) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, - request=request.http_request, - response=response.http_response) + retry_settings, request=request.http_request, response=response.http_response + ) if retries_remaining: await retry_hook( - retry_settings, - request=request.http_request, - response=response.http_response, - error=None) + retry_settings, request=request.http_request, response=response.http_response, error=None + ) await self.sleep(retry_settings, request.context.transport) continue break except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - await retry_hook( - retry_settings, - request=request.http_request, - response=None, - error=err) + await retry_hook(retry_settings, request=request.http_request, response=None, error=err) await self.sleep(retry_settings, request.context.transport) continue raise err - if retry_settings['history']: - response.context['history'] = retry_settings['history'] - response.http_response.location_mode = retry_settings['mode'] + if retry_settings["history"]: + response.context["history"] = retry_settings["history"] + response.http_response.location_mode = retry_settings["mode"] return response @@ -176,7 +166,8 @@ def __init__( increment_base: int = 3, retry_total: int = 3, retry_to_secondary: bool = False, - random_jitter_range: int = 3, **kwargs + random_jitter_range: int = 3, + **kwargs ) -> None: """ Constructs an Exponential retry object. 
The initial_backoff is used for @@ -203,8 +194,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -217,7 +207,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count'])) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -232,7 +222,8 @@ class LinearRetry(AsyncStorageRetryPolicy): """A number in seconds which indicates a range to jitter/randomize for the back-off interval.""" def __init__( - self, backoff: int = 15, + self, + backoff: int = 15, retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, @@ -255,8 +246,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -271,14 +261,13 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range \ - if self.backoff > self.random_jitter_range else 0 + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): - """ Custom Bearer token credential policy for following Storage Bearer challenges """ + """Custom Bearer token credential policy for following Storage Bearer challenges""" def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/request_handlers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/request_handlers.py index af500c8727fa..b23f65859690 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/request_handlers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/request_handlers.py @@ -6,7 +6,7 @@ import logging import stat -from io import (SEEK_END, SEEK_SET, UnsupportedOperation) +from io import SEEK_END, SEEK_SET, UnsupportedOperation from os import fstat from typing import Dict, Optional @@ -37,12 +37,13 @@ def serialize_iso(attr): raise 
OverflowError("Hit max or min date") date = f"{utc.tm_year:04}-{utc.tm_mon:02}-{utc.tm_mday:02}T{utc.tm_hour:02}:{utc.tm_min:02}:{utc.tm_sec:02}" - return date + 'Z' + return date + "Z" except (ValueError, OverflowError) as err: raise ValueError("Unable to serialize datetime object.") from err except AttributeError as err: raise TypeError("ISO-8601 object must be valid datetime object.") from err + def get_length(data): length = None # Check if object implements the __len__ method, covers most input cases such as bytearray. @@ -62,7 +63,7 @@ def get_length(data): try: mode = fstat(fileno).st_mode if stat.S_ISREG(mode) or stat.S_ISLNK(mode): - #st_size only meaningful if regular file or symlink, other types + # st_size only meaningful if regular file or symlink, other types # e.g. sockets may return misleading sizes like 0 return fstat(fileno).st_size except OSError: @@ -84,13 +85,13 @@ def get_length(data): def read_length(data): try: - if hasattr(data, 'read'): - read_data = b'' + if hasattr(data, "read"): + read_data = b"" for chunk in iter(lambda: data.read(4096), b""): read_data += chunk return len(read_data), read_data - if hasattr(data, '__iter__'): - read_data = b'' + if hasattr(data, "__iter__"): + read_data = b"" for chunk in data: read_data += chunk return len(read_data), read_data @@ -100,8 +101,13 @@ def read_length(data): def validate_and_format_range_headers( - start_range, end_range, start_range_required=True, - end_range_required=True, check_content_md5=False, align_to_page=False): + start_range, + end_range, + start_range_required=True, + end_range_required=True, + check_content_md5=False, + align_to_page=False, +): # If end range is provided, start range must be provided if (start_range_required or end_range is not None) and start_range is None: raise ValueError("start_range value cannot be None.") @@ -111,16 +117,18 @@ def validate_and_format_range_headers( # Page ranges must be 512 aligned if align_to_page: if start_range is not None and start_range % 512 != 0: - raise ValueError(f"Invalid page blob start_range: {start_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob start_range: {start_range}. " "The size must be aligned to a 512-byte boundary." + ) if end_range is not None and end_range % 512 != 511: - raise ValueError(f"Invalid page blob end_range: {end_range}. " - "The size must be aligned to a 512-byte boundary.") + raise ValueError( + f"Invalid page blob end_range: {end_range}. " "The size must be aligned to a 512-byte boundary." 
+ ) # Format based on whether end_range is present range_header = None if end_range is not None: - range_header = f'bytes={start_range}-{end_range}' + range_header = f"bytes={start_range}-{end_range}" elif start_range is not None: range_header = f"bytes={start_range}-" @@ -131,7 +139,7 @@ def validate_and_format_range_headers( raise ValueError("Both start and end range required for MD5 content validation.") if end_range - start_range > 4 * 1024 * 1024: raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.") - range_validation = 'true' + range_validation = "true" return range_header, range_validation @@ -140,7 +148,7 @@ def add_metadata_headers(metadata: Optional[Dict[str, str]] = None) -> Dict[str, headers = {} if metadata: for key, value in metadata.items(): - headers[f'x-ms-meta-{key.strip()}'] = value.strip() if value else value + headers[f"x-ms-meta-{key.strip()}"] = value.strip() if value else value return headers @@ -158,29 +166,26 @@ def serialize_batch_body(requests, batch_id): a list of sub-request for the batch request :param str batch_id: to be embedded in batch sub-request delimiter - :returns: The body bytes for this batch. + :return: The body bytes for this batch. :rtype: bytes """ if requests is None or len(requests) == 0: - raise ValueError('Please provide sub-request(s) for this batch request') + raise ValueError("Please provide sub-request(s) for this batch request") - delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode('utf-8') - newline_bytes = _HTTP_LINE_ENDING.encode('utf-8') + delimiter_bytes = (_get_batch_request_delimiter(batch_id, True, False) + _HTTP_LINE_ENDING).encode("utf-8") + newline_bytes = _HTTP_LINE_ENDING.encode("utf-8") batch_body = [] content_index = 0 for request in requests: - request.headers.update({ - "Content-ID": str(content_index), - "Content-Length": str(0) - }) + request.headers.update({"Content-ID": str(content_index), "Content-Length": str(0)}) batch_body.append(delimiter_bytes) batch_body.append(_make_body_from_sub_request(request)) batch_body.append(newline_bytes) content_index += 1 - batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode('utf-8')) + batch_body.append(_get_batch_request_delimiter(batch_id, True, True).encode("utf-8")) # final line of body MUST have \r\n at the end, or it will not be properly read by the service batch_body.append(newline_bytes) @@ -197,35 +202,35 @@ def _get_batch_request_delimiter(batch_id, is_prepend_dashes=False, is_append_da Whether to include the starting dashes. Used in the body, but non on defining the delimiter. :param bool is_append_dashes: Whether to include the ending dashes. Used in the body on the closing delimiter only. - :returns: The delimiter, WITHOUT a trailing newline. + :return: The delimiter, WITHOUT a trailing newline. :rtype: str """ - prepend_dashes = '--' if is_prepend_dashes else '' - append_dashes = '--' if is_append_dashes else '' + prepend_dashes = "--" if is_prepend_dashes else "" + append_dashes = "--" if is_append_dashes else "" return prepend_dashes + _REQUEST_DELIMITER_PREFIX + batch_id + append_dashes def _make_body_from_sub_request(sub_request): """ - Content-Type: application/http - Content-ID: - Content-Transfer-Encoding: (if present) + Content-Type: application/http + Content-ID: + Content-Transfer-Encoding: (if present) - HTTP/ -
<header key>: <header value> (repeated as necessary) - Content-Length: <value> - (newline if content length > 0) - <body> (if content length > 0) + <verb> <path><query> HTTP/<version> + <header key>: <header value>
(repeated as necessary) + Content-Length: + (newline if content length > 0) + (if content length > 0) - Serializes an http request. + Serializes an http request. - :param ~azure.core.pipeline.transport.HttpRequest sub_request: - Request to serialize. - :returns: The serialized sub-request in bytes - :rtype: bytes - """ + :param ~azure.core.pipeline.transport.HttpRequest sub_request: + Request to serialize. + :return: The serialized sub-request in bytes + :rtype: bytes + """ # put the sub-request's headers into a list for efficient str concatenation sub_request_body = [] @@ -249,9 +254,9 @@ def _make_body_from_sub_request(sub_request): # append HTTP verb and path and query and HTTP version sub_request_body.append(sub_request.method) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(sub_request.url) - sub_request_body.append(' ') + sub_request_body.append(" ") sub_request_body.append(_HTTP1_1_IDENTIFIER) sub_request_body.append(_HTTP_LINE_ENDING) @@ -266,4 +271,4 @@ def _make_body_from_sub_request(sub_request): # append blank line sub_request_body.append(_HTTP_LINE_ENDING) - return ''.join(sub_request_body).encode() + return "".join(sub_request_body).encode() diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py index af9a2fcdcdc2..bcfa4147763e 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/response_handlers.py @@ -46,23 +46,23 @@ def parse_length_from_content_range(content_range): # First, split in space and take the second half: '1-3/65537' # Next, split on slash and take the second half: '65537' # Finally, convert to an int: 65537 - return int(content_range.split(' ', 1)[1].split('/', 1)[1]) + return int(content_range.split(" ", 1)[1].split("/", 1)[1]) def normalize_headers(headers): normalized = {} for key, value in headers.items(): - if key.startswith('x-ms-'): + if key.startswith("x-ms-"): key = key[5:] - normalized[key.lower().replace('-', '_')] = get_enum_value(value) + normalized[key.lower().replace("-", "_")] = get_enum_value(value) return normalized def deserialize_metadata(response, obj, headers): # pylint: disable=unused-argument try: - raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.http_response.headers.items() if k.lower().startswith("x-ms-meta-")} except AttributeError: - raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith('x-ms-meta-')} + raw_metadata = {k: v for k, v in response.headers.items() if k.lower().startswith("x-ms-meta-")} return {k[10:]: v for k, v in raw_metadata.items()} @@ -82,19 +82,23 @@ def return_raw_deserialized(response, *_): return response.http_response.location_mode, response.context[ContentDecodePolicy.CONTEXT_NAME] -def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches +def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # pylint:disable=too-many-statements, too-many-branches raise_error = HttpResponseError serialized = False if isinstance(storage_error, AzureSigningError): - storage_error.message = storage_error.message + \ - '. This is likely due to an invalid shared key. Please check your shared key and try again.' 
+ storage_error.message = ( + storage_error.message + + ". This is likely due to an invalid shared key. Please check your shared key and try again." + ) if not storage_error.response or storage_error.response.status_code in [200, 204]: raise storage_error # If it is one of those three then it has been serialized prior by the generated layer. - if isinstance(storage_error, (PartialBatchErrorException, - ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError)): + if isinstance( + storage_error, + (PartialBatchErrorException, ClientAuthenticationError, ResourceNotFoundError, ResourceExistsError), + ): serialized = True - error_code = storage_error.response.headers.get('x-ms-error-code') + error_code = storage_error.response.headers.get("x-ms-error-code") error_message = storage_error.message additional_data = {} error_dict = {} @@ -104,27 +108,25 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py if error_body is None or len(error_body) == 0: error_body = storage_error.response.reason except AttributeError: - error_body = '' + error_body = "" # If it is an XML response if isinstance(error_body, Element): - error_dict = { - child.tag.lower(): child.text - for child in error_body - } + error_dict = {child.tag.lower(): child.text for child in error_body} # If it is a JSON response elif isinstance(error_body, dict): - error_dict = error_body.get('error', {}) + error_dict = error_body.get("error", {}) elif not error_code: _LOGGER.warning( - 'Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.', type(error_body)) - error_dict = {'message': str(error_body)} + "Unexpected return type %s from ContentDecodePolicy.deserialize_from_http_generics.", type(error_body) + ) + error_dict = {"message": str(error_body)} # If we extracted from a Json or XML response # There is a chance error_dict is just a string if error_dict and isinstance(error_dict, dict): - error_code = error_dict.get('code') - error_message = error_dict.get('message') - additional_data = {k: v for k, v in error_dict.items() if k not in {'code', 'message'}} + error_code = error_dict.get("code") + error_message = error_dict.get("message") + additional_data = {k: v for k, v in error_dict.items() if k not in {"code", "message"}} except DecodeError: pass @@ -132,31 +134,33 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py # This check would be unnecessary if we have already serialized the error if error_code and not serialized: error_code = StorageErrorCode(error_code) - if error_code in [StorageErrorCode.condition_not_met, - StorageErrorCode.blob_overwritten]: + if error_code in [StorageErrorCode.condition_not_met, StorageErrorCode.blob_overwritten]: raise_error = ResourceModifiedError - if error_code in [StorageErrorCode.invalid_authentication_info, - StorageErrorCode.authentication_failed]: + if error_code in [StorageErrorCode.invalid_authentication_info, StorageErrorCode.authentication_failed]: raise_error = ClientAuthenticationError - if error_code in [StorageErrorCode.resource_not_found, - StorageErrorCode.cannot_verify_copy_source, - StorageErrorCode.blob_not_found, - StorageErrorCode.queue_not_found, - StorageErrorCode.container_not_found, - StorageErrorCode.parent_not_found, - StorageErrorCode.share_not_found]: + if error_code in [ + StorageErrorCode.resource_not_found, + StorageErrorCode.cannot_verify_copy_source, + StorageErrorCode.blob_not_found, + StorageErrorCode.queue_not_found, + StorageErrorCode.container_not_found, + 
StorageErrorCode.parent_not_found, + StorageErrorCode.share_not_found, + ]: raise_error = ResourceNotFoundError - if error_code in [StorageErrorCode.account_already_exists, - StorageErrorCode.account_being_created, - StorageErrorCode.resource_already_exists, - StorageErrorCode.resource_type_mismatch, - StorageErrorCode.blob_already_exists, - StorageErrorCode.queue_already_exists, - StorageErrorCode.container_already_exists, - StorageErrorCode.container_being_deleted, - StorageErrorCode.queue_being_deleted, - StorageErrorCode.share_already_exists, - StorageErrorCode.share_being_deleted]: + if error_code in [ + StorageErrorCode.account_already_exists, + StorageErrorCode.account_being_created, + StorageErrorCode.resource_already_exists, + StorageErrorCode.resource_type_mismatch, + StorageErrorCode.blob_already_exists, + StorageErrorCode.queue_already_exists, + StorageErrorCode.container_already_exists, + StorageErrorCode.container_being_deleted, + StorageErrorCode.queue_being_deleted, + StorageErrorCode.share_already_exists, + StorageErrorCode.share_being_deleted, + ]: raise_error = ResourceExistsError except ValueError: # Got an unknown error code @@ -183,7 +187,7 @@ def process_storage_error(storage_error) -> NoReturn: # type: ignore [misc] # py error.args = (error.message,) try: # `from None` prevents us from double printing the exception (suppresses generated layer error context) - exec("raise error from None") # pylint: disable=exec-used # nosec + exec("raise error from None") # pylint: disable=exec-used # nosec except SyntaxError as exc: raise error from exc diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/shared_access_signature.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/shared_access_signature.py index 36d05f67b061..89518fd27dc3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/shared_access_signature.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/shared_access_signature.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +# pylint: disable=docstring-keyword-should-match-keyword-only from datetime import date @@ -10,44 +11,45 @@ from .constants import X_MS_VERSION from . import sign_string, url_quote + # cspell:ignoreRegExp rsc. 
# cspell:ignoreRegExp s..?id class QueryStringConstants(object): - SIGNED_SIGNATURE = 'sig' - SIGNED_PERMISSION = 'sp' - SIGNED_START = 'st' - SIGNED_EXPIRY = 'se' - SIGNED_RESOURCE = 'sr' - SIGNED_IDENTIFIER = 'si' - SIGNED_IP = 'sip' - SIGNED_PROTOCOL = 'spr' - SIGNED_VERSION = 'sv' - SIGNED_CACHE_CONTROL = 'rscc' - SIGNED_CONTENT_DISPOSITION = 'rscd' - SIGNED_CONTENT_ENCODING = 'rsce' - SIGNED_CONTENT_LANGUAGE = 'rscl' - SIGNED_CONTENT_TYPE = 'rsct' - START_PK = 'spk' - START_RK = 'srk' - END_PK = 'epk' - END_RK = 'erk' - SIGNED_RESOURCE_TYPES = 'srt' - SIGNED_SERVICES = 'ss' - SIGNED_OID = 'skoid' - SIGNED_TID = 'sktid' - SIGNED_KEY_START = 'skt' - SIGNED_KEY_EXPIRY = 'ske' - SIGNED_KEY_SERVICE = 'sks' - SIGNED_KEY_VERSION = 'skv' - SIGNED_ENCRYPTION_SCOPE = 'ses' - SIGNED_KEY_DELEGATED_USER_TID = 'skdutid' - SIGNED_DELEGATED_USER_OID = 'sduoid' + SIGNED_SIGNATURE = "sig" + SIGNED_PERMISSION = "sp" + SIGNED_START = "st" + SIGNED_EXPIRY = "se" + SIGNED_RESOURCE = "sr" + SIGNED_IDENTIFIER = "si" + SIGNED_IP = "sip" + SIGNED_PROTOCOL = "spr" + SIGNED_VERSION = "sv" + SIGNED_CACHE_CONTROL = "rscc" + SIGNED_CONTENT_DISPOSITION = "rscd" + SIGNED_CONTENT_ENCODING = "rsce" + SIGNED_CONTENT_LANGUAGE = "rscl" + SIGNED_CONTENT_TYPE = "rsct" + START_PK = "spk" + START_RK = "srk" + END_PK = "epk" + END_RK = "erk" + SIGNED_RESOURCE_TYPES = "srt" + SIGNED_SERVICES = "ss" + SIGNED_OID = "skoid" + SIGNED_TID = "sktid" + SIGNED_KEY_START = "skt" + SIGNED_KEY_EXPIRY = "ske" + SIGNED_KEY_SERVICE = "sks" + SIGNED_KEY_VERSION = "skv" + SIGNED_ENCRYPTION_SCOPE = "ses" + SIGNED_KEY_DELEGATED_USER_TID = "skdutid" + SIGNED_DELEGATED_USER_OID = "sduoid" # for ADLS - SIGNED_AUTHORIZED_OID = 'saoid' - SIGNED_UNAUTHORIZED_OID = 'suoid' - SIGNED_CORRELATION_ID = 'scid' - SIGNED_DIRECTORY_DEPTH = 'sdd' + SIGNED_AUTHORIZED_OID = "saoid" + SIGNED_UNAUTHORIZED_OID = "suoid" + SIGNED_CORRELATION_ID = "scid" + SIGNED_DIRECTORY_DEPTH = "sdd" @staticmethod def to_list(): @@ -90,28 +92,29 @@ def to_list(): class SharedAccessSignature(object): - ''' + """ Provides a factory for creating account access signature tokens with an account name and account key. Users can either use the factory or can construct the appropriate service and use the generate_*_shared_access_signature method directly. - ''' + """ def __init__(self, account_name, account_key, x_ms_version=X_MS_VERSION): - ''' + """ :param str account_name: The storage account name used to generate the shared access signatures. :param str account_key: The access key to generate the shares access signatures. :param str x_ms_version: The service version used to generate the shared access signatures. - ''' + """ self.account_name = account_name self.account_key = account_key self.x_ms_version = x_ms_version def generate_account( - self, services, + self, + services, resource_types, permission, expiry, @@ -120,7 +123,7 @@ def generate_account( protocol=None, sts_hook=None, ) -> str: - ''' + """ Generates a shared access signature for the account. Use the returned signature with the sas_token parameter of the service or to create a new account object. @@ -164,9 +167,9 @@ def generate_account( For debugging purposes only. If provided, the hook is called with the string to sign that was used to generate the SAS. :type sts_hook: Optional[Callable[[str], None]] - :returns: The generated SAS token for the account. + :return: The generated SAS token for the account. 
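Once the `s*` query parameters defined above have been collected, `get_token` (further down in this file) simply joins the non-empty ones into the query string appended to resource URLs; a standalone sketch with made-up values, using `urllib.parse.quote` in place of the SDK's `url_quote` helper:

```python
from urllib.parse import quote   # stands in for the SDK's url_quote helper

query_dict = {"sv": "2025-05-05", "ss": "q", "srt": "sco", "sp": "rl",
              "se": "2025-06-01T00:00:00Z", "sig": None}    # sig is added after signing
token = "&".join(f"{name}={quote(value)}" for name, value in query_dict.items()
                 if value is not None)
print(token)   # sv=2025-05-05&ss=q&srt=sco&sp=rl&se=2025-06-01T00%3A00%3A00Z
```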
:rtype: str - ''' + """ sas = _SharedAccessHelper() sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version) sas.add_account(services, resource_types) @@ -211,11 +214,9 @@ def add_account(self, services, resource_types): self._add_query(QueryStringConstants.SIGNED_SERVICES, services) self._add_query(QueryStringConstants.SIGNED_RESOURCE_TYPES, resource_types) - def add_override_response_headers(self, cache_control, - content_disposition, - content_encoding, - content_language, - content_type): + def add_override_response_headers( + self, cache_control, content_disposition, content_encoding, content_language, content_type + ): self._add_query(QueryStringConstants.SIGNED_CACHE_CONTROL, cache_control) self._add_query(QueryStringConstants.SIGNED_CONTENT_DISPOSITION, content_disposition) self._add_query(QueryStringConstants.SIGNED_CONTENT_ENCODING, content_encoding) @@ -224,25 +225,25 @@ def add_override_response_headers(self, cache_control, def add_account_signature(self, account_name, account_key): def get_value_to_append(query): - return_value = self.query_dict.get(query) or '' - return return_value + '\n' - - string_to_sign = \ - (account_name + '\n' + - get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + - get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + - get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + - get_value_to_append(QueryStringConstants.SIGNED_START) + - get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + - get_value_to_append(QueryStringConstants.SIGNED_IP) + - get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + - get_value_to_append(QueryStringConstants.SIGNED_VERSION) + - '\n' # Signed Encryption Scope - always empty for queue - ) - - self._add_query(QueryStringConstants.SIGNED_SIGNATURE, - sign_string(account_key, string_to_sign)) + return_value = self.query_dict.get(query) or "" + return return_value + "\n" + + string_to_sign = ( + account_name + + "\n" + + get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + + get_value_to_append(QueryStringConstants.SIGNED_SERVICES) + + get_value_to_append(QueryStringConstants.SIGNED_RESOURCE_TYPES) + + get_value_to_append(QueryStringConstants.SIGNED_START) + + get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + + get_value_to_append(QueryStringConstants.SIGNED_IP) + + get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + + get_value_to_append(QueryStringConstants.SIGNED_VERSION) + + "\n" # Signed Encryption Scope - always empty for queue + ) + + self._add_query(QueryStringConstants.SIGNED_SIGNATURE, sign_string(account_key, string_to_sign)) self.string_to_sign = string_to_sign def get_token(self) -> str: - return '&'.join([f'{n}={url_quote(v)}' for n, v in self.query_dict.items() if v is not None]) + return "&".join([f"{n}={url_quote(v)}" for n, v in self.query_dict.items() if v is not None]) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads.py index b31cfb3291d9..7a5fb3f3dc91 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads.py @@ -12,7 +12,7 @@ from azure.core.tracing.common import with_current_context -from .import encode_base64, url_quote +from . 
import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers @@ -41,20 +41,21 @@ def _parallel_uploads(executor, uploader, pending, running): def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - validate_content=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + validate_content=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -64,7 +65,8 @@ def upload_data_chunks( parallel=parallel, validate_content=validate_content, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: upload_tasks = uploader.get_chunk_streams() @@ -81,18 +83,19 @@ def upload_data_chunks( def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -100,7 +103,8 @@ def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: with futures.ThreadPoolExecutor(max_concurrency) as executor: @@ -120,15 +124,17 @@ def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size self.chunk_size = chunk_size @@ -253,7 +259,7 @@ def __init__(self, *args, **kwargs): def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. 
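As a simplified, stdlib-only sketch of the dispatch pattern used by upload_data_chunks above (serial when max_concurrency is 1, otherwise fanned out over a thread pool); the function and variable names here are hypothetical and not part of the SDK:

from concurrent import futures

def upload_in_chunks(upload_chunk, chunks, max_concurrency):
    # Mirrors the `parallel = max_concurrency > 1` branch above: run serially for a
    # single worker, otherwise keep up to max_concurrency chunk uploads in flight.
    if max_concurrency <= 1:
        return [upload_chunk(chunk) for chunk in chunks]
    with futures.ThreadPoolExecutor(max_concurrency) as executor:
        return list(executor.map(upload_chunk, chunks))

# Hypothetical usage: split an in-memory payload into 1 MiB chunks and "upload" them.
payload = b"x" * (4 * 1024 * 1024)
chunk_size = 1024 * 1024
chunks = [payload[i:i + chunk_size] for i in range(0, len(payload), chunk_size)]
print(upload_in_chunks(len, chunks, max_concurrency=4))  # [1048576, 1048576, 1048576, 1048576]

As in the real dispatch, modified_access_conditions are dropped when running in parallel because access conditions do not work with concurrent chunk commits.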
- index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) self.service.stage_block( block_id, @@ -261,20 +267,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) return index, block_id def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -302,11 +308,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): pass @@ -326,19 +332,20 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) def _upload_substream_block(self, index, block_stream): @@ -356,11 +363,11 @@ def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] def _upload_substream_block(self, index, block_stream): try: @@ -371,7 +378,7 @@ def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -388,9 +395,9 @@ def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - return f'bytes={chunk_offset}-{chunk_end}', response + return f"bytes={chunk_offset}-{chunk_end}", response # TODO: Implement this method. 
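The zero-padded index and double-base64, URL-quoted block id built by BlockBlobChunkUploader._upload_chunk above can be reproduced with the standard library alone; the _b64 helper below is a simplified stand-in for the SDK's encode_base64/url_quote helpers:

import base64
from urllib.parse import quote

def _b64(text: str) -> str:
    # Simplified stand-in for the SDK's encode_base64 helper.
    return base64.b64encode(text.encode("utf-8")).decode("utf-8")

chunk_offset = 4 * 1024 * 1024
index = f"{chunk_offset:032d}"       # 32-character, zero-padded chunk offset
block_id = _b64(quote(_b64(index)))  # the value passed to stage_block above
print(index, block_id)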
def _upload_substream_block(self, index, block_stream): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads_async.py index a056cd290230..6ed5ba1d0f91 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/uploads_async.py @@ -12,7 +12,7 @@ from math import ceil from typing import AsyncGenerator, Union -from .import encode_base64, url_quote +from . import encode_base64, url_quote from .request_handlers import get_length from .response_handlers import return_response_headers from .uploads import SubStream, IterStreamer # pylint: disable=unused-import @@ -59,19 +59,20 @@ async def _parallel_uploads(uploader, pending, running): async def upload_data_chunks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, @@ -80,7 +81,8 @@ async def upload_data_chunks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_chunk_streams() @@ -104,18 +106,19 @@ async def upload_data_chunks( async def upload_substream_blocks( - service=None, - uploader_class=None, - total_size=None, - chunk_size=None, - max_concurrency=None, - stream=None, - progress_hook=None, - **kwargs): + service=None, + uploader_class=None, + total_size=None, + chunk_size=None, + max_concurrency=None, + stream=None, + progress_hook=None, + **kwargs, +): parallel = max_concurrency > 1 - if parallel and 'modified_access_conditions' in kwargs: + if parallel and "modified_access_conditions" in kwargs: # Access conditions do not work with parallelism - kwargs['modified_access_conditions'] = None + kwargs["modified_access_conditions"] = None uploader = uploader_class( service=service, total_size=total_size, @@ -123,13 +126,13 @@ async def upload_substream_blocks( stream=stream, parallel=parallel, progress_hook=progress_hook, - **kwargs) + **kwargs, + ) if parallel: upload_tasks = uploader.get_substream_blocks() running_futures = [ - asyncio.ensure_future(uploader.process_substream_block(u)) - for u in islice(upload_tasks, 0, max_concurrency) + asyncio.ensure_future(uploader.process_substream_block(u)) for u in islice(upload_tasks, 0, max_concurrency) ] range_ids = await _parallel_uploads(uploader.process_substream_block, upload_tasks, running_futures) else: @@ -144,15 +147,17 @@ async def upload_substream_blocks( class _ChunkUploader(object): # pylint: disable=too-many-instance-attributes def __init__( - self, service, - total_size, - chunk_size, - stream, - parallel, - encryptor=None, - padder=None, - progress_hook=None, - **kwargs): + self, + service, + total_size, + chunk_size, + stream, + parallel, + encryptor=None, + padder=None, + progress_hook=None, + **kwargs, + ): self.service = service self.total_size = total_size self.chunk_size = chunk_size @@ -178,7 +183,7 @@ def __init__( 
async def get_chunk_streams(self): index = 0 while True: - data = b'' + data = b"" read_size = self.chunk_size # Buffer until we either reach the end of the stream or get a whole chunk. @@ -189,12 +194,12 @@ async def get_chunk_streams(self): if inspect.isawaitable(temp): temp = await temp if not isinstance(temp, bytes): - raise TypeError('Blob data should be of type bytes.') + raise TypeError("Blob data should be of type bytes.") data += temp or b"" # We have read an empty string and so are at the end # of the buffer or we have read a full chunk. - if temp == b'' or len(data) == self.chunk_size: + if temp == b"" or len(data) == self.chunk_size: break if len(data) == self.chunk_size: @@ -273,13 +278,13 @@ def set_response_properties(self, resp): class BlockBlobChunkUploader(_ChunkUploader): def __init__(self, *args, **kwargs): - kwargs.pop('modified_access_conditions', None) + kwargs.pop("modified_access_conditions", None) super(BlockBlobChunkUploader, self).__init__(*args, **kwargs) self.current_length = None async def _upload_chunk(self, chunk_offset, chunk_data): # TODO: This is incorrect, but works with recording. - index = f'{chunk_offset:032d}' + index = f"{chunk_offset:032d}" block_id = encode_base64(url_quote(encode_base64(index))) await self.service.stage_block( block_id, @@ -287,19 +292,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): body=chunk_data, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) return index, block_id async def _upload_substream_block(self, index, block_stream): try: - block_id = f'BlockId{(index//self.chunk_size):05}' + block_id = f"BlockId{(index//self.chunk_size):05}" await self.service.stage_block( block_id, len(block_stream), block_stream, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) finally: block_stream.close() return block_id @@ -311,7 +318,7 @@ def _is_chunk_empty(self, chunk_data): # read until non-zero byte is encountered # if reached the end without returning, then chunk_data is all 0's for each_byte in chunk_data: - if each_byte not in [0, b'\x00']: + if each_byte not in [0, b"\x00"]: return False return True @@ -319,7 +326,7 @@ async def _upload_chunk(self, chunk_offset, chunk_data): # avoid uploading the empty pages if not self._is_chunk_empty(chunk_data): chunk_end = chunk_offset + len(chunk_data) - 1 - content_range = f'bytes={chunk_offset}-{chunk_end}' + content_range = f"bytes={chunk_offset}-{chunk_end}" computed_md5 = None self.response_headers = await self.service.upload_pages( body=chunk_data, @@ -329,10 +336,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): pass @@ -352,18 +360,21 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) - 
self.current_length = int(self.response_headers['blob_append_offset']) + **self.request_options, + ) + self.current_length = int(self.response_headers["blob_append_offset"]) else: - self.request_options['append_position_access_conditions'].append_position = \ + self.request_options["append_position_access_conditions"].append_position = ( self.current_length + chunk_offset + ) self.response_headers = await self.service.append_block( body=chunk_data, content_length=len(chunk_data), cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options) + **self.request_options, + ) async def _upload_substream_block(self, index, block_stream): pass @@ -379,11 +390,11 @@ async def _upload_chunk(self, chunk_offset, chunk_data): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - if not self.parallel and self.request_options.get('modified_access_conditions'): - self.request_options['modified_access_conditions'].if_match = self.response_headers['etag'] + if not self.parallel and self.request_options.get("modified_access_conditions"): + self.request_options["modified_access_conditions"].if_match = self.response_headers["etag"] async def _upload_substream_block(self, index, block_stream): try: @@ -394,7 +405,7 @@ async def _upload_substream_block(self, index, block_stream): cls=return_response_headers, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) finally: block_stream.close() @@ -411,9 +422,9 @@ async def _upload_chunk(self, chunk_offset, chunk_data): length, data_stream_total=self.total_size, upload_stream_current=self.progress_total, - **self.request_options + **self.request_options, ) - range_id = f'bytes={chunk_offset}-{chunk_end}' + range_id = f"bytes={chunk_offset}-{chunk_end}" return range_id, response # TODO: Implement this method. @@ -421,10 +432,11 @@ async def _upload_substream_block(self, index, block_stream): pass -class AsyncIterStreamer(): +class AsyncIterStreamer: """ File-like streaming object for AsyncGenerators. 
""" + def __init__(self, generator: AsyncGenerator[Union[bytes, str], None], encoding: str = "UTF-8"): self.iterator = generator.__aiter__() self.leftover = b"" diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py index 9c63cc22ab15..80003a57c9dd 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared_access_signature.py @@ -5,10 +5,7 @@ # -------------------------------------------------------------------------- # pylint: disable=docstring-keyword-should-match-keyword-only -from typing import ( - Any, Callable, Optional, Union, - TYPE_CHECKING -) +from typing import Any, Callable, Optional, Union, TYPE_CHECKING from urllib.parse import parse_qs from azure.storage.queue._shared import sign_string @@ -17,45 +14,43 @@ from azure.storage.queue._shared.shared_access_signature import ( QueryStringConstants, SharedAccessSignature, - _SharedAccessHelper + _SharedAccessHelper, ) if TYPE_CHECKING: - from azure.storage.queue import ( - AccountSasPermissions, - QueueSasPermissions, - ResourceTypes - ) + from azure.storage.queue import AccountSasPermissions, QueueSasPermissions, ResourceTypes from datetime import datetime + class QueueSharedAccessSignature(SharedAccessSignature): - ''' + """ Provides a factory for creating queue shares access signature tokens with a common account name and account key. Users can either use the factory or can construct the appropriate service and use the generate_*_shared_access_signature method directly. - ''' + """ def __init__(self, account_name: str, account_key: str) -> None: - ''' + """ :param str account_name: The storage account name used to generate the shared access signatures. :param str account_key: The access key to generate the shares access signatures. - ''' + """ super(QueueSharedAccessSignature, self).__init__(account_name, account_key, x_ms_version=X_MS_VERSION) def generate_queue( - self, queue_name: str, + self, + queue_name: str, permission: Optional[Union["QueueSasPermissions", str]] = None, expiry: Optional[Union["datetime", str]] = None, start: Optional[Union["datetime", str]] = None, policy_id: Optional[str] = None, ip: Optional[str] = None, protocol: Optional[str] = None, - sts_hook: Optional[Callable[[str], None]] = None + sts_hook: Optional[Callable[[str], None]] = None, ) -> str: - ''' + """ Generates a shared access signature for the queue. Use the returned signature with the sas_token parameter of QueueService. :param str queue_name: @@ -100,7 +95,7 @@ def generate_queue( :type sts_hook: Optional[Callable[[str], None]] :return: A Shared Access Signature (sas) token. 
:rtype: str - ''' + """ sas = _QueueSharedAccessHelper() sas.add_base(permission, expiry, start, ip, protocol, self.x_ms_version) sas.add_id(policy_id) @@ -116,32 +111,32 @@ class _QueueSharedAccessHelper(_SharedAccessHelper): def add_resource_signature(self, account_name: str, account_key: str, path: str): def get_value_to_append(query): - return_value = self.query_dict.get(query) or '' - return return_value + '\n' + return_value = self.query_dict.get(query) or "" + return return_value + "\n" - if path[0] != '/': - path = '/' + path + if path[0] != "/": + path = "/" + path - canonicalized_resource = '/queue/' + account_name + path + '\n' + canonicalized_resource = "/queue/" + account_name + path + "\n" # Form the string to sign from shared_access_policy and canonicalized # resource. The order of values is important. - string_to_sign = \ - (get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + - get_value_to_append(QueryStringConstants.SIGNED_START) + - get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + - canonicalized_resource + - get_value_to_append(QueryStringConstants.SIGNED_IDENTIFIER) + - get_value_to_append(QueryStringConstants.SIGNED_IP) + - get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + - get_value_to_append(QueryStringConstants.SIGNED_VERSION)) + string_to_sign = ( + get_value_to_append(QueryStringConstants.SIGNED_PERMISSION) + + get_value_to_append(QueryStringConstants.SIGNED_START) + + get_value_to_append(QueryStringConstants.SIGNED_EXPIRY) + + canonicalized_resource + + get_value_to_append(QueryStringConstants.SIGNED_IDENTIFIER) + + get_value_to_append(QueryStringConstants.SIGNED_IP) + + get_value_to_append(QueryStringConstants.SIGNED_PROTOCOL) + + get_value_to_append(QueryStringConstants.SIGNED_VERSION) + ) # remove the trailing newline - if string_to_sign[-1] == '\n': + if string_to_sign[-1] == "\n": string_to_sign = string_to_sign[:-1] - self._add_query(QueryStringConstants.SIGNED_SIGNATURE, - sign_string(account_key, string_to_sign)) + self._add_query(QueryStringConstants.SIGNED_SIGNATURE, sign_string(account_key, string_to_sign)) self.string_to_sign = string_to_sign @@ -302,6 +297,7 @@ def generate_queue_sas( **kwargs ) + def _is_credential_sastoken(credential: Any) -> bool: if not credential or not isinstance(credential, str): return False diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/__init__.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/__init__.py index 15781aefaf2e..434d5fe99bba 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/__init__.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/__init__.py @@ -9,6 +9,6 @@ __all__ = [ - 'QueueClient', - 'QueueServiceClient', + "QueueClient", + "QueueServiceClient", ] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py index 49dfd72c63d0..5a1dfba70bde 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_models.py @@ -31,17 +31,18 @@ class MessagesPaged(AsyncPageIterator): """The maximum number of messages to retrieve from the queue.""" def __init__( - self, command: Callable, + self, + command: Callable, results_per_page: Optional[int] = None, continuation_token: Optional[str] = None, - max_messages: Optional[int] = None + max_messages: Optional[int] = None, ) -> None: if continuation_token is not None: raise ValueError("This operation does not 
support continuation token") super(MessagesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, # type: ignore [arg-type] + self._extract_data_cb, # type: ignore [arg-type] ) self._command = command self.results_per_page = results_per_page @@ -97,15 +98,16 @@ class QueuePropertiesPaged(AsyncPageIterator): """Function to retrieve the next page of items.""" def __init__( - self, command: Callable, + self, + command: Callable, prefix: Optional[str] = None, results_per_page: Optional[int] = None, - continuation_token: Optional[str] = None + continuation_token: Optional[str] = None, ) -> None: super(QueuePropertiesPaged, self).__init__( self._get_next_cb, - self._extract_data_cb, # type: ignore [arg-type] - continuation_token=continuation_token or "" + self._extract_data_cb, # type: ignore [arg-type] + continuation_token=continuation_token or "", ) self._command = command self.service_endpoint = None @@ -120,7 +122,8 @@ async def _get_next_cb(self, continuation_token: Optional[str]) -> Any: marker=continuation_token or None, maxresults=self.results_per_page, cls=return_context_and_deserialized, - use_location=self.location_mode) + use_location=self.location_mode, + ) except HttpResponseError as error: process_storage_error(error) @@ -130,6 +133,8 @@ async def _extract_data_cb(self, get_next_return: Any) -> Tuple[Optional[str], L self.prefix = self._response.prefix self.marker = self._response.marker self.results_per_page = self._response.max_results - props_list = [QueueProperties._from_generated(q) for q in self._response.queue_items] # pylint: disable=protected-access + props_list = [ + QueueProperties._from_generated(q) for q in self._response.queue_items # pylint: disable=protected-access + ] next_marker = self._response.next_marker return next_marker or None, props_list diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py index 5f053b5c1c09..df12b5d1a961 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_client_async.py @@ -6,10 +6,7 @@ import functools import warnings -from typing import ( - Any, cast, Dict, List, - Optional, Tuple, TYPE_CHECKING, Union -) +from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union from typing_extensions import Self from azure.core.async_paging import AsyncItemPaged @@ -29,11 +26,7 @@ from .._shared.base_client_async import AsyncStorageAccountHostsMixin, parse_connection_str from .._shared.policies_async import ExponentialRetry from .._shared.request_handlers import add_metadata_headers, serialize_iso -from .._shared.response_handlers import ( - process_storage_error, - return_headers_and_deserialized, - return_response_headers -) +from .._shared.response_handlers import process_storage_error, return_headers_and_deserialized, return_response_headers if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential @@ -42,15 +35,13 @@ BinaryBase64DecodePolicy, BinaryBase64EncodePolicy, TextBase64DecodePolicy, - TextBase64EncodePolicy + TextBase64EncodePolicy, ) from .._models import QueueProperties class QueueClient( # type: ignore [misc] - AsyncStorageAccountHostsMixin, - StorageAccountHostsMixin, - StorageEncryptionMixin + AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin ): """A client to interact with a specific Queue. 
@@ -108,9 +99,12 @@ class QueueClient( # type: ignore [misc] """ def __init__( - self, account_url: str, + self, + account_url: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, @@ -120,13 +114,13 @@ def __init__( **kwargs: Any ) -> None: kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) - loop = kwargs.pop('loop', None) + loop = kwargs.pop("loop", None) parsed_url, sas_token = _parse_url(account_url=account_url, queue_name=queue_name, credential=credential) self.queue_name = queue_name self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueClient, self).__init__( parsed_url, - service='queue', + service="queue", credential=credential, secondary_hostname=secondary_hostname, audience=audience, @@ -148,17 +142,15 @@ def _format_url(self, hostname: str) -> str: :returns: The formatted endpoint URL according to the specified location mode hostname. :rtype: str """ - return _format_url( - queue_name=self.queue_name, - hostname=hostname, - scheme=self.scheme, - query_str=self._query_str - ) + return _format_url(queue_name=self.queue_name, hostname=hostname, scheme=self.scheme, query_str=self._query_str) @classmethod def from_queue_url( - cls, queue_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + cls, + queue_url: str, + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, @@ -218,9 +210,12 @@ def from_queue_url( @classmethod def from_connection_string( - cls, conn_str: str, + cls, + conn_str: str, queue_name: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, @@ -277,7 +272,7 @@ def from_connection_string( :dedent: 8 :caption: Create the queue client from connection string. """ - account_url, secondary, credential = parse_connection_str(conn_str, credential, 'queue') + account_url, secondary, credential = parse_connection_str(conn_str, credential, "queue") return cls( account_url, queue_name=queue_name, @@ -292,10 +287,7 @@ def from_connection_string( @distributed_trace_async async def create_queue( - self, *, - metadata: Optional[Dict[str, str]] = None, - timeout: Optional[int] = None, - **kwargs: Any + self, *, metadata: Optional[Dict[str, str]] = None, timeout: Optional[int] = None, **kwargs: Any ) -> None: """Creates a new queue in the storage account. 
@@ -329,11 +321,7 @@ async def create_queue( headers.update(add_metadata_headers(metadata)) try: return await self._client.queue.create( - metadata=metadata, - timeout=timeout, - headers=headers, - cls=deserialize_queue_creation, - **kwargs + metadata=metadata, timeout=timeout, headers=headers, cls=deserialize_queue_creation, **kwargs ) except HttpResponseError as error: process_storage_error(error) @@ -393,11 +381,10 @@ async def get_queue_properties(self, *, timeout: Optional[int] = None, **kwargs: :caption: Get the properties on the queue. """ try: - response = cast("QueueProperties", await (self._client.queue.get_properties( - timeout=timeout, - cls=deserialize_queue_properties, - **kwargs - ))) + response = cast( + "QueueProperties", + await self._client.queue.get_properties(timeout=timeout, cls=deserialize_queue_properties, **kwargs), + ) except HttpResponseError as error: process_storage_error(error) response.name = self.queue_name @@ -405,10 +392,7 @@ async def get_queue_properties(self, *, timeout: Optional[int] = None, **kwargs: @distributed_trace_async async def set_queue_metadata( - self, metadata: Optional[Dict[str, str]] = None, - *, - timeout: Optional[int] = None, - **kwargs: Any + self, metadata: Optional[Dict[str, str]] = None, *, timeout: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: """Sets user-defined metadata on the specified queue. @@ -439,10 +423,7 @@ async def set_queue_metadata( headers.update(add_metadata_headers(metadata)) try: return await self._client.queue.set_metadata( - timeout=timeout, - headers=headers, - cls=return_response_headers, - **kwargs + timeout=timeout, headers=headers, cls=return_response_headers, **kwargs ) except HttpResponseError as error: process_storage_error(error) @@ -462,21 +443,19 @@ async def get_queue_access_policy(self, *, timeout: Optional[int] = None, **kwar :rtype: dict(str, ~azure.storage.queue.AccessPolicy) """ try: - _, identifiers = cast(Tuple[Dict, List], await self._client.queue.get_access_policy( - timeout=timeout, - cls=return_headers_and_deserialized, - **kwargs - )) + _, identifiers = cast( + Tuple[Dict, List], + await self._client.queue.get_access_policy( + timeout=timeout, cls=return_headers_and_deserialized, **kwargs + ), + ) except HttpResponseError as error: process_storage_error(error) return {s.id: s.access_policy or AccessPolicy() for s in identifiers} @distributed_trace_async async def set_queue_access_policy( - self, signed_identifiers: Dict[str, AccessPolicy], - *, - timeout: Optional[int] = None, - **kwargs: Any + self, signed_identifiers: Dict[str, AccessPolicy], *, timeout: Optional[int] = None, **kwargs: Any ) -> None: """Sets stored access policies for the queue that may be used with Shared Access Signatures. 
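To make the shape of the signed_identifiers argument of set_queue_access_policy concrete, a short sketch follows; the connection string is a placeholder and the policy id "read-only" is arbitrary:

import asyncio
from datetime import datetime, timedelta

from azure.storage.queue import AccessPolicy, QueueSasPermissions
from azure.storage.queue.aio import QueueClient

async def restrict_queue(conn_str: str) -> None:
    policy = AccessPolicy(
        permission=QueueSasPermissions(read=True),
        expiry=datetime.utcnow() + timedelta(hours=1),
    )
    async with QueueClient.from_connection_string(conn_str, queue_name="myqueue") as queue:
        # signed_identifiers maps a stored-policy id to its access policy;
        # SAS tokens can later reference the id instead of repeating the permissions.
        await queue.set_queue_access_policy(signed_identifiers={"read-only": policy})
        print(await queue.get_queue_access_policy())

asyncio.run(restrict_queue("<storage-connection-string>"))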
@@ -531,7 +510,8 @@ async def set_queue_access_policy( @distributed_trace_async async def send_message( - self, content: Optional[object], + self, + content: Optional[object], *, visibility_timeout: Optional[int] = None, time_to_live: Optional[int] = None, @@ -589,10 +569,7 @@ async def send_message( """ if self.key_encryption_key: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) try: @@ -600,7 +577,7 @@ async def send_message( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, resolver=self.key_resolver_function, - encryption_version=self.encryption_version + encryption_version=self.encryption_version, ) except TypeError: warnings.warn( @@ -612,7 +589,7 @@ async def send_message( self._message_encode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) encoded_content = self._message_encode_policy(content) new_message = GenQueueMessage(message_text=encoded_content) @@ -631,7 +608,7 @@ async def send_message( inserted_on=enqueued[0].insertion_time, expires_on=enqueued[0].expiration_time, pop_receipt=enqueued[0].pop_receipt, - next_visible_on=enqueued[0].time_next_visible + next_visible_on=enqueued[0].time_next_visible, ) return queue_message except HttpResponseError as error: @@ -639,10 +616,7 @@ async def send_message( @distributed_trace_async async def receive_message( - self, *, - visibility_timeout: Optional[int] = None, - timeout: Optional[int] = None, - **kwargs: Any + self, *, visibility_timeout: Optional[int] = None, timeout: Optional[int] = None, **kwargs: Any ) -> Optional[QueueMessage]: """Removes one message from the front of the queue. 
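For orientation across the messaging methods touched in this and the following hunks (send_message, receive_message, delete_message), a condensed async round trip; the connection string and queue name are placeholders:

import asyncio

from azure.storage.queue.aio import QueueClient

async def round_trip(conn_str: str) -> None:
    async with QueueClient.from_connection_string(conn_str, queue_name="myqueue") as queue:
        await queue.send_message("hello", time_to_live=3600)
        # receive_message returns None when the queue is empty.
        message = await queue.receive_message(visibility_timeout=30)
        if message is not None:
            print(message.content)
            await queue.delete_message(message)

asyncio.run(round_trip("<storage-connection-string>"))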
@@ -683,16 +657,13 @@ async def receive_message( """ if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) try: message = await self._client.messages.dequeue( @@ -702,14 +673,17 @@ async def receive_message( cls=self._message_decode_policy, **kwargs ) - wrapped_message = QueueMessage._from_generated(message[0]) if message != [] else None # pylint: disable=protected-access + wrapped_message = ( + QueueMessage._from_generated(message[0]) if message != [] else None # pylint: disable=protected-access + ) return wrapped_message except HttpResponseError as error: process_storage_error(error) @distributed_trace def receive_messages( - self, *, + self, + *, messages_per_page: Optional[int] = None, visibility_timeout: Optional[int] = None, max_messages: Optional[int] = None, @@ -766,16 +740,13 @@ def receive_messages( """ if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) try: command = functools.partial( @@ -792,14 +763,15 @@ def receive_messages( command, results_per_page=messages_per_page, page_iterator_class=MessagesPaged, - max_messages=max_messages + max_messages=max_messages, ) except HttpResponseError as error: process_storage_error(error) @distributed_trace_async async def update_message( - self, message: Union[str, QueueMessage], + self, + message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, content: Optional[object] = None, *, @@ -859,10 +831,7 @@ async def update_message( """ if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) if isinstance(message, QueueMessage): @@ -888,7 +857,7 @@ async def update_message( self.require_encryption, self.key_encryption_key, self.key_resolver_function, - encryption_version=self.encryption_version + encryption_version=self.encryption_version, ) except TypeError: warnings.warn( @@ -898,32 +867,33 @@ async def update_message( Retrying without encryption_version." 
) self._message_encode_policy.configure( - self.require_encryption, - self.key_encryption_key, - self.key_resolver_function + self.require_encryption, self.key_encryption_key, self.key_resolver_function ) encoded_message_text = self._message_encode_policy(message_text) updated = GenQueueMessage(message_text=encoded_message_text) else: updated = None try: - response = cast(QueueMessage, await self._client.message_id.update( - queue_message=updated, - visibilitytimeout=visibility_timeout or 0, - timeout=timeout, - pop_receipt=receipt, - cls=return_response_headers, - queue_message_id=message_id, - **kwargs - )) + response = cast( + QueueMessage, + await self._client.message_id.update( + queue_message=updated, + visibilitytimeout=visibility_timeout or 0, + timeout=timeout, + pop_receipt=receipt, + cls=return_response_headers, + queue_message_id=message_id, + **kwargs + ), + ) new_message = QueueMessage( content=message_text, id=message_id, inserted_on=inserted_on, dequeue_count=dequeue_count, expires_on=expires_on, - pop_receipt=response['popreceipt'], - next_visible_on=response['time_next_visible'] + pop_receipt=response["popreceipt"], + next_visible_on=response["time_next_visible"], ) return new_message except HttpResponseError as error: @@ -931,10 +901,7 @@ async def update_message( @distributed_trace_async async def peek_messages( - self, max_messages: Optional[int] = None, - *, - timeout: Optional[int] = None, - **kwargs: Any + self, max_messages: Optional[int] = None, *, timeout: Optional[int] = None, **kwargs: Any ) -> List[QueueMessage]: """Retrieves one or more messages from the front of the queue, but does not alter the visibility of the message. @@ -980,23 +947,17 @@ async def peek_messages( if self.key_encryption_key or self.key_resolver_function: modify_user_agent_for_encryption( - self._config.user_agent_policy.user_agent, - self._sdk_moniker, - self.encryption_version, - kwargs + self._config.user_agent_policy.user_agent, self._sdk_moniker, self.encryption_version, kwargs ) self._message_decode_policy.configure( require_encryption=self.require_encryption, key_encryption_key=self.key_encryption_key, - resolver=self.key_resolver_function + resolver=self.key_resolver_function, ) try: messages = await self._client.messages.peek( - number_of_messages=max_messages, - timeout=timeout, - cls=self._message_decode_policy, - **kwargs + number_of_messages=max_messages, timeout=timeout, cls=self._message_decode_policy, **kwargs ) wrapped_messages = [] for peeked in messages: @@ -1032,7 +993,8 @@ async def clear_messages(self, *, timeout: Optional[int] = None, **kwargs: Any) @distributed_trace_async async def delete_message( - self, message: Union[str, QueueMessage], + self, + message: Union[str, QueueMessage], pop_receipt: Optional[str] = None, *, timeout: Optional[int] = None, diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py index 7fcb853601d3..577d7ec973a4 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/aio/_queue_service_client_async.py @@ -5,10 +5,7 @@ # -------------------------------------------------------------------------- import functools -from typing import ( - Any, Dict, List, Optional, - TYPE_CHECKING, Union -) +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union from typing_extensions import Self from 
azure.core.async_paging import AsyncItemPaged @@ -37,9 +34,7 @@ class QueueServiceClient( # type: ignore [misc] - AsyncStorageAccountHostsMixin, - StorageAccountHostsMixin, - StorageEncryptionMixin + AsyncStorageAccountHostsMixin, StorageAccountHostsMixin, StorageEncryptionMixin ): """A client to interact with the Queue Service at the account level. @@ -93,25 +88,28 @@ class QueueServiceClient( # type: ignore [misc] """ def __init__( - self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + self, + account_url: str, + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, audience: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> None: - kwargs['retry_policy'] = kwargs.get('retry_policy') or ExponentialRetry(**kwargs) - loop = kwargs.pop('loop', None) + kwargs["retry_policy"] = kwargs.get("retry_policy") or ExponentialRetry(**kwargs) + loop = kwargs.pop("loop", None) parsed_url, sas_token = _parse_url(account_url=account_url, credential=credential) self._query_str, credential = self._format_query_string(sas_token, credential) super(QueueServiceClient, self).__init__( parsed_url, - service='queue', + service="queue", credential=credential, secondary_hostname=secondary_hostname, audience=audience, - **kwargs + **kwargs, ) self._client = AzureQueueStorage(self.url, base_url=self.url, pipeline=self._pipeline, loop=loop) self._client._config.version = get_api_version(api_version) # type: ignore [assignment] @@ -130,13 +128,16 @@ def _format_url(self, hostname: str) -> str: @classmethod def from_connection_string( - cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"]] = None, # pylint: disable=line-too-long + cls, + conn_str: str, + credential: Optional[ + Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "AsyncTokenCredential"] + ] = None, *, api_version: Optional[str] = None, secondary_hostname: Optional[str] = None, audience: Optional[str] = None, - **kwargs: Any + **kwargs: Any, ) -> Self: """Create QueueServiceClient from a Connection String. @@ -173,14 +174,14 @@ def from_connection_string( :dedent: 8 :caption: Creating the QueueServiceClient with a connection string. 
""" - account_url, secondary, credential = parse_connection_str(conn_str, credential, 'queue') + account_url, secondary, credential = parse_connection_str(conn_str, credential, "queue") return cls( account_url, credential=credential, api_version=api_version, secondary_hostname=secondary_hostname or secondary, audience=audience, - **kwargs + **kwargs, ) @distributed_trace_async @@ -210,7 +211,8 @@ async def get_service_stats(self, *, timeout: Optional[int] = None, **kwargs: An """ try: stats = await self._client.service.get_statistics( - timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs) + timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs + ) return service_stats_deserialize(stats) except HttpResponseError as error: process_storage_error(error) @@ -243,13 +245,14 @@ async def get_service_properties(self, *, timeout: Optional[int] = None, **kwarg @distributed_trace_async async def set_service_properties( - self, analytics_logging: Optional["QueueAnalyticsLogging"] = None, + self, + analytics_logging: Optional["QueueAnalyticsLogging"] = None, hour_metrics: Optional["Metrics"] = None, minute_metrics: Optional["Metrics"] = None, cors: Optional[List[CorsRule]] = None, *, timeout: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """Sets the properties of a storage account's Queue service, including Azure Storage Analytics. @@ -289,7 +292,7 @@ async def set_service_properties( logging=analytics_logging, hour_metrics=hour_metrics, minute_metrics=minute_metrics, - cors=CorsRule._to_generated(cors) # pylint: disable=protected-access + cors=CorsRule._to_generated(cors), # pylint: disable=protected-access ) try: await self._client.service.set_properties(props, timeout=timeout, **kwargs) @@ -298,12 +301,13 @@ async def set_service_properties( @distributed_trace def list_queues( - self, name_starts_with: Optional[str] = None, + self, + name_starts_with: Optional[str] = None, include_metadata: Optional[bool] = False, *, results_per_page: Optional[int] = None, timeout: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> AsyncItemPaged: """Returns a generator to list the queues under the specified account. @@ -338,28 +342,24 @@ def list_queues( :dedent: 16 :caption: List queues in the service. """ - include = ['metadata'] if include_metadata else None + include = ["metadata"] if include_metadata else None command = functools.partial( self._client.service.list_queues_segment, prefix=name_starts_with, include=include, timeout=timeout, - **kwargs + **kwargs, ) return AsyncItemPaged( command, prefix=name_starts_with, results_per_page=results_per_page, - page_iterator_class=QueuePropertiesPaged + page_iterator_class=QueuePropertiesPaged, ) @distributed_trace_async async def create_queue( - self, name: str, - metadata: Optional[Dict[str, str]] = None, - *, - timeout: Optional[int] = None, - **kwargs: Any + self, name: str, metadata: Optional[Dict[str, str]] = None, *, timeout: Optional[int] = None, **kwargs: Any ) -> QueueClient: """Creates a new queue under the specified account. @@ -386,16 +386,13 @@ async def create_queue( :caption: Create a queue in the service. 
""" queue = self.get_queue_client(name) - kwargs.setdefault('merge_span', True) + kwargs.setdefault("merge_span", True) await queue.create_queue(metadata=metadata, timeout=timeout, **kwargs) return queue @distributed_trace_async async def delete_queue( - self, queue: Union["QueueProperties", str], - *, - timeout: Optional[int] = None, - **kwargs: Any + self, queue: Union["QueueProperties", str], *, timeout: Optional[int] = None, **kwargs: Any ) -> None: """Deletes the specified queue and any messages it contains. @@ -425,7 +422,7 @@ async def delete_queue( :caption: Delete a queue in the service. """ queue_client = self.get_queue_client(queue) - kwargs.setdefault('merge_span', True) + kwargs.setdefault("merge_span", True) await queue_client.delete_queue(timeout=timeout, **kwargs) def get_queue_client(self, queue: Union["QueueProperties", str], **kwargs: Any) -> QueueClient: @@ -456,12 +453,21 @@ def get_queue_client(self, queue: Union["QueueProperties", str], **kwargs: Any) _pipeline = AsyncPipeline( transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable=protected-access - policies=self._pipeline._impl_policies # type: ignore # pylint: disable=protected-access + policies=self._pipeline._impl_policies, # type: ignore # pylint: disable=protected-access ) return QueueClient( - self.url, queue_name=queue_name, credential=self.credential, - key_resolver_function=self.key_resolver_function, require_encryption=self.require_encryption, - encryption_version=self.encryption_version, key_encryption_key=self.key_encryption_key, - api_version=self.api_version, _pipeline=_pipeline, _configuration=self._config, - _location_mode=self._location_mode, _hosts=self._hosts, **kwargs) + self.url, + queue_name=queue_name, + credential=self.credential, + key_resolver_function=self.key_resolver_function, + require_encryption=self.require_encryption, + encryption_version=self.encryption_version, + key_encryption_key=self.key_encryption_key, + api_version=self.api_version, + _pipeline=_pipeline, + _configuration=self._config, + _location_mode=self._location_mode, + _hosts=self._hosts, + **kwargs, + ) diff --git a/sdk/storage/azure-storage-queue/pyproject.toml b/sdk/storage/azure-storage-queue/pyproject.toml index 508be0fa97ba..7ea997ba706c 100644 --- a/sdk/storage/azure-storage-queue/pyproject.toml +++ b/sdk/storage/azure-storage-queue/pyproject.toml @@ -3,4 +3,4 @@ mypy = true pyright = false type_check_samples = true verifytypes = true -black = false +black = true diff --git a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py index 91b36adc2940..8f4a3972f4f8 100644 --- a/sdk/storage/azure-storage-queue/samples/network_activity_logging.py +++ b/sdk/storage/azure-storage-queue/samples/network_activity_logging.py @@ -39,15 +39,15 @@ # Retrieve connection string from environment variables # and construct a blob service client. -connection_string = os.environ.get('STORAGE_CONNECTION_STRING', None) +connection_string = os.environ.get("STORAGE_CONNECTION_STRING", None) if not connection_string: - print('STORAGE_CONNECTION_STRING required.') + print("STORAGE_CONNECTION_STRING required.") sys.exit(1) service_client = QueueServiceClient.from_connection_string(connection_string) # Retrieve a compatible logger and add a handler to send the output to console (STDOUT). # Compatible loggers in this case include `azure` and `azure.storage`. 
-logger = logging.getLogger('azure.storage.queue') +logger = logging.getLogger("azure.storage.queue") logger.addHandler(logging.StreamHandler(stream=sys.stdout)) # Logging policy logs network activity at the DEBUG level. Set the level on the logger prior to the call. @@ -55,14 +55,14 @@ # The logger level must be set to DEBUG, AND the following must be true: # `logging_enable=True` passed as kwarg to the client constructor OR the API call -print('Request with logging enabled and log level set to DEBUG.') +print("Request with logging enabled and log level set to DEBUG.") queues = service_client.list_queues(logging_enable=True) for queue in queues: - print('Queue: {}'.format(queue.name)) + print("Queue: {}".format(queue.name)) queue_client = service_client.get_queue_client(queue.name) messages = queue_client.peek_messages(max_messages=20, logging_enable=True) for message in messages: try: - print(' Message: {!r}'.format(base64.b64decode(message.content))) + print(" Message: {!r}".format(base64.b64decode(message.content))) except (binascii.Error, ValueError) as e: - print(' Message: {}'.format(message.content)) + print(" Message: {}".format(message.content)) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py index c940a627010d..6219caf5436b 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication.py @@ -40,13 +40,17 @@ class QueueAuthSamples(object): def authentication_by_connection_string(self): if self.connection_string is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_connection_string") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: authentication_by_connection_string" + ) sys.exit(1) # Instantiate a QueueServiceClient using a connection string # [START auth_from_connection_string] from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [END auth_from_connection_string] @@ -55,13 +59,17 @@ def authentication_by_connection_string(self): def authentication_by_shared_key(self): if self.account_url is None or self.access_key is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_shared_key") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: authentication_by_shared_key" + ) sys.exit(1) # Instantiate a QueueServiceClient using a shared access key # [START create_queue_service_client] from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) # [END create_queue_service_client] @@ -70,16 +78,21 @@ def authentication_by_shared_key(self): def authentication_by_oauth(self): if self.account_url is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_oauth") + print( + "Missing required environment variable(s). Please see specific test for more details." 
+ + "\n" + + "Test: authentication_by_oauth" + ) sys.exit(1) # [START create_queue_service_client_oauth] # Get a token credential for authentication from azure.identity import DefaultAzureCredential + token_credential = DefaultAzureCredential() # Instantiate a QueueServiceClient using a token credential from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) # [END create_queue_service_client_oauth] @@ -87,17 +100,22 @@ def authentication_by_oauth(self): properties = queue_service.get_service_properties() def authentication_by_shared_access_signature(self): - if (self.connection_string is None or - self.account_name is None or - self.access_key is None or - self.account_url is None + if ( + self.connection_string is None + or self.account_name is None + or self.access_key is None + or self.account_url is None ): - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_shared_access_signature") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: authentication_by_shared_access_signature" + ) sys.exit(1) # Instantiate a QueueServiceClient using a connection string from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Create a SAS token to use for authentication of a client @@ -108,7 +126,7 @@ def authentication_by_shared_access_signature(self): self.access_key, resource_types=ResourceTypes(service=True), permission=AccountSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1) + expiry=datetime.utcnow() + timedelta(hours=1), ) token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) @@ -117,7 +135,7 @@ def authentication_by_shared_access_signature(self): properties = token_auth_queue_service.get_service_properties() -if __name__ == '__main__': +if __name__ == "__main__": sample = QueueAuthSamples() sample.authentication_by_connection_string() sample.authentication_by_shared_key() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py index 707e3adf6538..c109df896b73 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_authentication_async.py @@ -38,15 +38,20 @@ class QueueAuthSamplesAsync(object): account_url = os.getenv("STORAGE_ACCOUNT_QUEUE_URL") account_name = os.getenv("STORAGE_ACCOUNT_NAME") access_key = os.getenv("STORAGE_ACCOUNT_KEY") + async def authentication_by_connection_string_async(self): if self.connection_string is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_connection_string_async") + print( + "Missing required environment variable(s). Please see specific test for more details." 
+ + "\n" + + "Test: authentication_by_connection_string_async" + ) sys.exit(1) # Instantiate a QueueServiceClient using a connection string # [START async_auth_from_connection_string] from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [END async_auth_from_connection_string] @@ -56,13 +61,17 @@ async def authentication_by_connection_string_async(self): async def authentication_by_shared_key_async(self): if self.account_url is None or self.access_key is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_shared_key_async") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: authentication_by_shared_key_async" + ) sys.exit(1) # Instantiate a QueueServiceClient using a shared access key # [START async_create_queue_service_client] from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient(account_url=self.account_url, credential=self.access_key) # [END async_create_queue_service_client] # Get information for the Queue Service @@ -71,16 +80,21 @@ async def authentication_by_shared_key_async(self): async def authentication_by_oauth_async(self): if self.account_url is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_oauth") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: authentication_by_oauth" + ) sys.exit(1) # [START async_create_queue_service_client_oauth] # Get a token credential for authentication from azure.identity.aio import DefaultAzureCredential + token_credential = DefaultAzureCredential() # Instantiate a QueueServiceClient using a token credential from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient(account_url=self.account_url, credential=token_credential) # [END async_create_queue_service_client_oauth] @@ -89,27 +103,33 @@ async def authentication_by_oauth_async(self): properties = await queue_service.get_service_properties() async def authentication_by_shared_access_signature_async(self): - if (self.connection_string is None or - self.account_name is None or - self.access_key is None or - self.account_url is None + if ( + self.connection_string is None + or self.account_name is None + or self.access_key is None + or self.account_url is None ): - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: authentication_by_shared_access_signature_async") + print( + "Missing required environment variable(s). Please see specific test for more details." 
+ + "\n" + + "Test: authentication_by_shared_access_signature_async" + ) sys.exit(1) # Instantiate a QueueServiceClient using a connection string from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Create a SAS token to use for authentication of a client from azure.storage.queue import generate_account_sas, ResourceTypes, AccountSasPermissions + sas_token = generate_account_sas( self.account_name, self.access_key, resource_types=ResourceTypes(service=True), permission=AccountSasPermissions(read=True), - expiry=datetime.utcnow() + timedelta(hours=1) + expiry=datetime.utcnow() + timedelta(hours=1), ) token_auth_queue_service = QueueServiceClient(account_url=self.account_url, credential=sas_token) @@ -125,5 +145,6 @@ async def main(): await sample.authentication_by_oauth_async() await sample.authentication_by_shared_access_signature_async() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py index 3957feee19fc..03969be87f7d 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world.py @@ -30,12 +30,16 @@ class QueueHelloWorldSamples(object): def create_client_with_connection_string(self): if self.connection_string is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: create_client_with_connection_string") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: create_client_with_connection_string" + ) sys.exit(1) # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Get queue service properties @@ -43,12 +47,16 @@ def create_client_with_connection_string(self): def queue_and_messages_example(self): if self.connection_string is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: queue_and_messages_example") + print( + "Missing required environment variable(s). Please see specific test for more details." 
+ + "\n" + + "Test: queue_and_messages_example" + ) sys.exit(1) # Instantiate the QueueClient from a connection string from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="myqueue") # Create the queue @@ -74,7 +82,7 @@ def queue_and_messages_example(self): # [END delete_queue] -if __name__ == '__main__': +if __name__ == "__main__": sample = QueueHelloWorldSamples() sample.create_client_with_connection_string() sample.queue_and_messages_example() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py index 376826daec81..46803f90cda5 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_hello_world_async.py @@ -32,12 +32,16 @@ class QueueHelloWorldSamplesAsync(object): async def create_client_with_connection_string_async(self): if self.connection_string is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: create_client_with_connection_string_async") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: create_client_with_connection_string_async" + ) sys.exit(1) # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # Get queue service properties @@ -46,12 +50,16 @@ async def create_client_with_connection_string_async(self): async def queue_and_messages_example_async(self): if self.connection_string is None: - print("Missing required environment variable(s). Please see specific test for more details." + '\n' + - "Test: queue_and_messages_example_async") + print( + "Missing required environment variable(s). Please see specific test for more details." + + "\n" + + "Test: queue_and_messages_example_async" + ) sys.exit(1) # Instantiate the QueueClient from a connection string from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(conn_str=self.connection_string, queue_name="asyncmyqueue") async with queue: @@ -63,8 +71,7 @@ async def queue_and_messages_example_async(self): try: # Send messages await asyncio.gather( - queue.send_message("I'm using queues!"), - queue.send_message("This is my second message") + queue.send_message("I'm using queues!"), queue.send_message("This is my second message") ) # Receive the messages @@ -85,5 +92,6 @@ async def main(): await sample.create_client_with_connection_string_async() await sample.queue_and_messages_example_async() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py index c734ce59752a..1320ab463224 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message.py @@ -39,10 +39,10 @@ def set_access_policy(self): # [START create_queue_client_from_connection_string] from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue1") if queue.account_name is None: - print("Connection string did not provide an account name." 
+ '\n' + - "Test: set_access_policy") + print("Connection string did not provide an account name." + "\n" + "Test: set_access_policy") sys.exit(1) # [END create_queue_client_from_connection_string] @@ -56,11 +56,12 @@ def set_access_policy(self): # [START set_access_policy] # Create an access policy from azure.storage.queue import AccessPolicy, QueueSasPermissions + access_policy = AccessPolicy() access_policy.start = datetime.utcnow() - timedelta(hours=1) access_policy.expiry = datetime.utcnow() + timedelta(hours=1) access_policy.permission = QueueSasPermissions(read=True) - identifiers = {'my-access-policy-id': access_policy} + identifiers = {"my-access-policy-id": access_policy} # Set the access policy queue.set_queue_access_policy(identifiers) @@ -69,20 +70,15 @@ def set_access_policy(self): # Use the access policy to generate a SAS token # [START queue_client_sas_token] from azure.storage.queue import generate_queue_sas + sas_token = generate_queue_sas( - queue.account_name, - queue.queue_name, - queue.credential.account_key, - policy_id='my-access-policy-id' + queue.account_name, queue.queue_name, queue.credential.account_key, policy_id="my-access-policy-id" ) # [END queue_client_sas_token] # Authenticate with the sas token # [START create_queue_client] - token_auth_queue = QueueClient.from_queue_url( - queue_url=queue.url, - credential=sas_token - ) + token_auth_queue = QueueClient.from_queue_url(queue_url=queue.url, credential=sas_token) # [END create_queue_client] # Use the newly authenticated client to receive messages @@ -99,6 +95,7 @@ def queue_metadata(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue2") # Create the queue @@ -106,7 +103,7 @@ def queue_metadata(self): try: # [START set_queue_metadata] - metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} + metadata = {"foo": "val1", "bar": "val2", "baz": "val3"} queue.set_queue_metadata(metadata=metadata) # [END set_queue_metadata] @@ -125,6 +122,7 @@ def send_and_receive_messages(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue3") # Create the queue @@ -170,6 +168,7 @@ def list_message_pages(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue4") # Create the queue @@ -207,6 +206,7 @@ def receive_one_message_from_queue(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue5") # Create the queue @@ -242,6 +242,7 @@ def delete_and_clear_messages(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue6") # Create the queue @@ -278,6 +279,7 @@ def peek_messages(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue7") # Create the queue @@ -314,6 +316,7 @@ def update_message(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue8") # Create the queue @@ -330,10 +333,8 @@ def update_message(self): # Update the message list_result = next(messages) message = queue.update_message( - list_result.id, - 
pop_receipt=list_result.pop_receipt, - visibility_timeout=0, - content="updated") + list_result.id, pop_receipt=list_result.pop_receipt, visibility_timeout=0, content="updated" + ) # [END update_message] finally: @@ -347,6 +348,7 @@ def receive_messages_with_max_messages(self): # Instantiate a queue client from azure.storage.queue import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueue9") # Create the queue @@ -382,7 +384,7 @@ def receive_messages_with_max_messages(self): queue.delete_queue() -if __name__ == '__main__': +if __name__ == "__main__": sample = QueueMessageSamples() sample.set_access_policy() sample.queue_metadata() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py index 884942c15ed0..1fa4ac924b99 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_message_async.py @@ -39,10 +39,10 @@ async def set_access_policy_async(self): # [START async_create_queue_client_from_connection_string] from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync1") if queue.account_name is None: - print("Connection string did not provide an account name." + '\n' + - "Test: set_access_policy_async") + print("Connection string did not provide an account name." + "\n" + "Test: set_access_policy_async") sys.exit(1) # [END async_create_queue_client_from_connection_string] @@ -57,11 +57,12 @@ async def set_access_policy_async(self): # [START async_set_access_policy] # Create an access policy from azure.storage.queue import AccessPolicy, QueueSasPermissions + access_policy = AccessPolicy() access_policy.start = datetime.utcnow() - timedelta(hours=1) access_policy.expiry = datetime.utcnow() + timedelta(hours=1) access_policy.permission = QueueSasPermissions(read=True) - identifiers = {'my-access-policy-id': access_policy} + identifiers = {"my-access-policy-id": access_policy} # Set the access policy await queue.set_queue_access_policy(identifiers) @@ -69,19 +70,14 @@ async def set_access_policy_async(self): # Use the access policy to generate a SAS token from azure.storage.queue import generate_queue_sas + sas_token = generate_queue_sas( - queue.account_name, - queue.queue_name, - queue.credential.account_key, - policy_id='my-access-policy-id' + queue.account_name, queue.queue_name, queue.credential.account_key, policy_id="my-access-policy-id" ) # Authenticate with the sas token # [START async_create_queue_client] - token_auth_queue = QueueClient.from_queue_url( - queue_url=queue.url, - credential=sas_token - ) + token_auth_queue = QueueClient.from_queue_url(queue_url=queue.url, credential=sas_token) # [END async_create_queue_client] # Use the newly authenticated client to receive messages @@ -98,6 +94,7 @@ async def queue_metadata_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync2") # Create the queue @@ -106,7 +103,7 @@ async def queue_metadata_async(self): try: # [START async_set_queue_metadata] - metadata = {'foo': 'val1', 'bar': 'val2', 'baz': 'val3'} + metadata = {"foo": "val1", "bar": "val2", "baz": "val3"} await queue.set_queue_metadata(metadata=metadata) # [END async_set_queue_metadata] @@ -125,6 +122,7 @@ async def send_and_receive_messages_async(self): # Instantiate a 
queue client from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync3") # Create the queue @@ -138,7 +136,7 @@ async def send_and_receive_messages_async(self): queue.send_message("message2", visibility_timeout=30), # wait 30s before becoming visible queue.send_message("message3"), queue.send_message("message4"), - queue.send_message("message5") + queue.send_message("message5"), ) # [END async_send_messages] @@ -173,6 +171,7 @@ async def receive_one_message_from_queue(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync4") # Create the queue @@ -181,9 +180,8 @@ async def receive_one_message_from_queue(self): try: await asyncio.gather( - queue.send_message("message1"), - queue.send_message("message2"), - queue.send_message("message3")) + queue.send_message("message1"), queue.send_message("message2"), queue.send_message("message3") + ) # [START receive_one_message] # Pop two messages from the front of the queue @@ -210,6 +208,7 @@ async def delete_and_clear_messages_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync5") # Create the queue @@ -223,7 +222,7 @@ async def delete_and_clear_messages_async(self): queue.send_message("message2"), queue.send_message("message3"), queue.send_message("message4"), - queue.send_message("message5") + queue.send_message("message5"), ) # [START async_delete_message] @@ -232,7 +231,7 @@ async def delete_and_clear_messages_async(self): async for msg in messages: # Delete the specified message await queue.delete_message(msg) - # [END async_delete_message] + # [END async_delete_message] break # [START async_clear_messages] @@ -250,6 +249,7 @@ async def peek_messages_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync6") # Create the queue @@ -263,7 +263,7 @@ async def peek_messages_async(self): queue.send_message("message2"), queue.send_message("message3"), queue.send_message("message4"), - queue.send_message("message5") + queue.send_message("message5"), ) # [START async_peek_message] @@ -289,6 +289,7 @@ async def update_message_async(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync7") # Create the queue @@ -305,17 +306,14 @@ async def update_message_async(self): # Update the message async for message in messages: - message = await queue.update_message( - message, - visibility_timeout=0, - content="updated") - # [END async_update_message] + message = await queue.update_message(message, visibility_timeout=0, content="updated") + # [END async_update_message] break finally: # Delete the queue await queue.delete_queue() - + async def receive_messages_with_max_messages(self): if self.connection_string is None: print("Missing required environment variable: connection_string") @@ -323,6 +321,7 @@ async def receive_messages_with_max_messages(self): # Instantiate a queue client from azure.storage.queue.aio import QueueClient + queue = QueueClient.from_connection_string(self.connection_string, "myqueueasync8") # Create the queue @@ -340,7 +339,7 @@ async def receive_messages_with_max_messages(self): await queue.send_message("message8") 
await queue.send_message("message9") await queue.send_message("message10") - + # Receive messages one-by-one messages = queue.receive_messages(max_messages=5) async for msg in messages: @@ -370,5 +369,6 @@ async def main(): await sample.update_message_async() await sample.receive_messages_with_max_messages() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py index 9905cf3f3067..b537470edb93 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service.py @@ -35,6 +35,7 @@ def queue_service_properties(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START set_queue_service_properties] @@ -42,16 +43,20 @@ def queue_service_properties(self): from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy # Create logging settings - logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + logging = QueueAnalyticsLogging( + read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5) + ) # Create metrics for requests statistics hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + minute_metrics = Metrics( + enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5) + ) # Create CORS rules - cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) - allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] - allowed_methods = ['GET', 'PUT'] + cors_rule1 = CorsRule(["www.xyz.com"], ["GET"]) + allowed_origins = ["www.xyz.com", "www.ab.com", "www.bc.com"] + allowed_methods = ["GET", "PUT"] max_age_in_seconds = 500 exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] @@ -60,7 +65,7 @@ def queue_service_properties(self): allowed_methods, max_age_in_seconds=max_age_in_seconds, exposed_headers=exposed_headers, - allowed_headers=allowed_headers + allowed_headers=allowed_headers, ) cors = [cors_rule1, cors_rule2] @@ -80,6 +85,7 @@ def queues_in_account(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START qsc_create_queue] @@ -111,6 +117,7 @@ def get_queue_client(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue import QueueServiceClient, QueueClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START get_queue_client] @@ -119,7 +126,7 @@ def get_queue_client(self): # [END get_queue_client] -if __name__ == '__main__': +if __name__ == "__main__": sample = QueueServiceSamples() sample.queue_service_properties() sample.queues_in_account() diff --git a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py 
b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py index b802abc36020..7b6d669fa6f5 100644 --- a/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py +++ b/sdk/storage/azure-storage-queue/samples/queue_samples_service_async.py @@ -37,6 +37,7 @@ async def queue_service_properties_async(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) async with queue_service: @@ -45,16 +46,22 @@ async def queue_service_properties_async(self): from azure.storage.queue import QueueAnalyticsLogging, Metrics, CorsRule, RetentionPolicy # Create logging settings - logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + logging = QueueAnalyticsLogging( + read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5) + ) # Create metrics for requests statistics - hour_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) - minute_metrics = Metrics(enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)) + hour_metrics = Metrics( + enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5) + ) + minute_metrics = Metrics( + enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5) + ) # Create CORS rules - cors_rule1 = CorsRule(['www.xyz.com'], ['GET']) - allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"] - allowed_methods = ['GET', 'PUT'] + cors_rule1 = CorsRule(["www.xyz.com"], ["GET"]) + allowed_origins = ["www.xyz.com", "www.ab.com", "www.bc.com"] + allowed_methods = ["GET", "PUT"] max_age_in_seconds = 500 exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"] allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"] @@ -63,7 +70,7 @@ async def queue_service_properties_async(self): allowed_methods, max_age_in_seconds=max_age_in_seconds, exposed_headers=exposed_headers, - allowed_headers=allowed_headers + allowed_headers=allowed_headers, ) cors = [cors_rule1, cors_rule2] @@ -83,6 +90,7 @@ async def queues_in_account_async(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) async with queue_service: @@ -115,6 +123,7 @@ async def get_queue_client_async(self): # Instantiate the QueueServiceClient from a connection string from azure.storage.queue.aio import QueueServiceClient, QueueClient + queue_service = QueueServiceClient.from_connection_string(conn_str=self.connection_string) # [START async_get_queue_client] @@ -129,5 +138,6 @@ async def main(): await sample.queues_in_account_async() await sample.get_queue_client_async() -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/sdk/storage/azure-storage-queue/setup.py b/sdk/storage/azure-storage-queue/setup.py index ee2dda40b482..52d6676f7f05 100644 --- a/sdk/storage/azure-storage-queue/setup.py +++ b/sdk/storage/azure-storage-queue/setup.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) 
Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import re import os.path @@ -16,63 +16,59 @@ PACKAGE_PPRINT_NAME = "Azure Queue Storage" # a-b-c => a/b/c -package_folder_path = PACKAGE_NAME.replace('-', '/') +package_folder_path = PACKAGE_NAME.replace("-", "/") # a-b-c => a.b.c -namespace_name = PACKAGE_NAME.replace('-', '.') +namespace_name = PACKAGE_NAME.replace("-", ".") # Version extraction inspired from 'requests' -with open(os.path.join(package_folder_path, '_version.py'), 'r') as fd: - version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', # type: ignore - fd.read(), re.MULTILINE).group(1) +with open(os.path.join(package_folder_path, "_version.py"), "r") as fd: + version = re.search(r'^VERSION\s*=\s*[\'"]([^\'"]*)[\'"]', fd.read(), re.MULTILINE).group(1) # type: ignore if not version: - raise RuntimeError('Cannot find version information') + raise RuntimeError("Cannot find version information") -with open('README.md', encoding='utf-8') as f: +with open("README.md", encoding="utf-8") as f: readme = f.read() -with open('CHANGELOG.md', encoding='utf-8') as f: +with open("CHANGELOG.md", encoding="utf-8") as f: changelog = f.read() setup( name=PACKAGE_NAME, version=version, include_package_data=True, - description=f'Microsoft Azure {PACKAGE_PPRINT_NAME} Client Library for Python', - long_description=readme + '\n\n' + changelog, - long_description_content_type='text/markdown', - license='MIT License', - author='Microsoft Corporation', - author_email='ascl@microsoft.com', - url='https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-queue', + description=f"Microsoft Azure {PACKAGE_PPRINT_NAME} Client Library for Python", + long_description=readme + "\n\n" + changelog, + long_description_content_type="text/markdown", + license="MIT License", + author="Microsoft Corporation", + author_email="ascl@microsoft.com", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-queue", keywords="azure, azure sdk", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Programming Language :: Python', + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'License :: OSI Approved :: MIT License', + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "License :: OSI Approved :: MIT License", ], zip_safe=False, - packages=find_packages(exclude=[ - # Exclude packages that will be covered by PEP420 or nspkg - 'azure', - 'azure.storage', - 'tests', - 'tests.queue', - 'tests.common' - ]), + packages=find_packages( + exclude=[ + # Exclude packages that will be covered by PEP420 or nspkg + "azure", + "azure.storage", + "tests", + "tests.queue", + "tests.common", + ] + ), python_requires=">=3.8", - install_requires=[ - "azure-core>=1.30.0", - 
"cryptography>=2.1.4", - "typing-extensions>=4.6.0", - "isodate>=0.6.1" - ], + install_requires=["azure-core>=1.30.0", "cryptography>=2.1.4", "typing-extensions>=4.6.0", "isodate>=0.6.1"], extras_require={ "aio": [ "azure-core[aio]>=1.30.0", diff --git a/sdk/storage/azure-storage-queue/tests/conftest.py b/sdk/storage/azure-storage-queue/tests/conftest.py index 8b1b824ea060..ca09905c2ca6 100644 --- a/sdk/storage/azure-storage-queue/tests/conftest.py +++ b/sdk/storage/azure-storage-queue/tests/conftest.py @@ -14,9 +14,10 @@ add_header_regex_sanitizer, add_oauth_response_sanitizer, add_uri_string_sanitizer, - test_proxy + test_proxy, ) + @pytest.fixture(scope="session", autouse=True) def add_sanitizers(test_proxy): subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000") diff --git a/sdk/storage/azure-storage-queue/tests/encryption_test_helper.py b/sdk/storage/azure-storage-queue/tests/encryption_test_helper.py index 50a7e7c9b29d..d54df3de963b 100644 --- a/sdk/storage/azure-storage-queue/tests/encryption_test_helper.py +++ b/sdk/storage/azure-storage-queue/tests/encryption_test_helper.py @@ -17,26 +17,26 @@ class KeyWrapper: - def __init__(self, kid='local:key1'): + def __init__(self, kid="local:key1"): # Must have constant key value for recorded tests, otherwise we could use a random generator. - self.kek = b'\xbe\xa4\x11K\x9eJ\x07\xdafF\x83\xad+\xadvA C\xe8\xbc\x90\xa4\x11}G\xc3\x0f\xd4\xb4\x19m\x11' + self.kek = b"\xbe\xa4\x11K\x9eJ\x07\xdafF\x83\xad+\xadvA C\xe8\xbc\x90\xa4\x11}G\xc3\x0f\xd4\xb4\x19m\x11" self.backend = default_backend() self.kid = kid - def wrap_key(self, key, algorithm='A256KW'): - if algorithm == 'A256KW': + def wrap_key(self, key, algorithm="A256KW"): + if algorithm == "A256KW": return aes_key_wrap(self.kek, key, self.backend) raise ValueError(_ERROR_UNKNOWN_KEY_WRAP_ALGORITHM) def unwrap_key(self, key, algorithm): - if algorithm == 'A256KW': + if algorithm == "A256KW": return aes_key_unwrap(self.kek, key, self.backend) raise ValueError(_ERROR_UNKNOWN_KEY_WRAP_ALGORITHM) def get_key_wrap_algorithm(self): - return 'A256KW' + return "A256KW" def get_kid(self): return self.kid @@ -54,37 +54,29 @@ def resolve_key(self, kid): class RSAKeyWrapper: - def __init__(self, kid='local:key2'): - self.private_key = generate_private_key(public_exponent=65537, - key_size=2048, - backend=default_backend()) + def __init__(self, kid="local:key2"): + self.private_key = generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) self.public_key = self.private_key.public_key() self.kid = kid - def wrap_key(self, key, algorithm='RSA'): - if algorithm == 'RSA': - return self.public_key.encrypt(key, - OAEP( - mgf=MGF1(algorithm=SHA1()), # nosec - algorithm=SHA1(), # nosec - label=None) - ) + def wrap_key(self, key, algorithm="RSA"): + if algorithm == "RSA": + return self.public_key.encrypt( + key, OAEP(mgf=MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None) # nosec # nosec + ) raise ValueError(_ERROR_UNKNOWN_KEY_WRAP_ALGORITHM) def unwrap_key(self, key, algorithm): - if algorithm == 'RSA': - return self.private_key.decrypt(key, - OAEP( - mgf=MGF1(algorithm=SHA1()), # nosec - algorithm=SHA1(), # nosec - label=None) - ) + if algorithm == "RSA": + return self.private_key.decrypt( + key, OAEP(mgf=MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None) # nosec # nosec + ) raise ValueError(_ERROR_UNKNOWN_KEY_WRAP_ALGORITHM) def get_key_wrap_algorithm(self): - return 'RSA' + return "RSA" def get_kid(self): return self.kid @@ -97,10 
+89,10 @@ def mock_urandom(size: int) -> bytes: to be recorded. """ if size == 12: - return b'Mb\xd5N\xc2\xbd\xa0\xc8\xa4L\xfb\xa0' + return b"Mb\xd5N\xc2\xbd\xa0\xc8\xa4L\xfb\xa0" elif size == 16: - return b'\xbb\xd6\x87\xb6j\xe5\xdc\x93\xb0\x13\x1e\xcc\x9f\xf4\xca\xab' + return b"\xbb\xd6\x87\xb6j\xe5\xdc\x93\xb0\x13\x1e\xcc\x9f\xf4\xca\xab" elif size == 32: - return b'\x08\xe0A\xb6\xf2\xb7x\x8f\xe5\xdap\x87^6x~\xa4F\xc4\xe9\xb1\x8a:\xfbC%S\x0cZ\xbb\xbe\x88' + return b"\x08\xe0A\xb6\xf2\xb7x\x8f\xe5\xdap\x87^6x~\xa4F\xc4\xe9\xb1\x8a:\xfbC%S\x0cZ\xbb\xbe\x88" else: - return os.urandom(size) \ No newline at end of file + return os.urandom(size) diff --git a/sdk/storage/azure-storage-queue/tests/settings/service_versions.py b/sdk/storage/azure-storage-queue/tests/settings/service_versions.py deleted file mode 100644 index 20219b27a8fc..000000000000 --- a/sdk/storage/azure-storage-queue/tests/settings/service_versions.py +++ /dev/null @@ -1,27 +0,0 @@ -from enum import Enum - - -class ServiceVersion(str, Enum): - - V2019_02_02 = "2019-02-02" - V2019_07_07 = "2019-07-07" - V2019_10_10 = "2019-10-10" - V2019_12_12 = "2019-12-12" - V2020_02_10 = "2020-02-10" - V2020_04_08 = "2020-04-08" - V2020_06_12 = "2020-06-12" - V2020_08_04 = "2020-08-04" - - -service_version_map = { - "V2019_02_02": ServiceVersion.V2019_02_02, - "V2019_07_07": ServiceVersion.V2019_07_07, - "V2019_10_10": ServiceVersion.V2019_10_10, - "V2019_12_12": ServiceVersion.V2019_12_12, - "V2020_02_10": ServiceVersion.V2020_02_10, - "V2020_04_08": ServiceVersion.V2020_04_08, - "V2020_06_12": ServiceVersion.V2020_06_12, - "V2020_08_04": ServiceVersion.V2020_08_04, - "LATEST": ServiceVersion.V2020_08_04, - "LATEST_PLUS_1": ServiceVersion.V2020_06_12 -} diff --git a/sdk/storage/azure-storage-queue/tests/settings/settings_fake.py b/sdk/storage/azure-storage-queue/tests/settings/settings_fake.py index 19f4029068c2..f393095e7019 100644 --- a/sdk/storage/azure-storage-queue/tests/settings/settings_fake.py +++ b/sdk/storage/azure-storage-queue/tests/settings/settings_fake.py @@ -7,7 +7,7 @@ STORAGE_ACCOUNT_NAME = "fakename" STORAGE_ACCOUNT_KEY = "fakekey" -ACCOUNT_URL_SUFFIX = 'core.windows.net' +ACCOUNT_URL_SUFFIX = "core.windows.net" RUN_IN_LIVE = "False" SKIP_LIVE_RECORDING = "True" diff --git a/sdk/storage/azure-storage-queue/tests/settings/testcase.py b/sdk/storage/azure-storage-queue/tests/settings/testcase.py index 9a2f59ba331c..5d139ce47608 100644 --- a/sdk/storage/azure-storage-queue/tests/settings/testcase.py +++ b/sdk/storage/azure-storage-queue/tests/settings/testcase.py @@ -16,7 +16,7 @@ from devtools_testutils.fake_credentials import STORAGE_ACCOUNT_FAKE_KEY try: - from cStringIO import StringIO # Python 2 + from cStringIO import StringIO # Python 2 except ImportError: from io import StringIO try: @@ -27,23 +27,26 @@ from .settings_fake import * -LOGGING_FORMAT = '%(asctime)s %(name)-20s %(levelname)-5s %(message)s' -os.environ['STORAGE_ACCOUNT_NAME'] = os.environ.get('STORAGE_ACCOUNT_NAME', None) or STORAGE_ACCOUNT_NAME -os.environ['STORAGE_ACCOUNT_KEY'] = os.environ.get('STORAGE_ACCOUNT_KEY', None) or STORAGE_ACCOUNT_KEY +LOGGING_FORMAT = "%(asctime)s %(name)-20s %(levelname)-5s %(message)s" +os.environ["STORAGE_ACCOUNT_NAME"] = os.environ.get("STORAGE_ACCOUNT_NAME", None) or STORAGE_ACCOUNT_NAME +os.environ["STORAGE_ACCOUNT_KEY"] = os.environ.get("STORAGE_ACCOUNT_KEY", None) or STORAGE_ACCOUNT_KEY -os.environ['AZURE_TEST_RUN_LIVE'] = os.environ.get('AZURE_TEST_RUN_LIVE', None) or RUN_IN_LIVE 
-os.environ['AZURE_SKIP_LIVE_RECORDING'] = os.environ.get('AZURE_SKIP_LIVE_RECORDING', None) or SKIP_LIVE_RECORDING -os.environ['PROTOCOL'] = PROTOCOL -os.environ['ACCOUNT_URL_SUFFIX'] = ACCOUNT_URL_SUFFIX +os.environ["AZURE_TEST_RUN_LIVE"] = os.environ.get("AZURE_TEST_RUN_LIVE", None) or RUN_IN_LIVE +os.environ["AZURE_SKIP_LIVE_RECORDING"] = os.environ.get("AZURE_SKIP_LIVE_RECORDING", None) or SKIP_LIVE_RECORDING +os.environ["PROTOCOL"] = PROTOCOL +os.environ["ACCOUNT_URL_SUFFIX"] = ACCOUNT_URL_SUFFIX QueuePreparer = functools.partial( - PowerShellPreparer, "storage", + PowerShellPreparer, + "storage", storage_account_name="storagename", storage_account_key=STORAGE_ACCOUNT_FAKE_KEY, ) + def not_for_emulator(test): def skip_test_if_targeting_emulator(self): test(self) + return skip_test_if_targeting_emulator diff --git a/sdk/storage/azure-storage-queue/tests/test_queue.py b/sdk/storage/azure-storage-queue/tests/test_queue.py index ba1b6361edf3..7b396cb28037 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue.py @@ -13,7 +13,7 @@ ClientAuthenticationError, HttpResponseError, ResourceExistsError, - ResourceNotFoundError + ResourceNotFoundError, ) from azure.core.pipeline.transport import RequestsTransport from azure.storage.queue import ( @@ -24,7 +24,7 @@ QueueClient, QueueSasPermissions, QueueServiceClient, - ResourceTypes + ResourceTypes, ) from devtools_testutils import FakeTokenCredential, recorded_by_proxy @@ -32,11 +32,12 @@ from settings.testcase import QueuePreparer # ------------------------------------------------------------------------------ -TEST_QUEUE_PREFIX = 'pyqueuesync' +TEST_QUEUE_PREFIX = "pyqueuesync" # ------------------------------------------------------------------------------ + # pylint: disable=locally-disabled, multiple-statements, fixme, too-many-lines class TestStorageQueue(StorageRecordedTestCase): # --Helpers----------------------------------------------------------------- @@ -110,15 +111,14 @@ def test_create_queue_with_options(self, **kwargs): url = self.account_url(storage_account_name, "queue") qsc = QueueServiceClient(url, storage_account_key) queue_client = self._get_queue_reference(qsc) - queue_client.create_queue( - metadata={'val1': 'test', 'val2': 'blah'}) + queue_client.create_queue(metadata={"val1": "test", "val2": "blah"}) props = queue_client.get_queue_properties() # Asserts assert 0 == props.approximate_message_count assert 2 == len(props.metadata) - assert 'test' == props.metadata['val1'] - assert 'blah' == props.metadata['val2'] + assert "test" == props.metadata["val1"] + assert "blah" == props.metadata["val2"] @QueuePreparer() @recorded_by_proxy @@ -173,21 +173,19 @@ def test_list_queues_with_options(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - prefix = 'listqueue' + prefix = "listqueue" qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_list = [] for i in range(0, 4): self._create_queue(qsc, prefix + str(i), queue_list) # Action - generator1 = qsc.list_queues( - name_starts_with=prefix, - results_per_page=3).by_page() + generator1 = qsc.list_queues(name_starts_with=prefix, results_per_page=3).by_page() queues1 = list(next(generator1)) - generator2 = qsc.list_queues( - name_starts_with=prefix, - include_metadata=True).by_page(generator1.continuation_token) + generator2 = qsc.list_queues(name_starts_with=prefix, include_metadata=True).by_page( + generator1.continuation_token + ) queues2 
= list(next(generator2)) # Asserts @@ -195,13 +193,13 @@ def test_list_queues_with_options(self, **kwargs): assert 3 == len(queues1) assert queues1[0] is not None assert queues1[0].metadata is None - assert '' != queues1[0].name + assert "" != queues1[0].name assert generator1.location_mode is not None # Asserts assert queues2 is not None assert len(queue_list) - 3 <= len(queues2) assert queues2[0] is not None - assert '' != queues2[0].name + assert "" != queues2[0].name @QueuePreparer() @recorded_by_proxy @@ -213,19 +211,18 @@ def test_list_queues_with_metadata(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._get_queue_reference(qsc) queue.create_queue() - queue.set_queue_metadata(metadata={'val1': 'test', 'val2': 'blah'}) + queue.set_queue_metadata(metadata={"val1": "test", "val2": "blah"}) - listed_queue = list(qsc.list_queues( - name_starts_with=queue.queue_name, - results_per_page=1, - include_metadata=True))[0] + listed_queue = list( + qsc.list_queues(name_starts_with=queue.queue_name, results_per_page=1, include_metadata=True) + )[0] # Asserts assert listed_queue is not None assert queue.queue_name == listed_queue.name assert listed_queue.metadata is not None assert len(listed_queue.metadata) == 2 - assert listed_queue.metadata['val1'] == 'test' + assert listed_queue.metadata["val1"] == "test" @QueuePreparer() @recorded_by_proxy @@ -243,7 +240,7 @@ def test_list_queues_account_sas(self, **kwargs): storage_account_key, ResourceTypes(service=True), AccountSasPermissions(list=True), - datetime.utcnow() + timedelta(hours=1) + datetime.utcnow() + timedelta(hours=1), ) # Act @@ -263,7 +260,7 @@ def test_set_queue_metadata(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._get_queue_reference(qsc) - metadata = {'hello': 'world', 'number': '43'} + metadata = {"hello": "world", "number": "43"} queue.create_queue() # Act @@ -282,11 +279,11 @@ def test_get_queue_metadata_message_count(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - sent_message = queue_client.send_message('message1') + sent_message = queue_client.send_message("message1") props = queue_client.get_queue_properties() # Asserts - assert 'message1' == sent_message.content + assert "message1" == sent_message.content assert props.approximate_message_count >= 1 assert 0 == len(props.metadata) @@ -315,7 +312,7 @@ def test_queue_not_exists(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - queue = qsc.get_queue_client(self.get_resource_name('missing')) + queue = qsc.get_queue_client(self.get_resource_name("missing")) # Act with pytest.raises(ResourceNotFoundError): queue.get_queue_properties() @@ -332,18 +329,18 @@ def test_put_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - message = queue_client.send_message('message4') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + message = queue_client.send_message("message4") # Asserts assert 
message is not None - assert '' != message.id + assert "" != message.id assert isinstance(message.inserted_on, datetime) assert isinstance(message.expires_on, datetime) - assert '' != message.pop_receipt - assert 'message4' == message.content + assert "" != message.pop_receipt + assert "message4" == message.content @QueuePreparer() @recorded_by_proxy @@ -356,7 +353,7 @@ def test_put_message_large_time_to_live(self, **kwargs): queue_client = self._get_queue_reference(qsc) queue_client.create_queue() # There should be no upper bound on a queue message's time to live - queue_client.send_message('message1', time_to_live=1024 * 1024 * 1024) + queue_client.send_message("message1", time_to_live=1024 * 1024 * 1024) # Act messages = queue_client.peek_messages() @@ -374,7 +371,7 @@ def test_put_message_infinite_time_to_live(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1', time_to_live=-1) + queue_client.send_message("message1", time_to_live=-1) # Act messages = queue_client.peek_messages() @@ -392,18 +389,18 @@ def test_get_messages(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") message = next(queue_client.receive_messages()) # Asserts assert message is not None assert message is not None - assert '' != message.id - assert 'message1' == message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "message1" == message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count assert isinstance(message.inserted_on, datetime) @@ -422,9 +419,9 @@ def test_receive_one_message(self, **kwargs): queue_client.create_queue() assert queue_client.receive_message() is None - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") message1 = queue_client.receive_message() message2 = queue_client.receive_message() @@ -432,18 +429,18 @@ def test_receive_one_message(self, **kwargs): # Asserts assert message1 is not None - assert '' != message1.id - assert 'message1' == message1.content - assert '' != message1.pop_receipt + assert "" != message1.id + assert "message1" == message1.content + assert "" != message1.pop_receipt assert 1 == message1.dequeue_count assert message2 is not None - assert '' != message2.id - assert 'message2' == message2.content - assert '' != message2.pop_receipt + assert "" != message2.id + assert "message2" == message2.content + assert "" != message2.pop_receipt assert 1 == message2.dequeue_count - assert 'message3' == peeked_message3.content + assert "message3" == peeked_message3.content assert 0 == peeked_message3.dequeue_count @QueuePreparer() @@ -456,10 +453,10 @@ def test_get_messages_with_options(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = 
self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") pager = queue_client.receive_messages(messages_per_page=4, visibility_timeout=20) result = list(pager) @@ -469,13 +466,13 @@ def test_get_messages_with_options(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy @@ -487,16 +484,16 @@ def test_get_messages_with_max_messages(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') - queue_client.send_message('message5') - queue_client.send_message('message6') - queue_client.send_message('message7') - queue_client.send_message('message8') - queue_client.send_message('message9') - queue_client.send_message('message10') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") + queue_client.send_message("message5") + queue_client.send_message("message6") + queue_client.send_message("message7") + queue_client.send_message("message8") + queue_client.send_message("message9") + queue_client.send_message("message10") pager = queue_client.receive_messages(max_messages=5) result = list(pager) @@ -506,13 +503,13 @@ def test_get_messages_with_max_messages(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy @@ -524,11 +521,11 @@ def test_get_messages_with_too_little_messages(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') - queue_client.send_message('message5') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") + queue_client.send_message("message5") pager = queue_client.receive_messages(max_messages=10) result = list(pager) @@ -538,13 +535,13 @@ 
def test_get_messages_with_too_little_messages(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy @@ -556,11 +553,11 @@ def test_get_messages_with_page_bigger_than_max(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') - queue_client.send_message('message5') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") + queue_client.send_message("message5") # Asserts with pytest.raises(ValueError): @@ -576,18 +573,18 @@ def test_get_messages_with_remainder(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') - queue_client.send_message('message5') - queue_client.send_message('message6') - queue_client.send_message('message7') - queue_client.send_message('message8') - queue_client.send_message('message9') - queue_client.send_message('message10') - queue_client.send_message('message11') - queue_client.send_message('message12') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") + queue_client.send_message("message5") + queue_client.send_message("message6") + queue_client.send_message("message7") + queue_client.send_message("message8") + queue_client.send_message("message9") + queue_client.send_message("message10") + queue_client.send_message("message11") + queue_client.send_message("message12") pager = queue_client.receive_messages(messages_per_page=3, max_messages=10) result = list(pager) @@ -604,23 +601,23 @@ def test_get_messages_with_remainder(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on for message in remainder_list: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != 
message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy @@ -632,10 +629,10 @@ def test_peek_messages(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") result = queue_client.peek_messages() # Asserts @@ -643,12 +640,12 @@ def test_peek_messages(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert '' != message.content + assert "" != message.id + assert "" != message.content assert message.pop_receipt is None assert 0 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on + assert "" != message.inserted_on + assert "" != message.expires_on assert message.next_visible_on is None @QueuePreparer() @@ -661,10 +658,10 @@ def test_peek_messages_with_options(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") result = queue_client.peek_messages(max_messages=4) # Asserts @@ -672,12 +669,12 @@ def test_peek_messages_with_options(self, **kwargs): assert 4 == len(result) for message in result: assert message is not None - assert '' != message.id - assert '' != message.content + assert "" != message.id + assert "" != message.content assert message.pop_receipt is None assert 0 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on + assert "" != message.inserted_on + assert "" != message.expires_on assert message.next_visible_on is None @QueuePreparer() @@ -690,10 +687,10 @@ def test_clear_messages(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + queue_client.send_message("message4") queue_client.clear_messages() result = queue_client.peek_messages() @@ -711,10 +708,10 @@ def test_delete_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') - queue_client.send_message('message2') - queue_client.send_message('message3') - queue_client.send_message('message4') + queue_client.send_message("message1") + queue_client.send_message("message2") + queue_client.send_message("message3") + 
queue_client.send_message("message4") message = next(queue_client.receive_messages()) queue_client.delete_message(message) @@ -735,13 +732,12 @@ def test_update_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') + queue_client.send_message("message1") messages = queue_client.receive_messages() list_result1 = next(messages) message = queue_client.update_message( - list_result1.id, - pop_receipt=list_result1.pop_receipt, - visibility_timeout=0) + list_result1.id, pop_receipt=list_result1.pop_receipt, visibility_timeout=0 + ) list_result2 = next(messages) # Asserts @@ -756,7 +752,7 @@ def test_update_message(self, **kwargs): message = list_result2 assert message is not None assert list_result1.id == message.id - assert 'message1' == message.content + assert "message1" == message.content assert 2 == message.dequeue_count assert message.pop_receipt is not None assert message.inserted_on is not None @@ -773,15 +769,13 @@ def test_update_message_content(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') + queue_client.send_message("message1") messages = queue_client.receive_messages() list_result1 = next(messages) message = queue_client.update_message( - list_result1.id, - pop_receipt=list_result1.pop_receipt, - visibility_timeout=0, - content='new text') + list_result1.id, pop_receipt=list_result1.pop_receipt, visibility_timeout=0, content="new text" + ) list_result2 = next(messages) # Asserts @@ -790,14 +784,14 @@ def test_update_message_content(self, **kwargs): assert message.pop_receipt is not None assert message.next_visible_on is not None assert isinstance(message.next_visible_on, datetime) - assert 'new text' == message.content + assert "new text" == message.content # Get response assert list_result2 is not None message = list_result2 assert message is not None assert list_result1.id == message.id - assert 'new text' == message.content + assert "new text" == message.content assert 2 == message.dequeue_count assert message.pop_receipt is not None assert message.inserted_on is not None @@ -814,7 +808,7 @@ def test_account_sas(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') + queue_client.send_message("message1") token = self.generate_sas( generate_account_sas, qsc.account_name, @@ -822,7 +816,7 @@ def test_account_sas(self, **kwargs): ResourceTypes(object=True), AccountSasPermissions(read=True), datetime.utcnow() + timedelta(hours=1), - datetime.utcnow() - timedelta(minutes=5) + datetime.utcnow() - timedelta(minutes=5), ) # Act @@ -839,8 +833,8 @@ def test_account_sas(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content @QueuePreparer() @recorded_by_proxy @@ -853,7 +847,7 @@ def test_azure_named_key_credential_access(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), named_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - 
queue_client.send_message('message1') + queue_client.send_message("message1") # Act result = queue_client.peek_messages() @@ -868,8 +862,8 @@ def test_account_sas_raises_if_sas_already_in_uri(self, **kwargs): with pytest.raises(ValueError): QueueServiceClient( - self.account_url(storage_account_name, "queue") + "?sig=foo", - credential=AzureSasCredential("?foo=bar")) + self.account_url(storage_account_name, "queue") + "?sig=foo", credential=AzureSasCredential("?foo=bar") + ) @pytest.mark.live_test_only @QueuePreparer() @@ -905,7 +899,7 @@ def test_sas_read(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') + queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, @@ -913,7 +907,7 @@ def test_sas_read(self, **kwargs): queue_client.credential.account_key, QueueSasPermissions(read=True), datetime.utcnow() + timedelta(hours=1), - datetime.utcnow() - timedelta(minutes=5) + datetime.utcnow() - timedelta(minutes=5), ) # Act @@ -928,8 +922,8 @@ def test_sas_read(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content @QueuePreparer() @recorded_by_proxy @@ -955,11 +949,11 @@ def test_sas_add(self, **kwargs): queue_url=queue_client.url, credential=token, ) - result = service.send_message('addedmessage') + result = service.send_message("addedmessage") # Assert result = next(queue_client.receive_messages()) - assert 'addedmessage' == result.content + assert "addedmessage" == result.content @QueuePreparer() @recorded_by_proxy @@ -971,7 +965,7 @@ def test_sas_update(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') + queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, @@ -992,12 +986,12 @@ def test_sas_update(self, **kwargs): result.id, pop_receipt=result.pop_receipt, visibility_timeout=0, - content='updatedmessage1', + content="updatedmessage1", ) # Assert result = next(messages) - assert 'updatedmessage1' == result.content + assert "updatedmessage1" == result.content @QueuePreparer() @recorded_by_proxy @@ -1009,7 +1003,7 @@ def test_sas_process(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') + queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, @@ -1028,39 +1022,39 @@ def test_sas_process(self, **kwargs): # Assert assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content @QueuePreparer() @recorded_by_proxy def test_sas_signed_identifier(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - variables = kwargs.pop('variables', {}) + variables = kwargs.pop("variables", {}) # Arrange access_policy = AccessPolicy() - start_time = self.get_datetime_variable(variables, 'start_time', 
datetime.utcnow() - timedelta(hours=1)) - expiry_time = self.get_datetime_variable(variables, 'expiry_time', datetime.utcnow() + timedelta(hours=1)) + start_time = self.get_datetime_variable(variables, "start_time", datetime.utcnow() - timedelta(hours=1)) + expiry_time = self.get_datetime_variable(variables, "expiry_time", datetime.utcnow() + timedelta(hours=1)) access_policy.start = start_time access_policy.expiry = expiry_time access_policy.permission = QueueSasPermissions(read=True) - identifiers = {'testid': access_policy} + identifiers = {"testid": access_policy} qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() resp = queue_client.set_queue_access_policy(identifiers) - queue_client.send_message('message1') + queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, queue_client.queue_name, queue_client.credential.account_key, - policy_id='testid' + policy_id="testid", ) # Act @@ -1075,8 +1069,8 @@ def test_sas_signed_identifier(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content return variables @@ -1184,23 +1178,23 @@ def test_set_queue_acl_with_empty_signed_identifier(self, **kwargs): queue_client.create_queue() # Act - queue_client.set_queue_access_policy(signed_identifiers={'empty': None}) + queue_client.set_queue_access_policy(signed_identifiers={"empty": None}) # Assert acl = queue_client.get_queue_access_policy() assert acl is not None assert len(acl) == 1 - assert acl['empty'] is not None - assert acl['empty'].permission is None - assert acl['empty'].expiry is None - assert acl['empty'].start is None + assert acl["empty"] is not None + assert acl["empty"].permission is None + assert acl["empty"].expiry is None + assert acl["empty"].start is None @QueuePreparer() @recorded_by_proxy def test_set_queue_acl_with_signed_identifiers(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - variables = kwargs.pop('variables', {}) + variables = kwargs.pop("variables", {}) # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) @@ -1208,12 +1202,10 @@ def test_set_queue_acl_with_signed_identifiers(self, **kwargs): queue_client.create_queue() # Act - expiry_time = self.get_datetime_variable(variables, 'expiry_time', datetime.utcnow() + timedelta(hours=1)) - start_time = self.get_datetime_variable(variables, 'start_time', datetime.utcnow() - timedelta(minutes=5)) - access_policy = AccessPolicy(permission=QueueSasPermissions(read=True), - expiry=expiry_time, - start=start_time) - identifiers = {'testid': access_policy} + expiry_time = self.get_datetime_variable(variables, "expiry_time", datetime.utcnow() + timedelta(hours=1)) + start_time = self.get_datetime_variable(variables, "start_time", datetime.utcnow() - timedelta(minutes=5)) + access_policy = AccessPolicy(permission=QueueSasPermissions(read=True), expiry=expiry_time, start=start_time) + identifiers = {"testid": access_policy} resp = queue_client.set_queue_access_policy(signed_identifiers=identifiers) @@ -1222,7 +1214,7 @@ def test_set_queue_acl_with_signed_identifiers(self, **kwargs): acl = queue_client.get_queue_access_policy() assert acl is not None assert len(acl) == 1 - 
assert 'testid' in acl + assert "testid" in acl return variables @@ -1240,7 +1232,7 @@ def test_set_queue_acl_too_many_ids(self, **kwargs): # Act identifiers = {} for i in range(0, 16): - identifiers[f'id{i}'] = AccessPolicy() + identifiers[f"id{i}"] = AccessPolicy() # Assert with pytest.raises(ValueError): @@ -1270,7 +1262,7 @@ def test_unicode_create_queue_unicode_name(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - queue_name = '啊齄丂狛狜' + queue_name = "啊齄丂狛狜" with pytest.raises(HttpResponseError): # not supported - queue name must be alphanumeric, lowercase @@ -1289,14 +1281,14 @@ def test_unicode_get_messages_unicode_data(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1㚈') + queue_client.send_message("message1㚈") message = next(queue_client.receive_messages()) # Asserts assert message is not None - assert '' != message.id - assert 'message1㚈' == message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "message1㚈" == message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count assert isinstance(message.inserted_on, datetime) assert isinstance(message.expires_on, datetime) @@ -1312,19 +1304,19 @@ def test_unicode_update_message_unicode_data(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) queue_client.create_queue() - queue_client.send_message('message1') + queue_client.send_message("message1") messages = queue_client.receive_messages() list_result1 = next(messages) - list_result1.content = '啊齄丂狛狜' + list_result1.content = "啊齄丂狛狜" queue_client.update_message(list_result1, visibility_timeout=0) # Asserts message = next(messages) assert message is not None assert list_result1.id == message.id - assert '啊齄丂狛狜' == message.content - assert '' != message.pop_receipt + assert "啊齄丂狛狜" == message.content + assert "" != message.pop_receipt assert 2 == message.dequeue_count assert isinstance(message.inserted_on, datetime) assert isinstance(message.expires_on, datetime) @@ -1340,8 +1332,8 @@ def test_transport_closed_only_once(self, **kwargs): prefix = TEST_QUEUE_PREFIX queue_name = self.get_resource_name(prefix) with QueueServiceClient( - self.account_url(storage_account_name, "queue"), - credential=storage_account_key, transport=transport) as qsc: + self.account_url(storage_account_name, "queue"), credential=storage_account_key, transport=transport + ) as qsc: qsc.get_service_properties() assert transport.session is not None with qsc.get_queue_client(queue_name) as qc: @@ -1362,8 +1354,9 @@ def test_storage_account_audience_queue_service_client(self, **kwargs): # Act token_credential = self.get_credential(QueueServiceClient) qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), credential=token_credential, - audience=f'https://{storage_account_name}.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + credential=token_credential, + audience=f"https://{storage_account_name}.queue.core.windows.net", ) # Assert @@ -1383,8 +1376,9 @@ def test_bad_audience_queue_service_client(self, **kwargs): # Act token_credential = self.get_credential(QueueServiceClient) qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), 
credential=token_credential, - audience=f'https://badaudience.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + credential=token_credential, + audience=f"https://badaudience.queue.core.windows.net", ) # Will not raise ClientAuthenticationError despite bad audience due to Bearer Challenge @@ -1397,14 +1391,16 @@ def test_storage_account_audience_queue_client(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - queue = QueueClient(self.account_url(storage_account_name, "queue"), 'testqueue1', storage_account_key) + queue = QueueClient(self.account_url(storage_account_name, "queue"), "testqueue1", storage_account_key) queue.create_queue() # Act token_credential = self.get_credential(QueueServiceClient) queue = QueueClient( - self.account_url(storage_account_name, "queue"), 'testqueue1', credential=token_credential, - audience=f'https://{storage_account_name}.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + "testqueue1", + credential=token_credential, + audience=f"https://{storage_account_name}.queue.core.windows.net", ) # Assert @@ -1418,14 +1414,16 @@ def test_bad_audience_queue_client(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - queue = QueueClient(self.account_url(storage_account_name, "queue"), 'testqueue2', storage_account_key) + queue = QueueClient(self.account_url(storage_account_name, "queue"), "testqueue2", storage_account_key) queue.create_queue() # Act token_credential = self.get_credential(QueueServiceClient) queue = QueueClient( - self.account_url(storage_account_name, "queue"), 'testqueue2', credential=token_credential, - audience=f'https://badaudience.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + "testqueue2", + credential=token_credential, + audience=f"https://badaudience.queue.core.windows.net", ) # Will not raise ClientAuthenticationError despite bad audience due to Bearer Challenge @@ -1433,5 +1431,5 @@ def test_bad_audience_queue_client(self, **kwargs): # ------------------------------------------------------------------------------ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_api_version.py b/sdk/storage/azure-storage-queue/tests/test_queue_api_version.py index 035b432a3837..2afdab424c7c 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_api_version.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_api_version.py @@ -13,6 +13,7 @@ # ------------------------------------------------------------------------------ + class TestStorageClient(StorageRecordedTestCase): def setUp(self): self.api_version_1 = "2019-02-02" @@ -22,9 +23,7 @@ def setUp(self): def test_service_client_api_version_property(self): self.setUp() - service_client = QueueServiceClient( - "https://foo.queue.core.windows.net/account", - credential="fake_key") + service_client = QueueServiceClient("https://foo.queue.core.windows.net/account", credential="fake_key") assert service_client.api_version == self.api_version_2 assert service_client._client._config.version == self.api_version_2 @@ -32,9 +31,8 @@ def test_service_client_api_version_property(self): service_client.api_version = "foo" service_client = QueueServiceClient( - "https://foo.queue.core.windows.net/account", - credential="fake_key", - api_version=self.api_version_1) + "https://foo.queue.core.windows.net/account", credential="fake_key", api_version=self.api_version_1 + ) assert 
service_client.api_version == self.api_version_1 assert service_client._client._config.version == self.api_version_1 @@ -48,15 +46,14 @@ def test_queue_client_api_version_property(self): "https://foo.queue.core.windows.net/account", "queue_name", credential="fake_key", - api_version=self.api_version_1) + api_version=self.api_version_1, + ) assert queue_client.api_version == self.api_version_1 assert queue_client._client._config.version == self.api_version_1 - queue_client = QueueClient( - "https://foo.queue.core.windows.net/account", - "queue_name", - credential="fake_key") + queue_client = QueueClient("https://foo.queue.core.windows.net/account", "queue_name", credential="fake_key") assert queue_client.api_version == self.api_version_2 assert queue_client._client._config.version == self.api_version_2 + # ------------------------------------------------------------------------------ diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_api_version_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_api_version_async.py index b4466d2cf653..3f3dc8ad6a5a 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_api_version_async.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_api_version_async.py @@ -13,6 +13,7 @@ # ------------------------------------------------------------------------------ + class TestAsyncStorageClient(AsyncStorageRecordedTestCase): def setUp(self): self.api_version_1 = "2019-02-02" @@ -22,9 +23,7 @@ def setUp(self): def test_service_client_api_version_property(self): self.setUp() - service_client = QueueServiceClient( - "https://foo.queue.core.windows.net/account", - credential="fake_key") + service_client = QueueServiceClient("https://foo.queue.core.windows.net/account", credential="fake_key") assert service_client.api_version == self.api_version_2 assert service_client._client._config.version == self.api_version_2 @@ -32,9 +31,8 @@ def test_service_client_api_version_property(self): service_client.api_version = "foo" service_client = QueueServiceClient( - "https://foo.queue.core.windows.net/account", - credential="fake_key", - api_version=self.api_version_1) + "https://foo.queue.core.windows.net/account", credential="fake_key", api_version=self.api_version_1 + ) assert service_client.api_version == self.api_version_1 assert service_client._client._config.version == self.api_version_1 @@ -48,15 +46,14 @@ def test_queue_client_api_version_property(self): "https://foo.queue.core.windows.net/account", "queue_name", credential="fake_key", - api_version=self.api_version_1) + api_version=self.api_version_1, + ) assert queue_client.api_version == self.api_version_1 assert queue_client._client._config.version == self.api_version_1 - queue_client = QueueClient( - "https://foo.queue.core.windows.net/account", - "queue_name", - credential="fake_key") + queue_client = QueueClient("https://foo.queue.core.windows.net/account", "queue_name", credential="fake_key") assert queue_client.api_version == self.api_version_2 assert queue_client._client._config.version == self.api_version_2 + # ------------------------------------------------------------------------------ diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_async.py index cd4662b62721..82007cab89c4 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_async.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_async.py @@ -13,7 +13,7 @@ ClientAuthenticationError, HttpResponseError, ResourceExistsError, - 
ResourceNotFoundError + ResourceNotFoundError, ) from azure.core.pipeline.transport import AioHttpTransport from azure.storage.queue import ( @@ -22,7 +22,7 @@ generate_account_sas, generate_queue_sas, QueueSasPermissions, - ResourceTypes + ResourceTypes, ) from azure.storage.queue.aio import QueueClient, QueueServiceClient @@ -32,7 +32,7 @@ from settings.testcase import QueuePreparer # ------------------------------------------------------------------------------ -TEST_QUEUE_PREFIX = 'pyqueueasync' +TEST_QUEUE_PREFIX = "pyqueueasync" # ------------------------------------------------------------------------------ # pylint: disable=locally-disabled, multiple-statements, fixme, too-many-lines @@ -107,15 +107,14 @@ async def test_create_queue_with_options(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) - await queue_client.create_queue( - metadata={'val1': 'test', 'val2': 'blah'}) + await queue_client.create_queue(metadata={"val1": "test", "val2": "blah"}) props = await queue_client.get_queue_properties() # Asserts assert 0 == props.approximate_message_count assert 2 == len(props.metadata) - assert 'test' == props.metadata['val1'] - assert 'blah' == props.metadata['val2'] + assert "test" == props.metadata["val1"] + assert "blah" == props.metadata["val2"] @QueuePreparer() @recorded_by_proxy_async @@ -127,16 +126,16 @@ async def test_get_messages_with_max_messages(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) await queue_client.create_queue() - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') - await queue_client.send_message('message5') - await queue_client.send_message('message6') - await queue_client.send_message('message7') - await queue_client.send_message('message8') - await queue_client.send_message('message9') - await queue_client.send_message('message10') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") + await queue_client.send_message("message5") + await queue_client.send_message("message6") + await queue_client.send_message("message7") + await queue_client.send_message("message8") + await queue_client.send_message("message9") + await queue_client.send_message("message10") result = [] async for m in queue_client.receive_messages(max_messages=5): result.append(m) @@ -147,13 +146,13 @@ async def test_get_messages_with_max_messages(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy_async @@ -162,16 +161,14 @@ async def test_get_messages_with_too_little_messages(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Action - qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), 
- storage_account_key) + qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) await queue_client.create_queue() - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') - await queue_client.send_message('message5') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") + await queue_client.send_message("message5") result = [] async for m in queue_client.receive_messages(max_messages=10): result.append(m) @@ -182,13 +179,13 @@ async def test_get_messages_with_too_little_messages(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy_async @@ -197,16 +194,14 @@ async def test_get_messages_with_page_bigger_than_max(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Action - qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), - storage_account_key) + qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) await queue_client.create_queue() - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') - await queue_client.send_message('message5') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") + await queue_client.send_message("message5") # Asserts with pytest.raises(ValueError): @@ -222,18 +217,18 @@ async def test_get_messages_with_remainder(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = self._get_queue_reference(qsc) await queue_client.create_queue() - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') - await queue_client.send_message('message5') - await queue_client.send_message('message6') - await queue_client.send_message('message7') - await queue_client.send_message('message8') - await queue_client.send_message('message9') - await queue_client.send_message('message10') - await queue_client.send_message('message11') - await queue_client.send_message('message12') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") + await queue_client.send_message("message5") + await queue_client.send_message("message6") + await queue_client.send_message("message7") + await queue_client.send_message("message8") + await queue_client.send_message("message9") + await 
queue_client.send_message("message10") + await queue_client.send_message("message11") + await queue_client.send_message("message12") result = [] async for m in queue_client.receive_messages(messages_per_page=3, max_messages=10): @@ -252,23 +247,23 @@ async def test_get_messages_with_remainder(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on for message in remainder: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy_async @@ -327,21 +322,19 @@ async def test_list_queues_with_options(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_list = [] - prefix = 'listqueue' + prefix = "listqueue" for i in range(0, 4): await self._create_queue(qsc, prefix + str(i), queue_list) # Action - generator1 = qsc.list_queues( - name_starts_with=prefix, - results_per_page=3).by_page() + generator1 = qsc.list_queues(name_starts_with=prefix, results_per_page=3).by_page() queues1 = [] async for el in await generator1.__anext__(): queues1.append(el) - generator2 = qsc.list_queues( - name_starts_with=prefix, - include_metadata=True).by_page(generator1.continuation_token) + generator2 = qsc.list_queues(name_starts_with=prefix, include_metadata=True).by_page( + generator1.continuation_token + ) queues2 = [] async for el in await generator2.__anext__(): queues2.append(el) @@ -351,12 +344,12 @@ async def test_list_queues_with_options(self, **kwargs): assert 3 == len(queues1) assert queues1[0] is not None assert queues1[0].metadata is None - assert '' != queues1[0].name + assert "" != queues1[0].name # Asserts assert queues2 is not None assert len(queue_list) - 3 <= len(queues2) assert queues2[0] is not None - assert '' != queues2[0].name + assert "" != queues2[0].name @QueuePreparer() @recorded_by_proxy_async @@ -367,13 +360,10 @@ async def test_list_queues_with_metadata(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = await self._create_queue(qsc) - await queue.set_queue_metadata(metadata={'val1': 'test', 'val2': 'blah'}) + await queue.set_queue_metadata(metadata={"val1": "test", "val2": "blah"}) listed_queue = [] - async for q in qsc.list_queues( - name_starts_with=queue.queue_name, - results_per_page=1, - include_metadata=True): + async for q in qsc.list_queues(name_starts_with=queue.queue_name, results_per_page=1, include_metadata=True): listed_queue.append(q) listed_queue = listed_queue[0] @@ -382,7 +372,7 @@ async def test_list_queues_with_metadata(self, **kwargs): assert queue.queue_name == listed_queue.name assert listed_queue.metadata is not None assert 
len(listed_queue.metadata) == 2 - assert listed_queue.metadata['val1'] == 'test' + assert listed_queue.metadata["val1"] == "test" @QueuePreparer() @recorded_by_proxy_async @@ -400,7 +390,7 @@ async def test_list_queues_account_sas(self, **kwargs): storage_account_key, ResourceTypes(service=True), AccountSasPermissions(list=True), - datetime.utcnow() + timedelta(hours=1) + datetime.utcnow() + timedelta(hours=1), ) # Act @@ -421,7 +411,7 @@ async def test_set_queue_metadata(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - metadata = {'hello': 'world', 'number': '43'} + metadata = {"hello": "world", "number": "43"} queue = await self._create_queue(qsc) # Act @@ -440,7 +430,7 @@ async def test_get_queue_metadata_message_count(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") props = await queue_client.get_queue_properties() # Asserts @@ -471,7 +461,7 @@ async def test_queue_not_exists(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - queue = qsc.get_queue_client(self.get_resource_name('missing')) + queue = qsc.get_queue_client(self.get_resource_name("missing")) # Act with pytest.raises(ResourceNotFoundError): await queue.get_queue_properties() @@ -487,18 +477,18 @@ async def test_put_message(self, **kwargs): # Action. No exception means pass. No asserts needed. qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - message = await queue_client.send_message('message4') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + message = await queue_client.send_message("message4") # Asserts assert message is not None - assert '' != message.id + assert "" != message.id assert isinstance(message.inserted_on, datetime) assert isinstance(message.expires_on, datetime) - assert '' != message.pop_receipt - assert 'message4' == message.content + assert "" != message.pop_receipt + assert "message4" == message.content @QueuePreparer() @recorded_by_proxy_async @@ -510,7 +500,7 @@ async def test_put_message_large_time_to_live(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) # There should be no upper bound on a queue message's time to live - await queue_client.send_message('message1', time_to_live=1024*1024*1024) + await queue_client.send_message("message1", time_to_live=1024 * 1024 * 1024) # Act messages = await queue_client.peek_messages() @@ -527,7 +517,7 @@ async def test_put_message_infinite_time_to_live(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1', time_to_live=-1) + await queue_client.send_message("message1", time_to_live=-1) # Act messages = await queue_client.peek_messages() @@ -544,10 +534,10 @@ async def test_get_messages(self, **kwargs): # Action qsc = 
QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") messages = [] async for m in queue_client.receive_messages(): messages.append(m) @@ -557,9 +547,9 @@ async def test_get_messages(self, **kwargs): # Asserts assert message is not None assert message is not None - assert '' != message.id - assert 'message1' == message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "message1" == message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count assert isinstance(message.inserted_on, datetime) @@ -577,9 +567,9 @@ async def test_receive_one_message(self, **kwargs): queue_client = await self._create_queue(qsc) assert await queue_client.receive_message() is None - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") message1 = await queue_client.receive_message() message2 = await queue_client.receive_message() @@ -587,18 +577,18 @@ async def test_receive_one_message(self, **kwargs): # Asserts assert message1 is not None - assert '' != message1.id - assert 'message1' == message1.content - assert '' != message1.pop_receipt + assert "" != message1.id + assert "message1" == message1.content + assert "" != message1.pop_receipt assert 1 == message1.dequeue_count assert message2 is not None - assert '' != message2.id - assert 'message2' == message2.content - assert '' != message2.pop_receipt + assert "" != message2.id + assert "message2" == message2.content + assert "" != message2.pop_receipt assert 1 == message2.dequeue_count - assert 'message3' == peeked_message3[0].content + assert "message3" == peeked_message3[0].content assert 0 == peeked_message3[0].dequeue_count @QueuePreparer() @@ -610,10 +600,10 @@ async def test_get_messages_with_options(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") pager = queue_client.receive_messages(messages_per_page=4, visibility_timeout=20) result = [] async for el in pager: @@ -625,13 +615,13 @@ async def test_get_messages_with_options(self, **kwargs): for message in result: assert message is not None - assert '' != message.id - assert '' != message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "" != message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on - assert '' != message.next_visible_on + assert "" != message.inserted_on + assert "" != 
message.expires_on + assert "" != message.next_visible_on @QueuePreparer() @recorded_by_proxy_async @@ -642,10 +632,10 @@ async def test_peek_messages(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") result = await queue_client.peek_messages() # Asserts @@ -653,12 +643,12 @@ async def test_peek_messages(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert '' != message.content + assert "" != message.id + assert "" != message.content assert message.pop_receipt is None assert 0 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on + assert "" != message.inserted_on + assert "" != message.expires_on assert message.next_visible_on is None @QueuePreparer() @@ -670,10 +660,10 @@ async def test_peek_messages_with_options(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") result = await queue_client.peek_messages(max_messages=4) # Asserts @@ -681,12 +671,12 @@ async def test_peek_messages_with_options(self, **kwargs): assert 4 == len(result) for message in result: assert message is not None - assert '' != message.id - assert '' != message.content + assert "" != message.id + assert "" != message.content assert message.pop_receipt is None assert 0 == message.dequeue_count - assert '' != message.inserted_on - assert '' != message.expires_on + assert "" != message.inserted_on + assert "" != message.expires_on assert message.next_visible_on is None @QueuePreparer() @@ -698,10 +688,10 @@ async def test_clear_messages(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') + await queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") await queue_client.clear_messages() result = await queue_client.peek_messages() @@ -718,10 +708,10 @@ async def test_delete_message(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') - await queue_client.send_message('message2') - await queue_client.send_message('message3') - await queue_client.send_message('message4') + await 
queue_client.send_message("message1") + await queue_client.send_message("message2") + await queue_client.send_message("message3") + await queue_client.send_message("message4") messages = [] async for m in queue_client.receive_messages(): messages.append(m) @@ -730,7 +720,7 @@ async def test_delete_message(self, **kwargs): messages.append(m) # Asserts assert messages is not None - assert 3 == len(messages)-1 + assert 3 == len(messages) - 1 @QueuePreparer() @recorded_by_proxy_async @@ -741,15 +731,14 @@ async def test_update_message(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") messages = [] async for m in queue_client.receive_messages(): messages.append(m) list_result1 = messages[0] message = await queue_client.update_message( - list_result1.id, - pop_receipt=list_result1.pop_receipt, - visibility_timeout=0) + list_result1.id, pop_receipt=list_result1.pop_receipt, visibility_timeout=0 + ) messages = [] async for m in queue_client.receive_messages(): messages.append(m) @@ -767,7 +756,7 @@ async def test_update_message(self, **kwargs): message = list_result2 assert message is not None assert list_result1.id == message.id - assert 'message1' == message.content + assert "message1" == message.content assert 2 == message.dequeue_count assert message.pop_receipt is not None assert message.inserted_on is not None @@ -783,18 +772,16 @@ async def test_update_message_content(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") messages = [] async for m in queue_client.receive_messages(): messages.append(m) list_result1 = messages[0] message = await queue_client.update_message( - list_result1.id, - pop_receipt=list_result1.pop_receipt, - visibility_timeout=0, - content='new text') - assert 'new text' == message.content + list_result1.id, pop_receipt=list_result1.pop_receipt, visibility_timeout=0, content="new text" + ) + assert "new text" == message.content messages = [] async for m in queue_client.receive_messages(): @@ -813,7 +800,7 @@ async def test_update_message_content(self, **kwargs): message = list_result2 assert message is not None assert list_result1.id == message.id - assert 'new text' == message.content + assert "new text" == message.content assert 2 == message.dequeue_count assert message.pop_receipt is not None assert message.inserted_on is not None @@ -830,7 +817,7 @@ async def test_account_sas(self, **kwargs): # Arrange queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") token = self.generate_sas( generate_account_sas, qsc.account_name, @@ -838,7 +825,7 @@ async def test_account_sas(self, **kwargs): ResourceTypes(object=True), AccountSasPermissions(read=True), datetime.utcnow() + timedelta(hours=1), - datetime.utcnow() - timedelta(minutes=5) + datetime.utcnow() - timedelta(minutes=5), ) # Act @@ -855,8 +842,8 @@ async def test_account_sas(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content @QueuePreparer() @recorded_by_proxy_async 
@@ -864,14 +851,12 @@ async def test_azure_named_key_credential_access(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - # Arrange named_key = AzureNamedKeyCredential(storage_account_name, storage_account_key) - qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), named_key) + qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), named_key) queue_client = self._get_queue_reference(qsc) await queue_client.create_queue() - await queue_client.send_message('message1') + await queue_client.send_message("message1") # Act result = await queue_client.peek_messages() @@ -886,8 +871,8 @@ async def test_account_sas_raises_if_sas_already_in_uri(self, **kwargs): with pytest.raises(ValueError): QueueServiceClient( - self.account_url(storage_account_name, "queue") + "?sig=foo", - credential=AzureSasCredential("?foo=bar")) + self.account_url(storage_account_name, "queue") + "?sig=foo", credential=AzureSasCredential("?foo=bar") + ) @pytest.mark.live_test_only @QueuePreparer() @@ -899,24 +884,18 @@ async def test_token_credential(self, **kwargs): token_credential = self.get_credential(QueueServiceClient, is_async=True) # Action 1: make sure token works - service = QueueServiceClient( - self.account_url(storage_account_name, "queue"), - credential=token_credential) + service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=token_credential) queues = await service.get_service_properties() assert queues is not None # Action 2: change token value to make request fail fake_credential = AsyncFakeCredential() - service = QueueServiceClient( - self.account_url(storage_account_name, "queue"), - credential=fake_credential) + service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=fake_credential) with pytest.raises(ClientAuthenticationError): await service.get_service_properties() # Action 3: update token to make it working again - service = QueueServiceClient( - self.account_url(storage_account_name, "queue"), - credential=token_credential) + service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=token_credential) queues = await service.get_service_properties() # Not raise means success assert queues is not None @@ -930,7 +909,7 @@ async def test_sas_read(self, **kwargs): # Arrange queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, @@ -938,7 +917,7 @@ async def test_sas_read(self, **kwargs): queue_client.credential.account_key, QueueSasPermissions(read=True), datetime.utcnow() + timedelta(hours=1), - datetime.utcnow() - timedelta(minutes=5) + datetime.utcnow() - timedelta(minutes=5), ) # Act @@ -953,8 +932,8 @@ async def test_sas_read(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content @QueuePreparer() @recorded_by_proxy_async @@ -980,15 +959,15 @@ async def test_sas_add(self, **kwargs): queue_url=queue_client.url, credential=token, ) - result = await service.send_message('addedmessage') - assert 'addedmessage' == result.content + result = await service.send_message("addedmessage") + assert "addedmessage" == result.content # Assert messages = [] async for m in 
queue_client.receive_messages(): messages.append(m) result = messages[0] - assert 'addedmessage' == result.content + assert "addedmessage" == result.content @QueuePreparer() @recorded_by_proxy_async @@ -1000,7 +979,7 @@ async def test_sas_update(self, **kwargs): # Arrange queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, @@ -1023,7 +1002,7 @@ async def test_sas_update(self, **kwargs): result.id, pop_receipt=result.pop_receipt, visibility_timeout=0, - content='updatedmessage1', + content="updatedmessage1", ) # Assert @@ -1031,7 +1010,7 @@ async def test_sas_update(self, **kwargs): async for m in queue_client.receive_messages(): messages.append(m) result = messages[0] - assert 'updatedmessage1' == result.content + assert "updatedmessage1" == result.content @QueuePreparer() @recorded_by_proxy_async @@ -1043,7 +1022,7 @@ async def test_sas_process(self, **kwargs): # Arrange queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, @@ -1065,39 +1044,39 @@ async def test_sas_process(self, **kwargs): # Assert assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content @QueuePreparer() @recorded_by_proxy_async async def test_sas_signed_identifier(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - variables = kwargs.pop('variables', {}) + variables = kwargs.pop("variables", {}) qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange access_policy = AccessPolicy() - start_time = self.get_datetime_variable(variables, 'start_time', datetime.utcnow() - timedelta(hours=1)) - expiry_time = self.get_datetime_variable(variables, 'expiry_time', datetime.utcnow() + timedelta(hours=1)) + start_time = self.get_datetime_variable(variables, "start_time", datetime.utcnow() - timedelta(hours=1)) + expiry_time = self.get_datetime_variable(variables, "expiry_time", datetime.utcnow() + timedelta(hours=1)) access_policy.start = start_time access_policy.expiry = expiry_time access_policy.permission = QueueSasPermissions(read=True) - identifiers = {'testid': access_policy} + identifiers = {"testid": access_policy} queue_client = await self._create_queue(qsc) resp = await queue_client.set_queue_access_policy(identifiers) - await queue_client.send_message('message1') + await queue_client.send_message("message1") token = self.generate_sas( generate_queue_sas, queue_client.account_name, queue_client.queue_name, queue_client.credential.account_key, - policy_id='testid' + policy_id="testid", ) # Act @@ -1112,8 +1091,8 @@ async def test_sas_signed_identifier(self, **kwargs): assert 1 == len(result) message = result[0] assert message is not None - assert '' != message.id - assert 'message1' == message.content + assert "" != message.id + assert "message1" == message.content return variables @@ -1216,35 +1195,33 @@ async def test_set_queue_acl_with_empty_signed_identifier(self, **kwargs): queue_client = await self._create_queue(qsc) # Act - await queue_client.set_queue_access_policy(signed_identifiers={'empty': None}) + await queue_client.set_queue_access_policy(signed_identifiers={"empty": None}) 
# Assert acl = await queue_client.get_queue_access_policy() assert acl is not None assert len(acl) == 1 - assert acl['empty'] is not None - assert acl['empty'].permission is None - assert acl['empty'].expiry is None - assert acl['empty'].start is None + assert acl["empty"] is not None + assert acl["empty"].permission is None + assert acl["empty"].expiry is None + assert acl["empty"].start is None @QueuePreparer() @recorded_by_proxy_async async def test_set_queue_acl_with_signed_identifiers(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - variables = kwargs.pop('variables', {}) + variables = kwargs.pop("variables", {}) # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) # Act - expiry_time = self.get_datetime_variable(variables, 'expiry_time', datetime.utcnow() + timedelta(hours=1)) - start_time = self.get_datetime_variable(variables, 'start_time', datetime.utcnow() - timedelta(minutes=5)) - access_policy = AccessPolicy(permission=QueueSasPermissions(read=True), - expiry=expiry_time, - start=start_time) - identifiers = {'testid': access_policy} + expiry_time = self.get_datetime_variable(variables, "expiry_time", datetime.utcnow() + timedelta(hours=1)) + start_time = self.get_datetime_variable(variables, "start_time", datetime.utcnow() - timedelta(minutes=5)) + access_policy = AccessPolicy(permission=QueueSasPermissions(read=True), expiry=expiry_time, start=start_time) + identifiers = {"testid": access_policy} resp = await queue_client.set_queue_access_policy(signed_identifiers=identifiers) @@ -1253,7 +1230,7 @@ async def test_set_queue_acl_with_signed_identifiers(self, **kwargs): acl = await queue_client.get_queue_access_policy() assert acl is not None assert len(acl) == 1 - assert 'testid' in acl + assert "testid" in acl return variables @@ -1270,7 +1247,7 @@ async def test_set_queue_acl_too_many_ids(self, **kwargs): # Act identifiers = {} for i in range(0, 16): - identifiers[f'id{i}'] = AccessPolicy() + identifiers[f"id{i}"] = AccessPolicy() # Assert with pytest.raises(ValueError): @@ -1298,7 +1275,7 @@ async def test_unicode_create_queue_unicode_name(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - queue_name = '啊齄丂狛狜' + queue_name = "啊齄丂狛狜" with pytest.raises(HttpResponseError): # not supported - queue name must be alphanumeric, lowercase @@ -1316,15 +1293,15 @@ async def test_unicode_get_messages_unicode_data(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await self._create_queue(qsc) - await queue_client.send_message('message1㚈') + await queue_client.send_message("message1㚈") message = None async for m in queue_client.receive_messages(): message = m # Asserts assert message is not None - assert '' != message.id - assert 'message1㚈' == message.content - assert '' != message.pop_receipt + assert "" != message.id + assert "message1㚈" == message.content + assert "" != message.pop_receipt assert 1 == message.dequeue_count assert isinstance(message.inserted_on, datetime) assert isinstance(message.expires_on, datetime) @@ -1339,13 +1316,13 @@ async def test_unicode_update_message_unicode_data(self, **kwargs): # Action qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue_client = await 
self._create_queue(qsc) - await queue_client.send_message('message1') + await queue_client.send_message("message1") messages = [] async for m in queue_client.receive_messages(): messages.append(m) list_result1 = messages[0] - list_result1.content = '啊齄丂狛狜' + list_result1.content = "啊齄丂狛狜" await queue_client.update_message(list_result1, visibility_timeout=0) messages = [] async for m in queue_client.receive_messages(): @@ -1354,8 +1331,8 @@ async def test_unicode_update_message_unicode_data(self, **kwargs): message = messages[0] assert message is not None assert list_result1.id == message.id - assert '啊齄丂狛狜' == message.content - assert '' != message.pop_receipt + assert "啊齄丂狛狜" == message.content + assert "" != message.pop_receipt assert 2 == message.dequeue_count assert isinstance(message.inserted_on, datetime) assert isinstance(message.expires_on, datetime) @@ -1371,8 +1348,8 @@ async def test_transport_closed_only_once(self, **kwargs): prefix = TEST_QUEUE_PREFIX queue_name = self.get_resource_name(prefix) async with QueueServiceClient( - self.account_url(storage_account_name, "queue"), - credential=storage_account_key, transport=transport) as qsc: + self.account_url(storage_account_name, "queue"), credential=storage_account_key, transport=transport + ) as qsc: await qsc.get_service_properties() assert transport.session is not None async with qsc.get_queue_client(queue_name) as qc: @@ -1393,8 +1370,9 @@ async def test_storage_account_audience_queue_service_client(self, **kwargs): # Act token_credential = self.get_credential(QueueServiceClient, is_async=True) qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), credential=token_credential, - audience=f'https://{storage_account_name}.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + credential=token_credential, + audience=f"https://{storage_account_name}.queue.core.windows.net", ) # Assert @@ -1414,8 +1392,9 @@ async def test_bad_audience_queue_service_client(self, **kwargs): # Act token_credential = self.get_credential(QueueServiceClient, is_async=True) qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), credential=token_credential, - audience=f'https://badaudience.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + credential=token_credential, + audience=f"https://badaudience.queue.core.windows.net", ) # Will not raise ClientAuthenticationError despite bad audience due to Bearer Challenge @@ -1435,8 +1414,10 @@ async def test_storage_account_audience_queue_client(self, **kwargs): # Act token_credential = self.get_credential(QueueServiceClient, is_async=True) queue = QueueClient( - self.account_url(storage_account_name, "queue"), queue_name, credential=token_credential, - audience=f'https://{storage_account_name}.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + queue_name, + credential=token_credential, + audience=f"https://{storage_account_name}.queue.core.windows.net", ) # Assert @@ -1457,13 +1438,16 @@ async def test_bad_audience_queue_client(self, **kwargs): # Act token_credential = self.get_credential(QueueServiceClient, is_async=True) queue = QueueClient( - self.account_url(storage_account_name, "queue"), queue_name, credential=token_credential, - audience=f'https://badaudience.queue.core.windows.net' + self.account_url(storage_account_name, "queue"), + queue_name, + credential=token_credential, + audience=f"https://badaudience.queue.core.windows.net", ) # Will not raise ClientAuthenticationError despite 
bad audience due to Bearer Challenge await queue.get_queue_properties() + # ------------------------------------------------------------------------------ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_client.py b/sdk/storage/azure-storage-queue/tests/test_queue_client.py index e237c6fb766f..e2d999f57a6a 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_client.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_client.py @@ -14,7 +14,7 @@ QueueClient, QueueServiceClient, ResourceTypes, - VERSION + VERSION, ) from devtools_testutils import recorded_by_proxy @@ -23,13 +23,14 @@ # ------------------------------------------------------------------------------ SERVICES = { - QueueServiceClient: 'queue', - QueueClient: 'queue', + QueueServiceClient: "queue", + QueueClient: "queue", } -_CONNECTION_ENDPOINTS = {'queue': 'QueueEndpoint'} +_CONNECTION_ENDPOINTS = {"queue": "QueueEndpoint"} + +_CONNECTION_ENDPOINTS_SECONDARY = {"queue": "QueueSecondaryEndpoint"} -_CONNECTION_ENDPOINTS_SECONDARY = {'queue': 'QueueSecondaryEndpoint'} class TestStorageQueueClient(StorageRecordedTestCase): def setUp(self): @@ -42,8 +43,8 @@ def validate_standard_account_endpoints(self, service, url_type, account_name, a assert service.account_name == account_name assert service.credential.account_name == account_name assert service.credential.account_key == account_key - assert f'{account_name}.{url_type}.core.windows.net' in service.url - assert f'{account_name}-secondary.{url_type}.core.windows.net' in service.secondary_endpoint + assert f"{account_name}.{url_type}.core.windows.net" in service.url + assert f"{account_name}-secondary.{url_type}.core.windows.net" in service.secondary_endpoint def generate_fake_sas_token(self): fake_key = "a" * 30 + "b" * 30 @@ -68,27 +69,29 @@ def test_create_service_with_key(self, **kwargs): for client, url in SERVICES.items(): # Act service = client( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='foo') + self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="foo" + ) # Assert self.validate_standard_account_endpoints(service, url, storage_account_name, storage_account_key) - assert service.scheme == 'https' + assert service.scheme == "https" @QueuePreparer() def test_create_service_with_connection_string(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - for service_type in SERVICES.items(): # Act service = service_type[0].from_connection_string( - self.connection_string(storage_account_name, storage_account_key), queue_name="test") + self.connection_string(storage_account_name, storage_account_key), queue_name="test" + ) # Assert self.validate_standard_account_endpoints( - service, service_type[1], storage_account_name, storage_account_key) - assert service.scheme == 'https' + service, service_type[1], storage_account_name, storage_account_key + ) + assert service.scheme == "https" @QueuePreparer() def test_create_service_with_sas(self, **kwargs): @@ -100,12 +103,13 @@ def test_create_service_with_sas(self, **kwargs): for service_type in SERVICES: # Act service = service_type( - self.account_url(storage_account_name, "queue"), credential=self.sas_token, queue_name='foo') + self.account_url(storage_account_name, "queue"), credential=self.sas_token, queue_name="foo" + ) # Assert assert service is not None assert 
service.account_name == storage_account_name - assert service.url.startswith('https://' + storage_account_name + '.queue.core.windows.net') + assert service.url.startswith("https://" + storage_account_name + ".queue.core.windows.net") assert service.url.endswith(self.sas_token) assert service.credential is None @@ -118,15 +122,16 @@ def test_create_service_with_token(self, **kwargs): for service_type in SERVICES: # Act service = service_type( - self.account_url(storage_account_name, "queue"), credential=self.token_credential, queue_name='foo') + self.account_url(storage_account_name, "queue"), credential=self.token_credential, queue_name="foo" + ) # Assert assert service is not None assert service.account_name == storage_account_name - assert service.url.startswith('https://' + storage_account_name + '.queue.core.windows.net') + assert service.url.startswith("https://" + storage_account_name + ".queue.core.windows.net") assert service.credential == self.token_credential - assert not hasattr(service.credential, 'account_key') - assert hasattr(service.credential, 'get_token') + assert not hasattr(service.credential, "account_key") + assert hasattr(service.credential, "get_token") @QueuePreparer() def test_create_service_with_token_and_http(self, **kwargs): @@ -137,8 +142,8 @@ def test_create_service_with_token_and_http(self, **kwargs): for service_type in SERVICES: # Act with pytest.raises(ValueError): - url = self.account_url(storage_account_name, "queue").replace('https', 'http') - service_type(url, credential=self.token_credential, queue_name='foo') + url = self.account_url(storage_account_name, "queue").replace("https", "http") + service_type(url, credential=self.token_credential, queue_name="foo") @QueuePreparer() def test_create_service_china(self, **kwargs): @@ -149,17 +154,26 @@ def test_create_service_china(self, **kwargs): for service_type in SERVICES.items(): # Act - url = self.account_url(storage_account_name, "queue").replace('core.windows.net', 'core.chinacloudapi.cn') - service = service_type[0]( - url, credential=storage_account_key, queue_name='foo') + url = self.account_url(storage_account_name, "queue").replace("core.windows.net", "core.chinacloudapi.cn") + service = service_type[0](url, credential=storage_account_key, queue_name="foo") # Assert assert service is not None assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith(f'https://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn') is True - assert service.secondary_endpoint.startswith(f'https://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn') is True + assert ( + service.primary_endpoint.startswith( + f"https://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn" + ) + is True + ) + assert ( + service.secondary_endpoint.startswith( + f"https://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn" + ) + is True + ) @QueuePreparer() def test_create_service_protocol(self, **kwargs): @@ -170,14 +184,14 @@ def test_create_service_protocol(self, **kwargs): for service_type in SERVICES.items(): # Act - url = self.account_url(storage_account_name, "queue").replace('https', 'http') - service = service_type[0]( - url, credential=storage_account_key, queue_name='foo') + url = self.account_url(storage_account_name, "queue").replace("https", "http") + service = service_type[0](url, 
credential=storage_account_key, queue_name="foo") # Assert self.validate_standard_account_endpoints( - service, service_type[1], storage_account_name, storage_account_key) - assert service.scheme == 'http' + service, service_type[1], storage_account_name, storage_account_key + ) + assert service.scheme == "http" @QueuePreparer() def test_create_service_empty_key(self, **kwargs): @@ -190,7 +204,7 @@ def test_create_service_empty_key(self, **kwargs): for service_type in QUEUE_SERVICES: # Act with pytest.raises(ValueError) as e: - test_service = service_type('testaccount', credential='', queue_name='foo') + test_service = service_type("testaccount", credential="", queue_name="foo") assert str(e.value) == "You need to provide either a SAS token or an account shared key to authenticate." @@ -204,14 +218,19 @@ def test_create_service_with_socket_timeout(self, **kwargs): for service_type in SERVICES.items(): # Act default_service = service_type[0]( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='foo') + self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="foo" + ) service = service_type[0]( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, - queue_name='foo', connection_timeout=22) + self.account_url(storage_account_name, "queue"), + credential=storage_account_key, + queue_name="foo", + connection_timeout=22, + ) # Assert self.validate_standard_account_endpoints( - service, service_type[1], storage_account_name, storage_account_key) + service, service_type[1], storage_account_name, storage_account_key + ) assert service._client._client._pipeline._transport.connection_config.timeout == 22 assert default_service._client._client._pipeline._transport.connection_config.timeout in [20, (20, 2000)] @@ -222,16 +241,17 @@ def test_create_service_with_connection_string_key(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - conn_string = f'AccountName={storage_account_name};AccountKey={storage_account_key};' + conn_string = f"AccountName={storage_account_name};AccountKey={storage_account_key};" for service_type in SERVICES.items(): # Act - service = service_type[0].from_connection_string(conn_string, queue_name='foo') + service = service_type[0].from_connection_string(conn_string, queue_name="foo") # Assert self.validate_standard_account_endpoints( - service, service_type[1], storage_account_name, storage_account_key) - assert service.scheme == 'https' + service, service_type[1], storage_account_name, storage_account_key + ) + assert service.scheme == "https" @QueuePreparer() def test_create_service_with_connection_string_sas(self, **kwargs): @@ -239,16 +259,16 @@ def test_create_service_with_connection_string_sas(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - conn_string = f'AccountName={storage_account_name};SharedAccessSignature={self.sas_token};' + conn_string = f"AccountName={storage_account_name};SharedAccessSignature={self.sas_token};" for service_type in SERVICES: # Act - service = service_type.from_connection_string(conn_string, queue_name='foo') + service = service_type.from_connection_string(conn_string, queue_name="foo") # Assert assert service is not None assert service.account_name == storage_account_name - assert service.url.startswith('https://' + storage_account_name + '.queue.core.windows.net') + assert service.url.startswith("https://" + storage_account_name + ".queue.core.windows.net") assert 
service.url.endswith(self.sas_token) assert service.credential is None @@ -259,11 +279,11 @@ def test_create_service_with_connection_string_endpoint_protocol(self, **kwargs) # Arrange conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'DefaultEndpointsProtocol=http;' - 'EndpointSuffix=core.chinacloudapi.cn;' - ) + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "DefaultEndpointsProtocol=http;" + "EndpointSuffix=core.chinacloudapi.cn;" + ) for service_type in SERVICES.items(): # Act @@ -274,15 +294,25 @@ def test_create_service_with_connection_string_endpoint_protocol(self, **kwargs) assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith(f'http://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn/') is True - assert service.secondary_endpoint.startswith(f'http://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn') is True - assert service.scheme == 'http' + assert ( + service.primary_endpoint.startswith( + f"http://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn/" + ) + is True + ) + assert ( + service.secondary_endpoint.startswith( + f"http://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn" + ) + is True + ) + assert service.scheme == "http" @QueuePreparer() def test_create_service_with_connection_string_emulated(self, *args): # Arrange for service_type in SERVICES.items(): - conn_string = 'UseDevelopmentStorage=true;' + conn_string = "UseDevelopmentStorage=true;" # Act with pytest.raises(ValueError): @@ -296,10 +326,10 @@ def test_create_service_with_connection_string_custom_domain(self, **kwargs): # Arrange for service_type in SERVICES.items(): conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'QueueEndpoint=www.mydomain.com;' - ) + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "QueueEndpoint=www.mydomain.com;" + ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -309,8 +339,13 @@ def test_create_service_with_connection_string_custom_domain(self, **kwargs): assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith(f'https://{storage_account_name}-secondary.queue.core.windows.net') is True + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert ( + service.secondary_endpoint.startswith( + f"https://{storage_account_name}-secondary.queue.core.windows.net" + ) + is True + ) @QueuePreparer() def test_create_service_with_conn_str_custom_domain_trailing_slash(self, **kwargs): @@ -320,10 +355,10 @@ def test_create_service_with_conn_str_custom_domain_trailing_slash(self, **kwarg # Arrange for service_type in SERVICES.items(): conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'QueueEndpoint=www.mydomain.com/;' - ) + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "QueueEndpoint=www.mydomain.com/;" + ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -332,8 +367,13 @@ def 
test_create_service_with_conn_str_custom_domain_trailing_slash(self, **kwarg assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith(f'https://{storage_account_name}-secondary.queue.core.windows.net') is True + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert ( + service.secondary_endpoint.startswith( + f"https://{storage_account_name}-secondary.queue.core.windows.net" + ) + is True + ) @QueuePreparer() def test_create_service_with_conn_str_custom_domain_sec_override(self, **kwargs): @@ -343,21 +383,22 @@ def test_create_service_with_conn_str_custom_domain_sec_override(self, **kwargs) # Arrange for service_type in SERVICES.items(): conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'QueueEndpoint=www.mydomain.com/;' - ) + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "QueueEndpoint=www.mydomain.com/;" + ) # Act service = service_type[0].from_connection_string( - conn_string, secondary_hostname="www-sec.mydomain.com", queue_name="foo") + conn_string, secondary_hostname="www-sec.mydomain.com", queue_name="foo" + ) # Assert assert service is not None assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith('https://www-sec.mydomain.com/') + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert service.secondary_endpoint.startswith("https://www-sec.mydomain.com/") @QueuePreparer() def test_create_service_with_conn_str_fails_if_sec_without_primary(self, **kwargs): @@ -367,9 +408,9 @@ def test_create_service_with_conn_str_fails_if_sec_without_primary(self, **kwarg for service_type in SERVICES.items(): # Arrange conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - f'{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www.mydomain.com;' + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + f"{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www.mydomain.com;" ) # Act @@ -385,10 +426,10 @@ def test_create_service_with_conn_str_succeeds_if_sec_with_primary(self, **kwarg for service_type in SERVICES.items(): # Arrange conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - f'{_CONNECTION_ENDPOINTS.get(service_type[1])}=www.mydomain.com;' - f'{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www-sec.mydomain.com;' + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + f"{_CONNECTION_ENDPOINTS.get(service_type[1])}=www.mydomain.com;" + f"{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www-sec.mydomain.com;" ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -398,8 +439,8 @@ def test_create_service_with_conn_str_succeeds_if_sec_with_primary(self, **kwarg assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert 
service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith('https://www-sec.mydomain.com/') + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert service.secondary_endpoint.startswith("https://www-sec.mydomain.com/") @QueuePreparer() def test_create_service_with_custom_account_endpoint_path(self, **kwargs): @@ -409,10 +450,10 @@ def test_create_service_with_custom_account_endpoint_path(self, **kwargs): custom_account_url = "http://local-machine:11002/custom/account/path/" + self.sas_token for service_type in SERVICES.items(): conn_string = ( - f'DefaultEndpointsProtocol=http;' - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - f'QueueEndpoint={custom_account_url};' + f"DefaultEndpointsProtocol=http;" + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + f"QueueEndpoint={custom_account_url};" ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -421,27 +462,27 @@ def test_create_service_with_custom_account_endpoint_path(self, **kwargs): assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_hostname == 'local-machine:11002/custom/account/path' + assert service.primary_hostname == "local-machine:11002/custom/account/path" service = QueueServiceClient(account_url=custom_account_url) assert service.account_name == None assert service.credential == None - assert service.primary_hostname == 'local-machine:11002/custom/account/path' - assert service.url.startswith('http://local-machine:11002/custom/account/path/?') + assert service.primary_hostname == "local-machine:11002/custom/account/path" + assert service.url.startswith("http://local-machine:11002/custom/account/path/?") service = QueueClient(account_url=custom_account_url, queue_name="foo") assert service.account_name == None assert service.queue_name == "foo" assert service.credential == None - assert service.primary_hostname == 'local-machine:11002/custom/account/path' - assert service.url.startswith('http://local-machine:11002/custom/account/path/foo?') + assert service.primary_hostname == "local-machine:11002/custom/account/path" + assert service.url.startswith("http://local-machine:11002/custom/account/path/foo?") service = QueueClient.from_queue_url("http://local-machine:11002/custom/account/path/foo" + self.sas_token) assert service.account_name == None assert service.queue_name == "foo" assert service.credential == None - assert service.primary_hostname == 'local-machine:11002/custom/account/path' - assert service.url.startswith('http://local-machine:11002/custom/account/path/foo?') + assert service.primary_hostname == "local-machine:11002/custom/account/path" + assert service.url.startswith("http://local-machine:11002/custom/account/path/foo?") @QueuePreparer() @recorded_by_proxy @@ -451,16 +492,16 @@ def test_request_callback_signed_header(self, **kwargs): # Arrange service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) - name = self.get_resource_name('cont') + name = self.get_resource_name("cont") # Act try: - headers = {'x-ms-meta-hello': 'world'} + headers = {"x-ms-meta-hello": "world"} queue = service.create_queue(name, headers=headers) # Assert metadata = queue.get_queue_properties().metadata - assert metadata == {'hello': 'world'} + assert metadata == 
{"hello": "world"} finally: service.delete_queue(name) @@ -472,7 +513,7 @@ def test_response_callback(self, **kwargs): # Arrange service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) - name = self.get_resource_name('cont') + name = self.get_resource_name("cont") queue = service.get_queue_client(name) # Act @@ -493,8 +534,8 @@ def test_user_agent_default(self, **kwargs): service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) def callback(response): - assert 'User-Agent' in response.http_request.headers - assert f"azsdk-python-storage-queue/{VERSION}" in response.http_request.headers['User-Agent'] + assert "User-Agent" in response.http_request.headers + assert f"azsdk-python-storage-queue/{VERSION}" in response.http_request.headers["User-Agent"] service.get_service_properties(raw_response_hook=callback) @@ -506,24 +547,25 @@ def test_user_agent_custom(self, **kwargs): custom_app = "TestApp/v1.0" service = QueueServiceClient( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, user_agent=custom_app) + self.account_url(storage_account_name, "queue"), credential=storage_account_key, user_agent=custom_app + ) def callback(response): - assert 'User-Agent' in response.http_request.headers + assert "User-Agent" in response.http_request.headers assert ( f"TestApp/v1.0 azsdk-python-storage-queue/{VERSION} " f"Python/{platform.python_version()} " f"({platform.platform()})" - ) in response.http_request.headers['User-Agent'] + ) in response.http_request.headers["User-Agent"] service.get_service_properties(raw_response_hook=callback) def callback(response): - assert 'User-Agent' in response.http_request.headers + assert "User-Agent" in response.http_request.headers assert ( f"TestApp/v2.0 TestApp/v1.0 azsdk-python-storage-queue/{VERSION} " f"Python/{platform.python_version()} ({platform.platform()})" - ) in response.http_request.headers['User-Agent'] + ) in response.http_request.headers["User-Agent"] service.get_service_properties(raw_response_hook=callback, user_agent="TestApp/v2.0") @@ -536,13 +578,13 @@ def test_user_agent_append(self, **kwargs): service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) def callback(response): - assert 'User-Agent' in response.http_request.headers + assert "User-Agent" in response.http_request.headers assert ( f"customer_user_agent azsdk-python-storage-queue/{VERSION} " f"Python/{platform.python_version()} ({platform.platform()})" - ) in response.http_request.headers['User-Agent'] + ) in response.http_request.headers["User-Agent"] - service.get_service_properties(raw_response_hook=callback, user_agent='customer_user_agent') + service.get_service_properties(raw_response_hook=callback, user_agent="customer_user_agent") @QueuePreparer() def test_create_queue_client_with_complete_queue_url(self, **kwargs): @@ -551,11 +593,11 @@ def test_create_queue_client_with_complete_queue_url(self, **kwargs): # Arrange queue_url = self.account_url(storage_account_name, "queue") + "/foo" - service = QueueClient(queue_url, queue_name='bar', credential=storage_account_key) + service = QueueClient(queue_url, queue_name="bar", credential=storage_account_key) - # Assert - assert service.scheme == 'https' - assert service.queue_name == 'bar' + # Assert + assert service.scheme == "https" + assert service.queue_name == "bar" def test_error_with_malformed_conn_str(self): # Arrange @@ -566,9 +608,9 @@ def 
test_error_with_malformed_conn_str(self): with pytest.raises(ValueError) as e: service = service_type[0].from_connection_string(conn_str, queue_name="test") - if conn_str in("", "foobar", "foo;bar;baz", ";"): + if conn_str in ("", "foobar", "foo;bar;baz", ";"): assert str(e.value) == "Connection string is either blank or malformed." - elif conn_str in ("foobar=baz=foo" , "foo=;bar=;", "=", "=;=="): + elif conn_str in ("foobar=baz=foo", "foo=;bar=;", "=", "=;=="): assert str(e.value) == "Connection string missing required connection details." @QueuePreparer() @@ -580,11 +622,12 @@ def test_closing_pipeline_client(self, **kwargs): for client, url in SERVICES.items(): # Act service = client( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='queue') + self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="queue" + ) # Assert with service: - assert hasattr(service, 'close') + assert hasattr(service, "close") service.close() @QueuePreparer() @@ -596,7 +639,8 @@ def test_closing_pipeline_client_simple(self, **kwargs): for client, url in SERVICES.items(): # Act service = client( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='queue') + self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="queue" + ) service.close() @QueuePreparer() @@ -615,6 +659,7 @@ def test_get_and_set_queue_access_policy_oauth(self, **kwargs): acl = queue_client.get_queue_access_policy() assert acl is not None + # ------------------------------------------------------------------------------ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_client_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_client_async.py index 074fc0712d07..e5b4d516bb7b 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_client_async.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_client_async.py @@ -17,12 +17,12 @@ # ------------------------------------------------------------------------------ SERVICES = { - QueueServiceClient: 'queue', - QueueClient: 'queue', + QueueServiceClient: "queue", + QueueClient: "queue", } -_CONNECTION_ENDPOINTS = {'queue': 'QueueEndpoint'} +_CONNECTION_ENDPOINTS = {"queue": "QueueEndpoint"} -_CONNECTION_ENDPOINTS_SECONDARY = {'queue': 'QueueSecondaryEndpoint'} +_CONNECTION_ENDPOINTS_SECONDARY = {"queue": "QueueSecondaryEndpoint"} class TestAsyncStorageQueueClient(AsyncStorageRecordedTestCase): @@ -36,8 +36,8 @@ def validate_standard_account_endpoints(self, service, url_type, storage_account assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert f'{storage_account_name}.{url_type}.core.windows.net' in service.url - assert f'{storage_account_name}-secondary.{url_type}.core.windows.net' in service.secondary_endpoint + assert f"{storage_account_name}.{url_type}.core.windows.net" in service.url + assert f"{storage_account_name}-secondary.{url_type}.core.windows.net" in service.secondary_endpoint def generate_fake_sas_token(self): fake_key = "a" * 30 + "b" * 30 @@ -62,26 +62,29 @@ def test_create_service_with_key(self, **kwargs): for client, url in SERVICES.items(): # Act service = client( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='foo') + 
self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="foo" + ) # Assert self.validate_standard_account_endpoints(service, url, storage_account_name, storage_account_key) - assert service.scheme == 'https' + assert service.scheme == "https" @QueuePreparer() def test_create_service_with_connection_string(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - for service_type in SERVICES.items(): # Act service = service_type[0].from_connection_string( - self.connection_string(storage_account_name, storage_account_key), queue_name="test") + self.connection_string(storage_account_name, storage_account_key), queue_name="test" + ) # Assert - self.validate_standard_account_endpoints(service, service_type[1], storage_account_name, storage_account_key) - assert service.scheme == 'https' + self.validate_standard_account_endpoints( + service, service_type[1], storage_account_name, storage_account_key + ) + assert service.scheme == "https" @QueuePreparer() def test_create_service_with_sas(self, **kwargs): @@ -93,12 +96,13 @@ def test_create_service_with_sas(self, **kwargs): for service_type in SERVICES: # Act service = service_type( - self.account_url(storage_account_name, "queue"), credential=self.sas_token, queue_name='foo') + self.account_url(storage_account_name, "queue"), credential=self.sas_token, queue_name="foo" + ) # Assert assert service is not None assert service.account_name == storage_account_name - assert service.url.startswith('https://' + storage_account_name + '.queue.core.windows.net') + assert service.url.startswith("https://" + storage_account_name + ".queue.core.windows.net") assert service.url.endswith(self.sas_token) assert service.credential is None @@ -110,15 +114,16 @@ async def test_create_service_with_token(self, **kwargs): for service_type in SERVICES: # Act service = service_type( - self.account_url(storage_account_name, "queue"), credential=self.token_credential, queue_name='foo') + self.account_url(storage_account_name, "queue"), credential=self.token_credential, queue_name="foo" + ) # Assert assert service is not None assert service.account_name == storage_account_name - assert service.url.startswith('https://' + storage_account_name + '.queue.core.windows.net') + assert service.url.startswith("https://" + storage_account_name + ".queue.core.windows.net") assert service.credential == self.token_credential - assert not hasattr(service.credential, 'account_key') - assert hasattr(service.credential, 'get_token') + assert not hasattr(service.credential, "account_key") + assert hasattr(service.credential, "get_token") @QueuePreparer() async def test_create_service_with_token_and_http(self, **kwargs): @@ -128,8 +133,8 @@ async def test_create_service_with_token_and_http(self, **kwargs): for service_type in SERVICES: # Act with pytest.raises(ValueError): - url = self.account_url(storage_account_name, "queue").replace('https', 'http') - service_type(url, credential=self.token_credential, queue_name='foo') + url = self.account_url(storage_account_name, "queue").replace("https", "http") + service_type(url, credential=self.token_credential, queue_name="foo") @QueuePreparer() def test_create_service_china(self, **kwargs): @@ -140,17 +145,26 @@ def test_create_service_china(self, **kwargs): for service_type in SERVICES.items(): # Act - url = self.account_url(storage_account_name, "queue").replace('core.windows.net', 'core.chinacloudapi.cn') - service = 
service_type[0]( - url, credential=storage_account_key, queue_name='foo') + url = self.account_url(storage_account_name, "queue").replace("core.windows.net", "core.chinacloudapi.cn") + service = service_type[0](url, credential=storage_account_key, queue_name="foo") # Assert assert service is not None assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith(f'https://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn') is True - assert service.secondary_endpoint.startswith(f'https://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn') is True + assert ( + service.primary_endpoint.startswith( + f"https://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn" + ) + is True + ) + assert ( + service.secondary_endpoint.startswith( + f"https://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn" + ) + is True + ) @QueuePreparer() def test_create_service_protocol(self, **kwargs): @@ -161,13 +175,14 @@ def test_create_service_protocol(self, **kwargs): for service_type in SERVICES.items(): # Act - url = self.account_url(storage_account_name, "queue").replace('https', 'http') - service = service_type[0]( - url, credential=storage_account_key, queue_name='foo') + url = self.account_url(storage_account_name, "queue").replace("https", "http") + service = service_type[0](url, credential=storage_account_key, queue_name="foo") # Assert - self.validate_standard_account_endpoints(service, service_type[1], storage_account_name, storage_account_key) - assert service.scheme == 'http' + self.validate_standard_account_endpoints( + service, service_type[1], storage_account_name, storage_account_key + ) + assert service.scheme == "http" @QueuePreparer() def test_create_service_empty_key(self, **kwargs): @@ -180,7 +195,7 @@ def test_create_service_empty_key(self, **kwargs): for service_type in QUEUE_SERVICES: # Act with pytest.raises(ValueError) as e: - test_service = service_type('testaccount', credential='', queue_name='foo') + test_service = service_type("testaccount", credential="", queue_name="foo") assert str(e.value) == "You need to provide either a SAS token or an account shared key to authenticate." 
@@ -194,13 +209,19 @@ def test_create_service_with_socket_timeout(self, **kwargs): for service_type in SERVICES.items(): # Act default_service = service_type[0]( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='foo') + self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="foo" + ) service = service_type[0]( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, - queue_name='foo', connection_timeout=22) + self.account_url(storage_account_name, "queue"), + credential=storage_account_key, + queue_name="foo", + connection_timeout=22, + ) # Assert - self.validate_standard_account_endpoints(service, service_type[1], storage_account_name, storage_account_key) + self.validate_standard_account_endpoints( + service, service_type[1], storage_account_name, storage_account_key + ) assert service._client._client._pipeline._transport.connection_config.timeout == 22 assert default_service._client._client._pipeline._transport.connection_config.timeout in [20, (20, 2000)] @@ -211,17 +232,16 @@ def test_create_service_with_connection_string_key(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - ) + conn_string = f"AccountName={storage_account_name};" f"AccountKey={storage_account_key};" for service_type in SERVICES.items(): # Act - service = service_type[0].from_connection_string(conn_string, queue_name='foo') + service = service_type[0].from_connection_string(conn_string, queue_name="foo") # Assert - self.validate_standard_account_endpoints(service, service_type[1], storage_account_name, storage_account_key) - assert service.scheme == 'https' + self.validate_standard_account_endpoints( + service, service_type[1], storage_account_name, storage_account_key + ) + assert service.scheme == "https" @QueuePreparer() def test_create_service_with_connection_string_sas(self, **kwargs): @@ -229,19 +249,16 @@ def test_create_service_with_connection_string_sas(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - conn_string = ( - f'AccountName={storage_account_name};' - f'SharedAccessSignature={self.sas_token};' - ) + conn_string = f"AccountName={storage_account_name};" f"SharedAccessSignature={self.sas_token};" for service_type in SERVICES: # Act - service = service_type.from_connection_string(conn_string, queue_name='foo') + service = service_type.from_connection_string(conn_string, queue_name="foo") # Assert assert service is not None assert service.account_name == storage_account_name - assert service.url.startswith('https://' + storage_account_name + '.queue.core.windows.net') + assert service.url.startswith("https://" + storage_account_name + ".queue.core.windows.net") assert service.url.endswith(self.sas_token) assert service.credential is None @@ -252,10 +269,10 @@ def test_create_service_with_conn_str_endpoint_protocol(self, **kwargs): # Arrange conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'DefaultEndpointsProtocol=http;EndpointSuffix=core.chinacloudapi.cn;' - ) + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "DefaultEndpointsProtocol=http;EndpointSuffix=core.chinacloudapi.cn;" + ) for service_type in SERVICES.items(): # Act @@ -266,15 +283,25 @@ def test_create_service_with_conn_str_endpoint_protocol(self, **kwargs): assert service.account_name == 
storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith(f'http://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn/') is True - assert service.secondary_endpoint.startswith(f'http://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn') is True - assert service.scheme == 'http' + assert ( + service.primary_endpoint.startswith( + f"http://{storage_account_name}.{service_type[1]}.core.chinacloudapi.cn/" + ) + is True + ) + assert ( + service.secondary_endpoint.startswith( + f"http://{storage_account_name}-secondary.{service_type[1]}.core.chinacloudapi.cn" + ) + is True + ) + assert service.scheme == "http" @QueuePreparer() def test_create_service_with_connection_string_emulated(self, *args): # Arrange for service_type in SERVICES.items(): - conn_string = 'UseDevelopmentStorage=true;' + conn_string = "UseDevelopmentStorage=true;" # Act with pytest.raises(ValueError): @@ -288,9 +315,10 @@ def test_create_service_with_connection_string_custom_domain(self, **kwargs): # Arrange for service_type in SERVICES.items(): conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'QueueEndpoint=www.mydomain.com;') + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "QueueEndpoint=www.mydomain.com;" + ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -300,8 +328,10 @@ def test_create_service_with_connection_string_custom_domain(self, **kwargs): assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith(f'https://{storage_account_name}-secondary.queue.core.windows.net') + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert service.secondary_endpoint.startswith( + f"https://{storage_account_name}-secondary.queue.core.windows.net" + ) @QueuePreparer() def test_create_serv_with_cs_custom_dmn_trlng_slash(self, **kwargs): @@ -311,9 +341,10 @@ def test_create_serv_with_cs_custom_dmn_trlng_slash(self, **kwargs): # Arrange for service_type in SERVICES.items(): conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'QueueEndpoint=www.mydomain.com/;') + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "QueueEndpoint=www.mydomain.com/;" + ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -323,9 +354,10 @@ def test_create_serv_with_cs_custom_dmn_trlng_slash(self, **kwargs): assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith(f'https://{storage_account_name}-secondary.queue.core.windows.net') - + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert service.secondary_endpoint.startswith( + f"https://{storage_account_name}-secondary.queue.core.windows.net" + ) @QueuePreparer() def test_create_service_with_cs_custom_dmn_sec_override(self, **kwargs): @@ -335,21 +367,23 @@ def 
test_create_service_with_cs_custom_dmn_sec_override(self, **kwargs): # Arrange for service_type in SERVICES.items(): conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - 'QueueEndpoint=www.mydomain.com/;') + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + "QueueEndpoint=www.mydomain.com/;" + ) # Act service = service_type[0].from_connection_string( - conn_string, secondary_hostname="www-sec.mydomain.com", queue_name="foo") + conn_string, secondary_hostname="www-sec.mydomain.com", queue_name="foo" + ) # Assert assert service is not None assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith('https://www-sec.mydomain.com/') + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert service.secondary_endpoint.startswith("https://www-sec.mydomain.com/") @QueuePreparer() def test_create_service_with_cs_fails_if_sec_without_prim(self, **kwargs): @@ -359,9 +393,10 @@ def test_create_service_with_cs_fails_if_sec_without_prim(self, **kwargs): for service_type in SERVICES.items(): # Arrange conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - f'{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www.mydomain.com;') + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + f"{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www.mydomain.com;" + ) # Act @@ -377,10 +412,11 @@ def test_create_service_with_cs_succeeds_if_sec_with_prim(self, **kwargs): for service_type in SERVICES.items(): # Arrange conn_string = ( - f'AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - f'{_CONNECTION_ENDPOINTS.get(service_type[1])}=www.mydomain.com;' - f'{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www-sec.mydomain.com;') + f"AccountName={storage_account_name};" + f"AccountKey={storage_account_key};" + f"{_CONNECTION_ENDPOINTS.get(service_type[1])}=www.mydomain.com;" + f"{_CONNECTION_ENDPOINTS_SECONDARY.get(service_type[1])}=www-sec.mydomain.com;" + ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -390,8 +426,8 @@ def test_create_service_with_cs_succeeds_if_sec_with_prim(self, **kwargs): assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_endpoint.startswith('https://www.mydomain.com/') - assert service.secondary_endpoint.startswith('https://www-sec.mydomain.com/') + assert service.primary_endpoint.startswith("https://www.mydomain.com/") + assert service.secondary_endpoint.startswith("https://www-sec.mydomain.com/") @QueuePreparer() def test_create_service_with_custom_account_endpoint_path(self, **kwargs): @@ -401,9 +437,10 @@ def test_create_service_with_custom_account_endpoint_path(self, **kwargs): custom_account_url = "http://local-machine:11002/custom/account/path/" + self.sas_token for service_type in SERVICES.items(): conn_string = ( - f'DefaultEndpointsProtocol=http;AccountName={storage_account_name};' - f'AccountKey={storage_account_key};' - f'QueueEndpoint={custom_account_url};') + f"DefaultEndpointsProtocol=http;AccountName={storage_account_name};" + 
f"AccountKey={storage_account_key};" + f"QueueEndpoint={custom_account_url};" + ) # Act service = service_type[0].from_connection_string(conn_string, queue_name="foo") @@ -412,27 +449,27 @@ def test_create_service_with_custom_account_endpoint_path(self, **kwargs): assert service.account_name == storage_account_name assert service.credential.account_name == storage_account_name assert service.credential.account_key == storage_account_key - assert service.primary_hostname == 'local-machine:11002/custom/account/path' + assert service.primary_hostname == "local-machine:11002/custom/account/path" service = QueueServiceClient(account_url=custom_account_url) assert service.account_name == None assert service.credential == None - assert service.primary_hostname == 'local-machine:11002/custom/account/path' - assert service.url.startswith('http://local-machine:11002/custom/account/path/?') + assert service.primary_hostname == "local-machine:11002/custom/account/path" + assert service.url.startswith("http://local-machine:11002/custom/account/path/?") service = QueueClient(account_url=custom_account_url, queue_name="foo") assert service.account_name == None assert service.queue_name == "foo" assert service.credential == None - assert service.primary_hostname == 'local-machine:11002/custom/account/path' - assert service.url.startswith('http://local-machine:11002/custom/account/path/foo?') + assert service.primary_hostname == "local-machine:11002/custom/account/path" + assert service.url.startswith("http://local-machine:11002/custom/account/path/foo?") service = QueueClient.from_queue_url("http://local-machine:11002/custom/account/path/foo" + self.sas_token) assert service.account_name == None assert service.queue_name == "foo" assert service.credential == None - assert service.primary_hostname == 'local-machine:11002/custom/account/path' - assert service.url.startswith('http://local-machine:11002/custom/account/path/foo?') + assert service.primary_hostname == "local-machine:11002/custom/account/path" + assert service.url.startswith("http://local-machine:11002/custom/account/path/foo?") @QueuePreparer() @recorded_by_proxy_async @@ -442,17 +479,17 @@ async def test_request_callback_signed_header(self, **kwargs): # Arrange service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) - name = self.get_resource_name('cont') + name = self.get_resource_name("cont") # Act try: - headers = {'x-ms-meta-hello': 'world'} + headers = {"x-ms-meta-hello": "world"} queue = await service.create_queue(name, headers=headers) # Assert metadata_cr = await queue.get_queue_properties() metadata = metadata_cr.metadata - assert metadata == {'hello': 'world'} + assert metadata == {"hello": "world"} finally: await service.delete_queue(name) @@ -464,14 +501,13 @@ async def test_response_callback(self, **kwargs): # Arrange service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) - name = self.get_resource_name('cont') + name = self.get_resource_name("cont") queue = service.get_queue_client(name) # Act def callback(response): response.http_response.status_code = 200 - # Assert exists = await queue.get_queue_properties(raw_response_hook=callback) assert exists @@ -485,8 +521,8 @@ async def test_user_agent_default(self, **kwargs): service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) def callback(response): - assert 'User-Agent' in response.http_request.headers - assert 
f"azsdk-python-storage-queue/{VERSION}" in response.http_request.headers['User-Agent'] + assert "User-Agent" in response.http_request.headers + assert f"azsdk-python-storage-queue/{VERSION}" in response.http_request.headers["User-Agent"] await service.get_service_properties(raw_response_hook=callback) @@ -498,23 +534,25 @@ async def test_user_agent_custom(self, **kwargs): custom_app = "TestApp/v1.0" service = QueueServiceClient( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, user_agent=custom_app) + self.account_url(storage_account_name, "queue"), credential=storage_account_key, user_agent=custom_app + ) def callback(response): - assert 'User-Agent' in response.http_request.headers + assert "User-Agent" in response.http_request.headers assert ( f"TestApp/v1.0 azsdk-python-storage-queue/{VERSION} " f"Python/{platform.python_version()} " - f"({platform.platform()})") in response.http_request.headers['User-Agent'] + f"({platform.platform()})" + ) in response.http_request.headers["User-Agent"] await service.get_service_properties(raw_response_hook=callback) def callback(response): - assert 'User-Agent' in response.http_request.headers + assert "User-Agent" in response.http_request.headers assert ( f"TestApp/v2.0 TestApp/v1.0 azsdk-python-storage-queue/{VERSION} " f"Python/{platform.python_version()} ({platform.platform()})" - ) in response.http_request.headers['User-Agent'] + ) in response.http_request.headers["User-Agent"] await service.get_service_properties(raw_response_hook=callback, user_agent="TestApp/v2.0") @@ -527,12 +565,13 @@ async def test_user_agent_append(self, **kwargs): service = QueueServiceClient(self.account_url(storage_account_name, "queue"), credential=storage_account_key) def callback(response): - assert 'User-Agent' in response.http_request.headers - assert (f"customer_user_agent azsdk-python-storage-queue/{VERSION} " - f"Python/{platform.python_version()} ({platform.platform()})" - ) in response.http_request.headers['User-Agent'] + assert "User-Agent" in response.http_request.headers + assert ( + f"customer_user_agent azsdk-python-storage-queue/{VERSION} " + f"Python/{platform.python_version()} ({platform.platform()})" + ) in response.http_request.headers["User-Agent"] - await service.get_service_properties(raw_response_hook=callback, user_agent='customer_user_agent') + await service.get_service_properties(raw_response_hook=callback, user_agent="customer_user_agent") @QueuePreparer() async def test_closing_pipeline_client(self, **kwargs): @@ -543,11 +582,12 @@ async def test_closing_pipeline_client(self, **kwargs): for client, url in SERVICES.items(): # Act service = client( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='queue') + self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="queue" + ) # Assert async with service: - assert hasattr(service, 'close') + assert hasattr(service, "close") await service.close() @QueuePreparer() @@ -559,7 +599,8 @@ async def test_closing_pipeline_client_simple(self, **kwargs): for client, url in SERVICES.items(): # Act service = client( - self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name='queue') + self.account_url(storage_account_name, "queue"), credential=storage_account_key, queue_name="queue" + ) await service.close() @QueuePreparer() @@ -578,6 +619,7 @@ async def test_get_and_set_queue_access_policy_oauth(self, **kwargs): acl = await 
queue_client.get_queue_access_policy() assert acl is not None + # ------------------------------------------------------------------------------ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py index 82d8a1532c59..4fd5d4818674 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_encodings.py @@ -13,7 +13,7 @@ QueueClient, QueueServiceClient, TextBase64DecodePolicy, - TextBase64EncodePolicy + TextBase64EncodePolicy, ) from azure.storage.queue._message_encoding import NoDecodePolicy, NoEncodePolicy @@ -22,11 +22,12 @@ from settings.testcase import QueuePreparer # ------------------------------------------------------------------------------ -TEST_QUEUE_PREFIX = 'mytestqueue' +TEST_QUEUE_PREFIX = "mytestqueue" # ------------------------------------------------------------------------------ + class TestStorageQueueEncoding(StorageRecordedTestCase): # --Helpers----------------------------------------------------------------- def _get_queue_reference(self, qsc, prefix=TEST_QUEUE_PREFIX): @@ -66,7 +67,7 @@ def test_message_text_xml(self, **kwargs): # Arrange. qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - message = '' + message = "" queue = qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX)) # Asserts @@ -82,7 +83,7 @@ def test_message_text_xml_whitespace(self, **kwargs): # Arrange. qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - message = ' mess\t age1\n' + message = " mess\t age1\n" queue = qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX)) # Asserts @@ -97,7 +98,7 @@ def test_message_text_xml_invalid_chars(self, **kwargs): # Action. qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._get_queue_reference(qsc) - message = '\u0001' + message = "\u0001" # Asserts with pytest.raises(HttpResponseError): @@ -116,9 +117,10 @@ def test_message_text_base64(self, **kwargs): queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=TextBase64EncodePolicy(), - message_decode_policy=TextBase64DecodePolicy()) + message_decode_policy=TextBase64DecodePolicy(), + ) - message = '\u0001' + message = "\u0001" # Asserts self._validate_encoding(queue, message) @@ -136,9 +138,10 @@ def test_message_bytes_base64(self, **kwargs): queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) + message_decode_policy=BinaryBase64DecodePolicy(), + ) - message = b'xyz' + message = b"xyz" # Asserts self._validate_encoding(queue, message) @@ -151,19 +154,20 @@ def test_message_bytes_fails(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - queue = qsc.get_queue_client(self.get_resource_name('failqueue')) + queue = qsc.get_queue_client(self.get_resource_name("failqueue")) queue.create_queue() - # Action. with pytest.raises(TypeError) as e: - message = b'xyz' + message = b"xyz" queue.send_message(message) # Asserts - assert str(e.exception.startswith( - 'Message content must not be bytes. 
' - 'Use the BinaryBase64EncodePolicy to send bytes.')) + assert str(e.value).startswith( + "Message content must not be bytes. Use the BinaryBase64EncodePolicy to send bytes." + ) @QueuePreparer() def test_message_text_fails(self, **kwargs): @@ -177,15 +181,16 @@ def test_message_text_fails(self, **kwargs): queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) + message_decode_policy=BinaryBase64DecodePolicy(), + ) # Action. with pytest.raises(TypeError) as e: - message = 'xyz' + message = "xyz" queue.send_message(message) # Asserts - assert str(e.value).startswith('Message content must be bytes') + assert str(e.value).startswith("Message content must be bytes") @QueuePreparer() @recorded_by_proxy def test_message_base64_decode_fails(self, **kwargs): @@ -200,12 +205,13 @@ def test_message_base64_decode_fails(self, **kwargs): queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=None, - message_decode_policy=BinaryBase64DecodePolicy()) + message_decode_policy=BinaryBase64DecodePolicy(), + ) try: queue.create_queue() except ResourceExistsError: pass - message = 'xyz' + message = "xyz" queue.send_message(message) # Action. @@ -213,7 +219,7 @@ def test_message_base64_decode_fails(self, **kwargs): queue.peek_messages() # Asserts - assert -1 != str(e.value).find('Message content is not valid base 64') + assert -1 != str(e.value).find("Message content is not valid base 64") def test_message_no_encoding(self): # Arrange @@ -222,7 +228,8 @@ def test_message_no_encoding(self): queue_name="queue", credential="account_key", message_encode_policy=None, - message_decode_policy=None) + message_decode_policy=None, + ) # Asserts assert isinstance(queue._message_encode_policy, NoEncodePolicy) @@ -230,5 +237,5 @@ def test_message_no_encoding(self): # ------------------------------------------------------------------------------ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encodings_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_encodings_async.py index 0759e7292a28..349d6cfc6ee9 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_encodings_async.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_encodings_async.py @@ -7,7 +7,12 @@ import pytest from azure.core.exceptions import DecodeError, HttpResponseError, ResourceExistsError -from azure.storage.queue import BinaryBase64DecodePolicy, BinaryBase64EncodePolicy, TextBase64DecodePolicy, TextBase64EncodePolicy +from azure.storage.queue import ( + BinaryBase64DecodePolicy, + BinaryBase64EncodePolicy, + TextBase64DecodePolicy, + TextBase64EncodePolicy, +) from azure.storage.queue.aio import QueueClient, QueueServiceClient from devtools_testutils.aio import recorded_by_proxy_async @@ -15,7 +20,7 @@ from settings.testcase import QueuePreparer # ------------------------------------------------------------------------------ -TEST_QUEUE_PREFIX = 'mytestqueue' +TEST_QUEUE_PREFIX = "mytestqueue" # ------------------------------------------------------------------------------ @@ -60,7 +65,7 @@ async def test_message_text_xml(self, **kwargs): # Arrange.
qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - message = '' + message = "" queue = qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX)) # Asserts @@ -74,7 +79,7 @@ async def test_message_text_xml_whitespace(self, **kwargs): # Arrange. qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - message = ' mess\t age1\n' + message = " mess\t age1\n" queue = qsc.get_queue_client(self.get_resource_name(TEST_QUEUE_PREFIX)) # Asserts @@ -89,7 +94,7 @@ async def test_message_text_xml_invalid_chars(self, **kwargs): # Action. qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._get_queue_reference(qsc) - message = '\u0001' + message = "\u0001" # Asserts with pytest.raises(HttpResponseError): @@ -108,9 +113,10 @@ async def test_message_text_base64(self, **kwargs): queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=TextBase64EncodePolicy(), - message_decode_policy=TextBase64DecodePolicy()) + message_decode_policy=TextBase64DecodePolicy(), + ) - message = '\u0001' + message = "\u0001" # Asserts await self._validate_encoding(queue, message) @@ -122,17 +128,16 @@ async def test_message_bytes_base64(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange. - qsc = QueueServiceClient( - self.account_url(storage_account_name, "queue"), - storage_account_key) + qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = QueueClient( account_url=self.account_url(storage_account_name, "queue"), queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) + message_decode_policy=BinaryBase64DecodePolicy(), + ) - message = b'xyz' + message = b"xyz" # Asserts await self._validate_encoding(queue, message) @@ -148,13 +153,15 @@ async def test_message_bytes_fails(self, **kwargs): queue = await self._create_queue(qsc) # Action. with pytest.raises(TypeError) as e: - message = b'xyz' + message = b"xyz" await queue.send_message(message) # Asserts - assert str(e.exception.startswith( - 'Message content must not be bytes. ' - 'Use the BinaryBase64EncodePolicy to send bytes.')) + assert str(e.value).startswith( + "Message content must not be bytes. Use the BinaryBase64EncodePolicy to send bytes." + ) @QueuePreparer() async def test_message_text_fails(self, **kwargs): @@ -168,15 +175,16 @@ async def test_message_text_fails(self, **kwargs): queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) + message_decode_policy=BinaryBase64DecodePolicy(), + ) # Action.
with pytest.raises(TypeError) as e: - message = 'xyz' + message = "xyz" await queue.send_message(message) # Asserts - assert str(e.value).startswith('Message content must be bytes') + assert str(e.value).startswith("Message content must be bytes") @QueuePreparer() @recorded_by_proxy_async @@ -191,12 +199,13 @@ async def test_message_base64_decode_fails(self, **kwargs): queue_name=self.get_resource_name(TEST_QUEUE_PREFIX), credential=storage_account_key, message_encode_policy=None, - message_decode_policy=BinaryBase64DecodePolicy()) + message_decode_policy=BinaryBase64DecodePolicy(), + ) try: await queue.create_queue() except ResourceExistsError: pass - message = 'xyz' + message = "xyz" await queue.send_message(message) # Action. @@ -204,8 +213,9 @@ async def test_message_base64_decode_fails(self, **kwargs): await queue.peek_messages() # Asserts - assert -1 != str(e.value).find('Message content is not valid base 64') + assert -1 != str(e.value).find("Message content is not valid base 64") + # ------------------------------------------------------------------------------ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py b/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py index 5655112e093e..6b776d8e0cb6 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_encryption.py @@ -43,15 +43,17 @@ # ------------------------------------------------------------------------------ -TEST_QUEUE_PREFIX = 'encryptionqueue' +TEST_QUEUE_PREFIX = "encryptionqueue" # ------------------------------------------------------------------------------ + def _decode_base64_to_bytes(data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") return b64decode(data) -@mock.patch('os.urandom', mock_urandom) + +@mock.patch("os.urandom", mock_urandom) class TestStorageQueueEncryption(StorageRecordedTestCase): # --Helpers----------------------------------------------------------------- def _get_queue_reference(self, qsc, prefix=TEST_QUEUE_PREFIX, **kwargs): @@ -77,15 +79,15 @@ def test_get_messages_encrypted_kek(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = self._create_queue(qsc) - queue.send_message('encrypted_message_2') + queue.send_message("encrypted_message_2") # Act li = next(queue.receive_messages()) # Assert - assert li.content == 'encrypted_message_2' + assert li.content == "encrypted_message_2" @QueuePreparer() @recorded_by_proxy @@ -95,9 +97,9 @@ def test_get_messages_encrypted_resolver(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = self._create_queue(qsc) - queue.send_message('encrypted_message_2') + queue.send_message("encrypted_message_2") key_resolver = KeyResolver() key_resolver.put_key(qsc.key_encryption_key) queue.key_resolver_function = key_resolver.resolve_key @@ -107,7 +109,7 @@ def test_get_messages_encrypted_resolver(self, **kwargs): li = next(queue.receive_messages()) # Assert - assert li.content == 'encrypted_message_2' + assert li.content == "encrypted_message_2" @QueuePreparer() @recorded_by_proxy @@ -117,15 +119,15 @@ def 
test_peek_messages_encrypted_kek(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = self._create_queue(qsc) - queue.send_message('encrypted_message_3') + queue.send_message("encrypted_message_3") # Act li = queue.peek_messages() # Assert - assert li[0].content == 'encrypted_message_3' + assert li[0].content == "encrypted_message_3" @QueuePreparer() @recorded_by_proxy @@ -135,9 +137,9 @@ def test_peek_messages_encrypted_resolver(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = self._create_queue(qsc) - queue.send_message('encrypted_message_4') + queue.send_message("encrypted_message_4") key_resolver = KeyResolver() key_resolver.put_key(qsc.key_encryption_key) queue.key_resolver_function = key_resolver.resolve_key @@ -147,7 +149,7 @@ def test_peek_messages_encrypted_resolver(self, **kwargs): li = queue.peek_messages() # Assert - assert li[0].content == 'encrypted_message_4' + assert li[0].content == "encrypted_message_4" @pytest.mark.live_test_only @QueuePreparer() @@ -155,21 +157,20 @@ def test_peek_messages_encrypted_kek_RSA(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - # We can only generate random RSA keys, so this must be run live or # the playback test will fail due to a change in kek values. # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - qsc.key_encryption_key = RSAKeyWrapper('key2') + qsc.key_encryption_key = RSAKeyWrapper("key2") queue = self._create_queue(qsc) - queue.send_message('encrypted_message_3') + queue.send_message("encrypted_message_3") # Act li = queue.peek_messages() # Assert - assert li[0].content == 'encrypted_message_3' + assert li[0].content == "encrypted_message_3" @QueuePreparer() @recorded_by_proxy @@ -180,19 +181,19 @@ def test_update_encrypted_message(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - queue.send_message('Update Me') + queue.key_encryption_key = KeyWrapper("key1") + queue.send_message("Update Me") messages = queue.receive_messages() list_result1 = next(messages) - list_result1.content = 'Updated' + list_result1.content = "Updated" # Act message = queue.update_message(list_result1) list_result2 = next(messages) # Assert - assert 'Updated' == list_result2.content + assert "Updated" == list_result2.content @QueuePreparer() @recorded_by_proxy @@ -203,9 +204,9 @@ def test_update_encrypted_binary_message(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue( - qsc, message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) - queue.key_encryption_key = KeyWrapper('key1') + qsc, message_encode_policy=BinaryBase64EncodePolicy(), message_decode_policy=BinaryBase64DecodePolicy() + ) + queue.key_encryption_key = KeyWrapper("key1") binary_message = self.get_random_bytes(100) queue.send_message(binary_message) @@ -223,7 +224,6 @@ def test_update_encrypted_binary_message(self, **kwargs): 
messages.append(m) list_result2 = messages[0] - # Assert assert binary_message == list_result2.content @@ -236,15 +236,15 @@ def test_update_encrypted_raw_text_message(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc, message_encode_policy=None, message_decode_policy=None) - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") - raw_text = 'Update Me' + raw_text = "Update Me" queue.send_message(raw_text) messages = queue.receive_messages() list_result1 = next(messages) # Act - raw_text = 'Updated' + raw_text = "Updated" list_result1.content = raw_text queue.update_message(list_result1) @@ -262,17 +262,17 @@ def test_update_encrypted_json_message(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc, message_encode_policy=None, message_decode_policy=None) - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") - message_dict = {'val1': 1, 'val2': '2'} + message_dict = {"val1": 1, "val2": "2"} json_text = dumps(message_dict) queue.send_message(json_text) messages = queue.receive_messages() list_result1 = next(messages) # Act - message_dict['val1'] = 0 - message_dict['val2'] = 'updated' + message_dict["val1"] = 0 + message_dict["val2"] = "updated" json_text = dumps(message_dict) list_result1.content = json_text queue.update_message(list_result1) @@ -291,23 +291,23 @@ def test_invalid_value_kek_wrap(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") queue.key_encryption_key.get_kid = None with pytest.raises(AttributeError) as e: - queue.send_message('message') + queue.send_message("message") - assert str(e.value.args[0]), _ERROR_OBJECT_INVALID.format('key encryption key' == 'get_kid') + assert str(e.value.args[0]), _ERROR_OBJECT_INVALID.format("key encryption key" == "get_kid") - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") queue.key_encryption_key.get_kid = None with pytest.raises(AttributeError): - queue.send_message('message') + queue.send_message("message") - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") queue.key_encryption_key.wrap_key = None with pytest.raises(AttributeError): - queue.send_message('message') + queue.send_message("message") @QueuePreparer() @recorded_by_proxy @@ -319,7 +319,7 @@ def test_missing_attribute_kek_wrap(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - valid_key = KeyWrapper('key1') + valid_key = KeyWrapper("key1") # Act invalid_key_1 = lambda: None # functions are objects, so this effectively creates an empty object @@ -328,7 +328,7 @@ def test_missing_attribute_kek_wrap(self, **kwargs): # No attribute wrap_key queue.key_encryption_key = invalid_key_1 with pytest.raises(AttributeError): - queue.send_message('message') + queue.send_message("message") invalid_key_2 = lambda: None # functions are objects, so this effectively creates an empty object invalid_key_2.wrap_key = valid_key.wrap_key @@ -336,7 +336,7 @@ def test_missing_attribute_kek_wrap(self, **kwargs): # No attribute 
get_key_wrap_algorithm queue.key_encryption_key = invalid_key_2 with pytest.raises(AttributeError): - queue.send_message('message') + queue.send_message("message") invalid_key_3 = lambda: None # functions are objects, so this effectively creates an empty object invalid_key_3.get_key_wrap_algorithm = valid_key.get_key_wrap_algorithm @@ -344,7 +344,7 @@ def test_missing_attribute_kek_wrap(self, **kwargs): # No attribute get_kid queue.key_encryption_key = invalid_key_3 with pytest.raises(AttributeError): - queue.send_message('message') + queue.send_message("message") @QueuePreparer() @recorded_by_proxy @@ -355,8 +355,8 @@ def test_invalid_value_kek_unwrap(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - queue.send_message('message') + queue.key_encryption_key = KeyWrapper("key1") + queue.send_message("message") # Act queue.key_encryption_key.unwrap_key = None @@ -376,11 +376,11 @@ def test_missing_attribute_kek_unwrap(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - queue.send_message('message') + queue.key_encryption_key = KeyWrapper("key1") + queue.send_message("message") # Act - valid_key = KeyWrapper('key1') + valid_key = KeyWrapper("key1") invalid_key_1 = lambda: None # functions are objects, so this effectively creates an empty object invalid_key_1.unwrap_key = valid_key.unwrap_key # No attribute get_kid @@ -406,9 +406,9 @@ def test_validate_encryption(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") queue.key_encryption_key = kek - queue.send_message('message') + queue.send_message("message") # Act queue.key_encryption_key = None # Message will not be decrypted @@ -416,30 +416,30 @@ def test_validate_encryption(self, **kwargs): message = li[0].content message = loads(message) - encryption_data = message['EncryptionData'] + encryption_data = message["EncryptionData"] - wrapped_content_key = encryption_data['WrappedContentKey'] + wrapped_content_key = encryption_data["WrappedContentKey"] wrapped_content_key = _WrappedContentKey( - wrapped_content_key['Algorithm'], - b64decode(wrapped_content_key['EncryptedKey'].encode(encoding='utf-8')), - wrapped_content_key['KeyId']) + wrapped_content_key["Algorithm"], + b64decode(wrapped_content_key["EncryptedKey"].encode(encoding="utf-8")), + wrapped_content_key["KeyId"], + ) - encryption_agent = encryption_data['EncryptionAgent'] - encryption_agent = _EncryptionAgent( - encryption_agent['EncryptionAlgorithm'], - encryption_agent['Protocol']) + encryption_agent = encryption_data["EncryptionAgent"] + encryption_agent = _EncryptionAgent(encryption_agent["EncryptionAlgorithm"], encryption_agent["Protocol"]) encryption_data = _EncryptionData( - b64decode(encryption_data['ContentEncryptionIV'].encode(encoding='utf-8')), + b64decode(encryption_data["ContentEncryptionIV"].encode(encoding="utf-8")), None, encryption_agent, wrapped_content_key, - {'EncryptionLibrary': VERSION}) + {"EncryptionLibrary": VERSION}, + ) - message = message['EncryptedMessageContents'] + message = message["EncryptedMessageContents"] content_encryption_key = kek.unwrap_key( - 
encryption_data.wrapped_content_key.encrypted_key, - encryption_data.wrapped_content_key.algorithm) + encryption_data.wrapped_content_key.encrypted_key, encryption_data.wrapped_content_key.algorithm + ) # Create decryption cipher backend = backends.default_backend() @@ -450,16 +450,16 @@ def test_validate_encryption(self, **kwargs): # decode and decrypt data decrypted_data = decode_base64_to_bytes(message) decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) + decrypted_data = decryptor.update(decrypted_data) + decryptor.finalize() # unpad data unpadder = PKCS7(128).unpadder() - decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize()) + decrypted_data = unpadder.update(decrypted_data) + unpadder.finalize() - decrypted_data = decrypted_data.decode(encoding='utf-8') + decrypted_data = decrypted_data.decode(encoding="utf-8") # Assert - assert decrypted_data == 'message' + assert decrypted_data == "message" @QueuePreparer() @recorded_by_proxy @@ -470,16 +470,16 @@ def test_put_with_strict_mode(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") queue.key_encryption_key = kek queue.require_encryption = True - queue.send_message('message') + queue.send_message("message") queue.key_encryption_key = None # Assert with pytest.raises(ValueError) as e: - queue.send_message('message') + queue.send_message("message") assert str(e.value.args[0]) == "Encryption required but no key was provided." @@ -492,14 +492,14 @@ def test_get_with_strict_mode(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - queue.send_message('message') + queue.send_message("message") queue.require_encryption = True - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") with pytest.raises(ValueError) as e: next(queue.receive_messages()) - assert 'Message was either not encrypted or metadata was incorrect.' in str(e.value.args[0]) + assert "Message was either not encrypted or metadata was incorrect." 
in str(e.value.args[0]) @QueuePreparer() @recorded_by_proxy @@ -510,13 +510,13 @@ def test_encryption_add_encrypted_64k_message(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - message = 'a' * 1024 * 64 + message = "a" * 1024 * 64 # Act queue.send_message(message) # Assert - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") with pytest.raises(HttpResponseError): queue.send_message(message) @@ -529,11 +529,11 @@ def test_encryption_nonmatching_kid(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) queue = self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - queue.send_message('message') + queue.key_encryption_key = KeyWrapper("key1") + queue.send_message("message") # Act - queue.key_encryption_key.kid = 'Invalid' + queue.key_encryption_key.kid = "Invalid" # Assert with pytest.raises(HttpResponseError) as e: @@ -552,10 +552,11 @@ def test_get_message_encrypted_kek_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) queue = self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" # Act queue.send_message(content) @@ -575,13 +576,14 @@ def test_get_message_encrypted_resolver_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) key_resolver = KeyResolver() key_resolver.put_key(qsc.key_encryption_key) queue = self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" # Act queue.send_message(content) @@ -607,10 +609,11 @@ def test_get_message_encrypted_kek_RSA_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=RSAKeyWrapper('key2')) + encryption_version="2.0", + key_encryption_key=RSAKeyWrapper("key2"), + ) queue = self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" 
# Act queue.send_message(content) @@ -630,20 +633,21 @@ def test_update_encrypted_message_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, requires_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) queue = self._create_queue(qsc) - queue.send_message('Update Me') + queue.send_message("Update Me") message = queue.receive_message() - message.content = 'Updated' + message.content = "Updated" # Act queue.update_message(message) message = queue.receive_message() # Assert - assert 'Updated' == message.content + assert "Updated" == message.content @QueuePreparer() @recorded_by_proxy @@ -656,24 +660,24 @@ def test_update_encrypted_binary_message_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, requires_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) queue = self._create_queue( - qsc, - message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) - queue.key_encryption_key = KeyWrapper('key1') + qsc, message_encode_policy=BinaryBase64EncodePolicy(), message_decode_policy=BinaryBase64DecodePolicy() + ) + queue.key_encryption_key = KeyWrapper("key1") - queue.send_message(b'Update Me') + queue.send_message(b"Update Me") message = queue.receive_message() - message.content = b'Updated' + message.content = b"Updated" # Act queue.update_message(message) message = queue.receive_message() # Assert - assert b'Updated' == message.content + assert b"Updated" == message.content @QueuePreparer() @recorded_by_proxy @@ -682,15 +686,16 @@ def test_encryption_v2_v1_downgrade(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") qsc = QueueServiceClient( self.account_url(storage_account_name, "queue"), storage_account_key, requires_encryption=True, - encryption_version='2.0', - key_encryption_key=kek) + encryption_version="2.0", + key_encryption_key=kek, + ) queue = self._create_queue(qsc) - queue.send_message('Hello World Encrypted!') + queue.send_message("Hello World Encrypted!") queue.require_encryption = False queue.key_encryption_key = None # Message will not be decrypted @@ -698,12 +703,12 @@ def test_encryption_v2_v1_downgrade(self, **kwargs): content = loads(message.content) # Modify metadata to look like V1 - encryption_data = content['EncryptionData'] - encryption_data['EncryptionAgent']['Protocol'] = '1.0' - encryption_data['EncryptionAgent']['EncryptionAlgorithm'] = 'AES_CBC_256' + encryption_data = content["EncryptionData"] + encryption_data["EncryptionAgent"]["Protocol"] = "1.0" + encryption_data["EncryptionAgent"]["EncryptionAlgorithm"] = "AES_CBC_256" iv = b64encode(os.urandom(16)) - encryption_data['ContentEncryptionIV'] = iv.decode('utf-8') - content['EncryptionData'] = encryption_data + encryption_data["ContentEncryptionIV"] = iv.decode("utf-8") + content["EncryptionData"] = encryption_data message.content = dumps(content) @@ -718,7 +723,7 @@ def test_encryption_v2_v1_downgrade(self, **kwargs): with pytest.raises(HttpResponseError) as e: new_message = queue.receive_message() - assert 'Decryption failed.' in str(e.value.args[0]) + assert "Decryption failed." 
in str(e.value.args[0]) @QueuePreparer() @recorded_by_proxy @@ -727,15 +732,16 @@ def test_validate_encryption_v2(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") qsc = QueueServiceClient( self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=kek) + encryption_version="2.0", + key_encryption_key=kek, + ) queue = self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" queue.send_message(content) # Act @@ -744,10 +750,10 @@ def test_validate_encryption_v2(self, **kwargs): message = queue.receive_message().content message = loads(message) - encryption_data = _dict_to_encryption_data(message['EncryptionData']) + encryption_data = _dict_to_encryption_data(message["EncryptionData"]) encryption_agent = encryption_data.encryption_agent - assert '2.0' == encryption_agent.protocol - assert 'AES_GCM_256' == encryption_agent.encryption_algorithm + assert "2.0" == encryption_agent.protocol + assert "AES_GCM_256" == encryption_agent.encryption_algorithm encrypted_region_info = encryption_data.encrypted_region_info assert _GCM_NONCE_LENGTH == encrypted_region_info.nonce_length @@ -757,7 +763,7 @@ def test_validate_encryption_v2(self, **kwargs): nonce_length = encrypted_region_info.nonce_length - message = message['EncryptedMessageContents'] + message = message["EncryptedMessageContents"] message = decode_base64_to_bytes(message) # First bytes are the nonce @@ -767,7 +773,7 @@ def test_validate_encryption_v2(self, **kwargs): aesgcm = AESGCM(content_encryption_key) decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) - decrypted_data = decrypted_data.decode(encoding='utf-8') + decrypted_data = decrypted_data.decode(encoding="utf-8") # Assert assert content == decrypted_data @@ -778,21 +784,22 @@ def test_encryption_user_agent(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - app_id = 'TestAppId' - content = 'Hello World Encrypted!' - kek = KeyWrapper('key1') + app_id = "TestAppId" + content = "Hello World Encrypted!" 
+ kek = KeyWrapper("key1") def assert_user_agent(request): - start = f'{app_id} azstorage-clientsideencryption/2.0 ' - assert request.http_request.headers['User-Agent'].startswith(start) + start = f"{app_id} azstorage-clientsideencryption/2.0 " + assert request.http_request.headers["User-Agent"].startswith(start) # Test method level keyword qsc = QueueServiceClient( self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=kek) + encryption_version="2.0", + key_encryption_key=kek, + ) queue = self._create_queue(qsc) queue.send_message(content, raw_request_hook=assert_user_agent, user_agent=app_id) @@ -801,14 +808,15 @@ def assert_user_agent(request): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', + encryption_version="2.0", key_encryption_key=kek, - user_agent=app_id) + user_agent=app_id, + ) queue = self._get_queue_reference(qsc) queue.send_message(content, raw_request_hook=assert_user_agent) # ------------------------------------------------------------------------------ -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_encryption_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_encryption_async.py index 8fdf84b2952b..9872122ad7ed 100644 --- a/sdk/storage/azure-storage-queue/tests/test_queue_encryption_async.py +++ b/sdk/storage/azure-storage-queue/tests/test_queue_encryption_async.py @@ -37,15 +37,17 @@ from settings.testcase import QueuePreparer # ------------------------------------------------------------------------------ -TEST_QUEUE_PREFIX = 'encryptionqueue' +TEST_QUEUE_PREFIX = "encryptionqueue" # ------------------------------------------------------------------------------ + def _decode_base64_to_bytes(data): - if isinstance(data, str): - data = data.encode('utf-8') - return b64decode(data) + if isinstance(data, str): + data = data.encode("utf-8") + return b64decode(data) + -@mock.patch('os.urandom', mock_urandom) +@mock.patch("os.urandom", mock_urandom) class TestAsyncStorageQueueEncryption(AsyncStorageRecordedTestCase): # --Helpers----------------------------------------------------------------- def _get_queue_reference(self, qsc, prefix=TEST_QUEUE_PREFIX, **kwargs): @@ -60,6 +62,7 @@ async def _create_queue(self, qsc, prefix=TEST_QUEUE_PREFIX, **kwargs): except ResourceExistsError: pass return queue + # -------------------------------------------------------------------------- @QueuePreparer() @@ -70,9 +73,9 @@ async def test_get_messages_encrypted_kek(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = await self._create_queue(qsc) - await queue.send_message('encrypted_message_2') + await queue.send_message("encrypted_message_2") # Act li = None @@ -80,7 +83,7 @@ async def test_get_messages_encrypted_kek(self, **kwargs): li = m # Assert - assert li.content == 'encrypted_message_2' + assert li.content == "encrypted_message_2" @QueuePreparer() @recorded_by_proxy_async @@ -90,9 +93,9 @@ async def test_get_messages_encrypted_resolver(self, **kwargs): # Arrange qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = await 
self._create_queue(qsc) - await queue.send_message('encrypted_message_2') + await queue.send_message("encrypted_message_2") key_resolver = KeyResolver() key_resolver.put_key(qsc.key_encryption_key) queue.key_resolver_function = key_resolver.resolve_key @@ -104,7 +107,7 @@ async def test_get_messages_encrypted_resolver(self, **kwargs): li = m # Assert - assert li.content == 'encrypted_message_2' + assert li.content == "encrypted_message_2" @QueuePreparer() @recorded_by_proxy_async @@ -114,15 +117,15 @@ async def test_peek_messages_encrypted_kek(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = await self._create_queue(qsc) - await queue.send_message('encrypted_message_3') + await queue.send_message("encrypted_message_3") # Act li = await queue.peek_messages() # Assert - assert li[0].content == 'encrypted_message_3' + assert li[0].content == "encrypted_message_3" @QueuePreparer() @recorded_by_proxy_async @@ -132,9 +135,9 @@ async def test_peek_messages_encrypted_resolver(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange - qsc.key_encryption_key = KeyWrapper('key1') + qsc.key_encryption_key = KeyWrapper("key1") queue = await self._create_queue(qsc) - await queue.send_message('encrypted_message_4') + await queue.send_message("encrypted_message_4") key_resolver = KeyResolver() key_resolver.put_key(qsc.key_encryption_key) queue.key_resolver_function = key_resolver.resolve_key @@ -144,7 +147,7 @@ async def test_peek_messages_encrypted_resolver(self, **kwargs): li = await queue.peek_messages() # Assert - assert li[0].content == 'encrypted_message_4' + assert li[0].content == "encrypted_message_4" @pytest.mark.live_test_only @QueuePreparer() @@ -157,15 +160,15 @@ async def test_peek_messages_encrypted_kek_RSA(self, **kwargs): # the playback test will fail due to a change in kek values. 
# Arrange - qsc.key_encryption_key = RSAKeyWrapper('key2') + qsc.key_encryption_key = RSAKeyWrapper("key2") queue = await self._create_queue(qsc) - await queue.send_message('encrypted_message_3') + await queue.send_message("encrypted_message_3") # Act li = await queue.peek_messages() # Assert - assert li[0].content == 'encrypted_message_3' + assert li[0].content == "encrypted_message_3" @QueuePreparer() @recorded_by_proxy_async @@ -176,14 +179,14 @@ async def test_update_encrypted_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - await queue.send_message('Update Me') + queue.key_encryption_key = KeyWrapper("key1") + await queue.send_message("Update Me") messages = [] async for m in queue.receive_messages(): messages.append(m) list_result1 = messages[0] - list_result1.content = 'Updated' + list_result1.content = "Updated" # Act message = await queue.update_message(list_result1) @@ -192,7 +195,7 @@ async def test_update_encrypted_message(self, **kwargs): list_result2 = messages[0] # Assert - assert 'Updated' == list_result2.content + assert "Updated" == list_result2.content @QueuePreparer() @recorded_by_proxy_async @@ -203,9 +206,9 @@ async def test_update_encrypted_binary_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue( - qsc, message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) - queue.key_encryption_key = KeyWrapper('key1') + qsc, message_encode_policy=BinaryBase64EncodePolicy(), message_decode_policy=BinaryBase64DecodePolicy() + ) + queue.key_encryption_key = KeyWrapper("key1") binary_message = self.get_random_bytes(100) await queue.send_message(binary_message) @@ -235,9 +238,9 @@ async def test_update_encrypted_raw_text_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc, message_encode_policy=None, message_decode_policy=None) - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") - raw_text = 'Update Me' + raw_text = "Update Me" await queue.send_message(raw_text) messages = [] async for m in queue.receive_messages(): @@ -245,7 +248,7 @@ async def test_update_encrypted_raw_text_message(self, **kwargs): list_result1 = messages[0] # Act - raw_text = 'Updated' + raw_text = "Updated" list_result1.content = raw_text async for m in queue.receive_messages(): messages.append(m) @@ -263,9 +266,9 @@ async def test_update_encrypted_json_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc, message_encode_policy=None, message_decode_policy=None) - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") - message_dict = {'val1': 1, 'val2': '2'} + message_dict = {"val1": 1, "val2": "2"} json_text = dumps(message_dict) await queue.send_message(json_text) messages = [] @@ -274,8 +277,8 @@ async def test_update_encrypted_json_message(self, **kwargs): list_result1 = messages[0] # Act - message_dict['val1'] = 0 - message_dict['val2'] = 'updated' + message_dict["val1"] = 0 + message_dict["val2"] = "updated" json_text = dumps(message_dict) list_result1.content = json_text await 
queue.update_message(list_result1) @@ -296,23 +299,23 @@ async def test_invalid_value_kek_wrap(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") queue.key_encryption_key.get_kid = None with pytest.raises(AttributeError) as e: - await queue.send_message('message') + await queue.send_message("message") - assert str(e.value.args[0]), _ERROR_OBJECT_INVALID.format('key encryption key' == 'get_kid') + assert str(e.value.args[0]), _ERROR_OBJECT_INVALID.format("key encryption key" == "get_kid") - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") queue.key_encryption_key.get_kid = None with pytest.raises(AttributeError): - await queue.send_message('message') + await queue.send_message("message") - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") queue.key_encryption_key.wrap_key = None with pytest.raises(AttributeError): - await queue.send_message('message') + await queue.send_message("message") @QueuePreparer() @recorded_by_proxy_async @@ -324,7 +327,7 @@ async def test_missing_attribute_kek_wrap(self, **kwargs): # Arrange queue = await self._create_queue(qsc) - valid_key = KeyWrapper('key1') + valid_key = KeyWrapper("key1") # Act invalid_key_1 = lambda: None # functions are objects, so this effectively creates an empty object @@ -333,7 +336,7 @@ async def test_missing_attribute_kek_wrap(self, **kwargs): # No attribute wrap_key queue.key_encryption_key = invalid_key_1 with pytest.raises(AttributeError): - await queue.send_message('message') + await queue.send_message("message") invalid_key_2 = lambda: None # functions are objects, so this effectively creates an empty object invalid_key_2.wrap_key = valid_key.wrap_key @@ -341,7 +344,7 @@ async def test_missing_attribute_kek_wrap(self, **kwargs): # No attribute get_key_wrap_algorithm queue.key_encryption_key = invalid_key_2 with pytest.raises(AttributeError): - await queue.send_message('message') + await queue.send_message("message") invalid_key_3 = lambda: None # functions are objects, so this effectively creates an empty object invalid_key_3.get_key_wrap_algorithm = valid_key.get_key_wrap_algorithm @@ -349,7 +352,7 @@ async def test_missing_attribute_kek_wrap(self, **kwargs): # No attribute get_kid queue.key_encryption_key = invalid_key_3 with pytest.raises(AttributeError): - await queue.send_message('message') + await queue.send_message("message") @QueuePreparer() @recorded_by_proxy_async @@ -360,8 +363,8 @@ async def test_invalid_value_kek_unwrap(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - await queue.send_message('message') + queue.key_encryption_key = KeyWrapper("key1") + await queue.send_message("message") # Act queue.key_encryption_key.unwrap_key = None @@ -381,11 +384,11 @@ async def test_missing_attribute_kek_unwrap(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - await queue.send_message('message') + queue.key_encryption_key = KeyWrapper("key1") + await queue.send_message("message") # Act - valid_key = KeyWrapper('key1') + 
valid_key = KeyWrapper("key1") invalid_key_1 = lambda: None # functions are objects, so this effectively creates an empty object invalid_key_1.unwrap_key = valid_key.unwrap_key # No attribute get_kid @@ -411,9 +414,9 @@ async def test_validate_encryption(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") queue.key_encryption_key = kek - await queue.send_message('message') + await queue.send_message("message") # Act queue.key_encryption_key = None # Message will not be decrypted @@ -421,30 +424,30 @@ async def test_validate_encryption(self, **kwargs): message = li[0].content message = loads(message) - encryption_data = message['EncryptionData'] + encryption_data = message["EncryptionData"] - wrapped_content_key = encryption_data['WrappedContentKey'] + wrapped_content_key = encryption_data["WrappedContentKey"] wrapped_content_key = _WrappedContentKey( - wrapped_content_key['Algorithm'], - b64decode(wrapped_content_key['EncryptedKey'].encode(encoding='utf-8')), - wrapped_content_key['KeyId']) + wrapped_content_key["Algorithm"], + b64decode(wrapped_content_key["EncryptedKey"].encode(encoding="utf-8")), + wrapped_content_key["KeyId"], + ) - encryption_agent = encryption_data['EncryptionAgent'] - encryption_agent = _EncryptionAgent( - encryption_agent['EncryptionAlgorithm'], - encryption_agent['Protocol']) + encryption_agent = encryption_data["EncryptionAgent"] + encryption_agent = _EncryptionAgent(encryption_agent["EncryptionAlgorithm"], encryption_agent["Protocol"]) encryption_data = _EncryptionData( - b64decode(encryption_data['ContentEncryptionIV'].encode(encoding='utf-8')), + b64decode(encryption_data["ContentEncryptionIV"].encode(encoding="utf-8")), None, encryption_agent, wrapped_content_key, - {'EncryptionLibrary': VERSION}) + {"EncryptionLibrary": VERSION}, + ) - message = message['EncryptedMessageContents'] + message = message["EncryptedMessageContents"] content_encryption_key = kek.unwrap_key( - encryption_data.wrapped_content_key.encrypted_key, - encryption_data.wrapped_content_key.algorithm) + encryption_data.wrapped_content_key.encrypted_key, encryption_data.wrapped_content_key.algorithm + ) # Create decryption cipher backend = backends.default_backend() @@ -455,16 +458,16 @@ async def test_validate_encryption(self, **kwargs): # decode and decrypt data decrypted_data = _decode_base64_to_bytes(message) decryptor = cipher.decryptor() - decrypted_data = (decryptor.update(decrypted_data) + decryptor.finalize()) + decrypted_data = decryptor.update(decrypted_data) + decryptor.finalize() # unpad data unpadder = PKCS7(128).unpadder() - decrypted_data = (unpadder.update(decrypted_data) + unpadder.finalize()) + decrypted_data = unpadder.update(decrypted_data) + unpadder.finalize() - decrypted_data = decrypted_data.decode(encoding='utf-8') + decrypted_data = decrypted_data.decode(encoding="utf-8") # Assert - assert decrypted_data == 'message' + assert decrypted_data == "message" @QueuePreparer() @recorded_by_proxy_async @@ -475,16 +478,16 @@ async def test_put_with_strict_mode(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") queue.key_encryption_key = kek queue.require_encryption = True - await queue.send_message('message') + await queue.send_message("message") 
queue.key_encryption_key = None # Assert with pytest.raises(ValueError) as e: - await queue.send_message('message') + await queue.send_message("message") assert str(e.value.args[0]) == "Encryption required but no key was provided." @@ -497,16 +500,16 @@ async def test_get_with_strict_mode(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - await queue.send_message('message') + await queue.send_message("message") queue.require_encryption = True - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") with pytest.raises(ValueError) as e: messages = [] async for m in queue.receive_messages(): messages.append(m) _ = messages[0] - assert 'Message was either not encrypted or metadata was incorrect.' in str(e.value.args[0]) + assert "Message was either not encrypted or metadata was incorrect." in str(e.value.args[0]) @QueuePreparer() @recorded_by_proxy_async @@ -517,13 +520,13 @@ async def test_encryption_add_encrypted_64k_message(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - message = 'a' * 1024 * 64 + message = "a" * 1024 * 64 # Act await queue.send_message(message) # Assert - queue.key_encryption_key = KeyWrapper('key1') + queue.key_encryption_key = KeyWrapper("key1") with pytest.raises(HttpResponseError): await queue.send_message(message) @@ -536,11 +539,11 @@ async def test_encryption_nonmatching_kid(self, **kwargs): qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key) # Arrange queue = await self._create_queue(qsc) - queue.key_encryption_key = KeyWrapper('key1') - await queue.send_message('message') + queue.key_encryption_key = KeyWrapper("key1") + await queue.send_message("message") # Act - queue.key_encryption_key.kid = 'Invalid' + queue.key_encryption_key.kid = "Invalid" # Assert with pytest.raises(HttpResponseError) as e: @@ -561,10 +564,11 @@ async def test_get_message_encrypted_kek_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) queue = await self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" # Act await queue.send_message(content) @@ -584,13 +588,14 @@ async def test_get_message_encrypted_resolver_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) key_resolver = KeyResolver() key_resolver.put_key(qsc.key_encryption_key) queue = await self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" # Act await queue.send_message(content) @@ -616,10 +621,11 @@ async def test_get_message_encrypted_kek_RSA_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=RSAKeyWrapper('key2')) + encryption_version="2.0", + key_encryption_key=RSAKeyWrapper("key2"), + ) queue = await self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" 
# Act await queue.send_message(content) @@ -639,20 +645,21 @@ async def test_update_encrypted_message_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) queue = await self._create_queue(qsc) - await queue.send_message('Update Me') + await queue.send_message("Update Me") message = await queue.receive_message() - message.content = 'Updated' + message.content = "Updated" # Act await queue.update_message(message) message = await queue.receive_message() # Assert - assert 'Updated' == message.content + assert "Updated" == message.content @QueuePreparer() @recorded_by_proxy_async @@ -665,24 +672,24 @@ async def test_update_encrypted_binary_message_v2(self, **kwargs): self.account_url(storage_account_name, "queue"), storage_account_key, requires_encryption=True, - encryption_version='2.0', - key_encryption_key=KeyWrapper('key1')) + encryption_version="2.0", + key_encryption_key=KeyWrapper("key1"), + ) queue = await self._create_queue( - qsc, - message_encode_policy=BinaryBase64EncodePolicy(), - message_decode_policy=BinaryBase64DecodePolicy()) - queue.key_encryption_key = KeyWrapper('key1') + qsc, message_encode_policy=BinaryBase64EncodePolicy(), message_decode_policy=BinaryBase64DecodePolicy() + ) + queue.key_encryption_key = KeyWrapper("key1") - await queue.send_message(b'Update Me') + await queue.send_message(b"Update Me") message = await queue.receive_message() - message.content = b'Updated' + message.content = b"Updated" # Act await queue.update_message(message) message = await queue.receive_message() # Assert - assert b'Updated' == message.content + assert b"Updated" == message.content @QueuePreparer() @recorded_by_proxy_async @@ -691,15 +698,16 @@ async def test_encryption_v2_v1_downgrade(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") qsc = QueueServiceClient( self.account_url(storage_account_name, "queue"), storage_account_key, requires_encryption=True, - encryption_version='2.0', - key_encryption_key=kek) + encryption_version="2.0", + key_encryption_key=kek, + ) queue = await self._create_queue(qsc) - await queue.send_message('Hello World Encrypted!') + await queue.send_message("Hello World Encrypted!") queue.require_encryption = False queue.key_encryption_key = None # Message will not be decrypted @@ -707,12 +715,12 @@ async def test_encryption_v2_v1_downgrade(self, **kwargs): content = loads(message.content) # Modify metadata to look like V1 - encryption_data = content['EncryptionData'] - encryption_data['EncryptionAgent']['Protocol'] = '1.0' - encryption_data['EncryptionAgent']['EncryptionAlgorithm'] = 'AES_CBC_256' + encryption_data = content["EncryptionData"] + encryption_data["EncryptionAgent"]["Protocol"] = "1.0" + encryption_data["EncryptionAgent"]["EncryptionAlgorithm"] = "AES_CBC_256" iv = b64encode(os.urandom(16)) - encryption_data['ContentEncryptionIV'] = iv.decode('utf-8') - content['EncryptionData'] = encryption_data + encryption_data["ContentEncryptionIV"] = iv.decode("utf-8") + content["EncryptionData"] = encryption_data message.content = dumps(content) @@ -727,7 +735,7 @@ async def test_encryption_v2_v1_downgrade(self, **kwargs): with pytest.raises(HttpResponseError) as e: await queue.receive_message() - assert 'Decryption failed.' 
in str(e.value.args[0]) + assert "Decryption failed." in str(e.value.args[0]) @QueuePreparer() @recorded_by_proxy_async @@ -736,15 +744,16 @@ async def test_validate_encryption_v2(self, **kwargs): storage_account_key = kwargs.pop("storage_account_key") # Arrange - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") qsc = QueueServiceClient( self.account_url(storage_account_name, "queue"), storage_account_key, require_encryption=True, - encryption_version='2.0', - key_encryption_key=kek) + encryption_version="2.0", + key_encryption_key=kek, + ) queue = await self._create_queue(qsc) - content = 'Hello World Encrypted!' + content = "Hello World Encrypted!" await queue.send_message(content) # Act @@ -753,10 +762,10 @@ async def test_validate_encryption_v2(self, **kwargs): message = (await queue.receive_message()).content message = loads(message) - encryption_data = _dict_to_encryption_data(message['EncryptionData']) + encryption_data = _dict_to_encryption_data(message["EncryptionData"]) encryption_agent = encryption_data.encryption_agent - assert '2.0' == encryption_agent.protocol - assert 'AES_GCM_256' == encryption_agent.encryption_algorithm + assert "2.0" == encryption_agent.protocol + assert "AES_GCM_256" == encryption_agent.encryption_algorithm encrypted_region_info = encryption_data.encrypted_region_info assert _GCM_NONCE_LENGTH == encrypted_region_info.nonce_length @@ -766,7 +775,7 @@ async def test_validate_encryption_v2(self, **kwargs): nonce_length = encrypted_region_info.nonce_length - message = message['EncryptedMessageContents'] + message = message["EncryptedMessageContents"] message = _decode_base64_to_bytes(message) # First bytes are the nonce @@ -776,7 +785,7 @@ async def test_validate_encryption_v2(self, **kwargs): aesgcm = AESGCM(content_encryption_key) decrypted_data = aesgcm.decrypt(nonce, ciphertext_with_tag, None) - decrypted_data = decrypted_data.decode(encoding='utf-8') + decrypted_data = decrypted_data.decode(encoding="utf-8") # Assert assert content == decrypted_data @@ -787,21 +796,22 @@ async def test_encryption_user_agent(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - app_id = 'TestAppId' - content = 'Hello World Encrypted!' - kek = KeyWrapper('key1') + app_id = "TestAppId" + content = "Hello World Encrypted!" 
+        kek = KeyWrapper("key1")
 
         def assert_user_agent(request):
-            start = f'{app_id} azstorage-clientsideencryption/2.0 '
-            assert request.http_request.headers['User-Agent'].startswith(start)
+            start = f"{app_id} azstorage-clientsideencryption/2.0 "
+            assert request.http_request.headers["User-Agent"].startswith(start)
 
         # Test method level keyword
         qsc = QueueServiceClient(
             self.account_url(storage_account_name, "queue"),
             storage_account_key,
             require_encryption=True,
-            encryption_version='2.0',
-            key_encryption_key=kek)
+            encryption_version="2.0",
+            key_encryption_key=kek,
+        )
         queue = await self._create_queue(qsc)
 
         await queue.send_message(content, raw_request_hook=assert_user_agent, user_agent=app_id)
@@ -810,14 +820,15 @@ def assert_user_agent(request):
             self.account_url(storage_account_name, "queue"),
             storage_account_key,
             require_encryption=True,
-            encryption_version='2.0',
+            encryption_version="2.0",
             key_encryption_key=kek,
-            user_agent=app_id)
+            user_agent=app_id,
+        )
         queue = self._get_queue_reference(qsc)
 
         await queue.send_message(content, raw_request_hook=assert_user_agent)
 
 
 # ------------------------------------------------------------------------------
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_service_properties.py b/sdk/storage/azure-storage-queue/tests/test_queue_service_properties.py
index 4cabe6cbb41d..d84b0ad05843 100644
--- a/sdk/storage/azure-storage-queue/tests/test_queue_service_properties.py
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_service_properties.py
@@ -24,10 +24,10 @@ class TestQueueServiceProperties(StorageRecordedTestCase):
     def _assert_properties_default(self, prop):
         assert prop is not None
 
-        self._assert_logging_equal(prop['analytics_logging'], QueueAnalyticsLogging())
-        self._assert_metrics_equal(prop['hour_metrics'], Metrics())
-        self._assert_metrics_equal(prop['minute_metrics'], Metrics())
-        self._assert_cors_equal(prop['cors'], [])
+        self._assert_logging_equal(prop["analytics_logging"], QueueAnalyticsLogging())
+        self._assert_metrics_equal(prop["hour_metrics"], Metrics())
+        self._assert_metrics_equal(prop["minute_metrics"], Metrics())
+        self._assert_cors_equal(prop["cors"], [])
 
     def _assert_logging_equal(self, log1, log2):
         if log1 is None or log2 is None:
@@ -106,19 +106,15 @@ def test_queue_service_properties(self, **kwargs):
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
         # Act
         resp = qsc.set_service_properties(
-            analytics_logging=QueueAnalyticsLogging(),
-            hour_metrics=Metrics(),
-            minute_metrics=Metrics(),
-            cors=[])
+            analytics_logging=QueueAnalyticsLogging(), hour_metrics=Metrics(), minute_metrics=Metrics(), cors=[]
+        )
 
         # Assert
         assert resp is None
         self._assert_properties_default(qsc.get_service_properties())
 
-
     # --Test cases per feature ---------------------------------------
-
     @QueuePreparer()
     @recorded_by_proxy
     def test_set_logging(self, **kwargs):
@@ -127,14 +123,16 @@ def test_set_logging(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5))
+        logging = QueueAnalyticsLogging(
+            read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)
+        )
 
         # Act
         qsc.set_service_properties(analytics_logging=logging)
 
         # Assert
         received_props = qsc.get_service_properties()
-        self._assert_logging_equal(received_props['analytics_logging'], logging)
+        self._assert_logging_equal(received_props["analytics_logging"], logging)
 
     @QueuePreparer()
     @recorded_by_proxy
@@ -151,7 +149,7 @@ def test_set_hour_metrics(self, **kwargs):
 
         # Assert
         received_props = qsc.get_service_properties()
-        self._assert_metrics_equal(received_props['hour_metrics'], hour_metrics)
+        self._assert_metrics_equal(received_props["hour_metrics"], hour_metrics)
 
     @QueuePreparer()
     @recorded_by_proxy
@@ -161,15 +159,16 @@ def test_set_minute_metrics(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        minute_metrics = Metrics(enabled=True, include_apis=True,
-                                 retention_policy=RetentionPolicy(enabled=True, days=5))
+        minute_metrics = Metrics(
+            enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)
+        )
 
         # Act
         qsc.set_service_properties(minute_metrics=minute_metrics)
 
         # Assert
         received_props = qsc.get_service_properties()
-        self._assert_metrics_equal(received_props['minute_metrics'], minute_metrics)
+        self._assert_metrics_equal(received_props["minute_metrics"], minute_metrics)
 
     @QueuePreparer()
     @recorded_by_proxy
@@ -179,10 +178,10 @@ def test_set_cors(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        cors_rule1 = CorsRule(['www.xyz.com'], ['GET'])
+        cors_rule1 = CorsRule(["www.xyz.com"], ["GET"])
 
-        allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"]
-        allowed_methods = ['GET', 'PUT']
+        allowed_origins = ["www.xyz.com", "www.ab.com", "www.bc.com"]
+        allowed_methods = ["GET", "PUT"]
         max_age_in_seconds = 500
         exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"]
         allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"]
@@ -191,7 +190,8 @@ def test_set_cors(self, **kwargs):
             allowed_methods,
             max_age_in_seconds=max_age_in_seconds,
             exposed_headers=exposed_headers,
-            allowed_headers=allowed_headers)
+            allowed_headers=allowed_headers,
+        )
 
         cors = [cors_rule1, cors_rule2]
 
@@ -200,7 +200,7 @@ def test_set_cors(self, **kwargs):
 
         # Assert
         received_props = qsc.get_service_properties()
-        self._assert_cors_equal(received_props['cors'], cors)
+        self._assert_cors_equal(received_props["cors"], cors)
 
     # --Test cases for errors ---------------------------------------
     @QueuePreparer()
@@ -209,9 +209,7 @@ def test_retention_no_days(self, **kwargs):
         storage_account_key = kwargs.pop("storage_account_key")
 
         # Assert
-        pytest.raises(ValueError,
-                      RetentionPolicy,
-                      True, None)
+        pytest.raises(ValueError, RetentionPolicy, True, None)
 
     @QueuePreparer()
     @recorded_by_proxy
@@ -223,11 +221,10 @@ def test_too_many_cors_rules(self, **kwargs):
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
         cors = []
         for i in range(0, 6):
-            cors.append(CorsRule(['www.xyz.com'], ['GET']))
+            cors.append(CorsRule(["www.xyz.com"], ["GET"]))
 
         # Assert
-        pytest.raises(HttpResponseError,
-                      qsc.set_service_properties, None, None, None, cors)
+        pytest.raises(HttpResponseError, qsc.set_service_properties, None, None, None, cors)
 
     @QueuePreparer()
     @recorded_by_proxy
@@ -237,15 +234,14 @@ def test_retention_too_long(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        minute_metrics = Metrics(enabled=True, include_apis=True,
-                                 retention_policy=RetentionPolicy(enabled=True, days=366))
+        minute_metrics = Metrics(
+            enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=366)
+        )
 
         # Assert
-        pytest.raises(HttpResponseError,
-                      qsc.set_service_properties,
-                      None, None, minute_metrics)
+        pytest.raises(HttpResponseError, qsc.set_service_properties, None, None, minute_metrics)
 
 
 # ------------------------------------------------------------------------------
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_service_properties_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_service_properties_async.py
index a372b7488c92..7e870420898b 100644
--- a/sdk/storage/azure-storage-queue/tests/test_queue_service_properties_async.py
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_service_properties_async.py
@@ -20,10 +20,10 @@ class TestAsyncQueueServiceProperties(AsyncStorageRecordedTestCase):
     def _assert_properties_default(self, prop):
         assert prop is not None
 
-        self._assert_logging_equal(prop['analytics_logging'], QueueAnalyticsLogging())
-        self._assert_metrics_equal(prop['hour_metrics'], Metrics())
-        self._assert_metrics_equal(prop['minute_metrics'], Metrics())
-        self._assert_cors_equal(prop['cors'], [])
+        self._assert_logging_equal(prop["analytics_logging"], QueueAnalyticsLogging())
+        self._assert_metrics_equal(prop["hour_metrics"], Metrics())
+        self._assert_metrics_equal(prop["minute_metrics"], Metrics())
+        self._assert_cors_equal(prop["cors"], [])
 
     def _assert_logging_equal(self, log1, log2):
         if log1 is None or log2 is None:
@@ -103,10 +103,8 @@ async def test_queue_service_properties(self, **kwargs):
 
         # Act
         resp = await qsc.set_service_properties(
-            analytics_logging=QueueAnalyticsLogging(),
-            hour_metrics=Metrics(),
-            minute_metrics=Metrics(),
-            cors=[])
+            analytics_logging=QueueAnalyticsLogging(), hour_metrics=Metrics(), minute_metrics=Metrics(), cors=[]
+        )
 
         # Assert
         assert resp is None
@@ -122,14 +120,16 @@ async def test_set_logging(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        logging = QueueAnalyticsLogging(read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5))
+        logging = QueueAnalyticsLogging(
+            read=True, write=True, delete=True, retention_policy=RetentionPolicy(enabled=True, days=5)
+        )
 
         # Act
         await qsc.set_service_properties(analytics_logging=logging)
 
         # Assert
         received_props = await qsc.get_service_properties()
-        self._assert_logging_equal(received_props['analytics_logging'], logging)
+        self._assert_logging_equal(received_props["analytics_logging"], logging)
 
     @QueuePreparer()
     @recorded_by_proxy_async
@@ -146,7 +146,7 @@ async def test_set_hour_metrics(self, **kwargs):
 
         # Assert
         received_props = await qsc.get_service_properties()
-        self._assert_metrics_equal(received_props['hour_metrics'], hour_metrics)
+        self._assert_metrics_equal(received_props["hour_metrics"], hour_metrics)
 
     @QueuePreparer()
     @recorded_by_proxy_async
@@ -156,15 +156,16 @@ async def test_set_minute_metrics(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        minute_metrics = Metrics(enabled=True, include_apis=True,
-                                 retention_policy=RetentionPolicy(enabled=True, days=5))
+        minute_metrics = Metrics(
+            enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=5)
+        )
 
         # Act
         await qsc.set_service_properties(minute_metrics=minute_metrics)
 
         # Assert
         received_props = await qsc.get_service_properties()
-        self._assert_metrics_equal(received_props['minute_metrics'], minute_metrics)
+        self._assert_metrics_equal(received_props["minute_metrics"], minute_metrics)
 
     @QueuePreparer()
     @recorded_by_proxy_async
@@ -174,10 +175,10 @@ async def test_set_cors(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        cors_rule1 = CorsRule(['www.xyz.com'], ['GET'])
+        cors_rule1 = CorsRule(["www.xyz.com"], ["GET"])
 
-        allowed_origins = ['www.xyz.com', "www.ab.com", "www.bc.com"]
-        allowed_methods = ['GET', 'PUT']
+        allowed_origins = ["www.xyz.com", "www.ab.com", "www.bc.com"]
+        allowed_methods = ["GET", "PUT"]
         max_age_in_seconds = 500
         exposed_headers = ["x-ms-meta-data*", "x-ms-meta-source*", "x-ms-meta-abc", "x-ms-meta-bcd"]
         allowed_headers = ["x-ms-meta-data*", "x-ms-meta-target*", "x-ms-meta-xyz", "x-ms-meta-foo"]
@@ -186,7 +187,8 @@ async def test_set_cors(self, **kwargs):
             allowed_methods,
             max_age_in_seconds=max_age_in_seconds,
             exposed_headers=exposed_headers,
-            allowed_headers=allowed_headers)
+            allowed_headers=allowed_headers,
+        )
 
         cors = [cors_rule1, cors_rule2]
 
@@ -195,7 +197,7 @@ async def test_set_cors(self, **kwargs):
 
         # Assert
         received_props = await qsc.get_service_properties()
-        self._assert_cors_equal(received_props['cors'], cors)
+        self._assert_cors_equal(received_props["cors"], cors)
 
     # --Test cases for errors ---------------------------------------
 
@@ -206,9 +208,7 @@ async def test_retention_no_days(self, **kwargs):
 
         # Assert
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        pytest.raises(ValueError,
-                      RetentionPolicy,
-                      True, None)
+        pytest.raises(ValueError, RetentionPolicy, True, None)
 
     @QueuePreparer()
     @recorded_by_proxy_async
@@ -220,7 +220,7 @@ async def test_too_many_cors_rules(self, **kwargs):
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
         cors = []
         for _ in range(0, 6):
-            cors.append(CorsRule(['www.xyz.com'], ['GET']))
+            cors.append(CorsRule(["www.xyz.com"], ["GET"]))
 
         # Assert
         with pytest.raises(HttpResponseError):
@@ -234,13 +234,15 @@ async def test_retention_too_long(self, **kwargs):
 
         # Arrange
         qsc = QueueServiceClient(self.account_url(storage_account_name, "queue"), storage_account_key)
-        minute_metrics = Metrics(enabled=True, include_apis=True,
-                                 retention_policy=RetentionPolicy(enabled=True, days=366))
+        minute_metrics = Metrics(
+            enabled=True, include_apis=True, retention_policy=RetentionPolicy(enabled=True, days=366)
+        )
 
         # Assert
         with pytest.raises(HttpResponseError):
             await qsc.set_service_properties()
 
+
 # ------------------------------------------------------------------------------
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_service_stats.py b/sdk/storage/azure-storage-queue/tests/test_queue_service_stats.py
index 10c34452eb85..21b5caaa3815 100644
--- a/sdk/storage/azure-storage-queue/tests/test_queue_service_stats.py
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_service_stats.py
@@ -12,23 +12,24 @@
 from devtools_testutils.storage import StorageRecordedTestCase
 from settings.testcase import QueuePreparer
 
+
 # --Test Class -----------------------------------------------------------------
 class TestQueueServiceStats(StorageRecordedTestCase):
     # --Helpers-----------------------------------------------------------------
     def _assert_stats_default(self, stats):
         assert stats is not None
-        assert stats['geo_replication'] is not None
+        assert stats["geo_replication"] is not None
 
-        assert stats['geo_replication']['status'] == 'live'
-        assert stats['geo_replication']['last_sync_time'] is not None
+        assert stats["geo_replication"]["status"] == "live"
+        assert stats["geo_replication"]["last_sync_time"] is not None
 
     def _assert_stats_unavailable(self, stats):
         assert stats is not None
-        assert stats['geo_replication'] is not None
+        assert stats["geo_replication"] is not None
 
-        assert stats['geo_replication']['status'] == 'unavailable'
-        assert stats['geo_replication']['last_sync_time'] is None
+        assert stats["geo_replication"]["status"] == "unavailable"
+        assert stats["geo_replication"]["last_sync_time"] is None
 
     # --Test cases per service ---------------------------------------
     @pytest.mark.playback_test_only
@@ -64,5 +65,5 @@ def test_queue_service_stats_when_unavailable(self, **kwargs):
 
 
 # ------------------------------------------------------------------------------
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/sdk/storage/azure-storage-queue/tests/test_queue_service_stats_async.py b/sdk/storage/azure-storage-queue/tests/test_queue_service_stats_async.py
index ad928208f578..cc9e2a5de7cc 100644
--- a/sdk/storage/azure-storage-queue/tests/test_queue_service_stats_async.py
+++ b/sdk/storage/azure-storage-queue/tests/test_queue_service_stats_async.py
@@ -12,23 +12,24 @@
 from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase
 from settings.testcase import QueuePreparer
 
+
 # --Test Class -----------------------------------------------------------------
 class TestAsyncQueueServiceStats(AsyncStorageRecordedTestCase):
     # --Helpers-----------------------------------------------------------------
     def _assert_stats_default(self, stats):
         assert stats is not None
-        assert stats['geo_replication'] is not None
+        assert stats["geo_replication"] is not None
 
-        assert stats['geo_replication']['status'] == 'live'
-        assert stats['geo_replication']['last_sync_time'] is not None
+        assert stats["geo_replication"]["status"] == "live"
+        assert stats["geo_replication"]["last_sync_time"] is not None
 
     def _assert_stats_unavailable(self, stats):
         assert stats is not None
-        assert stats['geo_replication'] is not None
+        assert stats["geo_replication"] is not None
 
-        assert stats['geo_replication']['status'] == 'unavailable'
-        assert stats['geo_replication']['last_sync_time'] is None
+        assert stats["geo_replication"]["status"] == "unavailable"
+        assert stats["geo_replication"]["last_sync_time"] is None
 
     # --Test cases per service ---------------------------------------
     @pytest.mark.playback_test_only
@@ -62,6 +63,7 @@ async def test_queue_service_stats_when_unavailable(self, **kwargs):
 
         # Assert
         self._assert_stats_unavailable(stats)
 
+
 # ------------------------------------------------------------------------------
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()