Skip to content

Commit c7d0845

Browse files
[Storage] Fixed Download/Upload APIs to be compatible with all transports + responses (#40615)
1 parent c0d1041 commit c7d0845

File tree

17 files changed

+399
-155
lines changed

17 files changed

+399
-155
lines changed

sdk/storage/azure-storage-blob/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/storage/azure-storage-blob",
5-
"Tag": "python/storage/azure-storage-blob_e1b13301f6"
5+
"Tag": "python/storage/azure-storage-blob_2bfcc41daa"
66
}

sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ async def retry_hook(settings, **kwargs):
3939
async def is_checksum_retry(response):
4040
# retry if invalid content md5
4141
if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"):
42-
try:
43-
await response.http_response.load_body() # Load the body in memory and close the socket
44-
except (StreamClosedError, StreamConsumedError):
45-
pass
42+
if hasattr(response.http_response, "load_body"):
43+
try:
44+
await response.http_response.load_body() # Load the body in memory and close the socket
45+
except (StreamClosedError, StreamConsumedError):
46+
pass
4647
computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64(
4748
StorageContentValidation.get_content_md5(response.http_response.body())
4849
)

sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@
4646
async def process_content(data: Any, start_offset: int, end_offset: int, encryption: Dict[str, Any]) -> bytes:
4747
if data is None:
4848
raise ValueError("Response cannot be None.")
49-
await data.response.load_body()
50-
content = cast(bytes, data.response.body())
49+
if hasattr(data.response, "is_stream_consumed") and data.response.is_stream_consumed:
50+
content = data.response.content
51+
else:
52+
content = b"".join([d async for d in data])
5153
if encryption.get('key') is not None or encryption.get('resolver') is not None:
5254
try:
5355
return decrypt_blob(
@@ -57,12 +59,14 @@ async def process_content(data: Any, start_offset: int, end_offset: int, encrypt
5759
content,
5860
start_offset,
5961
end_offset,
60-
data.response.headers)
62+
data.response.headers
63+
)
6164
except Exception as error:
6265
raise HttpResponseError(
6366
message="Decryption failed.",
6467
response=data.response,
65-
error=error) from error
68+
error=error
69+
) from error
6670
return content
6771

6872

sdk/storage/azure-storage-blob/tests/test_common_blob.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,7 @@
5353
from devtools_testutils import FakeTokenCredential, recorded_by_proxy
5454
from devtools_testutils.storage import StorageRecordedTestCase
5555
from settings.testcase import BlobPreparer
56-
from test_helpers import (
57-
MockStorageTransport,
58-
_build_base_file_share_headers,
59-
_create_file_share_oauth,
60-
)
56+
from test_helpers import _build_base_file_share_headers, _create_file_share_oauth
6157

6258
# ------------------------------------------------------------------------------
6359
SMALL_BLOB_SIZE = 1024
@@ -3597,57 +3593,4 @@ def test_upload_blob_partial_stream_chunked(self, **kwargs):
35973593
result = blob.download_blob().readall()
35983594
assert result == data[:length]
35993595

3600-
@BlobPreparer()
3601-
def test_mock_transport_no_content_validation(self, **kwargs):
3602-
storage_account_name = kwargs.pop("storage_account_name")
3603-
storage_account_key = kwargs.pop("storage_account_key")
3604-
3605-
transport = MockStorageTransport()
3606-
blob_client = BlobClient(
3607-
self.account_url(storage_account_name, "blob"),
3608-
container_name='test_cont',
3609-
blob_name='test_blob',
3610-
credential=storage_account_key,
3611-
transport=transport,
3612-
retry_total=0
3613-
)
3614-
3615-
content = blob_client.download_blob()
3616-
assert content is not None
3617-
3618-
props = blob_client.get_blob_properties()
3619-
assert props is not None
3620-
3621-
data = b"Hello World!"
3622-
resp = blob_client.upload_blob(data, overwrite=True)
3623-
assert resp is not None
3624-
3625-
blob_data = blob_client.download_blob().read()
3626-
assert blob_data == b"Hello World!" # data is fixed by mock transport
3627-
3628-
resp = blob_client.delete_blob()
3629-
assert resp is None
3630-
3631-
@BlobPreparer()
3632-
def test_mock_transport_with_content_validation(self, **kwargs):
3633-
storage_account_name = kwargs.pop("storage_account_name")
3634-
storage_account_key = kwargs.pop("storage_account_key")
3635-
3636-
transport = MockStorageTransport()
3637-
blob_client = BlobClient(
3638-
self.account_url(storage_account_name, "blob"),
3639-
container_name='test_cont',
3640-
blob_name='test_blob',
3641-
credential=storage_account_key,
3642-
transport=transport,
3643-
retry_total=0
3644-
)
3645-
3646-
data = b"Hello World!"
3647-
resp = blob_client.upload_blob(data, overwrite=True, validate_content=True)
3648-
assert resp is not None
3649-
3650-
blob_data = blob_client.download_blob(validate_content=True).read()
3651-
assert blob_data == b"Hello World!" # data is fixed by mock transport
3652-
36533596
# ------------------------------------------------------------------------------

