From 3fcee3b258c225a7c1fe907a6b5e087532803ea7 Mon Sep 17 00:00:00 2001
From: Kirill Safonov
Date: Thu, 27 Feb 2025 00:16:23 +0100
Subject: [PATCH 01/11] [901] Large file uploads

---
 databricks/sdk/config.py       |   40 +
 databricks/sdk/mixins/files.py |  537 ++++++++++++-
 tests/test_files.py            | 1311 +++++++++++++++++++++++++++++++-
 3 files changed, 1884 insertions(+), 4 deletions(-)

diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py
index 490c6ba4e..f047a2156 100644
--- a/databricks/sdk/config.py
+++ b/databricks/sdk/config.py
@@ -1,5 +1,6 @@
 import configparser
 import copy
+import datetime
 import logging
 import os
 import pathlib
@@ -97,6 +98,45 @@ class Config:
     files_api_client_download_max_total_recovers = None
     files_api_client_download_max_total_recovers_without_progressing = 1
 
+    # File multipart upload parameters
+    # ----------------------
+
+    # Minimum input stream size (bytes) at which multipart / resumable uploads are used.
+    # For smaller files it's more efficient to make a single one-shot upload request.
+    # When uploading a file, the SDK will initially buffer this many bytes from the input stream.
+    # This parameter can be smaller or larger than multipart_upload_chunk_size.
+    multipart_upload_min_stream_size: int = 5 * 1024 * 1024
+
+    # Maximum number of presigned URLs that can be requested at a time.
+    #
+    # The more URLs we request at once, the higher the chance that some of them expire
+    # before we get to use them. We discover that a presigned URL has expired only *after*
+    # sending the input stream partition to the server, so to retry the upload of that partition
+    # we must rewind the stream. For a non-seekable stream we cannot rewind, so we abort
+    # the upload. To reduce the chance of this, we request presigned URLs one by one
+    # and use them immediately.
+    multipart_upload_batch_url_count: int = 1
+
+    # Size of the chunk to use for multipart uploads.
+    #
+    # The smaller the chunk, the lower the chance of network errors (or of the URL expiring),
+    # but the more requests we'll make.
+    # For AWS, the minimum is 5 MiB: https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
+    # For GCP, the minimum is 256 KiB (and the recommended size is a multiple of 256 KiB).
+    # boto3 uses 8 MiB: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig
+    multipart_upload_chunk_size: int = 10 * 1024 * 1024
+
+    # Expiration duration requested for presigned URLs; use a maximum of 1 hour.
+    multipart_upload_url_expiration_duration: datetime.timedelta = datetime.timedelta(hours=1)
+
+    # This is not a "wall time" cutoff for the whole upload request,
+    # but a maximum time between consecutive data reception events (even 1 byte) from the server.
+    multipart_upload_single_chunk_upload_timeout_seconds: float = 60
+
+    # Limit on retries during multipart upload.
+    # The retry counter is reset when progressing along the stream.
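+    # A retry is counted when a presigned URL turns out to be expired (multipart upload) or a
+    # retryable network error interrupts a chunk upload (resumable upload); once the limit is
+    # exceeded the whole upload is aborted.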
+ multipart_upload_max_retries = 3 + def __init__( self, *, diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 678b4b630..d6d8b0437 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -1,26 +1,34 @@ from __future__ import annotations import base64 +import datetime import logging import os import pathlib import platform +import re import shutil import sys +import xml.etree.ElementTree as ET from abc import ABC, abstractmethod from collections import deque from collections.abc import Iterator +from datetime import timedelta from io import BytesIO from types import TracebackType from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable, Optional, Type, Union) from urllib import parse +import requests from requests import RequestException -from .._base_client import _RawResponse, _StreamingResponse +from .._base_client import _BaseClient, _RawResponse, _StreamingResponse from .._property import _cached_property -from ..errors import NotFound +from ..clock import Clock, RealClock +from ..errors import AlreadyExists, NotFound +from ..errors.mapper import _error_mapper +from ..retries import retried from ..service import files from ..service._internal import _escape_multi_segment_path_parameter from ..service.files import DownloadResponse @@ -650,9 +658,11 @@ def delete(self, path: str, *, recursive=False): class FilesExt(files.FilesAPI): __doc__ = files.FilesAPI.__doc__ - def __init__(self, api_client, config: Config): + def __init__(self, api_client, config: Config, clock: Clock = None): super().__init__(api_client) self._config = config.copy() + self._clock = clock or RealClock() + self._multipart_upload_read_ahead_bytes = 1 def download(self, file_path: str) -> DownloadResponse: """Download a file. @@ -678,6 +688,527 @@ def download(self, file_path: str) -> DownloadResponse: initial_response.contents._response = wrapped_response return initial_response + def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None): + # Upload empty and small files with one-shot upload. 
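+        # The input stream may be non-seekable and of unknown length, so we probe it by reading up
+        # to multipart_upload_min_stream_size bytes. If the stream ends within that prefix, the
+        # whole content is already in memory and a one-shot upload is used; otherwise the prefix is
+        # passed on to the multipart / resumable path as the initial buffer.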
+ pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size) + if len(pre_read_buffer) < self._config.multipart_upload_min_stream_size: + _LOG.debug( + f'Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes' + ) + return super().upload(file_path=file_path, contents=pre_read_buffer, overwrite=overwrite) + + query = {'action': 'initiate-upload'} + if overwrite is not None: + query['overwrite'] = overwrite + + # _api.do() does retry + initiate_upload_response = self._api.do( + 'POST', f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}', query=query) + # no need to check response status, _api.do() will throw exception on failure + + if initiate_upload_response.get('multipart_upload'): + cloud_provider_session = self._create_cloud_provider_session() + session_token = initiate_upload_response['multipart_upload'].get('session_token') + if not session_token: + raise ValueError(f'Unexpected server response: {initiate_upload_response}') + + try: + self._multipart_upload(file_path, contents, session_token, pre_read_buffer, + cloud_provider_session) + except Exception as e: + _LOG.info(f"Aborting multipart upload on error: {e}") + try: + self._abort_multipart_upload(file_path, session_token, cloud_provider_session) + except BaseException as ex: + _LOG.warning(f"Failed to abort upload: {ex}") + # ignore, abort is a best-effort + finally: + # rethrow original exception + raise e from None + + elif initiate_upload_response.get('resumable_upload'): + cloud_provider_session = self._create_cloud_provider_session() + session_token = initiate_upload_response['resumable_upload']['session_token'] + self._resumable_upload(file_path, contents, session_token, overwrite, pre_read_buffer, + cloud_provider_session) + else: + raise ValueError(f'Unexpected server response: {initiate_upload_response}') + + def _multipart_upload(self, target_path: str, input_stream: BinaryIO, session_token: str, + pre_read_buffer: bytes, cloud_provider_session: requests.Session): + current_part_number = 1 + etags: dict = {} + + # Why are we buffering the current chunk? + # AWS and Azure don't support traditional "Transfer-encoding: chunked", so we must + # provide each chunk size up front. In case of a non-seekable input stream we need + # to buffer a chunk before uploading to know its size. This also allows us to rewind + # the stream before retrying on request failure. + # AWS signed chunked upload: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html + # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-tune-upload-download-python#buffering-during-uploads + + chunk_offset = 0 # used only for logging + + # This buffer is expected to contain at least multipart_upload_chunk_size bytes. + # Note that initially buffer can be bigger (from pre_read_buffer). + buffer = pre_read_buffer + + def fill_buffer(): + bytes_to_read = max(0, self._config.multipart_upload_chunk_size - len(buffer)) + if bytes_to_read > 0: + next_buf = input_stream.read(bytes_to_read) + new_buffer = buffer + next_buf + return new_buffer + else: + # we have already buffered enough data + return buffer + + retry_count = 0 + eof = False + while not eof: + # If needed, buffer the next chunk. + buffer = fill_buffer() + if not len(buffer): + # End of stream, no need to request the next block of upload URLs. 
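+                # The buffer is empty and the stream is exhausted: everything has already been
+                # uploaded, so no further upload URLs are needed.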
+ break + + _LOG.debug( + f"Multipart upload: requesting next {self._config.multipart_upload_batch_url_count} upload URLs starting from part {current_part_number}" + ) + + body: dict = { + 'path': target_path, + 'session_token': session_token, + 'start_part_number': current_part_number, + 'count': self._config.multipart_upload_batch_url_count, + 'expire_time': self._get_url_expire_time() + } + + headers = {'Content-Type': 'application/json'} + + # _api.do() does retry + upload_part_urls_response = self._api.do('POST', + '/api/2.0/fs/create-upload-part-urls', + headers=headers, + body=body) + # no need to check response status, _api.do() will throw exception on failure + + upload_part_urls = upload_part_urls_response.get('upload_part_urls', []) + if not len(upload_part_urls): + raise ValueError(f'Unexpected server response: {upload_part_urls_response}') + + for upload_part_url in upload_part_urls: + buffer = fill_buffer() + actual_buffer_length = len(buffer) + if not actual_buffer_length: + eof = True + break + + url = upload_part_url['url'] + required_headers = upload_part_url.get('headers', []) + assert current_part_number == upload_part_url['part_number'] + + headers: dict = {'Content-Type': 'application/octet-stream'} + for h in required_headers: + headers[h['name']] = h['value'] + + actual_chunk_length = min(actual_buffer_length, self._config.multipart_upload_chunk_size) + _LOG.debug( + f'Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]' + ) + + chunk = BytesIO(buffer[:actual_chunk_length]) + + def rewind(): + chunk.seek(0, os.SEEK_SET) + + def perform(): + result = cloud_provider_session.request( + 'PUT', + url, + headers=headers, + data=chunk, + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + return result + + # following _BaseClient timeout + retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + upload_response = retried(timeout=timedelta(seconds=retry_timeout_seconds), + is_retryable=_BaseClient._is_retryable, + clock=self._clock, + before_retry=rewind)(perform)() + + if upload_response.status_code in (200, 201): + # Chunk upload successful + + chunk_offset += actual_chunk_length + + etag = upload_response.headers.get('ETag', '') + etags[current_part_number] = etag + + # Discard uploaded bytes + buffer = buffer[actual_chunk_length:] + + # Reset retry count when progressing along the stream + retry_count = 0 + + elif FilesExt._is_url_expired_response(upload_response): + if retry_count < self._config.multipart_upload_max_retries: + retry_count += 1 + _LOG.debug('Upload URL expired') + # Preserve the buffer so we'll upload the current part again using next upload URL + else: + # don't confuse user with unrelated "Permission denied" error. + raise ValueError(f'Unsuccessful chunk upload: upload URL expired') + + else: + message = f'Unsuccessful chunk upload. 
Response status: {upload_response.status_code}, body: {upload_response.content}' + _LOG.warning(message) + mapped_error = _error_mapper(upload_response, {}) + raise mapped_error or ValueError(message) + + current_part_number += 1 + + _LOG.debug( + f'Completing multipart upload after uploading {len(etags)} parts of up to {self._config.multipart_upload_chunk_size} bytes' + ) + + query = {'action': 'complete-upload', 'upload_type': 'multipart', 'session_token': session_token} + headers = {'Content-Type': 'application/json'} + body: dict = {} + + parts = [] + for etag in sorted(etags.items()): + part = {'part_number': etag[0], 'etag': etag[1]} + parts.append(part) + + body['parts'] = parts + + # _api.do() does retry + self._api.do('POST', + f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}', + query=query, + headers=headers, + body=body) + # no need to check response status, _api.do() will throw exception on failure + + @staticmethod + def _is_url_expired_response(response: requests.Response): + if response.status_code != 403: + return False + + try: + xml_root = ET.fromstring(response.content) + if xml_root.tag != 'Error': + return False + + code = xml_root.find('Code') + if code is None: + return False + + if code.text == 'AuthenticationFailed': + # Azure + details = xml_root.find('AuthenticationErrorDetail') + if details is not None and 'Signature not valid in the specified time frame' in details.text: + return True + + if code.text == 'AccessDenied': + # AWS + message = xml_root.find('Message') + if message is not None and message.text == 'Request has expired': + return True + + except ET.ParseError: + pass + + return False + + def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_token: str, overwrite: bool, + pre_read_buffer: bytes, cloud_provider_session: requests.Session): + # https://cloud.google.com/storage/docs/performing-resumable-uploads + # Session URI we're using expires after a week + + # Why are we buffering the current chunk? + # When using resumable upload API we're uploading data in chunks. During chunk upload + # server responds with the "received offset" confirming how much data it stored so far, + # so we should continue uploading from that offset. (Note this is not a failure but an + # expected behaviour as per the docs.) But, input stream might be consumed beyond that + # offset, since server might have read more data than it confirmed received, or some data + # might have been pre-cached by e.g. OS or a proxy. So, to continue upload, we must rewind + # the input stream back to the byte next to "received offset". This is not possible + # for non-seekable input stream, so we must buffer the whole last chunk and seek inside + # the buffer. By always uploading from the buffer we fully support non-seekable streams. + + # Why are we doing read-ahead? + # It's not possible to upload an empty chunk as "Content-Range" header format does not + # support this. So if current chunk happens to finish exactly at the end of the stream, + # we need to know that and mark the chunk as last (by passing real file size in the + # "Content-Range" header) when uploading it. To detect if we're at the end of the stream + # we're reading "ahead" an extra bytes but not uploading them immediately. If + # nothing has been read ahead, it means we're at the end of the stream. + # On the contrary, in multipart upload we can decide to complete upload *after* + # last chunk has been sent. 
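+        # Example (assuming the default 10 MiB chunk size and a fully confirmed first chunk):
+        # a 20 MiB stream is uploaded as
+        #   "Content-Range: bytes 0-10485759/*"                 (more data may follow)
+        #   "Content-Range: bytes 10485760-20971519/20971520"   (total size known => last chunk)
+        # The read-ahead byte is what lets us tell these two situations apart.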
+ + body: dict = {'path': target_path, 'session_token': session_token} + + headers = {'Content-Type': 'application/json'} + + # _api.do() does retry + resumable_upload_url_response = self._api.do('POST', + '/api/2.0/fs/create-resumable-upload-url', + headers=headers, + body=body) + # no need to check response status, _api.do() will throw exception on failure + + resumable_upload_url_node = resumable_upload_url_response.get('resumable_upload_url') + if not resumable_upload_url_node: + raise ValueError(f'Unexpected server response: {resumable_upload_url_response}') + + resumable_upload_url = resumable_upload_url_node.get('url') + if not resumable_upload_url: + raise ValueError(f'Unexpected server response: {resumable_upload_url_response}') + + required_headers = resumable_upload_url_node.get('headers', []) + + try: + # We will buffer this many bytes: one chunk + read-ahead block. + # Note buffer may contain more data initially (from pre_read_buffer). + min_buffer_size = self._config.multipart_upload_chunk_size + self._multipart_upload_read_ahead_bytes + + buffer = pre_read_buffer + + # How many bytes in the buffer were confirmed to be received by the server. + # All the remaining bytes in the buffer must be uploaded. + uploaded_bytes_count = 0 + + chunk_offset = 0 + + retry_count = 0 + while True: + # If needed, fill the buffer to contain at least min_buffer_size bytes + # (unless end of stream), discarding already uploaded bytes. + bytes_to_read = max(0, min_buffer_size - (len(buffer) - uploaded_bytes_count)) + next_buf = input_stream.read(bytes_to_read) + buffer = buffer[uploaded_bytes_count:] + next_buf + + if len(next_buf) < bytes_to_read: + # This is the last chunk in the stream. + # Let's upload all the remaining bytes in one go. + actual_chunk_length = len(buffer) + file_size = chunk_offset + actual_chunk_length + else: + # More chunks expected, let's upload current chunk (excluding read-ahead block). + actual_chunk_length = self._config.multipart_upload_chunk_size + file_size = '*' + + headers: dict = {'Content-Type': 'application/octet-stream'} + for h in required_headers: + headers[h['name']] = h['value'] + + chunk_last_byte_offset = chunk_offset + actual_chunk_length - 1 + content_range_header = f'bytes {chunk_offset}-{chunk_last_byte_offset}/{file_size}' + _LOG.debug(f'Uploading chunk: {content_range_header}') + headers['Content-Range'] = content_range_header + + try: + # We're not retrying this single request as we don't know where to rewind to. + # Instead, in case of retryable failure we'll re-request the current offset + # from the server and resume upload from there in the main upload loop. 
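+                    # (the status check is the "Content-Range: bytes */*" request with an empty
+                    # body in the exception handler below, as described in
+                    # https://cloud.google.com/storage/docs/performing-resumable-uploads#status-check)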
+ upload_response: requests.Response = cloud_provider_session.request( + 'PUT', + resumable_upload_url, + headers=headers, + data=BytesIO(buffer[:actual_chunk_length]), + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + retry_count = 0 # reset retry count when progressing along the stream + except RequestException as e: + _LOG.warning(f'Failure during upload request: {sys.exc_info()}') + if _BaseClient._is_retryable( + e) and retry_count < self._config.multipart_upload_max_retries: + retry_count += 1 + # Chunk upload threw an error, try to retrieve the current received offset + try: + # https://cloud.google.com/storage/docs/performing-resumable-uploads#status-check + headers['Content-Range'] = 'bytes */*' + upload_response = cloud_provider_session.request( + 'PUT', + resumable_upload_url, + headers=headers, + data=b'', + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + except RequestException: + # status check failed, abort the upload + raise e from None + else: + # error is not retryable, abort the upload + raise e from None + + if upload_response.status_code in (200, 201): + if file_size == '*': + raise ValueError( + f"Received unexpected status {upload_response.status_code} before reaching end of stream" + ) + + # upload complete + break + + elif upload_response.status_code == 308: + # chunk accepted, let's determine received offset to resume from there + range_string = upload_response.headers.get('Range') + confirmed_offset = self._extract_range_offset(range_string) + _LOG.debug(f"Received confirmed offset: {confirmed_offset}") + + if confirmed_offset: + if confirmed_offset < chunk_offset - 1 or confirmed_offset > chunk_last_byte_offset: + raise ValueError( + f"Unexpected received offset: {confirmed_offset} is outside of expected range, chunk offset: {chunk_offset}, chunk last byte offset: {chunk_last_byte_offset}" + ) + else: + if chunk_offset > 0: + raise ValueError( + f"Unexpected received offset: {confirmed_offset} is outside of expected range, chunk offset: {chunk_offset}, chunk last byte offset: {chunk_last_byte_offset}" + ) + + # We have just uploaded a part of chunk starting from offset "chunk_offset" and ending + # at offset "confirmed_offset" (inclusive), so the next chunk will start at + # offset "confirmed_offset + 1" + if confirmed_offset: + next_chunk_offset = confirmed_offset + 1 + else: + next_chunk_offset = chunk_offset + uploaded_bytes_count = next_chunk_offset - chunk_offset + chunk_offset = next_chunk_offset + + elif upload_response.status_code == 400: + # Expecting response body to be small to be safely logged + mapped_error = _error_mapper(upload_response, {}) + raise mapped_error or ValueError( + f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}") + + elif upload_response.status_code == 412 and not overwrite: + # Assuming this is only possible reason + # Full message in this case: "At least one of the pre-conditions you specified did not hold." 
+ raise AlreadyExists('The file being created already exists.') + else: + if _LOG.isEnabledFor(logging.DEBUG): + _LOG.debug( + f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}" + ) + + mapped_error = _error_mapper(upload_response, {}) + raise mapped_error or ValueError(f"Failed to upload: {upload_response}") + + except Exception as e: + _LOG.info(f"Aborting resumable upload on error: {e}") + try: + self._abort_resumable_upload(resumable_upload_url, required_headers, cloud_provider_session) + except BaseException as ex: + _LOG.warning(f"Failed to abort upload: {ex}") + # ignore, abort is a best-effort + finally: + # rethrow original exception + raise e from None + + @staticmethod + def _extract_range_offset(range_string: Optional[str]) -> Optional[int]: + if not range_string: + return None # server did not yet confirm any bytes + + if match := re.match('bytes=0-(\\d+)', range_string): + return int(match.group(1)) + else: + raise ValueError(f"Cannot parse response header: Range: {range_string}") + + def _get_url_expire_time(self): + current_time = datetime.datetime.now(datetime.timezone.utc) + expire_time = current_time + self._config.multipart_upload_url_expiration_duration + # From Google Protobuf doc: + # In JSON format, the Timestamp type is encoded as a string in the + # * [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the + # * format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" + return expire_time.strftime('%Y-%m-%dT%H:%M:%SZ') + + def _abort_multipart_upload(self, target_path: str, session_token: str, + cloud_provider_session: requests.Session): + body: dict = { + 'path': target_path, + 'session_token': session_token, + 'expire_time': self._get_url_expire_time() + } + + headers = {'Content-Type': 'application/json'} + + # _api.do() does retry + abort_url_response = self._api.do('POST', + '/api/2.0/fs/create-abort-upload-url', + headers=headers, + body=body) + # no need to check response status, _api.do() will throw exception on failure + + abort_upload_url_node = abort_url_response['abort_upload_url'] + abort_url = abort_upload_url_node['url'] + required_headers = abort_upload_url_node.get('headers', []) + + headers: dict = {'Content-Type': 'application/octet-stream'} + for h in required_headers: + headers[h['name']] = h['value'] + + def perform(): + result = cloud_provider_session.request( + 'DELETE', + abort_url, + headers=headers, + data=b'', + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + return result + + # following _BaseClient timeout + retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + abort_response = retried(timeout=timedelta(seconds=retry_timeout_seconds), + is_retryable=_BaseClient._is_retryable, + clock=self._clock)(perform)() + + if abort_response.status_code not in (200, 201): + raise ValueError(abort_response) + + def _abort_resumable_upload(self, resumable_upload_url: str, required_headers: list, + cloud_provider_session: requests.Session): + headers: dict = {} + for h in required_headers: + headers[h['name']] = h['value'] + + def perform(): + result = cloud_provider_session.request( + 'DELETE', + resumable_upload_url, + headers=headers, + data=b'', + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + return result + + # following _BaseClient timeout + retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + abort_response = retried(timeout=timedelta(seconds=retry_timeout_seconds), + 
is_retryable=_BaseClient._is_retryable, + clock=self._clock)(perform)() + + if abort_response.status_code not in (200, 201): + raise ValueError(abort_response) + + def _create_cloud_provider_session(self): + # Create a separate session which does not inherit + # auth headers from BaseClient session. + session = requests.Session() + + # following session config in _BaseClient + http_adapter = requests.adapters.HTTPAdapter(self._config.max_connection_pools or 20, + self._config.max_connections_per_pool or 20, + pool_block=True) + session.mount("https://", http_adapter) + # presigned URL for storage proxy can use plain HTTP + session.mount("http://", http_adapter) + return session + def _download_raw_stream(self, file_path: str, start_byte_offset: int, diff --git a/tests/test_files.py b/tests/test_files.py index f4d916f6f..3ed7cf8dd 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1,14 +1,27 @@ +import copy +import hashlib +import io +import json import logging import os +import random import re +import time from dataclasses import dataclass -from typing import List, Union +from datetime import datetime, timedelta, timezone +from tempfile import mkstemp +from typing import Callable, List, Optional, Union +from urllib.parse import parse_qs, urlparse import pytest +import requests +import requests_mock from requests import RequestException from databricks.sdk import WorkspaceClient from databricks.sdk.core import Config +from databricks.sdk.errors.platform import (AlreadyExists, BadRequest, + InternalError, PermissionDenied) logger = logging.getLogger(__name__) @@ -338,3 +351,1299 @@ class _Constants: ids=DownloadTestCase.to_string) def test_download_recover(config: Config, test_case: DownloadTestCase): test_case.run(config) + + +class FileContent: + + def __init__(self, length: int, checksum: str): + self._length = length + self.checksum = checksum + + @classmethod + def from_bytes(cls, data: bytes): + sha256 = hashlib.sha256() + sha256.update(data) + return FileContent(len(data), sha256.hexdigest()) + + def __repr__(self): + return f"Length: {self._length}, checksum: {self.checksum}" + + def __eq__(self, other): + if not isinstance(other, FileContent): + return NotImplemented + return self._length == other._length and self.checksum == other.checksum + + +class MultipartUploadServerState: + upload_chunk_url_prefix = 'https://cloud_provider.com/upload-chunk/' + abort_upload_url_prefix = 'https://cloud_provider.com/abort-upload/' + + def __init__(self): + self.issued_multipart_urls = {} # part_number -> expiration_time + self.uploaded_chunks = {} # part_number -> [chunk file path, etag] + self.session_token = 'token-' + MultipartUploadServerState.randomstr() + self.file_content = None + self.issued_abort_url_expire_time = None + self.aborted = False + + def create_upload_chunk_url(self, path: str, part_number: int, expire_time: datetime) -> str: + assert not self.aborted + # client may have requested a URL for the same part if retrying on network error + self.issued_multipart_urls[part_number] = expire_time + return f'{self.upload_chunk_url_prefix}{path}/{part_number}' + + def create_abort_url(self, path: str, expire_time: datetime) -> str: + assert not self.aborted + self.issued_abort_url_expire_time = expire_time + return f'{self.abort_upload_url_prefix}{path}' + + def save_part(self, part_number: int, part_content: bytes, etag: str): + assert not self.aborted + assert len(part_content) > 0 + + logger.info(f"Saving part {part_number} of size {len(part_content)}") + + # chunk 
might already have been uploaded + existing_chunk = self.uploaded_chunks.get(part_number) + if existing_chunk: + chunk_file = existing_chunk[0] + with open(chunk_file, "wb") as f: + f.write(part_content) + else: + fd, chunk_file = mkstemp() + with open(fd, "wb") as f: + f.write(part_content) + + self.uploaded_chunks[part_number] = [chunk_file, etag] + + def cleanup(self): + for [file, _] in self.uploaded_chunks.values(): + os.remove(file) + + def get_file_content(self) -> FileContent: + assert not self.aborted + return self.file_content + + def upload_complete(self, etags: dict): + assert not self.aborted + # validate etags + expected_etags = {} + for part_number in self.uploaded_chunks.keys(): + expected_etags[part_number] = self.uploaded_chunks[part_number][1] + assert etags == expected_etags + + size = 0 + sha256 = hashlib.sha256() + + sorted_chunks = sorted(self.uploaded_chunks.keys()) + for part_number in sorted_chunks: + [chunk_path, _] = self.uploaded_chunks[part_number] + size += os.path.getsize(chunk_path) + with open(chunk_path, "rb") as f: + chunk_content = f.read() + sha256.update(chunk_content) + + self.file_content = FileContent(size, sha256.hexdigest()) + + def abort_upload(self): + self.aborted = True + + @staticmethod + def randomstr(): + return f'{random.randrange(10000)}-{int(time.time())}' + + +class CustomResponse: + """ Custom response allows to override the "default" response generated by the server + with the "custom" response to simulate failure error code, unexpected response body or + network error. + + The server is represented by the `processor` parameter in `generate_response()` call. + """ + + def __init__( + self, + # If False, default response is always returned. + # If True, response is defined by the current invocation count + # with respect to first_invocation / last_invocation / only_invocation + enabled=True, + + # Custom code to return + code: Optional[int] = 200, + + # Custom body to return + body: Optional[str] = None, + + # Custom exception to raise + exception: Optional[type[BaseException]] = None, + + # Whether exception should be raised before calling processor() + # (so changing server state) + exception_happened_before_processing: bool = False, + + # First invocation (1-based) at which return custom response + first_invocation: Optional[int] = None, + + # Last invocation (1-based) at which return custom response + last_invocation: Optional[int] = None, + + # Only invocation (1-based) at which return custom response + only_invocation: Optional[int] = None, + ): + self.enabled = enabled + self.code = code + self.body = body + self.exception = exception + self.exception_happened_before_processing = exception_happened_before_processing + self.first_invocation = first_invocation + self.last_invocation = last_invocation + self.only_invocation = only_invocation + + if self.only_invocation and (self.first_invocation or self.last_invocation): + raise ValueError("Cannot set both only invocation and first/last invocation") + + self.invocation_count = 0 + + def invocation_matches(self): + if not self.enabled: + return False + + self.invocation_count += 1 + + if self.only_invocation: + return self.invocation_count == self.only_invocation + + if self.first_invocation and self.invocation_count < self.first_invocation: + return False + if self.last_invocation and self.invocation_count > self.last_invocation: + return False + return True + + def generate_response(self, request: requests.Request, processor: Callable[[], list]): + activate_for_current_invocation 
= self.invocation_matches() + + if activate_for_current_invocation and self.exception and self.exception_happened_before_processing: + # if network exception is thrown while processing a request, it's not defined + # if server actually processed the request (and so changed its state) + raise self.exception + + custom_response = [self.code, self.body or '', {}] + + if activate_for_current_invocation: + if self.code and 400 <= self.code < 500: + # if server returns client error, it's not supposed to change its state, + # so we're not calling processor() + [code, body, headers] = custom_response + else: + # we're calling processor() but override its response with the custom one + processor() + [code, body, headers] = custom_response + else: + [code, body, headers] = processor() + + if activate_for_current_invocation and self.exception: + # self.exception_happened_before_processing is False + raise self.exception + + resp = requests.Response() + + resp.request = request + resp.status_code = code + resp._content = body.encode() + + for key in headers: + resp.headers[key] = headers[key] + + return resp + + +class MultipartUploadTestCase: + """Test case for multipart upload of a file. Multipart uploads are used on AWS and Azure. + + Multipart upload via presigned URLs involves multiple HTTP requests: + - initiating upload (call to Databricks Files API) + - requesting upload part URLs (calls to Databricks Files API) + - uploading data in chunks (calls to cloud storage provider or Databricks storage proxy) + - completing the upload (call to Databricks Files API) + - requesting abort upload URL (call to Databricks Files API) + - aborting the upload (call to cloud storage provider or Databricks storage proxy) + + Test case uses requests-mock library to mock all these requests. Within a test, mocks use + shared server state that tracks the upload. Mocks generate the "default" (successful) response. + + Response of each call can be modified by parameterising a respective `CustomResponse` object. + """ + + path = '/test.txt' + + expired_url_aws_response = ('' + 'AuthenticationFailedServer failed to authenticate ' + 'the request. 
Make sure the value of Authorization header is formed ' + 'correctly including the signature.\nRequestId:1abde581-601e-0028-' + '4a6d-5c3952000000\nTime:2025-01-01T16:54:20.5343181ZSignature not valid in the specified ' + 'time frame: Start [Wed, 01 Jan 2025 16:38:41 GMT] - Expiry [Wed, ' + '01 Jan 2025 16:53:45 GMT] - Current [Wed, 01 Jan 2025 16:54:20 ' + 'GMT]') + + expired_url_azure_response = ('\nAccessDenied' + 'Request has expired' + '142025-01-01T17:47:13Z' + '2025-01-01T17:48:01Z' + 'JY66KDXM4CXBZ7X2n8Qayqg60rbvut9P7pk0' + '') + + # TODO test for overwrite = false + + def __init__( + self, + name: str, + stream_size: int, # size of uploaded file or, technically, stream + multipart_upload_chunk_size: Optional[int] = None, + sdk_retry_timeout_seconds: Optional[int] = None, + multipart_upload_max_retries: Optional[int] = None, + multipart_upload_batch_url_count: Optional[int] = None, + custom_response_on_initiate=CustomResponse(enabled=False), + custom_response_on_create_multipart_url=CustomResponse(enabled=False), + custom_response_on_upload=CustomResponse(enabled=False), + custom_response_on_complete=CustomResponse(enabled=False), + custom_response_on_create_abort_url=CustomResponse(enabled=False), + custom_response_on_abort=CustomResponse(enabled=False), + # exception which is expected to be thrown (so upload is expected to have failed) + expected_exception_type: Optional[type[BaseException]] = None, + # if abort is expected to be called + expected_aborted: bool = False): + self.name = name + self.stream_size = stream_size + self.multipart_upload_chunk_size = multipart_upload_chunk_size + self.sdk_retry_timeout_seconds = sdk_retry_timeout_seconds + self.multipart_upload_max_retries = multipart_upload_max_retries + self.multipart_upload_batch_url_count = multipart_upload_batch_url_count + self.custom_response_on_initiate = copy.deepcopy(custom_response_on_initiate) + self.custom_response_on_create_multipart_url = copy.deepcopy(custom_response_on_create_multipart_url) + self.custom_response_on_upload = copy.deepcopy(custom_response_on_upload) + self.custom_response_on_complete = copy.deepcopy(custom_response_on_complete) + self.custom_response_on_create_abort_url = copy.deepcopy(custom_response_on_create_abort_url) + self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort) + self.expected_exception_type = expected_exception_type + self.expected_aborted: bool = expected_aborted + + def setup_session_mock(self, session_mock: requests_mock.Mocker, + server_state: MultipartUploadServerState): + + def custom_matcher(request): + request_url = urlparse(request.url) + request_query = parse_qs(request_url.query) + + # initial request + if (request_url.hostname == 'localhost' + and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' + and request_query.get('action') == ['initiate-upload'] and request.method == 'POST'): + + assert MultipartUploadTestCase.is_auth_header_present(request) + assert request.text is None + + def processor(): + response_json = {'multipart_upload': {'session_token': server_state.session_token}} + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_initiate.generate_response(request, processor) + + # multipart upload, create upload part URLs + elif (request_url.hostname == 'localhost' + and request_url.path == '/api/2.0/fs/create-upload-part-urls' and request.method == 'POST'): + + assert MultipartUploadTestCase.is_auth_header_present(request) + + request_json = request.json() + assert request_json.keys() == 
{ + 'count', 'expire_time', 'path', 'session_token', 'start_part_number' + } + assert request_json['path'] == self.path + assert request_json['session_token'] == server_state.session_token + + start_part_number = int(request_json['start_part_number']) + count = int(request_json['count']) + assert count >= 1 + + expire_time = MultipartUploadTestCase.parse_and_validate_expire_time( + request_json['expire_time']) + + def processor(): + response_nodes = [] + for part_number in range(start_part_number, start_part_number + count): + upload_part_url = server_state.create_upload_chunk_url( + self.path, part_number, expire_time) + response_nodes.append({ + 'part_number': part_number, + 'url': upload_part_url, + 'headers': [{ + 'name': 'name1', + 'value': 'value1' + }] + }) + + response_json = {'upload_part_urls': response_nodes} + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_create_multipart_url.generate_response(request, processor) + + # multipart upload, uploading part + elif request.url.startswith( + MultipartUploadServerState.upload_chunk_url_prefix) and request.method == 'PUT': + + assert not MultipartUploadTestCase.is_auth_header_present(request) + + url_path = request.url[len(MultipartUploadServerState.abort_upload_url_prefix):] + part_num = url_path.split('/')[-1] + assert url_path[:-len(part_num) - 1] == self.path + + def processor(): + body = request.body.read() + etag = 'etag-' + MultipartUploadServerState.randomstr() + server_state.save_part(int(part_num), body, etag) + return [200, '', {'ETag': etag}] + + return self.custom_response_on_upload.generate_response(request, processor) + + # multipart upload, completion + elif (request_url.hostname == 'localhost' + and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' + and request_query.get('action') == ['complete-upload'] + and request_query.get('upload_type') == ['multipart'] and request.method == 'POST'): + + assert MultipartUploadTestCase.is_auth_header_present(request) + assert [server_state.session_token] == request_query.get('session_token') + + def processor(): + request_json = request.json() + etags = {} + + for part in request_json['parts']: + etags[part['part_number']] = part['etag'] + + server_state.upload_complete(etags) + return [200, '', {}] + + return self.custom_response_on_complete.generate_response(request, processor) + + # create abort URL + elif request.url == 'http://localhost/api/2.0/fs/create-abort-upload-url' and request.method == 'POST': + assert MultipartUploadTestCase.is_auth_header_present(request) + request_json = request.json() + assert request_json['path'] == self.path + expire_time = MultipartUploadTestCase.parse_and_validate_expire_time( + request_json['expire_time']) + + def processor(): + response_json = { + 'abort_upload_url': { + 'url': server_state.create_abort_url(self.path, expire_time), + 'headers': [{ + 'name': 'header1', + 'value': 'headervalue1' + }] + } + } + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_create_abort_url.generate_response(request, processor) + + # abort upload + elif request.url.startswith( + MultipartUploadServerState.abort_upload_url_prefix) and request.method == 'DELETE': + assert not MultipartUploadTestCase.is_auth_header_present(request) + assert request.url[len(MultipartUploadServerState.abort_upload_url_prefix):] == self.path + + def processor(): + server_state.abort_upload() + return [200, '', {}] + + return self.custom_response_on_abort.generate_response(request, processor) + + 
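+            # Returning None tells requests_mock that this matcher does not handle the request;
+            # with no other matcher registered, such an unexpected request fails with NoMockAddress.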
return None + + session_mock.add_matcher(matcher=custom_matcher) + + @staticmethod + def setup_token_auth(config: Config): + pat_token = "some_pat_token" + config._header_factory = lambda: {'Authorization': f'Bearer {pat_token}'} + + @staticmethod + def is_auth_header_present(r: requests.Request): + return r.headers.get('Authorization') is not None + + @staticmethod + def parse_and_validate_expire_time(s: str) -> datetime: + expire_time = datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ') + expire_time = expire_time.replace(tzinfo=timezone.utc) # Explicitly add timezone + now = datetime.now(timezone.utc) + max_expiration = now + timedelta(hours=2) + assert now < expire_time < max_expiration + return expire_time + + def run(self, config: Config): + config = config.copy() + + MultipartUploadTestCase.setup_token_auth(config) + + if self.sdk_retry_timeout_seconds: + config.retry_timeout_seconds = self.sdk_retry_timeout_seconds + if self.multipart_upload_chunk_size: + config.multipart_upload_chunk_size = self.multipart_upload_chunk_size + if self.multipart_upload_max_retries: + config.multipart_upload_max_retries = self.multipart_upload_max_retries + if self.multipart_upload_batch_url_count: + config.multipart_upload_batch_url_count = self.multipart_upload_batch_url_count + config.enable_experimental_files_api_client = True + config.multipart_upload_min_stream_size = 0 # disable single-shot uploads + + file_content = os.urandom(self.stream_size) + + upload_state = MultipartUploadServerState() + + try: + w = WorkspaceClient(config=config) + with requests_mock.Mocker() as session_mock: + self.setup_session_mock(session_mock, upload_state) + + def upload(): + w.files.upload('/test.txt', io.BytesIO(file_content), overwrite=True) + + if self.expected_exception_type is not None: + with pytest.raises(self.expected_exception_type): + upload() + else: + upload() + actual_content = upload_state.get_file_content() + assert actual_content == FileContent.from_bytes(file_content) + + assert upload_state.aborted == self.expected_aborted + + finally: + upload_state.cleanup() + + def __str__(self): + return self.name + + @staticmethod + def to_string(test_case): + return str(test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # -------------------------- failures on "initiate upload" -------------------------- + MultipartUploadTestCase( + 'Initiate: 400 response is not retried', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(code=400, only_invocation=1), + expected_exception_type=BadRequest, + expected_aborted=False # upload didn't start + ), + MultipartUploadTestCase( + 'Initiate: 403 response is not retried', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(code=403, only_invocation=1), + expected_exception_type=PermissionDenied, + expected_aborted=False # upload didn't start + ), + MultipartUploadTestCase( + 'Initiate: 500 response is not retried', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(code=500, only_invocation=1), + expected_exception_type=InternalError, + expected_aborted=False # upload didn't start + ), + MultipartUploadTestCase( + 'Initiate: non-JSON response is not retried', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(body='this is not a JSON', only_invocation=1), + expected_exception_type=requests.exceptions.JSONDecodeError, + expected_aborted=False # upload didn't start + ), + MultipartUploadTestCase( + 'Initiate: meaningless JSON response is not retried', + stream_size=1024 * 1024, + 
custom_response_on_initiate=CustomResponse(body='{"foo": 123}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=False # upload didn't start + ), + MultipartUploadTestCase( + 'Initiate: no session token in response is not retried', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse( + body='{"multipart_upload":{"session_token1": "token123"}}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=False # upload didn't start + ), + MultipartUploadTestCase( + 'Initiate: permanent retryable exception', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError), + sdk_retry_timeout_seconds=30, # let's not wait 5 min (SDK default timeout) + expected_exception_type=TimeoutError, + expected_aborted=False # upload didn't start + ), + MultipartUploadTestCase('Initiate: intermittent retryable exception', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError, + first_invocation=1, + last_invocation=3), + expected_aborted=False), + MultipartUploadTestCase('Initiate: intermittent retryable error', + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse( + code=429, first_invocation=1, last_invocation=3), + expected_aborted=False), + + # -------------------------- failures on "create upload URL" -------------------------- + MultipartUploadTestCase('Create upload URL: client error is not retied', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=400, + only_invocation=1), + expected_exception_type=BadRequest, + expected_aborted=True), + MultipartUploadTestCase('Create upload URL: internal error is not retied', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, + only_invocation=1), + expected_exception_type=InternalError, + expected_aborted=True), + MultipartUploadTestCase('Create upload URL: non-JSON response is not retried', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + body='this is not a JSON', only_invocation=1), + expected_exception_type=requests.exceptions.JSONDecodeError, + expected_aborted=True), + MultipartUploadTestCase('Create upload URL: meaningless JSON response is not retried', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(body='{"foo":123}', + only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=True), + MultipartUploadTestCase('Create upload URL: meaningless JSON response is not retried 2', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + body='{"upload_part_urls":[]}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=True), + MultipartUploadTestCase( + 'Create upload URL: meaningless JSON response is not retried 3', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(body='{"upload_part_urls":[{"url":""}]}', + only_invocation=1), + expected_exception_type=KeyError, # TODO we might want to make JSON parsing more reliable + expected_aborted=True), + MultipartUploadTestCase( + 'Create upload URL: permanent retryable exception', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(exception=requests.ConnectionError), + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + expected_exception_type=TimeoutError, + expected_aborted=True), + MultipartUploadTestCase('Create 
upload URL: intermittent retryable exception', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + exception=requests.Timeout, only_invocation=1), + expected_aborted=False), + MultipartUploadTestCase( + 'Create upload URL: intermittent retryable exception 2', + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + exception=requests.Timeout, + # 4th request for multipart URLs will fail 3 times and will be retried + first_invocation=4, + last_invocation=6), + expected_aborted=False), + + # -------------------------- failures on chunk upload -------------------------- + MultipartUploadTestCase( + 'Upload chunk: 403 response is not retried', + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse(code=403, only_invocation=2), + expected_exception_type=PermissionDenied, + expected_aborted=True), + MultipartUploadTestCase( + 'Upload chunk: 400 response is not retried', + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse(code=400, only_invocation=9), + expected_exception_type=BadRequest, + expected_aborted=True), + MultipartUploadTestCase( + 'Upload chunk: 500 response is not retried', # TODO should we retry chunk upload on internal error from Cloud provider? + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse(code=500, only_invocation=5), + expected_exception_type=InternalError, + expected_aborted=True), + MultipartUploadTestCase( + 'Upload chunk: expired URL is retried on AWS', + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, body=MultipartUploadTestCase.expired_url_aws_response, only_invocation=2), + expected_aborted=False), + MultipartUploadTestCase( + 'Upload chunk: expired URL is retried on Azure', + multipart_upload_max_retries=3, + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + body=MultipartUploadTestCase.expired_url_azure_response, + # 3 failures don't exceed multipart_upload_max_retries + first_invocation=2, + last_invocation=4), + expected_aborted=False), + MultipartUploadTestCase( + 'Upload chunk: expired URL is retried on Azure, requesting urls by 6', + multipart_upload_max_retries=3, + multipart_upload_batch_url_count=6, + stream_size=100 * 1024 * 1024, # 100 chunks + multipart_upload_chunk_size=1 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + body=MultipartUploadTestCase.expired_url_azure_response, + # 3 failures don't exceed multipart_upload_max_retries + first_invocation=2, + last_invocation=4), + expected_aborted=False), + MultipartUploadTestCase( + 'Upload chunk: expired URL retry is limited', + multipart_upload_max_retries=3, + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + body=MultipartUploadTestCase.expired_url_azure_response, + # 4 failures exceed multipart_upload_max_retries + first_invocation=2, + last_invocation=5), + expected_exception_type=ValueError, + expected_aborted=True), + MultipartUploadTestCase( + 'Upload chunk: permanent retryable error', + stream_size=100 * 
1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + custom_response_on_upload=CustomResponse(exception=requests.ConnectionError, first_invocation=8), + expected_exception_type=TimeoutError, + expected_aborted=True), + MultipartUploadTestCase( + 'Upload chunk: intermittent retryable error', + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, first_invocation=2, last_invocation=5), + expected_aborted=False), + + # -------------------------- failures on abort -------------------------- + MultipartUploadTestCase( + 'Abort URL: client error', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse(code=400), + expected_exception_type=InternalError, # original error + expected_aborted=False), + MultipartUploadTestCase( + 'Abort URL: forbidden', + stream_size=1024 * 1024, + custom_response_on_upload=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse(code=403), + expected_exception_type=InternalError, # original error + expected_aborted=False), + MultipartUploadTestCase( + 'Abort URL: intermittent retryable error', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse( + code=429, first_invocation=1, last_invocation=3), + expected_exception_type=InternalError, # original error + expected_aborted=True # abort successfully called after abort URL creation is retried + ), + MultipartUploadTestCase( + 'Abort URL: intermittent retryable error 2', + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse( + exception=requests.Timeout, first_invocation=1, last_invocation=3), + expected_exception_type=InternalError, # original error + expected_aborted=True # abort successfully called after abort URL creation is retried + ), + MultipartUploadTestCase( + 'Abort: exception', + stream_size=1024 * 1024, + # don't wait for 5 min (SDK default timeout) + sdk_retry_timeout_seconds=30, + # simulate PermissionDenied + custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), + custom_response_on_abort=CustomResponse(exception=requests.Timeout, + exception_happened_before_processing=False), + expected_exception_type=PermissionDenied, # original error is reported + expected_aborted=True # abort called but failed + ), + + # -------------------------- happy cases -------------------------- + MultipartUploadTestCase( + 'Multipart upload successful: single chunk', + stream_size=1024 * 1024, # less than chunk size + multipart_upload_chunk_size=10 * 1024 * 1024), + MultipartUploadTestCase( + 'Multipart upload successful: multiple chunks (aligned)', + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024), + MultipartUploadTestCase( + 'Multipart upload successful: multiple chunks (aligned), upload urls by 3', + multipart_upload_batch_url_count=3, + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024), + MultipartUploadTestCase( + 'Multipart upload successful: multiple chunks (not aligned)', + stream_size=100 * 1024 * 1024 + 1000 + 566, 
# 14 full chunks + remainder + multipart_upload_chunk_size=7 * 1024 * 1024 - 17), + MultipartUploadTestCase( + 'Multipart upload successful: multiple chunks (not aligned), upload urls by 5', + multipart_upload_batch_url_count=5, + stream_size=100 * 1024 * 1024 + 1000 + 566, # 14 full chunks + remainder + multipart_upload_chunk_size=7 * 1024 * 1024 - 17), + ], + ids=MultipartUploadTestCase.to_string) +def test_multipart_upload(config: Config, test_case: MultipartUploadTestCase): + test_case.run(config) + + +class SingleShotUploadState: + + def __init__(self): + self.single_shot_file_content = None + + +class SingleShotUploadTestCase: + + def __init__(self, name: str, stream_size: int, multipart_upload_min_stream_size: int, + expected_single_shot: bool): + self.name = name + self.stream_size = stream_size + self.multipart_upload_min_stream_size = multipart_upload_min_stream_size + self.expected_single_shot = expected_single_shot + + def __str__(self): + return self.name + + @staticmethod + def to_string(test_case): + return str(test_case) + + def run(self, config: Config): + config = config.copy() + config.enable_experimental_files_api_client = True + config.multipart_upload_min_stream_size = self.multipart_upload_min_stream_size + + file_content = os.urandom(self.stream_size) + + session = requests.Session() + with requests_mock.Mocker(session=session) as session_mock: + session_mock.get(f'http://localhost/api/2.0/fs/files{MultipartUploadTestCase.path}', + status_code=200) + + upload_state = SingleShotUploadState() + + def custom_matcher(request): + request_url = urlparse(request.url) + request_query = parse_qs(request_url.query) + + if self.expected_single_shot: + if (request_url.hostname == 'localhost' + and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' + and request.method == 'PUT'): + body = request.body.read() + upload_state.single_shot_file_content = FileContent.from_bytes(body) + + resp = requests.Response() + resp.status_code = 204 + resp.request = request + resp._content = b'' + return resp + else: + if (request_url.hostname == 'localhost' + and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' + and request_query.get('action') == ['initiate-upload'] + and request.method == 'POST'): + + resp = requests.Response() + resp.status_code = 403 # this will throw, that's fine + resp.request = request + resp._content = b'' + return resp + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + w = WorkspaceClient(config=config) + w.files._api._api_client._session = session + + def upload(): + w.files.upload('/test.txt', io.BytesIO(file_content), overwrite=True) + + if self.expected_single_shot: + upload() + actual_content = upload_state.single_shot_file_content + assert actual_content == FileContent.from_bytes(file_content) + else: + with pytest.raises(PermissionDenied): + upload() + + +@pytest.mark.parametrize("test_case", [ + SingleShotUploadTestCase('Single-shot upload', + stream_size=1024 * 1024, + multipart_upload_min_stream_size=1024 * 1024 + 1, + expected_single_shot=True), + SingleShotUploadTestCase('Multipart upload 1', + stream_size=1024 * 1024, + multipart_upload_min_stream_size=1024 * 1024, + expected_single_shot=False), + SingleShotUploadTestCase('Multipart upload 2', + stream_size=1024 * 1024, + multipart_upload_min_stream_size=0, + expected_single_shot=False), +], + ids=SingleShotUploadTestCase.to_string) +def test_single_shot_upload(config: Config, test_case: SingleShotUploadTestCase): + test_case.run(config) + + 
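+
+# The classes below mirror MultipartUploadServerState / MultipartUploadTestCase for the GCP
+# resumable upload flow. The mock server tracks the last confirmed byte and can deliberately
+# leave a tail of each uploaded chunk "unconfirmed" (unconfirmed_delta) to exercise the
+# client-side logic that resumes from the server-confirmed offset.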
+class ResumableUploadServerState: + resumable_upload_url_prefix = 'https://cloud_provider.com/resumable-upload/' + abort_upload_url_prefix = 'https://cloud_provider.com/abort-upload/' + + def __init__(self, unconfirmed_delta: Union[int, list]): + self.unconfirmed_delta = unconfirmed_delta + self.confirmed_last_byte: Optional[int] = None # inclusive + self.uploaded_parts = [] + self.session_token = 'token-' + MultipartUploadServerState.randomstr() + self.file_content = None + self.aborted = False + + def save_part(self, start_offset: int, end_offset_incl: int, part_content: bytes, file_size_s: str): + assert not self.aborted + + assert len(part_content) > 0 + if self.confirmed_last_byte: + assert start_offset == self.confirmed_last_byte + 1 + else: + assert start_offset == 0 + + assert end_offset_incl == start_offset + len(part_content) - 1 + + is_last_part = file_size_s != '*' + if is_last_part: + assert int(file_size_s) == end_offset_incl + 1 + else: + assert not self.file_content # last chunk should not have been uploaded yet + + if isinstance(self.unconfirmed_delta, int): + unconfirmed_delta = self.unconfirmed_delta + elif len(self.uploaded_parts) < len(self.unconfirmed_delta): + unconfirmed_delta = self.unconfirmed_delta[len(self.uploaded_parts)] + else: + unconfirmed_delta = self.unconfirmed_delta[-1] # take the last delta + + if unconfirmed_delta >= len(part_content): + unconfirmed_delta = 0 # otherwise we never finish + + logger.info( + f"Saving part {len(self.uploaded_parts) + 1} of original size {len(part_content)} with unconfirmed delta {unconfirmed_delta}. is_last_part = {is_last_part}" + ) + + if unconfirmed_delta > 0: + part_content = part_content[:-unconfirmed_delta] + + fd, chunk_file = mkstemp() + with open(fd, "wb") as f: + f.write(part_content) + + self.uploaded_parts.append(chunk_file) + + if is_last_part and unconfirmed_delta == 0: + size = 0 + sha256 = hashlib.sha256() + for chunk_path in self.uploaded_parts: + size += os.path.getsize(chunk_path) + with open(chunk_path, "rb") as f: + chunk_content = f.read() + sha256.update(chunk_content) + + assert size == end_offset_incl + 1 + self.file_content = FileContent(size, sha256.hexdigest()) + + self.confirmed_last_byte = end_offset_incl - unconfirmed_delta + + def create_abort_url(self, path: str, expire_time: datetime) -> str: + assert not self.aborted + self.issued_abort_url_expire_time = expire_time + return f'{self.abort_upload_url_prefix}{path}' + + def cleanup(self): + for file in self.uploaded_parts: + os.remove(file) + + def get_file_content(self) -> FileContent: + assert not self.aborted + return self.file_content + + def abort_upload(self): + self.aborted = True + + +class ResumableUploadTestCase: + """Test case for resumable upload of a file. Resumable uploads are used on GCP. + + Resumable upload involves multiple HTTP requests: + - initiating upload (call to Databricks Files API) + - requesting resumable upload URL (call to Databricks Files API) + - uploading chunks of data (calls to cloud storage provider or Databricks storage proxy) + - aborting the upload (call to cloud storage provider or Databricks storage proxy) + + Test case uses requests-mock library to mock all these requests. Within a test, mocks use + shared server state that tracks the upload. Mocks generate the "default" (successful) response. + + Response of each call can be modified by parameterising a respective `CustomResponse` object. 
+ """ + path = '/test.txt' + + def __init__( + self, + name: str, + stream_size: int, + overwrite: bool = True, + multipart_upload_chunk_size: Optional[int] = None, + sdk_retry_timeout_seconds: Optional[int] = None, + multipart_upload_max_retries: Optional[int] = None, + + # In resumable upload, when replying to chunk upload request, server returns + # (confirms) last accepted byte offset for the client to resume upload after. + # + # `unconfirmed_delta` defines offset from the end of the chunk that remains + # "unconfirmed", i.e. the last accepted offset would be (range_end - unconfirmed_delta). + # Can be int (same for all chunks) or list (individual for each chunk). + unconfirmed_delta: Union[int, list] = 0, + custom_response_on_create_resumable_url=CustomResponse(enabled=False), + custom_response_on_upload=CustomResponse(enabled=False), + custom_response_on_status_check=CustomResponse(enabled=False), + custom_response_on_abort=CustomResponse(enabled=False), + + # exception which is expected to be thrown (so upload is expected to have failed) + expected_exception_type: Optional[type[BaseException]] = None, + + # if abort is expected to be called + expected_aborted: bool = False): + self.name = name + self.stream_size = stream_size + self.overwrite = overwrite + self.multipart_upload_chunk_size = multipart_upload_chunk_size + self.sdk_retry_timeout_seconds = sdk_retry_timeout_seconds + self.multipart_upload_max_retries = multipart_upload_max_retries + self.unconfirmed_delta = unconfirmed_delta + self.custom_response_on_create_resumable_url = copy.deepcopy(custom_response_on_create_resumable_url) + self.custom_response_on_upload = copy.deepcopy(custom_response_on_upload) + self.custom_response_on_status_check = copy.deepcopy(custom_response_on_status_check) + self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort) + self.expected_exception_type = expected_exception_type + self.expected_aborted: bool = expected_aborted + + def setup_session_mock(self, session_mock: requests_mock.Mocker, + server_state: ResumableUploadServerState): + + def custom_matcher(request): + request_url = urlparse(request.url) + request_query = parse_qs(request_url.query) + + # initial request + if (request_url.hostname == 'localhost' + and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' + and request_query.get('action') == ['initiate-upload'] and request.method == 'POST'): + + assert MultipartUploadTestCase.is_auth_header_present(request) + assert request.text is None + + def processor(): + response_json = {'resumable_upload': {'session_token': server_state.session_token}} + return [200, json.dumps(response_json), {}] + + # Different initiate error responses have been verified by test_multipart_upload(), + # so we're always generating a "success" response. 
+ return CustomResponse(enabled=False).generate_response(request, processor) + + elif (request_url.hostname == 'localhost' + and request_url.path == '/api/2.0/fs/create-resumable-upload-url' + and request.method == 'POST'): + + assert MultipartUploadTestCase.is_auth_header_present(request) + + request_json = request.json() + assert request_json.keys() == {'path', 'session_token'} + assert request_json['path'] == self.path + assert request_json['session_token'] == server_state.session_token + + def processor(): + resumable_upload_url = f'{ResumableUploadServerState.resumable_upload_url_prefix}{self.path}' + + response_json = { + 'resumable_upload_url': { + 'url': resumable_upload_url, + 'headers': [{ + 'name': 'name1', + 'value': 'value1' + }] + } + } + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_create_resumable_url.generate_response(request, processor) + + # resumable upload, uploading part + elif request.url.startswith( + ResumableUploadServerState.resumable_upload_url_prefix) and request.method == 'PUT': + + assert not MultipartUploadTestCase.is_auth_header_present(request) + url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix):] + assert url_path == self.path + + content_range_header = request.headers['Content-range'] + is_status_check_request = re.match('bytes \\*/\\*', content_range_header) + if is_status_check_request: + assert not request.body + response_customizer = self.custom_response_on_status_check + else: + response_customizer = self.custom_response_on_upload + + def processor(): + if not is_status_check_request: + body = request.body.read() + + match = re.match('bytes (\d+)-(\d+)/(.+)', content_range_header) + [range_start_s, range_end_s, file_size_s] = match.groups() + + server_state.save_part(int(range_start_s), int(range_end_s), body, file_size_s) + + if server_state.file_content: + # upload complete + return [200, '', {}] + else: + # more data expected + if server_state.confirmed_last_byte: + headers = {'Range': f'bytes=0-{server_state.confirmed_last_byte}'} + else: + headers = {} + return [308, '', headers] + + return response_customizer.generate_response(request, processor) + + # abort upload + elif request.url.startswith( + ResumableUploadServerState.resumable_upload_url_prefix) and request.method == 'DELETE': + + assert not MultipartUploadTestCase.is_auth_header_present(request) + url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix):] + assert url_path == self.path + + def processor(): + server_state.abort_upload() + return [200, '', {}] + + return self.custom_response_on_abort.generate_response(request, processor) + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + def run(self, config: Config): + config = config.copy() + if self.sdk_retry_timeout_seconds: + config.retry_timeout_seconds = self.sdk_retry_timeout_seconds + if self.multipart_upload_chunk_size: + config.multipart_upload_chunk_size = self.multipart_upload_chunk_size + if self.multipart_upload_max_retries: + config.multipart_upload_max_retries = self.multipart_upload_max_retries + config.enable_experimental_files_api_client = True + config.multipart_upload_min_stream_size = 0 # disable single-shot uploads + + MultipartUploadTestCase.setup_token_auth(config) + + file_content = os.urandom(self.stream_size) + + upload_state = ResumableUploadServerState(self.unconfirmed_delta) + + try: + with requests_mock.Mocker() as session_mock: + self.setup_session_mock(session_mock, upload_state) + w = 
WorkspaceClient(config=config) + + def upload(): + w.files.upload('/test.txt', io.BytesIO(file_content), overwrite=self.overwrite) + + if self.expected_exception_type is not None: + with pytest.raises(self.expected_exception_type): + upload() + else: + upload() + actual_content = upload_state.get_file_content() + assert actual_content == FileContent.from_bytes(file_content) + + assert upload_state.aborted == self.expected_aborted + + finally: + upload_state.cleanup() + + def __str__(self): + return self.name + + @staticmethod + def to_string(test_case): + return str(test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # ------------------ failures on creating resumable upload URL ------------------ + ResumableUploadTestCase( + 'Create resumable URL: client error is not retried', + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(code=400, only_invocation=1), + expected_exception_type=BadRequest, + expected_aborted=False # upload didn't start + ), + ResumableUploadTestCase( + 'Create resumable URL: permission denied is not retried', + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(code=403, only_invocation=1), + expected_exception_type=PermissionDenied, + expected_aborted=False # upload didn't start + ), + ResumableUploadTestCase( + 'Create resumable URL: internal error is not retried', + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(code=500, only_invocation=1), + expected_exception_type=InternalError, + expected_aborted=False # upload didn't start + ), + ResumableUploadTestCase( + 'Create resumable URL: non-JSON response is not retried', + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(body='Foo bar', only_invocation=1), + expected_exception_type=requests.exceptions.JSONDecodeError, + expected_aborted=False # upload didn't start + ), + ResumableUploadTestCase( + 'Create resumable URL: meaningless JSON response is not retried', + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(body='{"upload_part_urls":[{"url":""}]}', + only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=False # upload didn't start + ), + ResumableUploadTestCase( + 'Create resumable URL: permanent retryable exception', + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(code=429), + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + expected_exception_type=TimeoutError, + expected_aborted=False # upload didn't start + ), + ResumableUploadTestCase( + 'Create resumable URL: intermittent retryable exception is retried', + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse( + exception=requests.Timeout, + # 3 failures total + first_invocation=1, + last_invocation=3), + expected_aborted=False # upload succeeds + ), + + # ------------------ failures during upload ------------------ + ResumableUploadTestCase( + 'Upload: retryable exception after file is uploaded', + stream_size=1024 * 1024, + custom_response_on_upload=CustomResponse(exception=requests.ConnectionError, + exception_happened_before_processing=False), + # Despite the returned error, file has been uploaded. We'll discover that + # on the next status check and consider upload completed. 
+ expected_aborted=False), + ResumableUploadTestCase( + 'Upload: retryable exception before file is uploaded, not enough retries', + stream_size=1024 * 1024, + multipart_upload_max_retries=3, + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, + # prevent server from saving this chunk + exception_happened_before_processing=True, + # fail 4 times, exceeding max_retries + first_invocation=1, + last_invocation=4), + # File was never uploaded and we gave up retrying + expected_exception_type=requests.ConnectionError, + expected_aborted=True), + ResumableUploadTestCase( + 'Upload: retryable exception before file is uploaded, enough retries', + stream_size=1024 * 1024, + multipart_upload_max_retries=4, + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, + # prevent server from saving this chunk + exception_happened_before_processing=True, + # fail 4 times, not exceeding max_retries + first_invocation=1, + last_invocation=4), + # File was uploaded after retries + expected_aborted=False), + + # -------------- abort failures -------------- + ResumableUploadTestCase( + 'Abort: client error', + stream_size=1024 * 1024, + # prevent chunk from being uploaded + custom_response_on_upload=CustomResponse(code=403), + # internal server error does not prevent server state change + custom_response_on_abort=CustomResponse(code=500), + expected_exception_type=PermissionDenied, + # abort returned error but was invoked + expected_aborted=True), + + # -------------- file already exists -------------- + ResumableUploadTestCase('File already exists', + stream_size=1024 * 1024, + overwrite=False, + custom_response_on_upload=CustomResponse(code=412), + expected_exception_type=AlreadyExists, + expected_aborted=True), + + # -------------- success cases -------------- + ResumableUploadTestCase('Multiple chunks, zero unconfirmed delta', + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + unconfirmed_delta=0, + expected_aborted=False), + ResumableUploadTestCase('Multiple chunks, non-zero unconfirmed delta', + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + unconfirmed_delta=239, + expected_aborted=False), + ResumableUploadTestCase('Multiple chunks, variable unconfirmed delta', + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + unconfirmed_delta=[15 * 1024, 0, 25000, 7 * 1024 * 1024, 5], + expected_aborted=False), + ], + ids=ResumableUploadTestCase.to_string) +def test_resumable_upload(config: Config, test_case: ResumableUploadTestCase): + test_case.run(config) From 17e388bc8c6e696423fb27dad0d5b17819bd7a7b Mon Sep 17 00:00:00 2001 From: Kirill Safonov Date: Thu, 27 Feb 2025 00:27:18 +0100 Subject: [PATCH 02/11] Fix warning --- tests/test_files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_files.py b/tests/test_files.py index 3ed7cf8dd..9e4bfc1b0 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1429,7 +1429,7 @@ def processor(): if not is_status_check_request: body = request.body.read() - match = re.match('bytes (\d+)-(\d+)/(.+)', content_range_header) + match = re.match('bytes (\\d+)-(\\d+)/(.+)', content_range_header) [range_start_s, range_end_s, file_size_s] = match.groups() server_state.save_part(int(range_start_s), int(range_end_s), body, file_size_s) From e86a56e10a1e9ab328b2db4fb173fb8e9bfc1cd2 Mon Sep 17 00:00:00 2001 From: Kirill Safonov Date: Thu, 27 Feb 2025 12:25:49 +0100 Subject: [PATCH 03/11] 
Restore Python 3.8 compatibility (where `builtins.type` does not support subscripting). --- tests/test_files.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_files.py b/tests/test_files.py index 9e4bfc1b0..07a570a74 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta, timezone from tempfile import mkstemp -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Union, Type from urllib.parse import parse_qs, urlparse import pytest @@ -475,7 +475,7 @@ def __init__( body: Optional[str] = None, # Custom exception to raise - exception: Optional[type[BaseException]] = None, + exception: Optional[Type[BaseException]] = None, # Whether exception should be raised before calling processor() # (so changing server state) @@ -610,7 +610,7 @@ def __init__( custom_response_on_create_abort_url=CustomResponse(enabled=False), custom_response_on_abort=CustomResponse(enabled=False), # exception which is expected to be thrown (so upload is expected to have failed) - expected_exception_type: Optional[type[BaseException]] = None, + expected_exception_type: Optional[Type[BaseException]] = None, # if abort is expected to be called expected_aborted: bool = False): self.name = name @@ -1341,7 +1341,7 @@ def __init__( custom_response_on_abort=CustomResponse(enabled=False), # exception which is expected to be thrown (so upload is expected to have failed) - expected_exception_type: Optional[type[BaseException]] = None, + expected_exception_type: Optional[Type[BaseException]] = None, # if abort is expected to be called expected_aborted: bool = False): From 5cf738d4d9bb807a6ce5d091fc82e0aa02e404d2 Mon Sep 17 00:00:00 2001 From: Kirill Safonov Date: Thu, 27 Feb 2025 12:42:20 +0100 Subject: [PATCH 04/11] Reformat --- databricks/sdk/mixins/files.py | 300 +++++------ tests/test_files.py | 907 ++++++++++++++++++--------------- 2 files changed, 641 insertions(+), 566 deletions(-) diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index c0e7aeb68..797f37d9c 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -732,33 +732,34 @@ def _download_raw_stream( headers = { "Accept": "application/octet-stream", } + def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None): # Upload empty and small files with one-shot upload. 
pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size) if len(pre_read_buffer) < self._config.multipart_upload_min_stream_size: _LOG.debug( - f'Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes' + f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes" ) return super().upload(file_path=file_path, contents=pre_read_buffer, overwrite=overwrite) - query = {'action': 'initiate-upload'} + query = {"action": "initiate-upload"} if overwrite is not None: - query['overwrite'] = overwrite + query["overwrite"] = overwrite # _api.do() does retry initiate_upload_response = self._api.do( - 'POST', f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}', query=query) + "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", query=query + ) # no need to check response status, _api.do() will throw exception on failure - if initiate_upload_response.get('multipart_upload'): + if initiate_upload_response.get("multipart_upload"): cloud_provider_session = self._create_cloud_provider_session() - session_token = initiate_upload_response['multipart_upload'].get('session_token') + session_token = initiate_upload_response["multipart_upload"].get("session_token") if not session_token: - raise ValueError(f'Unexpected server response: {initiate_upload_response}') + raise ValueError(f"Unexpected server response: {initiate_upload_response}") try: - self._multipart_upload(file_path, contents, session_token, pre_read_buffer, - cloud_provider_session) + self._multipart_upload(file_path, contents, session_token, pre_read_buffer, cloud_provider_session) except Exception as e: _LOG.info(f"Aborting multipart upload on error: {e}") try: @@ -770,16 +771,23 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool # rethrow original exception raise e from None - elif initiate_upload_response.get('resumable_upload'): + elif initiate_upload_response.get("resumable_upload"): cloud_provider_session = self._create_cloud_provider_session() - session_token = initiate_upload_response['resumable_upload']['session_token'] - self._resumable_upload(file_path, contents, session_token, overwrite, pre_read_buffer, - cloud_provider_session) + session_token = initiate_upload_response["resumable_upload"]["session_token"] + self._resumable_upload( + file_path, contents, session_token, overwrite, pre_read_buffer, cloud_provider_session + ) else: - raise ValueError(f'Unexpected server response: {initiate_upload_response}') + raise ValueError(f"Unexpected server response: {initiate_upload_response}") - def _multipart_upload(self, target_path: str, input_stream: BinaryIO, session_token: str, - pre_read_buffer: bytes, cloud_provider_session: requests.Session): + def _multipart_upload( + self, + target_path: str, + input_stream: BinaryIO, + session_token: str, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, + ): current_part_number = 1 etags: dict = {} @@ -791,7 +799,7 @@ def _multipart_upload(self, target_path: str, input_stream: BinaryIO, session_to # AWS signed chunked upload: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-tune-upload-download-python#buffering-during-uploads - chunk_offset = 0 # used only for logging + chunk_offset = 0 # used only for logging # This buffer is expected to contain at 
least multipart_upload_chunk_size bytes. # Note that initially buffer can be bigger (from pre_read_buffer). @@ -821,25 +829,24 @@ def fill_buffer(): ) body: dict = { - 'path': target_path, - 'session_token': session_token, - 'start_part_number': current_part_number, - 'count': self._config.multipart_upload_batch_url_count, - 'expire_time': self._get_url_expire_time() + "path": target_path, + "session_token": session_token, + "start_part_number": current_part_number, + "count": self._config.multipart_upload_batch_url_count, + "expire_time": self._get_url_expire_time(), } - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} # _api.do() does retry - upload_part_urls_response = self._api.do('POST', - '/api/2.0/fs/create-upload-part-urls', - headers=headers, - body=body) + upload_part_urls_response = self._api.do( + "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body + ) # no need to check response status, _api.do() will throw exception on failure - upload_part_urls = upload_part_urls_response.get('upload_part_urls', []) + upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) if not len(upload_part_urls): - raise ValueError(f'Unexpected server response: {upload_part_urls_response}') + raise ValueError(f"Unexpected server response: {upload_part_urls_response}") for upload_part_url in upload_part_urls: buffer = fill_buffer() @@ -848,17 +855,17 @@ def fill_buffer(): eof = True break - url = upload_part_url['url'] - required_headers = upload_part_url.get('headers', []) - assert current_part_number == upload_part_url['part_number'] + url = upload_part_url["url"] + required_headers = upload_part_url.get("headers", []) + assert current_part_number == upload_part_url["part_number"] - headers: dict = {'Content-Type': 'application/octet-stream'} + headers: dict = {"Content-Type": "application/octet-stream"} for h in required_headers: - headers[h['name']] = h['value'] + headers[h["name"]] = h["value"] actual_chunk_length = min(actual_buffer_length, self._config.multipart_upload_chunk_size) _LOG.debug( - f'Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]' + f"Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]" ) chunk = BytesIO(buffer[:actual_chunk_length]) @@ -868,26 +875,29 @@ def rewind(): def perform(): result = cloud_provider_session.request( - 'PUT', + "PUT", url, headers=headers, data=chunk, - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) return result # following _BaseClient timeout retry_timeout_seconds = self._config.retry_timeout_seconds or 300 - upload_response = retried(timeout=timedelta(seconds=retry_timeout_seconds), - is_retryable=_BaseClient._is_retryable, - clock=self._clock, - before_retry=rewind)(perform)() + upload_response = retried( + timeout=timedelta(seconds=retry_timeout_seconds), + is_retryable=_BaseClient._is_retryable, + clock=self._clock, + before_retry=rewind, + )(perform)() if upload_response.status_code in (200, 201): # Chunk upload successful chunk_offset += actual_chunk_length - etag = upload_response.headers.get('ETag', '') + etag = upload_response.headers.get("ETag", "") etags[current_part_number] = etag # Discard uploaded bytes @@ -899,14 +909,14 @@ def perform(): elif FilesExt._is_url_expired_response(upload_response): if retry_count < self._config.multipart_upload_max_retries: 
retry_count += 1 - _LOG.debug('Upload URL expired') + _LOG.debug("Upload URL expired") # Preserve the buffer so we'll upload the current part again using next upload URL else: # don't confuse user with unrelated "Permission denied" error. - raise ValueError(f'Unsuccessful chunk upload: upload URL expired') + raise ValueError(f"Unsuccessful chunk upload: upload URL expired") else: - message = f'Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}' + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" _LOG.warning(message) mapped_error = _error_mapper(upload_response, {}) raise mapped_error or ValueError(message) @@ -914,26 +924,28 @@ def perform(): current_part_number += 1 _LOG.debug( - f'Completing multipart upload after uploading {len(etags)} parts of up to {self._config.multipart_upload_chunk_size} bytes' + f"Completing multipart upload after uploading {len(etags)} parts of up to {self._config.multipart_upload_chunk_size} bytes" ) - query = {'action': 'complete-upload', 'upload_type': 'multipart', 'session_token': session_token} - headers = {'Content-Type': 'application/json'} + query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token} + headers = {"Content-Type": "application/json"} body: dict = {} parts = [] for etag in sorted(etags.items()): - part = {'part_number': etag[0], 'etag': etag[1]} + part = {"part_number": etag[0], "etag": etag[1]} parts.append(part) - body['parts'] = parts + body["parts"] = parts # _api.do() does retry - self._api.do('POST', - f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}', - query=query, - headers=headers, - body=body) + self._api.do( + "POST", + f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}", + query=query, + headers=headers, + body=body, + ) # no need to check response status, _api.do() will throw exception on failure @staticmethod @@ -943,23 +955,23 @@ def _is_url_expired_response(response: requests.Response): try: xml_root = ET.fromstring(response.content) - if xml_root.tag != 'Error': + if xml_root.tag != "Error": return False - code = xml_root.find('Code') + code = xml_root.find("Code") if code is None: return False - if code.text == 'AuthenticationFailed': + if code.text == "AuthenticationFailed": # Azure - details = xml_root.find('AuthenticationErrorDetail') - if details is not None and 'Signature not valid in the specified time frame' in details.text: + details = xml_root.find("AuthenticationErrorDetail") + if details is not None and "Signature not valid in the specified time frame" in details.text: return True - if code.text == 'AccessDenied': + if code.text == "AccessDenied": # AWS - message = xml_root.find('Message') - if message is not None and message.text == 'Request has expired': + message = xml_root.find("Message") + if message is not None and message.text == "Request has expired": return True except ET.ParseError: @@ -967,8 +979,15 @@ def _is_url_expired_response(response: requests.Response): return False - def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_token: str, overwrite: bool, - pre_read_buffer: bytes, cloud_provider_session: requests.Session): + def _resumable_upload( + self, + target_path: str, + input_stream: BinaryIO, + session_token: str, + overwrite: bool, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, + ): # 
https://cloud.google.com/storage/docs/performing-resumable-uploads # Session URI we're using expires after a week @@ -993,26 +1012,25 @@ def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_to # On the contrary, in multipart upload we can decide to complete upload *after* # last chunk has been sent. - body: dict = {'path': target_path, 'session_token': session_token} + body: dict = {"path": target_path, "session_token": session_token} - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} # _api.do() does retry - resumable_upload_url_response = self._api.do('POST', - '/api/2.0/fs/create-resumable-upload-url', - headers=headers, - body=body) + resumable_upload_url_response = self._api.do( + "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body + ) # no need to check response status, _api.do() will throw exception on failure - resumable_upload_url_node = resumable_upload_url_response.get('resumable_upload_url') + resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url") if not resumable_upload_url_node: - raise ValueError(f'Unexpected server response: {resumable_upload_url_response}') + raise ValueError(f"Unexpected server response: {resumable_upload_url_response}") - resumable_upload_url = resumable_upload_url_node.get('url') + resumable_upload_url = resumable_upload_url_node.get("url") if not resumable_upload_url: - raise ValueError(f'Unexpected server response: {resumable_upload_url_response}') + raise ValueError(f"Unexpected server response: {resumable_upload_url_response}") - required_headers = resumable_upload_url_node.get('headers', []) + required_headers = resumable_upload_url_node.get("headers", []) try: # We will buffer this many bytes: one chunk + read-ahead block. @@ -1043,43 +1061,44 @@ def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_to else: # More chunks expected, let's upload current chunk (excluding read-ahead block). actual_chunk_length = self._config.multipart_upload_chunk_size - file_size = '*' + file_size = "*" - headers: dict = {'Content-Type': 'application/octet-stream'} + headers: dict = {"Content-Type": "application/octet-stream"} for h in required_headers: - headers[h['name']] = h['value'] + headers[h["name"]] = h["value"] chunk_last_byte_offset = chunk_offset + actual_chunk_length - 1 - content_range_header = f'bytes {chunk_offset}-{chunk_last_byte_offset}/{file_size}' - _LOG.debug(f'Uploading chunk: {content_range_header}') - headers['Content-Range'] = content_range_header + content_range_header = f"bytes {chunk_offset}-{chunk_last_byte_offset}/{file_size}" + _LOG.debug(f"Uploading chunk: {content_range_header}") + headers["Content-Range"] = content_range_header try: # We're not retrying this single request as we don't know where to rewind to. # Instead, in case of retryable failure we'll re-request the current offset # from the server and resume upload from there in the main upload loop. 
upload_response: requests.Response = cloud_provider_session.request( - 'PUT', + "PUT", resumable_upload_url, headers=headers, data=BytesIO(buffer[:actual_chunk_length]), - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) - retry_count = 0 # reset retry count when progressing along the stream + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) + retry_count = 0 # reset retry count when progressing along the stream except RequestException as e: - _LOG.warning(f'Failure during upload request: {sys.exc_info()}') - if _BaseClient._is_retryable( - e) and retry_count < self._config.multipart_upload_max_retries: + _LOG.warning(f"Failure during upload request: {sys.exc_info()}") + if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries: retry_count += 1 # Chunk upload threw an error, try to retrieve the current received offset try: # https://cloud.google.com/storage/docs/performing-resumable-uploads#status-check - headers['Content-Range'] = 'bytes */*' + headers["Content-Range"] = "bytes */*" upload_response = cloud_provider_session.request( - 'PUT', + "PUT", resumable_upload_url, headers=headers, - data=b'', - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + data=b"", + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) except RequestException: # status check failed, abort the upload raise e from None @@ -1088,7 +1107,7 @@ def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_to raise e from None if upload_response.status_code in (200, 201): - if file_size == '*': + if file_size == "*": raise ValueError( f"Received unexpected status {upload_response.status_code} before reaching end of stream" ) @@ -1098,7 +1117,7 @@ def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_to elif upload_response.status_code == 308: # chunk accepted, let's determine received offset to resume from there - range_string = upload_response.headers.get('Range') + range_string = upload_response.headers.get("Range") confirmed_offset = self._extract_range_offset(range_string) _LOG.debug(f"Received confirmed offset: {confirmed_offset}") @@ -1127,17 +1146,16 @@ def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_to # Expecting response body to be small to be safely logged mapped_error = _error_mapper(upload_response, {}) raise mapped_error or ValueError( - f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}") + f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}" + ) elif upload_response.status_code == 412 and not overwrite: # Assuming this is only possible reason # Full message in this case: "At least one of the pre-conditions you specified did not hold." 
- raise AlreadyExists('The file being created already exists.') + raise AlreadyExists("The file being created already exists.") else: if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}" - ) + _LOG.debug(f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}") mapped_error = _error_mapper(upload_response, {}) raise mapped_error or ValueError(f"Failed to upload: {upload_response}") @@ -1156,9 +1174,9 @@ def _resumable_upload(self, target_path: str, input_stream: BinaryIO, session_to @staticmethod def _extract_range_offset(range_string: Optional[str]) -> Optional[int]: if not range_string: - return None # server did not yet confirm any bytes + return None # server did not yet confirm any bytes - if match := re.match('bytes=0-(\\d+)', range_string): + if match := re.match("bytes=0-(\\d+)", range_string): return int(match.group(1)) else: raise ValueError(f"Cannot parse response header: Range: {range_string}") @@ -1170,71 +1188,66 @@ def _get_url_expire_time(self): # In JSON format, the Timestamp type is encoded as a string in the # * [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the # * format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" - return expire_time.strftime('%Y-%m-%dT%H:%M:%SZ') - - def _abort_multipart_upload(self, target_path: str, session_token: str, - cloud_provider_session: requests.Session): - body: dict = { - 'path': target_path, - 'session_token': session_token, - 'expire_time': self._get_url_expire_time() - } + return expire_time.strftime("%Y-%m-%dT%H:%M:%SZ") - headers = {'Content-Type': 'application/json'} + def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_provider_session: requests.Session): + body: dict = {"path": target_path, "session_token": session_token, "expire_time": self._get_url_expire_time()} + + headers = {"Content-Type": "application/json"} # _api.do() does retry - abort_url_response = self._api.do('POST', - '/api/2.0/fs/create-abort-upload-url', - headers=headers, - body=body) + abort_url_response = self._api.do("POST", "/api/2.0/fs/create-abort-upload-url", headers=headers, body=body) # no need to check response status, _api.do() will throw exception on failure - abort_upload_url_node = abort_url_response['abort_upload_url'] - abort_url = abort_upload_url_node['url'] - required_headers = abort_upload_url_node.get('headers', []) + abort_upload_url_node = abort_url_response["abort_upload_url"] + abort_url = abort_upload_url_node["url"] + required_headers = abort_upload_url_node.get("headers", []) - headers: dict = {'Content-Type': 'application/octet-stream'} + headers: dict = {"Content-Type": "application/octet-stream"} for h in required_headers: - headers[h['name']] = h['value'] + headers[h["name"]] = h["value"] def perform(): result = cloud_provider_session.request( - 'DELETE', + "DELETE", abort_url, headers=headers, - data=b'', - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + data=b"", + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) return result # following _BaseClient timeout retry_timeout_seconds = self._config.retry_timeout_seconds or 300 - abort_response = retried(timeout=timedelta(seconds=retry_timeout_seconds), - is_retryable=_BaseClient._is_retryable, - clock=self._clock)(perform)() + abort_response = retried( + timeout=timedelta(seconds=retry_timeout_seconds), is_retryable=_BaseClient._is_retryable, 
clock=self._clock + )(perform)() if abort_response.status_code not in (200, 201): raise ValueError(abort_response) - def _abort_resumable_upload(self, resumable_upload_url: str, required_headers: list, - cloud_provider_session: requests.Session): + def _abort_resumable_upload( + self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session + ): headers: dict = {} for h in required_headers: - headers[h['name']] = h['value'] + headers[h["name"]] = h["value"] def perform(): result = cloud_provider_session.request( - 'DELETE', + "DELETE", resumable_upload_url, headers=headers, - data=b'', - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds) + data=b"", + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) return result # following _BaseClient timeout retry_timeout_seconds = self._config.retry_timeout_seconds or 300 - abort_response = retried(timeout=timedelta(seconds=retry_timeout_seconds), - is_retryable=_BaseClient._is_retryable, - clock=self._clock)(perform)() + abort_response = retried( + timeout=timedelta(seconds=retry_timeout_seconds), is_retryable=_BaseClient._is_retryable, clock=self._clock + )(perform)() if abort_response.status_code not in (200, 201): raise ValueError(abort_response) @@ -1245,19 +1258,20 @@ def _create_cloud_provider_session(self): session = requests.Session() # following session config in _BaseClient - http_adapter = requests.adapters.HTTPAdapter(self._config.max_connection_pools or 20, - self._config.max_connections_per_pool or 20, - pool_block=True) + http_adapter = requests.adapters.HTTPAdapter( + self._config.max_connection_pools or 20, self._config.max_connections_per_pool or 20, pool_block=True + ) session.mount("https://", http_adapter) # presigned URL for storage proxy can use plain HTTP session.mount("http://", http_adapter) return session - def _download_raw_stream(self, - file_path: str, - start_byte_offset: int, - if_unmodified_since_timestamp: Optional[str] = None) -> DownloadResponse: - headers = {'Accept': 'application/octet-stream', } + def _download_raw_stream( + self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None + ) -> DownloadResponse: + headers = { + "Accept": "application/octet-stream", + } if start_byte_offset and not if_unmodified_since_timestamp: raise Exception("if_unmodified_since_timestamp is required if start_byte_offset is specified") diff --git a/tests/test_files.py b/tests/test_files.py index 9e8949695..e33941f2f 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta, timezone from tempfile import mkstemp -from typing import Callable, List, Optional, Union, Type +from typing import Callable, List, Optional, Type, Union from urllib.parse import parse_qs, urlparse import pytest @@ -429,13 +429,13 @@ def __eq__(self, other): class MultipartUploadServerState: - upload_chunk_url_prefix = 'https://cloud_provider.com/upload-chunk/' - abort_upload_url_prefix = 'https://cloud_provider.com/abort-upload/' + upload_chunk_url_prefix = "https://cloud_provider.com/upload-chunk/" + abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/" def __init__(self): - self.issued_multipart_urls = {} # part_number -> expiration_time - self.uploaded_chunks = {} # part_number -> [chunk file path, etag] - self.session_token = 'token-' + MultipartUploadServerState.randomstr() + self.issued_multipart_urls = {} # part_number -> 
expiration_time + self.uploaded_chunks = {} # part_number -> [chunk file path, etag] + self.session_token = "token-" + MultipartUploadServerState.randomstr() self.file_content = None self.issued_abort_url_expire_time = None self.aborted = False @@ -444,12 +444,12 @@ def create_upload_chunk_url(self, path: str, part_number: int, expire_time: date assert not self.aborted # client may have requested a URL for the same part if retrying on network error self.issued_multipart_urls[part_number] = expire_time - return f'{self.upload_chunk_url_prefix}{path}/{part_number}' + return f"{self.upload_chunk_url_prefix}{path}/{part_number}" def create_abort_url(self, path: str, expire_time: datetime) -> str: assert not self.aborted self.issued_abort_url_expire_time = expire_time - return f'{self.abort_upload_url_prefix}{path}' + return f"{self.abort_upload_url_prefix}{path}" def save_part(self, part_number: int, part_content: bytes, etag: str): assert not self.aborted @@ -504,11 +504,11 @@ def abort_upload(self): @staticmethod def randomstr(): - return f'{random.randrange(10000)}-{int(time.time())}' + return f"{random.randrange(10000)}-{int(time.time())}" class CustomResponse: - """ Custom response allows to override the "default" response generated by the server + """Custom response allows to override the "default" response generated by the server with the "custom" response to simulate failure error code, unexpected response body or network error. @@ -521,26 +521,19 @@ def __init__( # If True, response is defined by the current invocation count # with respect to first_invocation / last_invocation / only_invocation enabled=True, - # Custom code to return code: Optional[int] = 200, - # Custom body to return body: Optional[str] = None, - # Custom exception to raise exception: Optional[Type[BaseException]] = None, - # Whether exception should be raised before calling processor() # (so changing server state) exception_happened_before_processing: bool = False, - # First invocation (1-based) at which return custom response first_invocation: Optional[int] = None, - # Last invocation (1-based) at which return custom response last_invocation: Optional[int] = None, - # Only invocation (1-based) at which return custom response only_invocation: Optional[int] = None, ): @@ -581,7 +574,7 @@ def generate_response(self, request: requests.Request, processor: Callable[[], l # if server actually processed the request (and so changed its state) raise self.exception - custom_response = [self.code, self.body or '', {}] + custom_response = [self.code, self.body or "", {}] if activate_for_current_invocation: if self.code and 400 <= self.code < 500: @@ -628,45 +621,50 @@ class MultipartUploadTestCase: Response of each call can be modified by parameterising a respective `CustomResponse` object. """ - path = '/test.txt' - - expired_url_aws_response = ('' - 'AuthenticationFailedServer failed to authenticate ' - 'the request. 
Make sure the value of Authorization header is formed ' - 'correctly including the signature.\nRequestId:1abde581-601e-0028-' - '4a6d-5c3952000000\nTime:2025-01-01T16:54:20.5343181ZSignature not valid in the specified ' - 'time frame: Start [Wed, 01 Jan 2025 16:38:41 GMT] - Expiry [Wed, ' - '01 Jan 2025 16:53:45 GMT] - Current [Wed, 01 Jan 2025 16:54:20 ' - 'GMT]') - - expired_url_azure_response = ('\nAccessDenied' - 'Request has expired' - '142025-01-01T17:47:13Z' - '2025-01-01T17:48:01Z' - 'JY66KDXM4CXBZ7X2n8Qayqg60rbvut9P7pk0' - '') + path = "/test.txt" + + expired_url_aws_response = ( + '' + "AuthenticationFailedServer failed to authenticate " + "the request. Make sure the value of Authorization header is formed " + "correctly including the signature.\nRequestId:1abde581-601e-0028-" + "4a6d-5c3952000000\nTime:2025-01-01T16:54:20.5343181ZSignature not valid in the specified " + "time frame: Start [Wed, 01 Jan 2025 16:38:41 GMT] - Expiry [Wed, " + "01 Jan 2025 16:53:45 GMT] - Current [Wed, 01 Jan 2025 16:54:20 " + "GMT]" + ) + + expired_url_azure_response = ( + '\nAccessDenied' + "Request has expired" + "142025-01-01T17:47:13Z" + "2025-01-01T17:48:01Z" + "JY66KDXM4CXBZ7X2n8Qayqg60rbvut9P7pk0" + "" + ) # TODO test for overwrite = false def __init__( - self, - name: str, - stream_size: int, # size of uploaded file or, technically, stream - multipart_upload_chunk_size: Optional[int] = None, - sdk_retry_timeout_seconds: Optional[int] = None, - multipart_upload_max_retries: Optional[int] = None, - multipart_upload_batch_url_count: Optional[int] = None, - custom_response_on_initiate=CustomResponse(enabled=False), - custom_response_on_create_multipart_url=CustomResponse(enabled=False), - custom_response_on_upload=CustomResponse(enabled=False), - custom_response_on_complete=CustomResponse(enabled=False), - custom_response_on_create_abort_url=CustomResponse(enabled=False), - custom_response_on_abort=CustomResponse(enabled=False), - # exception which is expected to be thrown (so upload is expected to have failed) - expected_exception_type: Optional[Type[BaseException]] = None, - # if abort is expected to be called - expected_aborted: bool = False): + self, + name: str, + stream_size: int, # size of uploaded file or, technically, stream + multipart_upload_chunk_size: Optional[int] = None, + sdk_retry_timeout_seconds: Optional[int] = None, + multipart_upload_max_retries: Optional[int] = None, + multipart_upload_batch_url_count: Optional[int] = None, + custom_response_on_initiate=CustomResponse(enabled=False), + custom_response_on_create_multipart_url=CustomResponse(enabled=False), + custom_response_on_upload=CustomResponse(enabled=False), + custom_response_on_complete=CustomResponse(enabled=False), + custom_response_on_create_abort_url=CustomResponse(enabled=False), + custom_response_on_abort=CustomResponse(enabled=False), + # exception which is expected to be thrown (so upload is expected to have failed) + expected_exception_type: Optional[Type[BaseException]] = None, + # if abort is expected to be called + expected_aborted: bool = False, + ): self.name = name self.stream_size = stream_size self.multipart_upload_chunk_size = multipart_upload_chunk_size @@ -682,121 +680,119 @@ def __init__( self.expected_exception_type = expected_exception_type self.expected_aborted: bool = expected_aborted - def setup_session_mock(self, session_mock: requests_mock.Mocker, - server_state: MultipartUploadServerState): + def setup_session_mock(self, session_mock: requests_mock.Mocker, server_state: 
MultipartUploadServerState): def custom_matcher(request): request_url = urlparse(request.url) request_query = parse_qs(request_url.query) # initial request - if (request_url.hostname == 'localhost' - and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' - and request_query.get('action') == ['initiate-upload'] and request.method == 'POST'): + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["initiate-upload"] + and request.method == "POST" + ): assert MultipartUploadTestCase.is_auth_header_present(request) assert request.text is None def processor(): - response_json = {'multipart_upload': {'session_token': server_state.session_token}} + response_json = {"multipart_upload": {"session_token": server_state.session_token}} return [200, json.dumps(response_json), {}] return self.custom_response_on_initiate.generate_response(request, processor) # multipart upload, create upload part URLs - elif (request_url.hostname == 'localhost' - and request_url.path == '/api/2.0/fs/create-upload-part-urls' and request.method == 'POST'): + elif ( + request_url.hostname == "localhost" + and request_url.path == "/api/2.0/fs/create-upload-part-urls" + and request.method == "POST" + ): assert MultipartUploadTestCase.is_auth_header_present(request) request_json = request.json() - assert request_json.keys() == { - 'count', 'expire_time', 'path', 'session_token', 'start_part_number' - } - assert request_json['path'] == self.path - assert request_json['session_token'] == server_state.session_token - - start_part_number = int(request_json['start_part_number']) - count = int(request_json['count']) + assert request_json.keys() == {"count", "expire_time", "path", "session_token", "start_part_number"} + assert request_json["path"] == self.path + assert request_json["session_token"] == server_state.session_token + + start_part_number = int(request_json["start_part_number"]) + count = int(request_json["count"]) assert count >= 1 - expire_time = MultipartUploadTestCase.parse_and_validate_expire_time( - request_json['expire_time']) + expire_time = MultipartUploadTestCase.parse_and_validate_expire_time(request_json["expire_time"]) def processor(): response_nodes = [] for part_number in range(start_part_number, start_part_number + count): - upload_part_url = server_state.create_upload_chunk_url( - self.path, part_number, expire_time) - response_nodes.append({ - 'part_number': part_number, - 'url': upload_part_url, - 'headers': [{ - 'name': 'name1', - 'value': 'value1' - }] - }) - - response_json = {'upload_part_urls': response_nodes} + upload_part_url = server_state.create_upload_chunk_url(self.path, part_number, expire_time) + response_nodes.append( + { + "part_number": part_number, + "url": upload_part_url, + "headers": [{"name": "name1", "value": "value1"}], + } + ) + + response_json = {"upload_part_urls": response_nodes} return [200, json.dumps(response_json), {}] return self.custom_response_on_create_multipart_url.generate_response(request, processor) # multipart upload, uploading part - elif request.url.startswith( - MultipartUploadServerState.upload_chunk_url_prefix) and request.method == 'PUT': + elif request.url.startswith(MultipartUploadServerState.upload_chunk_url_prefix) and request.method == "PUT": assert not MultipartUploadTestCase.is_auth_header_present(request) - url_path = request.url[len(MultipartUploadServerState.abort_upload_url_prefix):] - part_num = url_path.split('/')[-1] - 
assert url_path[:-len(part_num) - 1] == self.path + url_path = request.url[len(MultipartUploadServerState.abort_upload_url_prefix) :] + part_num = url_path.split("/")[-1] + assert url_path[: -len(part_num) - 1] == self.path def processor(): body = request.body.read() - etag = 'etag-' + MultipartUploadServerState.randomstr() + etag = "etag-" + MultipartUploadServerState.randomstr() server_state.save_part(int(part_num), body, etag) - return [200, '', {'ETag': etag}] + return [200, "", {"ETag": etag}] return self.custom_response_on_upload.generate_response(request, processor) # multipart upload, completion - elif (request_url.hostname == 'localhost' - and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' - and request_query.get('action') == ['complete-upload'] - and request_query.get('upload_type') == ['multipart'] and request.method == 'POST'): + elif ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["complete-upload"] + and request_query.get("upload_type") == ["multipart"] + and request.method == "POST" + ): assert MultipartUploadTestCase.is_auth_header_present(request) - assert [server_state.session_token] == request_query.get('session_token') + assert [server_state.session_token] == request_query.get("session_token") def processor(): request_json = request.json() etags = {} - for part in request_json['parts']: - etags[part['part_number']] = part['etag'] + for part in request_json["parts"]: + etags[part["part_number"]] = part["etag"] server_state.upload_complete(etags) - return [200, '', {}] + return [200, "", {}] return self.custom_response_on_complete.generate_response(request, processor) # create abort URL - elif request.url == 'http://localhost/api/2.0/fs/create-abort-upload-url' and request.method == 'POST': + elif request.url == "http://localhost/api/2.0/fs/create-abort-upload-url" and request.method == "POST": assert MultipartUploadTestCase.is_auth_header_present(request) request_json = request.json() - assert request_json['path'] == self.path - expire_time = MultipartUploadTestCase.parse_and_validate_expire_time( - request_json['expire_time']) + assert request_json["path"] == self.path + expire_time = MultipartUploadTestCase.parse_and_validate_expire_time(request_json["expire_time"]) def processor(): response_json = { - 'abort_upload_url': { - 'url': server_state.create_abort_url(self.path, expire_time), - 'headers': [{ - 'name': 'header1', - 'value': 'headervalue1' - }] + "abort_upload_url": { + "url": server_state.create_abort_url(self.path, expire_time), + "headers": [{"name": "header1", "value": "headervalue1"}], } } return [200, json.dumps(response_json), {}] @@ -804,14 +800,16 @@ def processor(): return self.custom_response_on_create_abort_url.generate_response(request, processor) # abort upload - elif request.url.startswith( - MultipartUploadServerState.abort_upload_url_prefix) and request.method == 'DELETE': + elif ( + request.url.startswith(MultipartUploadServerState.abort_upload_url_prefix) + and request.method == "DELETE" + ): assert not MultipartUploadTestCase.is_auth_header_present(request) - assert request.url[len(MultipartUploadServerState.abort_upload_url_prefix):] == self.path + assert request.url[len(MultipartUploadServerState.abort_upload_url_prefix) :] == self.path def processor(): server_state.abort_upload() - return [200, '', {}] + return [200, "", {}] return self.custom_response_on_abort.generate_response(request, processor) @@ -822,16 
+820,16 @@ def processor(): @staticmethod def setup_token_auth(config: Config): pat_token = "some_pat_token" - config._header_factory = lambda: {'Authorization': f'Bearer {pat_token}'} + config._header_factory = lambda: {"Authorization": f"Bearer {pat_token}"} @staticmethod def is_auth_header_present(r: requests.Request): - return r.headers.get('Authorization') is not None + return r.headers.get("Authorization") is not None @staticmethod def parse_and_validate_expire_time(s: str) -> datetime: - expire_time = datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ') - expire_time = expire_time.replace(tzinfo=timezone.utc) # Explicitly add timezone + expire_time = datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ") + expire_time = expire_time.replace(tzinfo=timezone.utc) # Explicitly add timezone now = datetime.now(timezone.utc) max_expiration = now + timedelta(hours=2) assert now < expire_time < max_expiration @@ -851,7 +849,7 @@ def run(self, config: Config): if self.multipart_upload_batch_url_count: config.multipart_upload_batch_url_count = self.multipart_upload_batch_url_count config.enable_experimental_files_api_client = True - config.multipart_upload_min_stream_size = 0 # disable single-shot uploads + config.multipart_upload_min_stream_size = 0 # disable single-shot uploads file_content = os.urandom(self.stream_size) @@ -863,7 +861,7 @@ def run(self, config: Config): self.setup_session_mock(session_mock, upload_state) def upload(): - w.files.upload('/test.txt', io.BytesIO(file_content), overwrite=True) + w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=True) if self.expected_exception_type is not None: with pytest.raises(self.expected_exception_type): @@ -891,283 +889,317 @@ def to_string(test_case): [ # -------------------------- failures on "initiate upload" -------------------------- MultipartUploadTestCase( - 'Initiate: 400 response is not retried', + "Initiate: 400 response is not retried", stream_size=1024 * 1024, custom_response_on_initiate=CustomResponse(code=400, only_invocation=1), expected_exception_type=BadRequest, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), MultipartUploadTestCase( - 'Initiate: 403 response is not retried', + "Initiate: 403 response is not retried", stream_size=1024 * 1024, custom_response_on_initiate=CustomResponse(code=403, only_invocation=1), expected_exception_type=PermissionDenied, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), MultipartUploadTestCase( - 'Initiate: 500 response is not retried', + "Initiate: 500 response is not retried", stream_size=1024 * 1024, custom_response_on_initiate=CustomResponse(code=500, only_invocation=1), expected_exception_type=InternalError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), MultipartUploadTestCase( - 'Initiate: non-JSON response is not retried', + "Initiate: non-JSON response is not retried", stream_size=1024 * 1024, - custom_response_on_initiate=CustomResponse(body='this is not a JSON', only_invocation=1), + custom_response_on_initiate=CustomResponse(body="this is not a JSON", only_invocation=1), expected_exception_type=requests.exceptions.JSONDecodeError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), MultipartUploadTestCase( - 'Initiate: meaningless JSON response is not retried', + "Initiate: meaningless JSON response is not retried", stream_size=1024 * 1024, 
custom_response_on_initiate=CustomResponse(body='{"foo": 123}', only_invocation=1), expected_exception_type=ValueError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), MultipartUploadTestCase( - 'Initiate: no session token in response is not retried', + "Initiate: no session token in response is not retried", stream_size=1024 * 1024, custom_response_on_initiate=CustomResponse( - body='{"multipart_upload":{"session_token1": "token123"}}', only_invocation=1), + body='{"multipart_upload":{"session_token1": "token123"}}', only_invocation=1 + ), expected_exception_type=ValueError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), MultipartUploadTestCase( - 'Initiate: permanent retryable exception', + "Initiate: permanent retryable exception", stream_size=1024 * 1024, custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError), - sdk_retry_timeout_seconds=30, # let's not wait 5 min (SDK default timeout) + sdk_retry_timeout_seconds=30, # let's not wait 5 min (SDK default timeout) expected_exception_type=TimeoutError, - expected_aborted=False # upload didn't start - ), - MultipartUploadTestCase('Initiate: intermittent retryable exception', - stream_size=1024 * 1024, - custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError, - first_invocation=1, - last_invocation=3), - expected_aborted=False), - MultipartUploadTestCase('Initiate: intermittent retryable error', - stream_size=1024 * 1024, - custom_response_on_initiate=CustomResponse( - code=429, first_invocation=1, last_invocation=3), - expected_aborted=False), - + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + "Initiate: intermittent retryable exception", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse( + exception=requests.ConnectionError, first_invocation=1, last_invocation=3 + ), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Initiate: intermittent retryable error", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(code=429, first_invocation=1, last_invocation=3), + expected_aborted=False, + ), # -------------------------- failures on "create upload URL" -------------------------- - MultipartUploadTestCase('Create upload URL: client error is not retied', - stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse(code=400, - only_invocation=1), - expected_exception_type=BadRequest, - expected_aborted=True), - MultipartUploadTestCase('Create upload URL: internal error is not retied', - stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse(code=500, - only_invocation=1), - expected_exception_type=InternalError, - expected_aborted=True), - MultipartUploadTestCase('Create upload URL: non-JSON response is not retried', - stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse( - body='this is not a JSON', only_invocation=1), - expected_exception_type=requests.exceptions.JSONDecodeError, - expected_aborted=True), - MultipartUploadTestCase('Create upload URL: meaningless JSON response is not retried', - stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse(body='{"foo":123}', - only_invocation=1), - expected_exception_type=ValueError, - expected_aborted=True), - MultipartUploadTestCase('Create upload URL: meaningless JSON response is not retried 2', - stream_size=1024 * 1024, - 
custom_response_on_create_multipart_url=CustomResponse( - body='{"upload_part_urls":[]}', only_invocation=1), - expected_exception_type=ValueError, - expected_aborted=True), MultipartUploadTestCase( - 'Create upload URL: meaningless JSON response is not retried 3', + "Create upload URL: client error is not retied", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=400, only_invocation=1), + expected_exception_type=BadRequest, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: internal error is not retied", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + expected_exception_type=InternalError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: non-JSON response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(body="this is not a JSON", only_invocation=1), + expected_exception_type=requests.exceptions.JSONDecodeError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: meaningless JSON response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(body='{"foo":123}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: meaningless JSON response is not retried 2", stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse(body='{"upload_part_urls":[{"url":""}]}', - only_invocation=1), - expected_exception_type=KeyError, # TODO we might want to make JSON parsing more reliable - expected_aborted=True), + custom_response_on_create_multipart_url=CustomResponse(body='{"upload_part_urls":[]}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=True, + ), MultipartUploadTestCase( - 'Create upload URL: permanent retryable exception', + "Create upload URL: meaningless JSON response is not retried 3", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + body='{"upload_part_urls":[{"url":""}]}', only_invocation=1 + ), + expected_exception_type=KeyError, # TODO we might want to make JSON parsing more reliable + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: permanent retryable exception", stream_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(exception=requests.ConnectionError), - sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) expected_exception_type=TimeoutError, - expected_aborted=True), - MultipartUploadTestCase('Create upload URL: intermittent retryable exception', - stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse( - exception=requests.Timeout, only_invocation=1), - expected_aborted=False), + expected_aborted=True, + ), MultipartUploadTestCase( - 'Create upload URL: intermittent retryable exception 2', - stream_size=100 * 1024 * 1024, # 10 chunks + "Create upload URL: intermittent retryable exception", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(exception=requests.Timeout, only_invocation=1), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Create upload URL: intermittent retryable exception 2", + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, 
custom_response_on_create_multipart_url=CustomResponse( exception=requests.Timeout, # 4th request for multipart URLs will fail 3 times and will be retried first_invocation=4, - last_invocation=6), - expected_aborted=False), - + last_invocation=6, + ), + expected_aborted=False, + ), # -------------------------- failures on chunk upload -------------------------- MultipartUploadTestCase( - 'Upload chunk: 403 response is not retried', - stream_size=100 * 1024 * 1024, # 10 chunks + "Upload chunk: 403 response is not retried", + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(code=403, only_invocation=2), expected_exception_type=PermissionDenied, - expected_aborted=True), + expected_aborted=True, + ), MultipartUploadTestCase( - 'Upload chunk: 400 response is not retried', - stream_size=100 * 1024 * 1024, # 10 chunks + "Upload chunk: 400 response is not retried", + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(code=400, only_invocation=9), expected_exception_type=BadRequest, - expected_aborted=True), + expected_aborted=True, + ), MultipartUploadTestCase( - 'Upload chunk: 500 response is not retried', # TODO should we retry chunk upload on internal error from Cloud provider? - stream_size=100 * 1024 * 1024, # 10 chunks + "Upload chunk: 500 response is not retried", # TODO should we retry chunk upload on internal error from Cloud provider? + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(code=500, only_invocation=5), expected_exception_type=InternalError, - expected_aborted=True), + expected_aborted=True, + ), MultipartUploadTestCase( - 'Upload chunk: expired URL is retried on AWS', - stream_size=100 * 1024 * 1024, # 10 chunks + "Upload chunk: expired URL is retried on AWS", + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( - code=403, body=MultipartUploadTestCase.expired_url_aws_response, only_invocation=2), - expected_aborted=False), + code=403, body=MultipartUploadTestCase.expired_url_aws_response, only_invocation=2 + ), + expected_aborted=False, + ), MultipartUploadTestCase( - 'Upload chunk: expired URL is retried on Azure', + "Upload chunk: expired URL is retried on Azure", multipart_upload_max_retries=3, - stream_size=100 * 1024 * 1024, # 10 chunks + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, body=MultipartUploadTestCase.expired_url_azure_response, # 3 failures don't exceed multipart_upload_max_retries first_invocation=2, - last_invocation=4), - expected_aborted=False), + last_invocation=4, + ), + expected_aborted=False, + ), MultipartUploadTestCase( - 'Upload chunk: expired URL is retried on Azure, requesting urls by 6', + "Upload chunk: expired URL is retried on Azure, requesting urls by 6", multipart_upload_max_retries=3, multipart_upload_batch_url_count=6, - stream_size=100 * 1024 * 1024, # 100 chunks + stream_size=100 * 1024 * 1024, # 100 chunks multipart_upload_chunk_size=1 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, body=MultipartUploadTestCase.expired_url_azure_response, # 3 failures don't exceed multipart_upload_max_retries first_invocation=2, - last_invocation=4), - expected_aborted=False), + last_invocation=4, + ), + 
expected_aborted=False, + ), MultipartUploadTestCase( - 'Upload chunk: expired URL retry is limited', + "Upload chunk: expired URL retry is limited", multipart_upload_max_retries=3, - stream_size=100 * 1024 * 1024, # 10 chunks + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( code=403, body=MultipartUploadTestCase.expired_url_azure_response, # 4 failures exceed multipart_upload_max_retries first_invocation=2, - last_invocation=5), + last_invocation=5, + ), expected_exception_type=ValueError, - expected_aborted=True), + expected_aborted=True, + ), MultipartUploadTestCase( - 'Upload chunk: permanent retryable error', - stream_size=100 * 1024 * 1024, # 10 chunks + "Upload chunk: permanent retryable error", + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, - sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) custom_response_on_upload=CustomResponse(exception=requests.ConnectionError, first_invocation=8), expected_exception_type=TimeoutError, - expected_aborted=True), + expected_aborted=True, + ), MultipartUploadTestCase( - 'Upload chunk: intermittent retryable error', - stream_size=100 * 1024 * 1024, # 10 chunks + "Upload chunk: intermittent retryable error", + stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse( - exception=requests.ConnectionError, first_invocation=2, last_invocation=5), - expected_aborted=False), - + exception=requests.ConnectionError, first_invocation=2, last_invocation=5 + ), + expected_aborted=False, + ), # -------------------------- failures on abort -------------------------- MultipartUploadTestCase( - 'Abort URL: client error', + "Abort URL: client error", stream_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), custom_response_on_create_abort_url=CustomResponse(code=400), - expected_exception_type=InternalError, # original error - expected_aborted=False), + expected_exception_type=InternalError, # original error + expected_aborted=False, + ), MultipartUploadTestCase( - 'Abort URL: forbidden', + "Abort URL: forbidden", stream_size=1024 * 1024, custom_response_on_upload=CustomResponse(code=500, only_invocation=1), custom_response_on_create_abort_url=CustomResponse(code=403), - expected_exception_type=InternalError, # original error - expected_aborted=False), + expected_exception_type=InternalError, # original error + expected_aborted=False, + ), MultipartUploadTestCase( - 'Abort URL: intermittent retryable error', + "Abort URL: intermittent retryable error", stream_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), - custom_response_on_create_abort_url=CustomResponse( - code=429, first_invocation=1, last_invocation=3), - expected_exception_type=InternalError, # original error - expected_aborted=True # abort successfully called after abort URL creation is retried + custom_response_on_create_abort_url=CustomResponse(code=429, first_invocation=1, last_invocation=3), + expected_exception_type=InternalError, # original error + expected_aborted=True, # abort successfully called after abort URL creation is retried ), MultipartUploadTestCase( - 'Abort URL: intermittent retryable error 2', + "Abort URL: intermittent retryable error 2", stream_size=1024 * 1024, 
custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), custom_response_on_create_abort_url=CustomResponse( - exception=requests.Timeout, first_invocation=1, last_invocation=3), - expected_exception_type=InternalError, # original error - expected_aborted=True # abort successfully called after abort URL creation is retried + exception=requests.Timeout, first_invocation=1, last_invocation=3 + ), + expected_exception_type=InternalError, # original error + expected_aborted=True, # abort successfully called after abort URL creation is retried ), MultipartUploadTestCase( - 'Abort: exception', + "Abort: exception", stream_size=1024 * 1024, # don't wait for 5 min (SDK default timeout) sdk_retry_timeout_seconds=30, # simulate PermissionDenied custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), - custom_response_on_abort=CustomResponse(exception=requests.Timeout, - exception_happened_before_processing=False), - expected_exception_type=PermissionDenied, # original error is reported - expected_aborted=True # abort called but failed + custom_response_on_abort=CustomResponse( + exception=requests.Timeout, exception_happened_before_processing=False + ), + expected_exception_type=PermissionDenied, # original error is reported + expected_aborted=True, # abort called but failed ), - # -------------------------- happy cases -------------------------- MultipartUploadTestCase( - 'Multipart upload successful: single chunk', - stream_size=1024 * 1024, # less than chunk size - multipart_upload_chunk_size=10 * 1024 * 1024), + "Multipart upload successful: single chunk", + stream_size=1024 * 1024, # less than chunk size + multipart_upload_chunk_size=10 * 1024 * 1024, + ), MultipartUploadTestCase( - 'Multipart upload successful: multiple chunks (aligned)', - stream_size=100 * 1024 * 1024, # 10 chunks - multipart_upload_chunk_size=10 * 1024 * 1024), + "Multipart upload successful: multiple chunks (aligned)", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + ), MultipartUploadTestCase( - 'Multipart upload successful: multiple chunks (aligned), upload urls by 3', + "Multipart upload successful: multiple chunks (aligned), upload urls by 3", multipart_upload_batch_url_count=3, - stream_size=100 * 1024 * 1024, # 10 chunks - multipart_upload_chunk_size=10 * 1024 * 1024), + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + ), MultipartUploadTestCase( - 'Multipart upload successful: multiple chunks (not aligned)', - stream_size=100 * 1024 * 1024 + 1000 + 566, # 14 full chunks + remainder - multipart_upload_chunk_size=7 * 1024 * 1024 - 17), + "Multipart upload successful: multiple chunks (not aligned)", + stream_size=100 * 1024 * 1024 + 1000 + 566, # 14 full chunks + remainder + multipart_upload_chunk_size=7 * 1024 * 1024 - 17, + ), MultipartUploadTestCase( - 'Multipart upload successful: multiple chunks (not aligned), upload urls by 5', + "Multipart upload successful: multiple chunks (not aligned), upload urls by 5", multipart_upload_batch_url_count=5, - stream_size=100 * 1024 * 1024 + 1000 + 566, # 14 full chunks + remainder - multipart_upload_chunk_size=7 * 1024 * 1024 - 17), + stream_size=100 * 1024 * 1024 + 1000 + 566, # 14 full chunks + remainder + multipart_upload_chunk_size=7 * 1024 * 1024 - 17, + ), ], - ids=MultipartUploadTestCase.to_string) + ids=MultipartUploadTestCase.to_string, +) def test_multipart_upload(config: Config, test_case: MultipartUploadTestCase): 
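    # A sketch of the flow each case above exercises against the mocked endpoints:
    # initiate-upload -> create-upload-part-urls -> PUT of each chunk to its presigned
    # URL -> complete-upload with the collected ETags, or create-abort-upload-url ->
    # DELETE to the abort URL when the upload fails partway.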
test_case.run(config) @@ -1180,8 +1212,7 @@ def __init__(self): class SingleShotUploadTestCase: - def __init__(self, name: str, stream_size: int, multipart_upload_min_stream_size: int, - expected_single_shot: bool): + def __init__(self, name: str, stream_size: int, multipart_upload_min_stream_size: int, expected_single_shot: bool): self.name = name self.stream_size = stream_size self.multipart_upload_min_stream_size = multipart_upload_min_stream_size @@ -1203,8 +1234,7 @@ def run(self, config: Config): session = requests.Session() with requests_mock.Mocker(session=session) as session_mock: - session_mock.get(f'http://localhost/api/2.0/fs/files{MultipartUploadTestCase.path}', - status_code=200) + session_mock.get(f"http://localhost/api/2.0/fs/files{MultipartUploadTestCase.path}", status_code=200) upload_state = SingleShotUploadState() @@ -1213,27 +1243,31 @@ def custom_matcher(request): request_query = parse_qs(request_url.query) if self.expected_single_shot: - if (request_url.hostname == 'localhost' - and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' - and request.method == 'PUT'): + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request.method == "PUT" + ): body = request.body.read() upload_state.single_shot_file_content = FileContent.from_bytes(body) resp = requests.Response() resp.status_code = 204 resp.request = request - resp._content = b'' + resp._content = b"" return resp else: - if (request_url.hostname == 'localhost' - and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' - and request_query.get('action') == ['initiate-upload'] - and request.method == 'POST'): + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["initiate-upload"] + and request.method == "POST" + ): resp = requests.Response() - resp.status_code = 403 # this will throw, that's fine + resp.status_code = 403 # this will throw, that's fine resp.request = request - resp._content = b'' + resp._content = b"" return resp return None @@ -1244,7 +1278,7 @@ def custom_matcher(request): w.files._api._api_client._session = session def upload(): - w.files.upload('/test.txt', io.BytesIO(file_content), overwrite=True) + w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=True) if self.expected_single_shot: upload() @@ -1255,34 +1289,43 @@ def upload(): upload() -@pytest.mark.parametrize("test_case", [ - SingleShotUploadTestCase('Single-shot upload', - stream_size=1024 * 1024, - multipart_upload_min_stream_size=1024 * 1024 + 1, - expected_single_shot=True), - SingleShotUploadTestCase('Multipart upload 1', - stream_size=1024 * 1024, - multipart_upload_min_stream_size=1024 * 1024, - expected_single_shot=False), - SingleShotUploadTestCase('Multipart upload 2', - stream_size=1024 * 1024, - multipart_upload_min_stream_size=0, - expected_single_shot=False), -], - ids=SingleShotUploadTestCase.to_string) +@pytest.mark.parametrize( + "test_case", + [ + SingleShotUploadTestCase( + "Single-shot upload", + stream_size=1024 * 1024, + multipart_upload_min_stream_size=1024 * 1024 + 1, + expected_single_shot=True, + ), + SingleShotUploadTestCase( + "Multipart upload 1", + stream_size=1024 * 1024, + multipart_upload_min_stream_size=1024 * 1024, + expected_single_shot=False, + ), + SingleShotUploadTestCase( + "Multipart upload 2", + stream_size=1024 * 1024, + multipart_upload_min_stream_size=0, + 
expected_single_shot=False, + ), + ], + ids=SingleShotUploadTestCase.to_string, +) def test_single_shot_upload(config: Config, test_case: SingleShotUploadTestCase): test_case.run(config) class ResumableUploadServerState: - resumable_upload_url_prefix = 'https://cloud_provider.com/resumable-upload/' - abort_upload_url_prefix = 'https://cloud_provider.com/abort-upload/' + resumable_upload_url_prefix = "https://cloud_provider.com/resumable-upload/" + abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/" def __init__(self, unconfirmed_delta: Union[int, list]): self.unconfirmed_delta = unconfirmed_delta - self.confirmed_last_byte: Optional[int] = None # inclusive + self.confirmed_last_byte: Optional[int] = None # inclusive self.uploaded_parts = [] - self.session_token = 'token-' + MultipartUploadServerState.randomstr() + self.session_token = "token-" + MultipartUploadServerState.randomstr() self.file_content = None self.aborted = False @@ -1297,21 +1340,21 @@ def save_part(self, start_offset: int, end_offset_incl: int, part_content: bytes assert end_offset_incl == start_offset + len(part_content) - 1 - is_last_part = file_size_s != '*' + is_last_part = file_size_s != "*" if is_last_part: assert int(file_size_s) == end_offset_incl + 1 else: - assert not self.file_content # last chunk should not have been uploaded yet + assert not self.file_content # last chunk should not have been uploaded yet if isinstance(self.unconfirmed_delta, int): unconfirmed_delta = self.unconfirmed_delta elif len(self.uploaded_parts) < len(self.unconfirmed_delta): unconfirmed_delta = self.unconfirmed_delta[len(self.uploaded_parts)] else: - unconfirmed_delta = self.unconfirmed_delta[-1] # take the last delta + unconfirmed_delta = self.unconfirmed_delta[-1] # take the last delta if unconfirmed_delta >= len(part_content): - unconfirmed_delta = 0 # otherwise we never finish + unconfirmed_delta = 0 # otherwise we never finish logger.info( f"Saving part {len(self.uploaded_parts) + 1} of original size {len(part_content)} with unconfirmed delta {unconfirmed_delta}. is_last_part = {is_last_part}" @@ -1343,7 +1386,7 @@ def save_part(self, start_offset: int, end_offset_incl: int, part_content: bytes def create_abort_url(self, path: str, expire_time: datetime) -> str: assert not self.aborted self.issued_abort_url_expire_time = expire_time - return f'{self.abort_upload_url_prefix}{path}' + return f"{self.abort_upload_url_prefix}{path}" def cleanup(self): for file in self.uploaded_parts: @@ -1371,34 +1414,33 @@ class ResumableUploadTestCase: Response of each call can be modified by parameterising a respective `CustomResponse` object. """ - path = '/test.txt' + + path = "/test.txt" def __init__( - self, - name: str, - stream_size: int, - overwrite: bool = True, - multipart_upload_chunk_size: Optional[int] = None, - sdk_retry_timeout_seconds: Optional[int] = None, - multipart_upload_max_retries: Optional[int] = None, - - # In resumable upload, when replying to chunk upload request, server returns - # (confirms) last accepted byte offset for the client to resume upload after. - # - # `unconfirmed_delta` defines offset from the end of the chunk that remains - # "unconfirmed", i.e. the last accepted offset would be (range_end - unconfirmed_delta). - # Can be int (same for all chunks) or list (individual for each chunk). 
- unconfirmed_delta: Union[int, list] = 0, - custom_response_on_create_resumable_url=CustomResponse(enabled=False), - custom_response_on_upload=CustomResponse(enabled=False), - custom_response_on_status_check=CustomResponse(enabled=False), - custom_response_on_abort=CustomResponse(enabled=False), - - # exception which is expected to be thrown (so upload is expected to have failed) - expected_exception_type: Optional[Type[BaseException]] = None, - - # if abort is expected to be called - expected_aborted: bool = False): + self, + name: str, + stream_size: int, + overwrite: bool = True, + multipart_upload_chunk_size: Optional[int] = None, + sdk_retry_timeout_seconds: Optional[int] = None, + multipart_upload_max_retries: Optional[int] = None, + # In resumable upload, when replying to chunk upload request, server returns + # (confirms) last accepted byte offset for the client to resume upload after. + # + # `unconfirmed_delta` defines offset from the end of the chunk that remains + # "unconfirmed", i.e. the last accepted offset would be (range_end - unconfirmed_delta). + # Can be int (same for all chunks) or list (individual for each chunk). + unconfirmed_delta: Union[int, list] = 0, + custom_response_on_create_resumable_url=CustomResponse(enabled=False), + custom_response_on_upload=CustomResponse(enabled=False), + custom_response_on_status_check=CustomResponse(enabled=False), + custom_response_on_abort=CustomResponse(enabled=False), + # exception which is expected to be thrown (so upload is expected to have failed) + expected_exception_type: Optional[Type[BaseException]] = None, + # if abort is expected to be called + expected_aborted: bool = False, + ): self.name = name self.stream_size = stream_size self.overwrite = overwrite @@ -1413,50 +1455,51 @@ def __init__( self.expected_exception_type = expected_exception_type self.expected_aborted: bool = expected_aborted - def setup_session_mock(self, session_mock: requests_mock.Mocker, - server_state: ResumableUploadServerState): + def setup_session_mock(self, session_mock: requests_mock.Mocker, server_state: ResumableUploadServerState): def custom_matcher(request): request_url = urlparse(request.url) request_query = parse_qs(request_url.query) # initial request - if (request_url.hostname == 'localhost' - and request_url.path == f'/api/2.0/fs/files{MultipartUploadTestCase.path}' - and request_query.get('action') == ['initiate-upload'] and request.method == 'POST'): + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["initiate-upload"] + and request.method == "POST" + ): assert MultipartUploadTestCase.is_auth_header_present(request) assert request.text is None def processor(): - response_json = {'resumable_upload': {'session_token': server_state.session_token}} + response_json = {"resumable_upload": {"session_token": server_state.session_token}} return [200, json.dumps(response_json), {}] # Different initiate error responses have been verified by test_multipart_upload(), # so we're always generating a "success" response. 
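                # A CustomResponse constructed with enabled=False is presumably a pass-through:
                # it just lets `processor` build the normal (successful) response.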
return CustomResponse(enabled=False).generate_response(request, processor) - elif (request_url.hostname == 'localhost' - and request_url.path == '/api/2.0/fs/create-resumable-upload-url' - and request.method == 'POST'): + elif ( + request_url.hostname == "localhost" + and request_url.path == "/api/2.0/fs/create-resumable-upload-url" + and request.method == "POST" + ): assert MultipartUploadTestCase.is_auth_header_present(request) request_json = request.json() - assert request_json.keys() == {'path', 'session_token'} - assert request_json['path'] == self.path - assert request_json['session_token'] == server_state.session_token + assert request_json.keys() == {"path", "session_token"} + assert request_json["path"] == self.path + assert request_json["session_token"] == server_state.session_token def processor(): - resumable_upload_url = f'{ResumableUploadServerState.resumable_upload_url_prefix}{self.path}' + resumable_upload_url = f"{ResumableUploadServerState.resumable_upload_url_prefix}{self.path}" response_json = { - 'resumable_upload_url': { - 'url': resumable_upload_url, - 'headers': [{ - 'name': 'name1', - 'value': 'value1' - }] + "resumable_upload_url": { + "url": resumable_upload_url, + "headers": [{"name": "name1", "value": "value1"}], } } return [200, json.dumps(response_json), {}] @@ -1464,15 +1507,17 @@ def processor(): return self.custom_response_on_create_resumable_url.generate_response(request, processor) # resumable upload, uploading part - elif request.url.startswith( - ResumableUploadServerState.resumable_upload_url_prefix) and request.method == 'PUT': + elif ( + request.url.startswith(ResumableUploadServerState.resumable_upload_url_prefix) + and request.method == "PUT" + ): assert not MultipartUploadTestCase.is_auth_header_present(request) - url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix):] + url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix) :] assert url_path == self.path - content_range_header = request.headers['Content-range'] - is_status_check_request = re.match('bytes \\*/\\*', content_range_header) + content_range_header = request.headers["Content-range"] + is_status_check_request = re.match("bytes \\*/\\*", content_range_header) if is_status_check_request: assert not request.body response_customizer = self.custom_response_on_status_check @@ -1483,35 +1528,37 @@ def processor(): if not is_status_check_request: body = request.body.read() - match = re.match('bytes (\\d+)-(\\d+)/(.+)', content_range_header) + match = re.match("bytes (\\d+)-(\\d+)/(.+)", content_range_header) [range_start_s, range_end_s, file_size_s] = match.groups() server_state.save_part(int(range_start_s), int(range_end_s), body, file_size_s) if server_state.file_content: # upload complete - return [200, '', {}] + return [200, "", {}] else: # more data expected if server_state.confirmed_last_byte: - headers = {'Range': f'bytes=0-{server_state.confirmed_last_byte}'} + headers = {"Range": f"bytes=0-{server_state.confirmed_last_byte}"} else: headers = {} - return [308, '', headers] + return [308, "", headers] return response_customizer.generate_response(request, processor) # abort upload - elif request.url.startswith( - ResumableUploadServerState.resumable_upload_url_prefix) and request.method == 'DELETE': + elif ( + request.url.startswith(ResumableUploadServerState.resumable_upload_url_prefix) + and request.method == "DELETE" + ): assert not MultipartUploadTestCase.is_auth_header_present(request) - url_path = 
request.url[len(ResumableUploadServerState.resumable_upload_url_prefix):] + url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix) :] assert url_path == self.path def processor(): server_state.abort_upload() - return [200, '', {}] + return [200, "", {}] return self.custom_response_on_abort.generate_response(request, processor) @@ -1528,7 +1575,7 @@ def run(self, config: Config): if self.multipart_upload_max_retries: config.multipart_upload_max_retries = self.multipart_upload_max_retries config.enable_experimental_files_api_client = True - config.multipart_upload_min_stream_size = 0 # disable single-shot uploads + config.multipart_upload_min_stream_size = 0 # disable single-shot uploads MultipartUploadTestCase.setup_token_auth(config) @@ -1542,7 +1589,7 @@ def run(self, config: Config): w = WorkspaceClient(config=config) def upload(): - w.files.upload('/test.txt', io.BytesIO(file_content), overwrite=self.overwrite) + w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=self.overwrite) if self.expected_exception_type is not None: with pytest.raises(self.expected_exception_type): @@ -1570,71 +1617,74 @@ def to_string(test_case): [ # ------------------ failures on creating resumable upload URL ------------------ ResumableUploadTestCase( - 'Create resumable URL: client error is not retried', + "Create resumable URL: client error is not retried", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=400, only_invocation=1), expected_exception_type=BadRequest, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - 'Create resumable URL: permission denied is not retried', + "Create resumable URL: permission denied is not retried", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=403, only_invocation=1), expected_exception_type=PermissionDenied, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - 'Create resumable URL: internal error is not retried', + "Create resumable URL: internal error is not retried", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=500, only_invocation=1), expected_exception_type=InternalError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - 'Create resumable URL: non-JSON response is not retried', + "Create resumable URL: non-JSON response is not retried", stream_size=1024 * 1024, - custom_response_on_create_resumable_url=CustomResponse(body='Foo bar', only_invocation=1), + custom_response_on_create_resumable_url=CustomResponse(body="Foo bar", only_invocation=1), expected_exception_type=requests.exceptions.JSONDecodeError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - 'Create resumable URL: meaningless JSON response is not retried', + "Create resumable URL: meaningless JSON response is not retried", stream_size=1024 * 1024, - custom_response_on_create_resumable_url=CustomResponse(body='{"upload_part_urls":[{"url":""}]}', - only_invocation=1), + custom_response_on_create_resumable_url=CustomResponse( + body='{"upload_part_urls":[{"url":""}]}', only_invocation=1 + ), expected_exception_type=ValueError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - 'Create resumable 
URL: permanent retryable exception', + "Create resumable URL: permanent retryable exception", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=429), - sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) expected_exception_type=TimeoutError, - expected_aborted=False # upload didn't start + expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - 'Create resumable URL: intermittent retryable exception is retried', + "Create resumable URL: intermittent retryable exception is retried", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse( exception=requests.Timeout, # 3 failures total first_invocation=1, - last_invocation=3), - expected_aborted=False # upload succeeds + last_invocation=3, + ), + expected_aborted=False, # upload succeeds ), - # ------------------ failures during upload ------------------ ResumableUploadTestCase( - 'Upload: retryable exception after file is uploaded', + "Upload: retryable exception after file is uploaded", stream_size=1024 * 1024, - custom_response_on_upload=CustomResponse(exception=requests.ConnectionError, - exception_happened_before_processing=False), + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, exception_happened_before_processing=False + ), # Despite the returned error, file has been uploaded. We'll discover that # on the next status check and consider upload completed. - expected_aborted=False), + expected_aborted=False, + ), ResumableUploadTestCase( - 'Upload: retryable exception before file is uploaded, not enough retries', + "Upload: retryable exception before file is uploaded, not enough retries", stream_size=1024 * 1024, multipart_upload_max_retries=3, custom_response_on_upload=CustomResponse( @@ -1643,12 +1693,14 @@ def to_string(test_case): exception_happened_before_processing=True, # fail 4 times, exceeding max_retries first_invocation=1, - last_invocation=4), + last_invocation=4, + ), # File was never uploaded and we gave up retrying expected_exception_type=requests.ConnectionError, - expected_aborted=True), + expected_aborted=True, + ), ResumableUploadTestCase( - 'Upload: retryable exception before file is uploaded, enough retries', + "Upload: retryable exception before file is uploaded, enough retries", stream_size=1024 * 1024, multipart_upload_max_retries=4, custom_response_on_upload=CustomResponse( @@ -1657,13 +1709,14 @@ def to_string(test_case): exception_happened_before_processing=True, # fail 4 times, not exceeding max_retries first_invocation=1, - last_invocation=4), + last_invocation=4, + ), # File was uploaded after retries - expected_aborted=False), - + expected_aborted=False, + ), # -------------- abort failures -------------- ResumableUploadTestCase( - 'Abort: client error', + "Abort: client error", stream_size=1024 * 1024, # prevent chunk from being uploaded custom_response_on_upload=CustomResponse(code=403), @@ -1671,33 +1724,41 @@ def to_string(test_case): custom_response_on_abort=CustomResponse(code=500), expected_exception_type=PermissionDenied, # abort returned error but was invoked - expected_aborted=True), - + expected_aborted=True, + ), # -------------- file already exists -------------- - ResumableUploadTestCase('File already exists', - stream_size=1024 * 1024, - overwrite=False, - custom_response_on_upload=CustomResponse(code=412), - expected_exception_type=AlreadyExists, - expected_aborted=True), - + 
ResumableUploadTestCase( + "File already exists", + stream_size=1024 * 1024, + overwrite=False, + custom_response_on_upload=CustomResponse(code=412), + expected_exception_type=AlreadyExists, + expected_aborted=True, + ), # -------------- success cases -------------- - ResumableUploadTestCase('Multiple chunks, zero unconfirmed delta', - stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=7 * 1024 * 1024 + 566, - unconfirmed_delta=0, - expected_aborted=False), - ResumableUploadTestCase('Multiple chunks, non-zero unconfirmed delta', - stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=7 * 1024 * 1024 + 566, - unconfirmed_delta=239, - expected_aborted=False), - ResumableUploadTestCase('Multiple chunks, variable unconfirmed delta', - stream_size=100 * 1024 * 1024, - multipart_upload_chunk_size=7 * 1024 * 1024 + 566, - unconfirmed_delta=[15 * 1024, 0, 25000, 7 * 1024 * 1024, 5], - expected_aborted=False), + ResumableUploadTestCase( + "Multiple chunks, zero unconfirmed delta", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + unconfirmed_delta=0, + expected_aborted=False, + ), + ResumableUploadTestCase( + "Multiple chunks, non-zero unconfirmed delta", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + unconfirmed_delta=239, + expected_aborted=False, + ), + ResumableUploadTestCase( + "Multiple chunks, variable unconfirmed delta", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + unconfirmed_delta=[15 * 1024, 0, 25000, 7 * 1024 * 1024, 5], + expected_aborted=False, + ), ], - ids=ResumableUploadTestCase.to_string) + ids=ResumableUploadTestCase.to_string, +) def test_resumable_upload(config: Config, test_case: ResumableUploadTestCase): test_case.run(config) From a146db27cf881216145b883fd1fd164496ff855e Mon Sep 17 00:00:00 2001 From: Kirill Safonov Date: Thu, 27 Feb 2025 14:26:24 +0100 Subject: [PATCH 05/11] Cleanup after merge, add some docs --- databricks/sdk/mixins/files.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 797f37d9c..3c69d4bb0 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -723,15 +723,6 @@ def download(self, file_path: str) -> DownloadResponse: initial_response.contents._response = wrapped_response return initial_response - def _download_raw_stream( - self, - file_path: str, - start_byte_offset: int, - if_unmodified_since_timestamp: Optional[str] = None, - ) -> DownloadResponse: - headers = { - "Accept": "application/octet-stream", - } def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None): # Upload empty and small files with one-shot upload. @@ -885,6 +876,8 @@ def perform(): # following _BaseClient timeout retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + + # Uploading same data to the same URL is an idempotent operation, safe to retry. upload_response = retried( timeout=timedelta(seconds=retry_timeout_seconds), is_retryable=_BaseClient._is_retryable, @@ -1219,6 +1212,8 @@ def perform(): # following _BaseClient timeout retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + + # Aborting upload is an idempotent operation, safe to retry. 
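+        # A sketch of the `retried` contract as used here: retried(timeout=...,
+        # is_retryable=..., clock=...) returns a decorator; calling the decorated
+        # `perform` repeats it until it succeeds, a non-retryable error is raised,
+        # or the timeout elapses and a TimeoutError is raised.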
abort_response = retried( timeout=timedelta(seconds=retry_timeout_seconds), is_retryable=_BaseClient._is_retryable, clock=self._clock )(perform)() @@ -1245,6 +1240,7 @@ def perform(): # following _BaseClient timeout retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + # Aborting upload is an idempotent operation, safe to retry. abort_response = retried( timeout=timedelta(seconds=retry_timeout_seconds), is_retryable=_BaseClient._is_retryable, clock=self._clock )(perform)() From 8c97a49073f7ced30ff6bd2b264477d81d103bab Mon Sep 17 00:00:00 2001 From: Kirill Safonov Date: Thu, 27 Feb 2025 14:54:37 +0100 Subject: [PATCH 06/11] Followups --- databricks/sdk/mixins/files.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 3c69d4bb0..839d308c4 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -21,11 +21,13 @@ from urllib import parse import requests +import requests.adapters from requests import RequestException from .._base_client import _BaseClient, _RawResponse, _StreamingResponse from .._property import _cached_property from ..clock import Clock, RealClock +from ..config import Config from ..errors import AlreadyExists, NotFound from ..errors.mapper import _error_mapper from ..retries import retried @@ -723,7 +725,6 @@ def download(self, file_path: str) -> DownloadResponse: initial_response.contents._response = wrapped_response return initial_response - def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None): # Upload empty and small files with one-shot upload. pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size) @@ -731,7 +732,7 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool _LOG.debug( f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes" ) - return super().upload(file_path=file_path, contents=pre_read_buffer, overwrite=overwrite) + return super().upload(file_path=file_path, contents=BytesIO(pre_read_buffer), overwrite=overwrite) query = {"action": "initiate-upload"} if overwrite is not None: @@ -829,6 +830,7 @@ def fill_buffer(): headers = {"Content-Type": "application/json"} + # Requesting URLs for the same set of parts is an idempotent operation, safe to retry. # _api.do() does retry upload_part_urls_response = self._api.do( "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body @@ -931,6 +933,7 @@ def perform(): body["parts"] = parts + # Completing upload is an idempotent operation, safe to retry. 
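+        # Illustrative wire shape of the completion body, inferred from the test mock
+        # (which reads "part_number" and "etag" from every entry in "parts"):
+        #   {"parts": [{"part_number": 1, "etag": "etag-..."}, {"part_number": 2, "etag": "..."}]}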
# _api.do() does retry self._api.do( "POST", @@ -1109,7 +1112,7 @@ def _resumable_upload( break elif upload_response.status_code == 308: - # chunk accepted, let's determine received offset to resume from there + # chunk accepted (or check-status succeeded), let's determine received offset to resume from there range_string = upload_response.headers.get("Range") confirmed_offset = self._extract_range_offset(range_string) _LOG.debug(f"Received confirmed offset: {confirmed_offset}") From f1f3e260b66f6d8378625ff92171d4c37e7b85d7 Mon Sep 17 00:00:00 2001 From: Kirill Safonov Date: Fri, 28 Feb 2025 13:26:35 +0100 Subject: [PATCH 07/11] Improve retry logic, add more test cases and comments --- databricks/sdk/config.py | 8 +- databricks/sdk/mixins/files.py | 149 +++++++++++++++++-------------- tests/test_files.py | 157 ++++++++++++++++++++++++++------- 3 files changed, 212 insertions(+), 102 deletions(-) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index c81367ff2..591aafc44 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -134,8 +134,12 @@ class Config: # but a maximum time between consecutive data reception events (even 1 byte) from the server multipart_upload_single_chunk_upload_timeout_seconds: float = 60 - # Limit of retries during multipart upload. - # Retry counter is reset when progressing along the stream. + # Cap on the number of custom retries during incremental uploads: + # 1) multipart: upload part URL is expired, so new upload URLs must be requested to continue upload + # 2) resumable: chunk upload produced a retryable response (or exception), so upload status must be + # retrieved to continue the upload. + # In these two cases standard SDK retries (which are capped by the `retry_timeout_seconds` option) are not used. + # Note that retry counter is reset when upload is successfully resumed. 
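+    # Illustrative override (construction details are placeholders; the attribute
+    # assignment mirrors how the tests tune this limit):
+    #   config = Config(host=..., token=...)
+    #   config.multipart_upload_max_retries = 5
+    #   w = WorkspaceClient(config=config)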
multipart_upload_max_retries = 3 def __init__( diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 839d308c4..84a71c520 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -16,8 +16,8 @@ from datetime import timedelta from io import BytesIO from types import TracebackType -from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable, - Optional, Type, Union) +from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Callable, Generator, + Iterable, Optional, Type, Union) from urllib import parse import requests @@ -26,9 +26,9 @@ from .._base_client import _BaseClient, _RawResponse, _StreamingResponse from .._property import _cached_property -from ..clock import Clock, RealClock from ..config import Config from ..errors import AlreadyExists, NotFound +from ..errors.customizer import _RetryAfterCustomizer from ..errors.mapper import _error_mapper from ..retries import retried from ..service import files @@ -693,10 +693,12 @@ def delete(self, path: str, *, recursive=False): class FilesExt(files.FilesAPI): __doc__ = files.FilesAPI.__doc__ - def __init__(self, api_client, config: Config, clock: Clock = None): + # note that these error codes are retryable only for idempotent operations + _RETRYABLE_STATUS_CODES = [408, 429, 500, 502, 503, 504] + + def __init__(self, api_client, config: Config): super().__init__(api_client) self._config = config.copy() - self._clock = clock or RealClock() self._multipart_upload_read_ahead_bytes = 1 def download(self, file_path: str) -> DownloadResponse: @@ -867,25 +869,15 @@ def rewind(): chunk.seek(0, os.SEEK_SET) def perform(): - result = cloud_provider_session.request( + return cloud_provider_session.request( "PUT", url, headers=headers, data=chunk, timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, ) - return result - - # following _BaseClient timeout - retry_timeout_seconds = self._config.retry_timeout_seconds or 300 - # Uploading same data to the same URL is an idempotent operation, safe to retry. - upload_response = retried( - timeout=timedelta(seconds=retry_timeout_seconds), - is_retryable=_BaseClient._is_retryable, - clock=self._clock, - before_retry=rewind, - )(perform)() + upload_response = self._retry_idempotent_operation(perform, rewind) if upload_response.status_code in (200, 201): # Chunk upload successful @@ -1068,38 +1060,55 @@ def _resumable_upload( _LOG.debug(f"Uploading chunk: {content_range_header}") headers["Content-Range"] = content_range_header + def retrieve_upload_status() -> Optional[requests.Response]: + def perform(): + return cloud_provider_session.request( + "PUT", + resumable_upload_url, + headers={"Content-Range": "bytes */*"}, + data=b"", + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) + + try: + return self._retry_idempotent_operation(perform) + except RequestException: + _LOG.warning("Failed to retrieve upload status") + return None + try: - # We're not retrying this single request as we don't know where to rewind to. - # Instead, in case of retryable failure we'll re-request the current offset - # from the server and resume upload from there in the main upload loop. 
- upload_response: requests.Response = cloud_provider_session.request( + upload_response = cloud_provider_session.request( "PUT", resumable_upload_url, headers=headers, data=BytesIO(buffer[:actual_chunk_length]), timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, ) - retry_count = 0 # reset retry count when progressing along the stream + + # https://cloud.google.com/storage/docs/performing-resumable-uploads#resume-upload + # If an upload request is terminated before receiving a response, or if you receive + # a 503 or 500 response, then you need to resume the interrupted upload from where it left off. + + # Let's follow that for all potentially retryable status codes. + if upload_response.status_code in self._RETRYABLE_STATUS_CODES: + if retry_count < self._config.multipart_upload_max_retries: + retry_count += 1 + # let original upload_response be handled as an error + upload_response = retrieve_upload_status() or upload_response + else: + # we received non-retryable response, reset retry count + retry_count = 0 + except RequestException as e: - _LOG.warning(f"Failure during upload request: {sys.exc_info()}") + # Let's do the same for retryable network errors if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries: retry_count += 1 - # Chunk upload threw an error, try to retrieve the current received offset - try: - # https://cloud.google.com/storage/docs/performing-resumable-uploads#status-check - headers["Content-Range"] = "bytes */*" - upload_response = cloud_provider_session.request( - "PUT", - resumable_upload_url, - headers=headers, - data=b"", - timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, - ) - except RequestException: - # status check failed, abort the upload + upload_response = retrieve_upload_status() + if not upload_response: + # rethrow original exception raise e from None else: - # error is not retryable, abort the upload + # rethrow original exception raise e from None if upload_response.status_code in (200, 201): @@ -1138,23 +1147,16 @@ def _resumable_upload( uploaded_bytes_count = next_chunk_offset - chunk_offset chunk_offset = next_chunk_offset - elif upload_response.status_code == 400: - # Expecting response body to be small to be safely logged - mapped_error = _error_mapper(upload_response, {}) - raise mapped_error or ValueError( - f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}" - ) - elif upload_response.status_code == 412 and not overwrite: # Assuming this is only possible reason # Full message in this case: "At least one of the pre-conditions you specified did not hold." raise AlreadyExists("The file being created already exists.") - else: - if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug(f"Failed to upload (status: {upload_response.status_code}): {upload_response.text}") + else: + message = f"Unsuccessful chunk upload. 
Response status: {upload_response.status_code}, body: {upload_response.content}" + _LOG.warning(message) mapped_error = _error_mapper(upload_response, {}) - raise mapped_error or ValueError(f"Failed to upload: {upload_response}") + raise mapped_error or ValueError(message) except Exception as e: _LOG.info(f"Aborting resumable upload on error: {e}") @@ -1204,22 +1206,15 @@ def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_pr headers[h["name"]] = h["value"] def perform(): - result = cloud_provider_session.request( + return cloud_provider_session.request( "DELETE", abort_url, headers=headers, data=b"", timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, ) - return result - # following _BaseClient timeout - retry_timeout_seconds = self._config.retry_timeout_seconds or 300 - - # Aborting upload is an idempotent operation, safe to retry. - abort_response = retried( - timeout=timedelta(seconds=retry_timeout_seconds), is_retryable=_BaseClient._is_retryable, clock=self._clock - )(perform)() + abort_response = self._retry_idempotent_operation(perform) if abort_response.status_code not in (200, 201): raise ValueError(abort_response) @@ -1232,21 +1227,15 @@ def _abort_resumable_upload( headers[h["name"]] = h["value"] def perform(): - result = cloud_provider_session.request( + return cloud_provider_session.request( "DELETE", resumable_upload_url, headers=headers, data=b"", timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, ) - return result - # following _BaseClient timeout - retry_timeout_seconds = self._config.retry_timeout_seconds or 300 - # Aborting upload is an idempotent operation, safe to retry. - abort_response = retried( - timeout=timedelta(seconds=retry_timeout_seconds), is_retryable=_BaseClient._is_retryable, clock=self._clock - )(perform)() + abort_response = self._retry_idempotent_operation(perform) if abort_response.status_code not in (200, 201): raise ValueError(abort_response) @@ -1265,6 +1254,30 @@ def _create_cloud_provider_session(self): session.mount("http://", http_adapter) return session + def _retry_idempotent_operation( + self, operation: Callable[[], requests.Response], before_retry: Callable = None + ) -> requests.Response: + def delegate(): + response = operation() + if response.status_code in self._RETRYABLE_STATUS_CODES: + attrs = {} + # this will assign "retry_after_secs" to the attrs, essentially making exception look retryable + _RetryAfterCustomizer().customize_error(response, attrs) + raise _error_mapper(response, attrs) + else: + return response + + # following _BaseClient timeout + retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + + return retried( + timeout=timedelta(seconds=retry_timeout_seconds), + # also retry on network errors (connection error, connection timeout) + # where we believe request didn't reach the server + is_retryable=_BaseClient._is_retryable, + before_retry=before_retry, + )(delegate)() + def _download_raw_stream( self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None ) -> DownloadResponse: @@ -1302,12 +1315,12 @@ def _download_raw_stream( return result - def _wrap_stream(self, file_path: str, downloadResponse: DownloadResponse): - underlying_response = _ResilientIterator._extract_raw_response(downloadResponse) + def _wrap_stream(self, file_path: str, download_response: DownloadResponse): + underlying_response = _ResilientIterator._extract_raw_response(download_response) return _ResilientResponse( self, 
file_path, - downloadResponse.last_modified, + download_response.last_modified, offset=0, underlying_response=underlying_response, ) diff --git a/tests/test_files.py b/tests/test_files.py index e33941f2f..e25035523 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -21,7 +21,8 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.core import Config from databricks.sdk.errors.platform import (AlreadyExists, BadRequest, - InternalError, PermissionDenied) + InternalError, PermissionDenied, + TooManyRequests) logger = logging.getLogger(__name__) @@ -549,6 +550,9 @@ def __init__( if self.only_invocation and (self.first_invocation or self.last_invocation): raise ValueError("Cannot set both only invocation and first/last invocation") + if self.exception_happened_before_processing and not self.exception: + raise ValueError("Exception is not defined") + self.invocation_count = 0 def invocation_matches(self): @@ -937,33 +941,45 @@ def to_string(test_case): stream_size=1024 * 1024, custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError), sdk_retry_timeout_seconds=30, # let's not wait 5 min (SDK default timeout) - expected_exception_type=TimeoutError, + expected_exception_type=TimeoutError, # SDK throws this if retries are taking too long expected_aborted=False, # upload didn't start ), MultipartUploadTestCase( "Initiate: intermittent retryable exception", stream_size=1024 * 1024, custom_response_on_initiate=CustomResponse( - exception=requests.ConnectionError, first_invocation=1, last_invocation=3 + exception=requests.ConnectionError, + # 3 calls fail, but request is successfully retried + first_invocation=1, + last_invocation=3, ), expected_aborted=False, ), MultipartUploadTestCase( - "Initiate: intermittent retryable error", + "Initiate: intermittent retryable status code", stream_size=1024 * 1024, - custom_response_on_initiate=CustomResponse(code=429, first_invocation=1, last_invocation=3), + custom_response_on_initiate=CustomResponse( + code=429, + # 3 calls fail, then retry succeeds + first_invocation=1, + last_invocation=3, + ), expected_aborted=False, ), # -------------------------- failures on "create upload URL" -------------------------- MultipartUploadTestCase( - "Create upload URL: client error is not retied", + "Create upload URL: 400 response is not retied", stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse(code=400, only_invocation=1), + custom_response_on_create_multipart_url=CustomResponse( + code=400, + # 1 failure is enough + only_invocation=1, + ), expected_exception_type=BadRequest, expected_aborted=True, ), MultipartUploadTestCase( - "Create upload URL: internal error is not retied", + "Create upload URL: 500 error is not retied", stream_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), expected_exception_type=InternalError, @@ -1010,7 +1026,11 @@ def to_string(test_case): MultipartUploadTestCase( "Create upload URL: intermittent retryable exception", stream_size=1024 * 1024, - custom_response_on_create_multipart_url=CustomResponse(exception=requests.Timeout, only_invocation=1), + custom_response_on_create_multipart_url=CustomResponse( + exception=requests.Timeout, + # happens only once, retry succeeds + only_invocation=1, + ), expected_aborted=False, ), MultipartUploadTestCase( @@ -1019,7 +1039,7 @@ def to_string(test_case): multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_create_multipart_url=CustomResponse( 
exception=requests.Timeout, - # 4th request for multipart URLs will fail 3 times and will be retried + # 4th request for multipart URLs fails 3 times, then retry succeeds first_invocation=4, last_invocation=6, ), @@ -1030,7 +1050,11 @@ def to_string(test_case): "Upload chunk: 403 response is not retried", stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, - custom_response_on_upload=CustomResponse(code=403, only_invocation=2), + custom_response_on_upload=CustomResponse( + code=403, + # fail only once + only_invocation=1, + ), expected_exception_type=PermissionDenied, expected_aborted=True, ), @@ -1038,12 +1062,16 @@ def to_string(test_case): "Upload chunk: 400 response is not retried", stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, - custom_response_on_upload=CustomResponse(code=400, only_invocation=9), + custom_response_on_upload=CustomResponse( + code=400, + # fail once, but not on the first chunk + only_invocation=3, + ), expected_exception_type=BadRequest, expected_aborted=True, ), MultipartUploadTestCase( - "Upload chunk: 500 response is not retried", # TODO should we retry chunk upload on internal error from Cloud provider? + "Upload chunk: 500 response is not retried", stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, custom_response_on_upload=CustomResponse(code=500, only_invocation=5), @@ -1089,7 +1117,7 @@ def to_string(test_case): expected_aborted=False, ), MultipartUploadTestCase( - "Upload chunk: expired URL retry is limited", + "Upload chunk: expired URL retry is exhausted", multipart_upload_max_retries=3, stream_size=100 * 1024 * 1024, # 10 chunks multipart_upload_chunk_size=10 * 1024 * 1024, @@ -1112,6 +1140,15 @@ def to_string(test_case): expected_exception_type=TimeoutError, expected_aborted=True, ), + MultipartUploadTestCase( + "Upload chunk: permanent retryable status code", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + custom_response_on_upload=CustomResponse(code=429, first_invocation=8), + expected_exception_type=TimeoutError, + expected_aborted=True, + ), MultipartUploadTestCase( "Upload chunk: intermittent retryable error", stream_size=100 * 1024 * 1024, # 10 chunks @@ -1121,22 +1158,29 @@ def to_string(test_case): ), expected_aborted=False, ), + MultipartUploadTestCase( + "Upload chunk: intermittent retryable status code", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse(code=429, first_invocation=2, last_invocation=4), + expected_aborted=False, + ), # -------------------------- failures on abort -------------------------- MultipartUploadTestCase( - "Abort URL: client error", + "Abort URL: 500 response", stream_size=1024 * 1024, custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), custom_response_on_create_abort_url=CustomResponse(code=400), expected_exception_type=InternalError, # original error - expected_aborted=False, + expected_aborted=False, # server state didn't change to record abort ), MultipartUploadTestCase( - "Abort URL: forbidden", + "Abort URL: 403 response", stream_size=1024 * 1024, custom_response_on_upload=CustomResponse(code=500, only_invocation=1), custom_response_on_create_abort_url=CustomResponse(code=403), expected_exception_type=InternalError, # original error - 
expected_aborted=False, + expected_aborted=False, # server state didn't change to record abort ), MultipartUploadTestCase( "Abort URL: intermittent retryable error", @@ -1161,13 +1205,14 @@ def to_string(test_case): stream_size=1024 * 1024, # don't wait for 5 min (SDK default timeout) sdk_retry_timeout_seconds=30, - # simulate PermissionDenied custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), custom_response_on_abort=CustomResponse( - exception=requests.Timeout, exception_happened_before_processing=False + exception=requests.Timeout, + # this allows to change the server state to "aborted" + exception_happened_before_processing=False, ), expected_exception_type=PermissionDenied, # original error is reported - expected_aborted=True, # abort called but failed + expected_aborted=True, ), # -------------------------- happy cases -------------------------- MultipartUploadTestCase( @@ -1187,14 +1232,14 @@ def to_string(test_case): multipart_upload_chunk_size=10 * 1024 * 1024, ), MultipartUploadTestCase( - "Multipart upload successful: multiple chunks (not aligned)", - stream_size=100 * 1024 * 1024 + 1000 + 566, # 14 full chunks + remainder + "Multipart upload successful: multiple chunks (not aligned), upload urls by 1", + stream_size=100 * 1024 * 1024 + 1566, # 14 full chunks + remainder multipart_upload_chunk_size=7 * 1024 * 1024 - 17, ), MultipartUploadTestCase( "Multipart upload successful: multiple chunks (not aligned), upload urls by 5", multipart_upload_batch_url_count=5, - stream_size=100 * 1024 * 1024 + 1000 + 566, # 14 full chunks + remainder + stream_size=100 * 1024 * 1024 + 1566, # 14 full chunks + remainder multipart_upload_chunk_size=7 * 1024 * 1024 - 17, ), ], @@ -1617,21 +1662,25 @@ def to_string(test_case): [ # ------------------ failures on creating resumable upload URL ------------------ ResumableUploadTestCase( - "Create resumable URL: client error is not retried", + "Create resumable URL: 400 response is not retried", stream_size=1024 * 1024, - custom_response_on_create_resumable_url=CustomResponse(code=400, only_invocation=1), + custom_response_on_create_resumable_url=CustomResponse( + code=400, + # 1 failure is enough + only_invocation=1, + ), expected_exception_type=BadRequest, expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - "Create resumable URL: permission denied is not retried", + "Create resumable URL: 403 response is not retried", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=403, only_invocation=1), expected_exception_type=PermissionDenied, expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - "Create resumable URL: internal error is not retried", + "Create resumable URL: 500 response is not retried", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=500, only_invocation=1), expected_exception_type=InternalError, @@ -1654,7 +1703,7 @@ def to_string(test_case): expected_aborted=False, # upload didn't start ), ResumableUploadTestCase( - "Create resumable URL: permanent retryable exception", + "Create resumable URL: permanent retryable status code", stream_size=1024 * 1024, custom_response_on_create_resumable_url=CustomResponse(code=429), sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) @@ -1677,7 +1726,9 @@ def to_string(test_case): "Upload: retryable exception after file is uploaded", stream_size=1024 * 1024, custom_response_on_upload=CustomResponse( - exception=requests.ConnectionError, 
exception_happened_before_processing=False + exception=requests.ConnectionError, + # this makes server state change before exception is thrown + exception_happened_before_processing=False, ), # Despite the returned error, file has been uploaded. We'll discover that # on the next status check and consider upload completed. @@ -1714,6 +1765,33 @@ def to_string(test_case): # File was uploaded after retries expected_aborted=False, ), + ResumableUploadTestCase( + "Upload: intermittent 429 response: retried", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024, + multipart_upload_max_retries=3, + custom_response_on_upload=CustomResponse( + code=429, + # 3 failures not exceeding max_retries + first_invocation=2, + last_invocation=4, + ), + expected_aborted=False, # upload succeeded + ), + ResumableUploadTestCase( + "Upload: intermittent 429 response: retry exhausted", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=1 * 1024 * 1024, + multipart_upload_max_retries=3, + custom_response_on_upload=CustomResponse( + code=429, + # 4 failures exceeding max_retries + first_invocation=2, + last_invocation=5, + ), + expected_exception_type=TooManyRequests, + expected_aborted=True, + ), # -------------- abort failures -------------- ResumableUploadTestCase( "Abort: client error", @@ -1723,7 +1801,7 @@ def to_string(test_case): # internal server error does not prevent server state change custom_response_on_abort=CustomResponse(code=500), expected_exception_type=PermissionDenied, - # abort returned error but was invoked + # abort returned error but was actually processed expected_aborted=True, ), # -------------- file already exists -------------- @@ -1731,7 +1809,7 @@ def to_string(test_case): "File already exists", stream_size=1024 * 1024, overwrite=False, - custom_response_on_upload=CustomResponse(code=412), + custom_response_on_upload=CustomResponse(code=412, only_invocation=1), expected_exception_type=AlreadyExists, expected_aborted=True, ), @@ -1740,6 +1818,15 @@ def to_string(test_case): "Multiple chunks, zero unconfirmed delta", stream_size=100 * 1024 * 1024, multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + # server accepts all the chunks in full + unconfirmed_delta=0, + expected_aborted=False, + ), + ResumableUploadTestCase( + "Multiple small chunks, zero unconfirmed delta", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=100 * 1024, + # server accepts all the chunks in full unconfirmed_delta=0, expected_aborted=False, ), @@ -1747,6 +1834,7 @@ def to_string(test_case): "Multiple chunks, non-zero unconfirmed delta", stream_size=100 * 1024 * 1024, multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + # for every chunk, server accepts all except last 239 bytes unconfirmed_delta=239, expected_aborted=False, ), @@ -1754,6 +1842,11 @@ def to_string(test_case): "Multiple chunks, variable unconfirmed delta", stream_size=100 * 1024 * 1024, multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + # for the first chunk, server accepts all except last 15Kib + # for the second chunk, server accepts it all + # for the 3rd chunk, server accepts all except last 25000 bytes + # for the 4th chunk, server accepts all except last 7 Mb + # for the 5th chunk onwards server accepts all except last 5 bytes unconfirmed_delta=[15 * 1024, 0, 25000, 7 * 1024 * 1024, 5], expected_aborted=False, ), From a320c69abbd6e47b5e797a407debdb2fbc965391 Mon Sep 17 00:00:00 2001 From: Kirill Safonov Date: Mon, 3 Mar 2025 16:15:58 +0100 Subject: [PATCH 08/11] Add integration test for the 
new Files API client --- databricks/sdk/mixins/files.py | 3 +- tests/integration/test_files.py | 136 ++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 84a71c520..02eb7765b 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -1090,6 +1090,7 @@ def perform(): # a 503 or 500 response, then you need to resume the interrupted upload from where it left off. # Let's follow that for all potentially retryable status codes. + # Together with the catch block below we replicate the logic in _retry_idempotent_operation(). if upload_response.status_code in self._RETRYABLE_STATUS_CODES: if retry_count < self._config.multipart_upload_max_retries: retry_count += 1 @@ -1100,7 +1101,7 @@ def perform(): retry_count = 0 except RequestException as e: - # Let's do the same for retryable network errors + # Let's do the same for retryable network errors. if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries: retry_count += 1 upload_response = retrieve_upload_status() diff --git a/tests/integration/test_files.py b/tests/integration/test_files.py index 348f88b05..e2a0fbd0c 100644 --- a/tests/integration/test_files.py +++ b/tests/integration/test_files.py @@ -1,13 +1,17 @@ +import datetime import io import logging import pathlib import platform +import re import time +from textwrap import dedent from typing import Callable, List, Tuple, Union import pytest from databricks.sdk.core import DatabricksError +from databricks.sdk.errors.sdk import OperationFailed from databricks.sdk.service.catalog import VolumeType @@ -382,3 +386,135 @@ def test_files_api_download_benchmark(ucws, files_api, random): ) min_str = str(best[0]) + "kb" if best[0] else "None" logging.info("Fastest chunk size: %s in %f seconds", min_str, best[1]) + + +@pytest.mark.parametrize("is_serverless", [True, False], ids=["Classic", "Serverless"]) +@pytest.mark.parametrize("use_new_files_api_client", [True, False], ids=["Default client", "Experimental client"]) +def test_files_api_in_cluster(ucws, random, env_or_skip, is_serverless, use_new_files_api_client): + from databricks.sdk.service import compute, jobs + + databricks_sdk_pypi_package = "databricks-sdk" + option_env_name = "DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT" + + launcher_file_path = f"/home/{ucws.current_user.me().user_name}/test_launcher.py" + + schema = "filesit-" + random() + volume = "filesit-" + random() + with ResourceWithCleanup.create_schema(ucws, "main", schema): + with ResourceWithCleanup.create_volume(ucws, "main", schema, volume): + + cloud_file_path = f"/Volumes/main/{schema}/{volume}/test-{random()}.txt" + file_size = 100 * 1024 * 1024 + + if use_new_files_api_client: + enable_new_files_api_env = f"os.environ['{option_env_name}'] = 'True'" + expected_files_api_client_class = "FilesExt" + else: + enable_new_files_api_env = "" + expected_files_api_client_class = "FilesAPI" + + using_files_api_client_msg = "Using files API client: " + + command = f""" + from databricks.sdk import WorkspaceClient + import io + import os + import hashlib + import logging + + logging.basicConfig(level=logging.DEBUG) + + {enable_new_files_api_env} + + file_size = {file_size} + original_content = os.urandom(file_size) + cloud_file_path = '{cloud_file_path}' + + w = WorkspaceClient() + print(f"Using SDK: {{w.config._product_info}}") + + print(f"{using_files_api_client_msg}{{type(w.files).__name__}}") + + 
w.files.upload(cloud_file_path, io.BytesIO(original_content), overwrite=True) + print("Upload succeeded") + + response = w.files.download(cloud_file_path) + resulting_content = response.contents.read() + print("Download succeeded") + + def hash(data: bytes): + sha256 = hashlib.sha256() + sha256.update(data) + return sha256.hexdigest() + + if len(resulting_content) != len(original_content): + raise ValueError(f"Content length does not match: expected {{len(original_content)}}, actual {{len(resulting_content)}}") + + expected_hash = hash(original_content) + actual_hash = hash(resulting_content) + if actual_hash != expected_hash: + raise ValueError(f"Content hash does not match: expected {{expected_hash}}, actual {{actual_hash}}") + + print(f"Contents of size {{len(resulting_content)}} match") + """ + + with ucws.dbfs.open(launcher_file_path, write=True, overwrite=True) as f: + f.write(dedent(command).encode()) + + if is_serverless: + # If no job_cluster_key, existing_cluster_id, or new_cluster were specified in task definition, + # then task will be executed using serverless compute. + new_cluster_spec = None + + # Library is specified in the environment + env_key = "test_env" + envs = [jobs.JobEnvironment(env_key, compute.Environment("test", [databricks_sdk_pypi_package]))] + libs = [] + else: + new_cluster_spec = compute.ClusterSpec( + spark_version=ucws.clusters.select_spark_version(long_term_support=True), + instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"), + num_workers=1, + ) + + # Library is specified in the task definition + env_key = None + envs = [] + libs = [compute.Library(pypi=compute.PythonPyPiLibrary(package=databricks_sdk_pypi_package))] + + waiter = ucws.jobs.submit( + run_name=f"py-sdk-{random(8)}", + tasks=[ + jobs.SubmitTask( + task_key="task1", + new_cluster=new_cluster_spec, + spark_python_task=jobs.SparkPythonTask(python_file=f"dbfs:{launcher_file_path}"), + libraries=libs, + environment_key=env_key, + ) + ], + environments=envs, + ) + + def print_status(r: jobs.Run): + statuses = [f"{t.task_key}: {t.state.life_cycle_state}" for t in r.tasks] + logging.info(f'Run status: {", ".join(statuses)}') + + logging.info(f"Waiting for the job run: {waiter.run_id}") + try: + job_run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=print_status) + task_run_id = job_run.tasks[0].run_id + task_run_logs = ucws.jobs.get_run_output(task_run_id).logs + logging.info(f"Run finished, output: {task_run_logs}") + match = re.search(f"{using_files_api_client_msg}(.*)$", task_run_logs, re.MULTILINE) + assert match is not None + files_api_client_class = match.group(1) + assert files_api_client_class == expected_files_api_client_class + + except OperationFailed: + job_run = ucws.jobs.get_run(waiter.run_id) + task_run_id = job_run.tasks[0].run_id + task_run_logs = ucws.jobs.get_run_output(task_run_id) + raise ValueError( + f"Run failed, error: {task_run_logs.error}, error trace: {task_run_logs.error_trace}, output: {task_run_logs.logs}" + ) From 6679798ff36733a4f8e70e884bfcb14269e956c7 Mon Sep 17 00:00:00 2001 From: ksafonov-db Date: Wed, 5 Mar 2025 12:09:45 +0100 Subject: [PATCH 09/11] Cleanup / address review comments --- databricks/sdk/mixins/files.py | 218 +++++++++++++++++++-------------- 1 file changed, 126 insertions(+), 92 deletions(-) diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 02eb7765b..6a7e9d6a2 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -50,13 +50,13 @@ class _DbfsIO(BinaryIO): _closed 
= False def __init__( - self, - api: files.DbfsAPI, - path: str, - *, - read: bool = False, - write: bool = False, - overwrite: bool = False, + self, + api: files.DbfsAPI, + path: str, + *, + read: bool = False, + write: bool = False, + overwrite: bool = False, ): self._api = api self._path = path @@ -115,10 +115,10 @@ def closed(self) -> bool: return self._closed def __exit__( - self, - __t: Type[BaseException] | None, - __value: BaseException | None, - __traceback: TracebackType | None, + self, + __t: Type[BaseException] | None, + __value: BaseException | None, + __traceback: TracebackType | None, ): self.close() @@ -200,13 +200,13 @@ def __repr__(self) -> str: class _VolumesIO(BinaryIO): def __init__( - self, - api: files.FilesAPI, - path: str, - *, - read: bool, - write: bool, - overwrite: bool, + self, + api: files.FilesAPI, + path: str, + *, + read: bool, + write: bool, + overwrite: bool, ): self._buffer = [] self._api = api @@ -586,12 +586,12 @@ def __init__(self, api_client): self._dbfs_api = files.DbfsAPI(api_client) def open( - self, - path: str, - *, - read: bool = False, - write: bool = False, - overwrite: bool = False, + self, + path: str, + *, + read: bool = False, + write: bool = False, + overwrite: bool = False, ) -> BinaryIO: return self._path(path).open(read=read, write=write, overwrite=overwrite) @@ -717,7 +717,7 @@ def download(self, file_path: str) -> DownloadResponse: :returns: :class:`DownloadResponse` """ - initial_response: DownloadResponse = self._download_raw_stream( + initial_response: DownloadResponse = self._open_download_stream( file_path=file_path, start_byte_offset=0, if_unmodified_since_timestamp=None, @@ -728,6 +728,20 @@ def download(self, file_path: str) -> DownloadResponse: return initial_response def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None): + """Upload a file. + + Uploads a file. The file contents should be sent as the request body as raw bytes (an + octet stream); do not encode or otherwise modify the bytes before sending. The contents of the + resulting file will be exactly the bytes sent in the request body. If the request is successful, there + is no response body. + + :param file_path: str + The absolute remote path of the target file. + :param contents: BinaryIO + :param overwrite: bool (optional) + If true, an existing file will be overwritten. When not specified, assumed True. + """ + # Upload empty and small files with one-shot upload. pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size) if len(pre_read_buffer) < self._config.multipart_upload_min_stream_size: @@ -740,11 +754,10 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool if overwrite is not None: query["overwrite"] = overwrite - # _api.do() does retry + # Method _api.do() takes care of retrying and will raise an exception in case of failure. 
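For reference, a minimal usage sketch of the branch above, assuming a workspace client `w` whose `files` attribute is the extended client and a Unity Catalog volume path that exists (both are illustrative assumptions, not part of the patch). Streams shorter than `multipart_upload_min_stream_size` take the single-shot branch; longer streams fall through to the multipart or resumable flow.

    import io
    import os

    from databricks.sdk import WorkspaceClient

    w = WorkspaceClient()

    # Hypothetical volume paths, used only for illustration.
    small = io.BytesIO(os.urandom(1024))              # below the threshold: single-shot upload
    large = io.BytesIO(os.urandom(64 * 1024 * 1024))  # at or above the threshold: multipart or resumable upload
    w.files.upload("/Volumes/main/default/vol/small.bin", small, overwrite=True)
    w.files.upload("/Volumes/main/default/vol/large.bin", large, overwrite=True)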
initiate_upload_response = self._api.do( "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", query=query ) - # no need to check response status, _api.do() will throw exception on failure if initiate_upload_response.get("multipart_upload"): cloud_provider_session = self._create_cloud_provider_session() @@ -753,7 +766,9 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool raise ValueError(f"Unexpected server response: {initiate_upload_response}") try: - self._multipart_upload(file_path, contents, session_token, pre_read_buffer, cloud_provider_session) + self._perform_multipart_upload( + file_path, contents, session_token, pre_read_buffer, cloud_provider_session + ) except Exception as e: _LOG.info(f"Aborting multipart upload on error: {e}") try: @@ -768,20 +783,24 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool elif initiate_upload_response.get("resumable_upload"): cloud_provider_session = self._create_cloud_provider_session() session_token = initiate_upload_response["resumable_upload"]["session_token"] - self._resumable_upload( + self._perform_resumable_upload( file_path, contents, session_token, overwrite, pre_read_buffer, cloud_provider_session ) else: raise ValueError(f"Unexpected server response: {initiate_upload_response}") - def _multipart_upload( - self, - target_path: str, - input_stream: BinaryIO, - session_token: str, - pre_read_buffer: bytes, - cloud_provider_session: requests.Session, + def _perform_multipart_upload( + self, + target_path: str, + input_stream: BinaryIO, + session_token: str, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, ): + """ + Performs multipart upload using presigned URLs on AWS and Azure: + https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html + """ current_part_number = 1 etags: dict = {} @@ -799,22 +818,12 @@ def _multipart_upload( # Note that initially buffer can be bigger (from pre_read_buffer). buffer = pre_read_buffer - def fill_buffer(): - bytes_to_read = max(0, self._config.multipart_upload_chunk_size - len(buffer)) - if bytes_to_read > 0: - next_buf = input_stream.read(bytes_to_read) - new_buffer = buffer + next_buf - return new_buffer - else: - # we have already buffered enough data - return buffer - retry_count = 0 eof = False while not eof: # If needed, buffer the next chunk. - buffer = fill_buffer() - if not len(buffer): + buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream) + if len(buffer) == 0: # End of stream, no need to request the next block of upload URLs. break @@ -833,20 +842,19 @@ def fill_buffer(): headers = {"Content-Type": "application/json"} # Requesting URLs for the same set of parts is an idempotent operation, safe to retry. - # _api.do() does retry + # Method _api.do() takes care of retrying and will raise an exception in case of failure. 
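To make the per-part step concrete, here is a hedged sketch of uploading a single buffered chunk to one presigned part URL. The helper name is invented for illustration, and it assumes the storage provider returns the part's `ETag` response header (as S3 does) so it can be recorded for the final completion call; URL-expiry handling and retries are omitted here.

    import requests

    def upload_one_part(session: requests.Session, url: str, headers: dict, chunk: bytes, timeout: float) -> str:
        """Illustrative helper: PUT one chunk to a presigned part URL and return its ETag."""
        response = session.request("PUT", url, headers=headers, data=chunk, timeout=timeout)
        response.raise_for_status()
        # The ETag identifies the uploaded part and is echoed back to the server
        # when completing the multipart upload.
        return response.headers["ETag"]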
upload_part_urls_response = self._api.do( "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body ) - # no need to check response status, _api.do() will throw exception on failure upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) - if not len(upload_part_urls): + if len(upload_part_urls) == 0: raise ValueError(f"Unexpected server response: {upload_part_urls_response}") for upload_part_url in upload_part_urls: - buffer = fill_buffer() + buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream) actual_buffer_length = len(buffer) - if not actual_buffer_length: + if actual_buffer_length == 0: eof = True break @@ -926,7 +934,7 @@ def perform(): body["parts"] = parts # Completing upload is an idempotent operation, safe to retry. - # _api.do() does retry + # Method _api.do() takes care of retrying and will raise an exception in case of failure. self._api.do( "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}", @@ -934,10 +942,26 @@ def perform(): headers=headers, body=body, ) - # no need to check response status, _api.do() will throw exception on failure + + @staticmethod + def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO): + """ + Tries to fill given buffer to contain at least `desired_min_size` bytes by reading from input stream. + """ + bytes_to_read = max(0, desired_min_size - len(buffer)) + if bytes_to_read > 0: + next_buf = input_stream.read(bytes_to_read) + new_buffer = buffer + next_buf + return new_buffer + else: + # we have already buffered enough data + return buffer @staticmethod def _is_url_expired_response(response: requests.Response): + """ + Checks if response matches one of the known "URL expired" responses from the cloud storage providers. + """ if response.status_code != 403: return False @@ -967,16 +991,19 @@ def _is_url_expired_response(response: requests.Response): return False - def _resumable_upload( - self, - target_path: str, - input_stream: BinaryIO, - session_token: str, - overwrite: bool, - pre_read_buffer: bytes, - cloud_provider_session: requests.Session, + def _perform_resumable_upload( + self, + target_path: str, + input_stream: BinaryIO, + session_token: str, + overwrite: bool, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, ): - # https://cloud.google.com/storage/docs/performing-resumable-uploads + """ + Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads + """ + # Session URI we're using expires after a week # Why are we buffering the current chunk? @@ -1004,11 +1031,10 @@ def _resumable_upload( headers = {"Content-Type": "application/json"} - # _api.do() does retry + # Method _api.do() takes care of retrying and will raise an exception in case of failure. 
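As a rough illustration of the expired-URL detection above, the sketch below parses the `<Error>` XML document that cloud providers return alongside a 403. The specific error codes checked here are assumptions for illustration only, not an exhaustive or authoritative list.

    import xml.etree.ElementTree as ET

    import requests

    # Assumed examples of codes a provider might return for an expired presigned URL.
    ASSUMED_EXPIRY_CODES = {"AccessDenied", "AuthenticationFailed"}

    def looks_like_expired_url(response: requests.Response) -> bool:
        if response.status_code != 403:
            return False
        try:
            error = ET.fromstring(response.content)
        except ET.ParseError:
            return False
        code = error.findtext("Code") or ""
        message = error.findtext("Message") or ""
        return code in ASSUMED_EXPIRY_CODES and "expire" in message.lower()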
resumable_upload_url_response = self._api.do( "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body ) - # no need to check response status, _api.do() will throw exception on failure resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url") if not resumable_upload_url_node: @@ -1172,6 +1198,7 @@ def perform(): @staticmethod def _extract_range_offset(range_string: Optional[str]) -> Optional[int]: + """Parses the response range header to extract the last byte.""" if not range_string: return None # server did not yet confirm any bytes @@ -1181,6 +1208,7 @@ def _extract_range_offset(range_string: Optional[str]) -> Optional[int]: raise ValueError(f"Cannot parse response header: Range: {range_string}") def _get_url_expire_time(self): + """Generates expiration time and save it in the required format.""" current_time = datetime.datetime.now(datetime.timezone.utc) expire_time = current_time + self._config.multipart_upload_url_expiration_duration # From Google Protobuf doc: @@ -1190,13 +1218,13 @@ def _get_url_expire_time(self): return expire_time.strftime("%Y-%m-%dT%H:%M:%SZ") def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_provider_session: requests.Session): + """Aborts ongoing multipart upload session to clean up incomplete file.""" body: dict = {"path": target_path, "session_token": session_token, "expire_time": self._get_url_expire_time()} headers = {"Content-Type": "application/json"} - # _api.do() does retry + # Method _api.do() takes care of retrying and will raise an exception in case of failure. abort_url_response = self._api.do("POST", "/api/2.0/fs/create-abort-upload-url", headers=headers, body=body) - # no need to check response status, _api.do() will throw exception on failure abort_upload_url_node = abort_url_response["abort_upload_url"] abort_url = abort_upload_url_node["url"] @@ -1221,8 +1249,9 @@ def perform(): raise ValueError(abort_response) def _abort_resumable_upload( - self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session + self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session ): + """Aborts ongoing resumable upload session to clean up incomplete file.""" headers: dict = {} for h in required_headers: headers[h["name"]] = h["value"] @@ -1242,8 +1271,7 @@ def perform(): raise ValueError(abort_response) def _create_cloud_provider_session(self): - # Create a separate session which does not inherit - # auth headers from BaseClient session. + """Creates a separate session which does not inherit auth headers from BaseClient session.""" session = requests.Session() # following session config in _BaseClient @@ -1256,8 +1284,12 @@ def _create_cloud_provider_session(self): return session def _retry_idempotent_operation( - self, operation: Callable[[], requests.Response], before_retry: Callable = None + self, operation: Callable[[], requests.Response], before_retry: Callable = None ) -> requests.Response: + """Perform given idempotent operation with necessary retries. Since operation is idempotent it's + safe to retry it for response codes where server state might have changed. 
+ """ + def delegate(): response = operation() if response.status_code in self._RETRYABLE_STATUS_CODES: @@ -1279,9 +1311,10 @@ def delegate(): before_retry=before_retry, )(delegate)() - def _download_raw_stream( - self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None + def _open_download_stream( + self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None ) -> DownloadResponse: + """Opens a download stream from given offset, performing necessary retries.""" headers = { "Accept": "application/octet-stream", } @@ -1300,6 +1333,7 @@ def _download_raw_stream( "content-type", "last-modified", ] + # Method _api.do() takes care of retrying and will raise an exception in case of failure. res = self._api.do( "GET", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", @@ -1330,12 +1364,12 @@ def _wrap_stream(self, file_path: str, download_response: DownloadResponse): class _ResilientResponse(_RawResponse): def __init__( - self, - api: FilesExt, - file_path: str, - file_last_modified: str, - offset: int, - underlying_response: _RawResponse, + self, + api: FilesExt, + file_path: str, + file_last_modified: str, + offset: int, + underlying_response: _RawResponse, ): self.api = api self.file_path = file_path @@ -1368,19 +1402,19 @@ class _ResilientIterator(Iterator): @staticmethod def _extract_raw_response( - download_response: DownloadResponse, + download_response: DownloadResponse, ) -> _RawResponse: streaming_response: _StreamingResponse = download_response.contents # this is an instance of _StreamingResponse return streaming_response._response def __init__( - self, - underlying_iterator, - file_path: str, - file_last_modified: str, - offset: int, - api: FilesExt, - chunk_size: int, + self, + underlying_iterator, + file_path: str, + file_last_modified: str, + offset: int, + api: FilesExt, + chunk_size: int, ): self._underlying_iterator = underlying_iterator self._api = api @@ -1401,9 +1435,9 @@ def _should_recover(self) -> bool: _LOG.debug("Total recovers limit exceeded") return False if ( - self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None - and self._recovers_without_progressing_count - >= self._api._config.files_api_client_download_max_total_recovers_without_progressing + self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None + and self._recovers_without_progressing_count + >= self._api._config.files_api_client_download_max_total_recovers_without_progressing ): _LOG.debug("No progression recovers limit exceeded") return False @@ -1422,7 +1456,7 @@ def _recover(self) -> bool: _LOG.debug("Trying to recover from offset " + str(self._offset)) # following call includes all the required network retries - downloadResponse = self._api._download_raw_stream(self._file_path, self._offset, self._file_last_modified) + downloadResponse = self._api._open_download_stream(self._file_path, self._offset, self._file_last_modified) underlying_response = _ResilientIterator._extract_raw_response(downloadResponse) self._underlying_iterator = underlying_response.iter_content( chunk_size=self._chunk_size, decode_unicode=False From 731a3b75f12719e4350edf7ca53b1307dfa89aae Mon Sep 17 00:00:00 2001 From: ksafonov-db Date: Wed, 5 Mar 2025 12:48:23 +0100 Subject: [PATCH 10/11] Reformat --- databricks/sdk/mixins/files.py | 114 ++++++++++++++++----------------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git 
a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 6a7e9d6a2..8d9923b4f 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -50,13 +50,13 @@ class _DbfsIO(BinaryIO): _closed = False def __init__( - self, - api: files.DbfsAPI, - path: str, - *, - read: bool = False, - write: bool = False, - overwrite: bool = False, + self, + api: files.DbfsAPI, + path: str, + *, + read: bool = False, + write: bool = False, + overwrite: bool = False, ): self._api = api self._path = path @@ -115,10 +115,10 @@ def closed(self) -> bool: return self._closed def __exit__( - self, - __t: Type[BaseException] | None, - __value: BaseException | None, - __traceback: TracebackType | None, + self, + __t: Type[BaseException] | None, + __value: BaseException | None, + __traceback: TracebackType | None, ): self.close() @@ -200,13 +200,13 @@ def __repr__(self) -> str: class _VolumesIO(BinaryIO): def __init__( - self, - api: files.FilesAPI, - path: str, - *, - read: bool, - write: bool, - overwrite: bool, + self, + api: files.FilesAPI, + path: str, + *, + read: bool, + write: bool, + overwrite: bool, ): self._buffer = [] self._api = api @@ -586,12 +586,12 @@ def __init__(self, api_client): self._dbfs_api = files.DbfsAPI(api_client) def open( - self, - path: str, - *, - read: bool = False, - write: bool = False, - overwrite: bool = False, + self, + path: str, + *, + read: bool = False, + write: bool = False, + overwrite: bool = False, ) -> BinaryIO: return self._path(path).open(read=read, write=write, overwrite=overwrite) @@ -790,12 +790,12 @@ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool raise ValueError(f"Unexpected server response: {initiate_upload_response}") def _perform_multipart_upload( - self, - target_path: str, - input_stream: BinaryIO, - session_token: str, - pre_read_buffer: bytes, - cloud_provider_session: requests.Session, + self, + target_path: str, + input_stream: BinaryIO, + session_token: str, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, ): """ Performs multipart upload using presigned URLs on AWS and Azure: @@ -992,13 +992,13 @@ def _is_url_expired_response(response: requests.Response): return False def _perform_resumable_upload( - self, - target_path: str, - input_stream: BinaryIO, - session_token: str, - overwrite: bool, - pre_read_buffer: bytes, - cloud_provider_session: requests.Session, + self, + target_path: str, + input_stream: BinaryIO, + session_token: str, + overwrite: bool, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, ): """ Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads @@ -1249,7 +1249,7 @@ def perform(): raise ValueError(abort_response) def _abort_resumable_upload( - self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session + self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session ): """Aborts ongoing resumable upload session to clean up incomplete file.""" headers: dict = {} @@ -1284,7 +1284,7 @@ def _create_cloud_provider_session(self): return session def _retry_idempotent_operation( - self, operation: Callable[[], requests.Response], before_retry: Callable = None + self, operation: Callable[[], requests.Response], before_retry: Callable = None ) -> requests.Response: """Perform given idempotent operation with necessary retries. 
Since operation is idempotent it's safe to retry it for response codes where server state might have changed. @@ -1312,7 +1312,7 @@ def delegate(): )(delegate)() def _open_download_stream( - self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None + self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None ) -> DownloadResponse: """Opens a download stream from given offset, performing necessary retries.""" headers = { @@ -1364,12 +1364,12 @@ def _wrap_stream(self, file_path: str, download_response: DownloadResponse): class _ResilientResponse(_RawResponse): def __init__( - self, - api: FilesExt, - file_path: str, - file_last_modified: str, - offset: int, - underlying_response: _RawResponse, + self, + api: FilesExt, + file_path: str, + file_last_modified: str, + offset: int, + underlying_response: _RawResponse, ): self.api = api self.file_path = file_path @@ -1402,19 +1402,19 @@ class _ResilientIterator(Iterator): @staticmethod def _extract_raw_response( - download_response: DownloadResponse, + download_response: DownloadResponse, ) -> _RawResponse: streaming_response: _StreamingResponse = download_response.contents # this is an instance of _StreamingResponse return streaming_response._response def __init__( - self, - underlying_iterator, - file_path: str, - file_last_modified: str, - offset: int, - api: FilesExt, - chunk_size: int, + self, + underlying_iterator, + file_path: str, + file_last_modified: str, + offset: int, + api: FilesExt, + chunk_size: int, ): self._underlying_iterator = underlying_iterator self._api = api @@ -1435,9 +1435,9 @@ def _should_recover(self) -> bool: _LOG.debug("Total recovers limit exceeded") return False if ( - self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None - and self._recovers_without_progressing_count - >= self._api._config.files_api_client_download_max_total_recovers_without_progressing + self._api._config.files_api_client_download_max_total_recovers_without_progressing is not None + and self._recovers_without_progressing_count + >= self._api._config.files_api_client_download_max_total_recovers_without_progressing ): _LOG.debug("No progression recovers limit exceeded") return False From b4d39cc0f13c89a3ea2e51c4b3b60bec90c1466a Mon Sep 17 00:00:00 2001 From: ksafonov-db Date: Wed, 5 Mar 2025 14:09:39 +0100 Subject: [PATCH 11/11] Remove new test --- tests/integration/test_files.py | 136 -------------------------------- 1 file changed, 136 deletions(-) diff --git a/tests/integration/test_files.py b/tests/integration/test_files.py index e2a0fbd0c..348f88b05 100644 --- a/tests/integration/test_files.py +++ b/tests/integration/test_files.py @@ -1,17 +1,13 @@ -import datetime import io import logging import pathlib import platform -import re import time -from textwrap import dedent from typing import Callable, List, Tuple, Union import pytest from databricks.sdk.core import DatabricksError -from databricks.sdk.errors.sdk import OperationFailed from databricks.sdk.service.catalog import VolumeType @@ -386,135 +382,3 @@ def test_files_api_download_benchmark(ucws, files_api, random): ) min_str = str(best[0]) + "kb" if best[0] else "None" logging.info("Fastest chunk size: %s in %f seconds", min_str, best[1]) - - -@pytest.mark.parametrize("is_serverless", [True, False], ids=["Classic", "Serverless"]) -@pytest.mark.parametrize("use_new_files_api_client", [True, False], ids=["Default client", "Experimental client"]) -def 
test_files_api_in_cluster(ucws, random, env_or_skip, is_serverless, use_new_files_api_client): - from databricks.sdk.service import compute, jobs - - databricks_sdk_pypi_package = "databricks-sdk" - option_env_name = "DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT" - - launcher_file_path = f"/home/{ucws.current_user.me().user_name}/test_launcher.py" - - schema = "filesit-" + random() - volume = "filesit-" + random() - with ResourceWithCleanup.create_schema(ucws, "main", schema): - with ResourceWithCleanup.create_volume(ucws, "main", schema, volume): - - cloud_file_path = f"/Volumes/main/{schema}/{volume}/test-{random()}.txt" - file_size = 100 * 1024 * 1024 - - if use_new_files_api_client: - enable_new_files_api_env = f"os.environ['{option_env_name}'] = 'True'" - expected_files_api_client_class = "FilesExt" - else: - enable_new_files_api_env = "" - expected_files_api_client_class = "FilesAPI" - - using_files_api_client_msg = "Using files API client: " - - command = f""" - from databricks.sdk import WorkspaceClient - import io - import os - import hashlib - import logging - - logging.basicConfig(level=logging.DEBUG) - - {enable_new_files_api_env} - - file_size = {file_size} - original_content = os.urandom(file_size) - cloud_file_path = '{cloud_file_path}' - - w = WorkspaceClient() - print(f"Using SDK: {{w.config._product_info}}") - - print(f"{using_files_api_client_msg}{{type(w.files).__name__}}") - - w.files.upload(cloud_file_path, io.BytesIO(original_content), overwrite=True) - print("Upload succeeded") - - response = w.files.download(cloud_file_path) - resulting_content = response.contents.read() - print("Download succeeded") - - def hash(data: bytes): - sha256 = hashlib.sha256() - sha256.update(data) - return sha256.hexdigest() - - if len(resulting_content) != len(original_content): - raise ValueError(f"Content length does not match: expected {{len(original_content)}}, actual {{len(resulting_content)}}") - - expected_hash = hash(original_content) - actual_hash = hash(resulting_content) - if actual_hash != expected_hash: - raise ValueError(f"Content hash does not match: expected {{expected_hash}}, actual {{actual_hash}}") - - print(f"Contents of size {{len(resulting_content)}} match") - """ - - with ucws.dbfs.open(launcher_file_path, write=True, overwrite=True) as f: - f.write(dedent(command).encode()) - - if is_serverless: - # If no job_cluster_key, existing_cluster_id, or new_cluster were specified in task definition, - # then task will be executed using serverless compute. 
- new_cluster_spec = None - - # Library is specified in the environment - env_key = "test_env" - envs = [jobs.JobEnvironment(env_key, compute.Environment("test", [databricks_sdk_pypi_package]))] - libs = [] - else: - new_cluster_spec = compute.ClusterSpec( - spark_version=ucws.clusters.select_spark_version(long_term_support=True), - instance_pool_id=env_or_skip("TEST_INSTANCE_POOL_ID"), - num_workers=1, - ) - - # Library is specified in the task definition - env_key = None - envs = [] - libs = [compute.Library(pypi=compute.PythonPyPiLibrary(package=databricks_sdk_pypi_package))] - - waiter = ucws.jobs.submit( - run_name=f"py-sdk-{random(8)}", - tasks=[ - jobs.SubmitTask( - task_key="task1", - new_cluster=new_cluster_spec, - spark_python_task=jobs.SparkPythonTask(python_file=f"dbfs:{launcher_file_path}"), - libraries=libs, - environment_key=env_key, - ) - ], - environments=envs, - ) - - def print_status(r: jobs.Run): - statuses = [f"{t.task_key}: {t.state.life_cycle_state}" for t in r.tasks] - logging.info(f'Run status: {", ".join(statuses)}') - - logging.info(f"Waiting for the job run: {waiter.run_id}") - try: - job_run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=print_status) - task_run_id = job_run.tasks[0].run_id - task_run_logs = ucws.jobs.get_run_output(task_run_id).logs - logging.info(f"Run finished, output: {task_run_logs}") - match = re.search(f"{using_files_api_client_msg}(.*)$", task_run_logs, re.MULTILINE) - assert match is not None - files_api_client_class = match.group(1) - assert files_api_client_class == expected_files_api_client_class - - except OperationFailed: - job_run = ucws.jobs.get_run(waiter.run_id) - task_run_id = job_run.tasks[0].run_id - task_run_logs = ucws.jobs.get_run_output(task_run_id) - raise ValueError( - f"Run failed, error: {task_run_logs.error}, error trace: {task_run_logs.error_trace}, output: {task_run_logs.logs}" - )
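Although the integration test is removed in this commit, it still documents how the experimental client is switched on. Below is a minimal sketch of that mechanism, assuming the environment variable is read when the workspace client is constructed; the volume path is hypothetical.

    import io
    import os

    # Opt in before constructing the client; the integration test asserted that
    # w.files is then a FilesExt instance instead of the default FilesAPI.
    os.environ["DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT"] = "True"

    from databricks.sdk import WorkspaceClient

    w = WorkspaceClient()
    print(type(w.files).__name__)  # "FilesExt" when the experimental client is active

    content = os.urandom(16 * 1024 * 1024)
    w.files.upload("/Volumes/main/default/vol/data.bin", io.BytesIO(content), overwrite=True)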