sdk/storage/azure-storage-blob/tests/test_common_blob_async.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
from settings.testcase import BlobPreparer
5656
from test_helpers_async import (
5757
AsyncStream,
58-
MockStorageTransport,
5958
_build_base_file_share_headers,
6059
_create_file_share_oauth
6160
)
@@ -3525,59 +3524,4 @@ async def test_upload_blob_partial_stream_chunked(self, **kwargs):
35253524
result = await (await blob.download_blob()).readall()
35263525
assert result == data[:length]
35273526

3528-
@BlobPreparer()
3529-
async def test_mock_transport_no_content_validation(self, **kwargs):
3530-
storage_account_name = kwargs.pop("storage_account_name")
3531-
storage_account_key = kwargs.pop("storage_account_key")
3532-
3533-
transport = MockStorageTransport()
3534-
blob_client = BlobClient(
3535-
self.account_url(storage_account_name, "blob"),
3536-
container_name='test_cont',
3537-
blob_name='test_blob',
3538-
credential=storage_account_key,
3539-
transport=transport,
3540-
retry_total=0
3541-
)
3542-
3543-
content = await blob_client.download_blob()
3544-
assert content is not None
3545-
3546-
props = await blob_client.get_blob_properties()
3547-
assert props is not None
3548-
3549-
data = b"Hello Async World!"
3550-
stream = AsyncStream(data)
3551-
resp = await blob_client.upload_blob(stream, overwrite=True)
3552-
assert resp is not None
3553-
3554-
blob_data = await (await blob_client.download_blob()).read()
3555-
assert blob_data == b"Hello Async World!" # data is fixed by mock transport
3556-
3557-
resp = await blob_client.delete_blob()
3558-
assert resp is None
3559-
3560-
@BlobPreparer()
3561-
async def test_mock_transport_with_content_validation(self, **kwargs):
3562-
storage_account_name = kwargs.pop("storage_account_name")
3563-
storage_account_key = kwargs.pop("storage_account_key")
3564-
3565-
transport = MockStorageTransport()
3566-
blob_client = BlobClient(
3567-
self.account_url(storage_account_name, "blob"),
3568-
container_name='test_cont',
3569-
blob_name='test_blob',
3570-
credential=storage_account_key,
3571-
transport=transport,
3572-
retry_total=0
3573-
)
3574-
3575-
data = b"Hello Async World!"
3576-
stream = AsyncStream(data)
3577-
resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True)
3578-
assert resp is not None
3579-
3580-
blob_data = await (await blob_client.download_blob(validate_content=True)).read()
3581-
assert blob_data == b"Hello Async World!" # data is fixed by mock transport
3582-
35833527
# ------------------------------------------------------------------------------

sdk/storage/azure-storage-blob/tests/test_helpers.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Any, Dict, Optional, Tuple
1111
from typing_extensions import Self
1212

13-
from azure.core.pipeline.transport import HttpTransport, RequestsTransportResponse
13+
from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse
1414
from azure.core.rest import HttpRequest
1515
from azure.storage.blob._serialize import get_api_version
1616
from requests import Response
@@ -92,15 +92,15 @@ def tell(self):
9292
return self.wrapped_stream.tell()
9393

9494

95-
class MockHttpClientResponse(Response):
95+
class MockClientResponse(Response):
9696
def __init__(
9797
self, url: str,
9898
body_bytes: bytes,
9999
headers: Dict[str, Any],
100100
status: int = 200,
101101
reason: str = "OK"
102102
) -> None:
103-
super(MockHttpClientResponse).__init__()
103+
super(MockClientResponse).__init__()
104104
self._url = url
105105
self._body = body_bytes
106106
self._content = body_bytes
@@ -113,9 +113,9 @@ def __init__(
113113
self.raw = HTTPResponse()
114114

115115

116-
class MockStorageTransport(HttpTransport):
116+
class MockLegacyTransport(RequestsTransport):
117117
"""
118-
This transport returns legacy http response objects from azure core and is
118+
This transport returns http response objects from azure core pipelines and is
119119
intended only to test our backwards compatibility support.
120120
"""
121121
def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse:
@@ -132,7 +132,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse
132132

133133
rest_response = RequestsTransportResponse(
134134
request=request,
135-
requests_response=MockHttpClientResponse(
135+
requests_response=MockClientResponse(
136136
request.url,
137137
b"Hello World!",
138138
headers,
@@ -142,7 +142,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse
142142
# get_blob_properties
143143
rest_response = RequestsTransportResponse(
144144
request=request,
145-
requests_response=MockHttpClientResponse(
145+
requests_response=MockClientResponse(
146146
request.url,
147147
b"",
148148
{
@@ -155,7 +155,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse
155155
# upload_blob
156156
rest_response = RequestsTransportResponse(
157157
request=request,
158-
requests_response=MockHttpClientResponse(
158+
requests_response=MockClientResponse(
159159
request.url,
160160
b"",
161161
{
@@ -169,7 +169,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse
169169
# delete_blob
170170
rest_response = RequestsTransportResponse(
171171
request=request,
172-
requests_response=MockHttpClientResponse(
172+
requests_response=MockClientResponse(
173173
request.url,
174174
b"",
175175
{
@@ -180,7 +180,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse
180180
)
181181
)
182182
else:
183-
raise ValueError("The request is not accepted as part of MockStorageTransport.")
183+
raise ValueError("The request is not accepted as part of MockLegacyTransport.")
184184
return rest_response
185185

186186
def __enter__(self) -> Self:

sdk/storage/azure-storage-blob/tests/test_helpers_async.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,20 @@
33
# Licensed under the MIT License. See License.txt in the project root for
44
# license information.
55
# --------------------------------------------------------------------------
6-
6+
import asyncio
77
import aiohttp
8+
from collections import deque
89
from datetime import datetime, timezone
910
from io import IOBase, UnsupportedOperation
1011
from typing import Any, Dict, Optional, Tuple
12+
from unittest.mock import Mock, AsyncMock
1113

1214
from azure.core.pipeline.transport import AioHttpTransportResponse, AsyncHttpTransport
1315
from azure.core.rest import HttpRequest
1416
from azure.storage.blob._serialize import get_api_version
1517
from aiohttp import ClientResponse
18+
from aiohttp.streams import StreamReader
19+
from aiohttp.client_proto import ResponseHandler
1620

1721

1822
def _build_base_file_share_headers(bearer_token_string: str, content_length: int = 0) -> Dict[str, Any]:
@@ -126,11 +130,15 @@ def __init__(
126130
self._loop = None
127131
self.status = status
128132
self.reason = reason
133+
self.content = StreamReader(ResponseHandler(asyncio.get_event_loop()), 65535)
134+
self.content.total_bytes = len(body_bytes)
135+
self.content._buffer = deque([body_bytes])
136+
self.content._eof = True
129137

130138

131-
class MockStorageTransport(AsyncHttpTransport):
139+
class MockLegacyTransport(AsyncHttpTransport):
132140
"""
133-
This transport returns legacy http response objects from azure core and is
141+
This transport returns legacy http response objects from azure core and is
134142
intended only to test our backwards compatibility support.
135143
"""
136144
async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportResponse:
@@ -199,7 +207,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes
199207
decompress=False
200208
)
201209
else:
202-
raise ValueError("The request is not accepted as part of MockStorageTransport.")
210+
raise ValueError("The request is not accepted as part of MockLegacyTransport.")
203211

204212
await rest_response.load_body()
205213
return rest_response

0 commit comments

Comments
 (0)