diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py
index 3334fa94f..591aafc44 100644
--- a/databricks/sdk/config.py
+++ b/databricks/sdk/config.py
@@ -1,5 +1,6 @@
 import configparser
 import copy
+import datetime
 import logging
 import os
 import pathlib
@@ -98,6 +99,49 @@ class Config:
     files_api_client_download_max_total_recovers = None
     files_api_client_download_max_total_recovers_without_progressing = 1

+    # File multipart upload parameters
+    # ----------------------
+
+    # Minimum input stream size (in bytes) at which multipart / resumable uploads are used.
+    # For smaller files it's more efficient to make a single one-shot upload request.
+    # When uploading a file, the SDK will initially buffer this many bytes from the input stream.
+    # This parameter can be smaller or larger than multipart_upload_chunk_size.
+    multipart_upload_min_stream_size: int = 5 * 1024 * 1024
+
+    # Maximum number of presigned URLs that can be requested at a time.
+    #
+    # The more URLs we request at once, the higher the chance that some of them will expire
+    # before we get to use them. We discover that a presigned URL has expired only *after* sending the
+    # input stream partition to the server. So to retry the upload of this partition we must rewind
+    # the stream. In case of a non-seekable stream we cannot rewind, so we'll abort
+    # the upload. To reduce the chance of this, we request presigned URLs one by one
+    # and use them immediately.
+    multipart_upload_batch_url_count: int = 1
+
+    # Size of the chunk to use for multipart uploads.
+    #
+    # The smaller the chunk, the lower the chance of network errors (or of the URL expiring),
+    # but the more requests we'll make.
+    # For AWS, the minimum is 5 MiB: https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
+    # For GCP, the minimum is 256 KiB (and the recommended multiple is also 256 KiB).
+    # boto3 uses 8 MiB: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig
+    multipart_upload_chunk_size: int = 10 * 1024 * 1024
+
+    # Use the maximum allowed duration of 1 hour.
+    multipart_upload_url_expiration_duration: datetime.timedelta = datetime.timedelta(hours=1)
+
+    # This is not a "wall time" cutoff for the whole upload request,
+    # but the maximum time between consecutive data reception events (even 1 byte) from the server.
+    multipart_upload_single_chunk_upload_timeout_seconds: float = 60
+
+    # Cap on the number of custom retries during incremental uploads:
+    # 1) multipart: an upload part URL has expired, so new upload URLs must be requested to continue the upload;
+    # 2) resumable: a chunk upload produced a retryable response (or exception), so the upload status must be
+    #    retrieved to continue the upload.
+    # In these two cases the standard SDK retries (capped by the `retry_timeout_seconds` option) are not used.
+    # Note that the retry counter is reset when the upload successfully resumes.
+ multipart_upload_max_retries = 3 + def __init__( self, *, diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 6a9b263bd..8d9923b4f 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -1,26 +1,36 @@ from __future__ import annotations import base64 +import datetime import logging import os import pathlib import platform +import re import shutil import sys +import xml.etree.ElementTree as ET from abc import ABC, abstractmethod from collections import deque from collections.abc import Iterator +from datetime import timedelta from io import BytesIO from types import TracebackType -from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable, - Optional, Type, Union) +from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Callable, Generator, + Iterable, Optional, Type, Union) from urllib import parse +import requests +import requests.adapters from requests import RequestException -from .._base_client import _RawResponse, _StreamingResponse +from .._base_client import _BaseClient, _RawResponse, _StreamingResponse from .._property import _cached_property -from ..errors import NotFound +from ..config import Config +from ..errors import AlreadyExists, NotFound +from ..errors.customizer import _RetryAfterCustomizer +from ..errors.mapper import _error_mapper +from ..retries import retried from ..service import files from ..service._internal import _escape_multi_segment_path_parameter from ..service.files import DownloadResponse @@ -683,9 +693,13 @@ def delete(self, path: str, *, recursive=False): class FilesExt(files.FilesAPI): __doc__ = files.FilesAPI.__doc__ + # note that these error codes are retryable only for idempotent operations + _RETRYABLE_STATUS_CODES = [408, 429, 500, 502, 503, 504] + def __init__(self, api_client, config: Config): super().__init__(api_client) self._config = config.copy() + self._multipart_upload_read_ahead_bytes = 1 def download(self, file_path: str) -> DownloadResponse: """Download a file. @@ -703,7 +717,7 @@ def download(self, file_path: str) -> DownloadResponse: :returns: :class:`DownloadResponse` """ - initial_response: DownloadResponse = self._download_raw_stream( + initial_response: DownloadResponse = self._open_download_stream( file_path=file_path, start_byte_offset=0, if_unmodified_since_timestamp=None, @@ -713,12 +727,594 @@ def download(self, file_path: str) -> DownloadResponse: initial_response.contents._response = wrapped_response return initial_response - def _download_raw_stream( + def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None): + """Upload a file. + + Uploads a file. The file contents should be sent as the request body as raw bytes (an + octet stream); do not encode or otherwise modify the bytes before sending. The contents of the + resulting file will be exactly the bytes sent in the request body. If the request is successful, there + is no response body. + + :param file_path: str + The absolute remote path of the target file. + :param contents: BinaryIO + :param overwrite: bool (optional) + If true, an existing file will be overwritten. When not specified, assumed True. + """ + + # Upload empty and small files with one-shot upload. 
+ pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size) + if len(pre_read_buffer) < self._config.multipart_upload_min_stream_size: + _LOG.debug( + f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes" + ) + return super().upload(file_path=file_path, contents=BytesIO(pre_read_buffer), overwrite=overwrite) + + query = {"action": "initiate-upload"} + if overwrite is not None: + query["overwrite"] = overwrite + + # Method _api.do() takes care of retrying and will raise an exception in case of failure. + initiate_upload_response = self._api.do( + "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", query=query + ) + + if initiate_upload_response.get("multipart_upload"): + cloud_provider_session = self._create_cloud_provider_session() + session_token = initiate_upload_response["multipart_upload"].get("session_token") + if not session_token: + raise ValueError(f"Unexpected server response: {initiate_upload_response}") + + try: + self._perform_multipart_upload( + file_path, contents, session_token, pre_read_buffer, cloud_provider_session + ) + except Exception as e: + _LOG.info(f"Aborting multipart upload on error: {e}") + try: + self._abort_multipart_upload(file_path, session_token, cloud_provider_session) + except BaseException as ex: + _LOG.warning(f"Failed to abort upload: {ex}") + # ignore, abort is a best-effort + finally: + # rethrow original exception + raise e from None + + elif initiate_upload_response.get("resumable_upload"): + cloud_provider_session = self._create_cloud_provider_session() + session_token = initiate_upload_response["resumable_upload"]["session_token"] + self._perform_resumable_upload( + file_path, contents, session_token, overwrite, pre_read_buffer, cloud_provider_session + ) + else: + raise ValueError(f"Unexpected server response: {initiate_upload_response}") + + def _perform_multipart_upload( self, - file_path: str, - start_byte_offset: int, - if_unmodified_since_timestamp: Optional[str] = None, + target_path: str, + input_stream: BinaryIO, + session_token: str, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, + ): + """ + Performs multipart upload using presigned URLs on AWS and Azure: + https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html + """ + current_part_number = 1 + etags: dict = {} + + # Why are we buffering the current chunk? + # AWS and Azure don't support traditional "Transfer-encoding: chunked", so we must + # provide each chunk size up front. In case of a non-seekable input stream we need + # to buffer a chunk before uploading to know its size. This also allows us to rewind + # the stream before retrying on request failure. + # AWS signed chunked upload: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html + # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-tune-upload-download-python#buffering-during-uploads + + chunk_offset = 0 # used only for logging + + # This buffer is expected to contain at least multipart_upload_chunk_size bytes. + # Note that initially buffer can be bigger (from pre_read_buffer). + buffer = pre_read_buffer + + retry_count = 0 + eof = False + while not eof: + # If needed, buffer the next chunk. + buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream) + if len(buffer) == 0: + # End of stream, no need to request the next block of upload URLs. 
+ break + + _LOG.debug( + f"Multipart upload: requesting next {self._config.multipart_upload_batch_url_count} upload URLs starting from part {current_part_number}" + ) + + body: dict = { + "path": target_path, + "session_token": session_token, + "start_part_number": current_part_number, + "count": self._config.multipart_upload_batch_url_count, + "expire_time": self._get_url_expire_time(), + } + + headers = {"Content-Type": "application/json"} + + # Requesting URLs for the same set of parts is an idempotent operation, safe to retry. + # Method _api.do() takes care of retrying and will raise an exception in case of failure. + upload_part_urls_response = self._api.do( + "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body + ) + + upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) + if len(upload_part_urls) == 0: + raise ValueError(f"Unexpected server response: {upload_part_urls_response}") + + for upload_part_url in upload_part_urls: + buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream) + actual_buffer_length = len(buffer) + if actual_buffer_length == 0: + eof = True + break + + url = upload_part_url["url"] + required_headers = upload_part_url.get("headers", []) + assert current_part_number == upload_part_url["part_number"] + + headers: dict = {"Content-Type": "application/octet-stream"} + for h in required_headers: + headers[h["name"]] = h["value"] + + actual_chunk_length = min(actual_buffer_length, self._config.multipart_upload_chunk_size) + _LOG.debug( + f"Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]" + ) + + chunk = BytesIO(buffer[:actual_chunk_length]) + + def rewind(): + chunk.seek(0, os.SEEK_SET) + + def perform(): + return cloud_provider_session.request( + "PUT", + url, + headers=headers, + data=chunk, + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) + + upload_response = self._retry_idempotent_operation(perform, rewind) + + if upload_response.status_code in (200, 201): + # Chunk upload successful + + chunk_offset += actual_chunk_length + + etag = upload_response.headers.get("ETag", "") + etags[current_part_number] = etag + + # Discard uploaded bytes + buffer = buffer[actual_chunk_length:] + + # Reset retry count when progressing along the stream + retry_count = 0 + + elif FilesExt._is_url_expired_response(upload_response): + if retry_count < self._config.multipart_upload_max_retries: + retry_count += 1 + _LOG.debug("Upload URL expired") + # Preserve the buffer so we'll upload the current part again using next upload URL + else: + # don't confuse user with unrelated "Permission denied" error. + raise ValueError(f"Unsuccessful chunk upload: upload URL expired") + + else: + message = f"Unsuccessful chunk upload. 
Response status: {upload_response.status_code}, body: {upload_response.content}" + _LOG.warning(message) + mapped_error = _error_mapper(upload_response, {}) + raise mapped_error or ValueError(message) + + current_part_number += 1 + + _LOG.debug( + f"Completing multipart upload after uploading {len(etags)} parts of up to {self._config.multipart_upload_chunk_size} bytes" + ) + + query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token} + headers = {"Content-Type": "application/json"} + body: dict = {} + + parts = [] + for etag in sorted(etags.items()): + part = {"part_number": etag[0], "etag": etag[1]} + parts.append(part) + + body["parts"] = parts + + # Completing upload is an idempotent operation, safe to retry. + # Method _api.do() takes care of retrying and will raise an exception in case of failure. + self._api.do( + "POST", + f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}", + query=query, + headers=headers, + body=body, + ) + + @staticmethod + def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO): + """ + Tries to fill given buffer to contain at least `desired_min_size` bytes by reading from input stream. + """ + bytes_to_read = max(0, desired_min_size - len(buffer)) + if bytes_to_read > 0: + next_buf = input_stream.read(bytes_to_read) + new_buffer = buffer + next_buf + return new_buffer + else: + # we have already buffered enough data + return buffer + + @staticmethod + def _is_url_expired_response(response: requests.Response): + """ + Checks if response matches one of the known "URL expired" responses from the cloud storage providers. + """ + if response.status_code != 403: + return False + + try: + xml_root = ET.fromstring(response.content) + if xml_root.tag != "Error": + return False + + code = xml_root.find("Code") + if code is None: + return False + + if code.text == "AuthenticationFailed": + # Azure + details = xml_root.find("AuthenticationErrorDetail") + if details is not None and "Signature not valid in the specified time frame" in details.text: + return True + + if code.text == "AccessDenied": + # AWS + message = xml_root.find("Message") + if message is not None and message.text == "Request has expired": + return True + + except ET.ParseError: + pass + + return False + + def _perform_resumable_upload( + self, + target_path: str, + input_stream: BinaryIO, + session_token: str, + overwrite: bool, + pre_read_buffer: bytes, + cloud_provider_session: requests.Session, + ): + """ + Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads + """ + + # Session URI we're using expires after a week + + # Why are we buffering the current chunk? + # When using resumable upload API we're uploading data in chunks. During chunk upload + # server responds with the "received offset" confirming how much data it stored so far, + # so we should continue uploading from that offset. (Note this is not a failure but an + # expected behaviour as per the docs.) But, input stream might be consumed beyond that + # offset, since server might have read more data than it confirmed received, or some data + # might have been pre-cached by e.g. OS or a proxy. So, to continue upload, we must rewind + # the input stream back to the byte next to "received offset". This is not possible + # for non-seekable input stream, so we must buffer the whole last chunk and seek inside + # the buffer. By always uploading from the buffer we fully support non-seekable streams. 
+
+        # Why are we doing read-ahead?
+        # It's not possible to upload an empty chunk, as the "Content-Range" header format does not
+        # support this. So if the current chunk happens to end exactly at the end of the stream,
+        # we need to know that and mark the chunk as the last one (by passing the real file size in the
+        # "Content-Range" header) when uploading it. To detect whether we're at the end of the stream,
+        # we read "ahead" a few extra bytes but don't upload them immediately. If
+        # nothing has been read ahead, it means we're at the end of the stream.
+        # By contrast, in multipart upload we can decide to complete the upload *after*
+        # the last chunk has been sent.
+
+        body: dict = {"path": target_path, "session_token": session_token}
+
+        headers = {"Content-Type": "application/json"}
+
+        # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+        resumable_upload_url_response = self._api.do(
+            "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body
+        )
+
+        resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url")
+        if not resumable_upload_url_node:
+            raise ValueError(f"Unexpected server response: {resumable_upload_url_response}")
+
+        resumable_upload_url = resumable_upload_url_node.get("url")
+        if not resumable_upload_url:
+            raise ValueError(f"Unexpected server response: {resumable_upload_url_response}")
+
+        required_headers = resumable_upload_url_node.get("headers", [])
+
+        try:
+            # We will buffer this many bytes: one chunk + read-ahead block.
+            # Note that the buffer may initially contain more data (from pre_read_buffer).
+            min_buffer_size = self._config.multipart_upload_chunk_size + self._multipart_upload_read_ahead_bytes
+
+            buffer = pre_read_buffer
+
+            # How many bytes in the buffer were confirmed to be received by the server.
+            # All the remaining bytes in the buffer must be uploaded.
+            uploaded_bytes_count = 0
+
+            chunk_offset = 0
+
+            retry_count = 0
+            while True:
+                # If needed, fill the buffer to contain at least min_buffer_size bytes
+                # (unless end of stream), discarding already uploaded bytes.
+                bytes_to_read = max(0, min_buffer_size - (len(buffer) - uploaded_bytes_count))
+                next_buf = input_stream.read(bytes_to_read)
+                buffer = buffer[uploaded_bytes_count:] + next_buf
+
+                if len(next_buf) < bytes_to_read:
+                    # This is the last chunk in the stream.
+                    # Let's upload all the remaining bytes in one go.
+                    actual_chunk_length = len(buffer)
+                    file_size = chunk_offset + actual_chunk_length
+                else:
+                    # More chunks are expected; upload the current chunk (excluding the read-ahead block).
+ actual_chunk_length = self._config.multipart_upload_chunk_size + file_size = "*" + + headers: dict = {"Content-Type": "application/octet-stream"} + for h in required_headers: + headers[h["name"]] = h["value"] + + chunk_last_byte_offset = chunk_offset + actual_chunk_length - 1 + content_range_header = f"bytes {chunk_offset}-{chunk_last_byte_offset}/{file_size}" + _LOG.debug(f"Uploading chunk: {content_range_header}") + headers["Content-Range"] = content_range_header + + def retrieve_upload_status() -> Optional[requests.Response]: + def perform(): + return cloud_provider_session.request( + "PUT", + resumable_upload_url, + headers={"Content-Range": "bytes */*"}, + data=b"", + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) + + try: + return self._retry_idempotent_operation(perform) + except RequestException: + _LOG.warning("Failed to retrieve upload status") + return None + + try: + upload_response = cloud_provider_session.request( + "PUT", + resumable_upload_url, + headers=headers, + data=BytesIO(buffer[:actual_chunk_length]), + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) + + # https://cloud.google.com/storage/docs/performing-resumable-uploads#resume-upload + # If an upload request is terminated before receiving a response, or if you receive + # a 503 or 500 response, then you need to resume the interrupted upload from where it left off. + + # Let's follow that for all potentially retryable status codes. + # Together with the catch block below we replicate the logic in _retry_idempotent_operation(). + if upload_response.status_code in self._RETRYABLE_STATUS_CODES: + if retry_count < self._config.multipart_upload_max_retries: + retry_count += 1 + # let original upload_response be handled as an error + upload_response = retrieve_upload_status() or upload_response + else: + # we received non-retryable response, reset retry count + retry_count = 0 + + except RequestException as e: + # Let's do the same for retryable network errors. 
+ if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries: + retry_count += 1 + upload_response = retrieve_upload_status() + if not upload_response: + # rethrow original exception + raise e from None + else: + # rethrow original exception + raise e from None + + if upload_response.status_code in (200, 201): + if file_size == "*": + raise ValueError( + f"Received unexpected status {upload_response.status_code} before reaching end of stream" + ) + + # upload complete + break + + elif upload_response.status_code == 308: + # chunk accepted (or check-status succeeded), let's determine received offset to resume from there + range_string = upload_response.headers.get("Range") + confirmed_offset = self._extract_range_offset(range_string) + _LOG.debug(f"Received confirmed offset: {confirmed_offset}") + + if confirmed_offset: + if confirmed_offset < chunk_offset - 1 or confirmed_offset > chunk_last_byte_offset: + raise ValueError( + f"Unexpected received offset: {confirmed_offset} is outside of expected range, chunk offset: {chunk_offset}, chunk last byte offset: {chunk_last_byte_offset}" + ) + else: + if chunk_offset > 0: + raise ValueError( + f"Unexpected received offset: {confirmed_offset} is outside of expected range, chunk offset: {chunk_offset}, chunk last byte offset: {chunk_last_byte_offset}" + ) + + # We have just uploaded a part of chunk starting from offset "chunk_offset" and ending + # at offset "confirmed_offset" (inclusive), so the next chunk will start at + # offset "confirmed_offset + 1" + if confirmed_offset: + next_chunk_offset = confirmed_offset + 1 + else: + next_chunk_offset = chunk_offset + uploaded_bytes_count = next_chunk_offset - chunk_offset + chunk_offset = next_chunk_offset + + elif upload_response.status_code == 412 and not overwrite: + # Assuming this is only possible reason + # Full message in this case: "At least one of the pre-conditions you specified did not hold." + raise AlreadyExists("The file being created already exists.") + + else: + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" + _LOG.warning(message) + mapped_error = _error_mapper(upload_response, {}) + raise mapped_error or ValueError(message) + + except Exception as e: + _LOG.info(f"Aborting resumable upload on error: {e}") + try: + self._abort_resumable_upload(resumable_upload_url, required_headers, cloud_provider_session) + except BaseException as ex: + _LOG.warning(f"Failed to abort upload: {ex}") + # ignore, abort is a best-effort + finally: + # rethrow original exception + raise e from None + + @staticmethod + def _extract_range_offset(range_string: Optional[str]) -> Optional[int]: + """Parses the response range header to extract the last byte.""" + if not range_string: + return None # server did not yet confirm any bytes + + if match := re.match("bytes=0-(\\d+)", range_string): + return int(match.group(1)) + else: + raise ValueError(f"Cannot parse response header: Range: {range_string}") + + def _get_url_expire_time(self): + """Generates expiration time and save it in the required format.""" + current_time = datetime.datetime.now(datetime.timezone.utc) + expire_time = current_time + self._config.multipart_upload_url_expiration_duration + # From Google Protobuf doc: + # In JSON format, the Timestamp type is encoded as a string in the + # * [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. 
That is, the + # * format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z" + return expire_time.strftime("%Y-%m-%dT%H:%M:%SZ") + + def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_provider_session: requests.Session): + """Aborts ongoing multipart upload session to clean up incomplete file.""" + body: dict = {"path": target_path, "session_token": session_token, "expire_time": self._get_url_expire_time()} + + headers = {"Content-Type": "application/json"} + + # Method _api.do() takes care of retrying and will raise an exception in case of failure. + abort_url_response = self._api.do("POST", "/api/2.0/fs/create-abort-upload-url", headers=headers, body=body) + + abort_upload_url_node = abort_url_response["abort_upload_url"] + abort_url = abort_upload_url_node["url"] + required_headers = abort_upload_url_node.get("headers", []) + + headers: dict = {"Content-Type": "application/octet-stream"} + for h in required_headers: + headers[h["name"]] = h["value"] + + def perform(): + return cloud_provider_session.request( + "DELETE", + abort_url, + headers=headers, + data=b"", + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) + + abort_response = self._retry_idempotent_operation(perform) + + if abort_response.status_code not in (200, 201): + raise ValueError(abort_response) + + def _abort_resumable_upload( + self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session + ): + """Aborts ongoing resumable upload session to clean up incomplete file.""" + headers: dict = {} + for h in required_headers: + headers[h["name"]] = h["value"] + + def perform(): + return cloud_provider_session.request( + "DELETE", + resumable_upload_url, + headers=headers, + data=b"", + timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds, + ) + + abort_response = self._retry_idempotent_operation(perform) + + if abort_response.status_code not in (200, 201): + raise ValueError(abort_response) + + def _create_cloud_provider_session(self): + """Creates a separate session which does not inherit auth headers from BaseClient session.""" + session = requests.Session() + + # following session config in _BaseClient + http_adapter = requests.adapters.HTTPAdapter( + self._config.max_connection_pools or 20, self._config.max_connections_per_pool or 20, pool_block=True + ) + session.mount("https://", http_adapter) + # presigned URL for storage proxy can use plain HTTP + session.mount("http://", http_adapter) + return session + + def _retry_idempotent_operation( + self, operation: Callable[[], requests.Response], before_retry: Callable = None + ) -> requests.Response: + """Perform given idempotent operation with necessary retries. Since operation is idempotent it's + safe to retry it for response codes where server state might have changed. 
+ """ + + def delegate(): + response = operation() + if response.status_code in self._RETRYABLE_STATUS_CODES: + attrs = {} + # this will assign "retry_after_secs" to the attrs, essentially making exception look retryable + _RetryAfterCustomizer().customize_error(response, attrs) + raise _error_mapper(response, attrs) + else: + return response + + # following _BaseClient timeout + retry_timeout_seconds = self._config.retry_timeout_seconds or 300 + + return retried( + timeout=timedelta(seconds=retry_timeout_seconds), + # also retry on network errors (connection error, connection timeout) + # where we believe request didn't reach the server + is_retryable=_BaseClient._is_retryable, + before_retry=before_retry, + )(delegate)() + + def _open_download_stream( + self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None ) -> DownloadResponse: + """Opens a download stream from given offset, performing necessary retries.""" headers = { "Accept": "application/octet-stream", } @@ -737,6 +1333,7 @@ def _download_raw_stream( "content-type", "last-modified", ] + # Method _api.do() takes care of retrying and will raise an exception in case of failure. res = self._api.do( "GET", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", @@ -753,12 +1350,12 @@ def _download_raw_stream( return result - def _wrap_stream(self, file_path: str, downloadResponse: DownloadResponse): - underlying_response = _ResilientIterator._extract_raw_response(downloadResponse) + def _wrap_stream(self, file_path: str, download_response: DownloadResponse): + underlying_response = _ResilientIterator._extract_raw_response(download_response) return _ResilientResponse( self, file_path, - downloadResponse.last_modified, + download_response.last_modified, offset=0, underlying_response=underlying_response, ) @@ -859,7 +1456,7 @@ def _recover(self) -> bool: _LOG.debug("Trying to recover from offset " + str(self._offset)) # following call includes all the required network retries - downloadResponse = self._api._download_raw_stream(self._file_path, self._offset, self._file_last_modified) + downloadResponse = self._api._open_download_stream(self._file_path, self._offset, self._file_last_modified) underlying_response = _ResilientIterator._extract_raw_response(downloadResponse) self._underlying_iterator = underlying_response.iter_content( chunk_size=self._chunk_size, decode_unicode=False diff --git a/tests/test_files.py b/tests/test_files.py index 50e6cb470..e25035523 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -1,14 +1,28 @@ +import copy +import hashlib +import io +import json import logging import os +import random import re +import time from dataclasses import dataclass -from typing import List, Union +from datetime import datetime, timedelta, timezone +from tempfile import mkstemp +from typing import Callable, List, Optional, Type, Union +from urllib.parse import parse_qs, urlparse import pytest +import requests +import requests_mock from requests import RequestException from databricks.sdk import WorkspaceClient from databricks.sdk.core import Config +from databricks.sdk.errors.platform import (AlreadyExists, BadRequest, + InternalError, PermissionDenied, + TooManyRequests) logger = logging.getLogger(__name__) @@ -392,3 +406,1452 @@ class _Constants: ) def test_download_recover(config: Config, test_case: DownloadTestCase): test_case.run(config) + + +class FileContent: + + def __init__(self, length: int, checksum: str): + self._length = length + self.checksum = 
checksum + + @classmethod + def from_bytes(cls, data: bytes): + sha256 = hashlib.sha256() + sha256.update(data) + return FileContent(len(data), sha256.hexdigest()) + + def __repr__(self): + return f"Length: {self._length}, checksum: {self.checksum}" + + def __eq__(self, other): + if not isinstance(other, FileContent): + return NotImplemented + return self._length == other._length and self.checksum == other.checksum + + +class MultipartUploadServerState: + upload_chunk_url_prefix = "https://cloud_provider.com/upload-chunk/" + abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/" + + def __init__(self): + self.issued_multipart_urls = {} # part_number -> expiration_time + self.uploaded_chunks = {} # part_number -> [chunk file path, etag] + self.session_token = "token-" + MultipartUploadServerState.randomstr() + self.file_content = None + self.issued_abort_url_expire_time = None + self.aborted = False + + def create_upload_chunk_url(self, path: str, part_number: int, expire_time: datetime) -> str: + assert not self.aborted + # client may have requested a URL for the same part if retrying on network error + self.issued_multipart_urls[part_number] = expire_time + return f"{self.upload_chunk_url_prefix}{path}/{part_number}" + + def create_abort_url(self, path: str, expire_time: datetime) -> str: + assert not self.aborted + self.issued_abort_url_expire_time = expire_time + return f"{self.abort_upload_url_prefix}{path}" + + def save_part(self, part_number: int, part_content: bytes, etag: str): + assert not self.aborted + assert len(part_content) > 0 + + logger.info(f"Saving part {part_number} of size {len(part_content)}") + + # chunk might already have been uploaded + existing_chunk = self.uploaded_chunks.get(part_number) + if existing_chunk: + chunk_file = existing_chunk[0] + with open(chunk_file, "wb") as f: + f.write(part_content) + else: + fd, chunk_file = mkstemp() + with open(fd, "wb") as f: + f.write(part_content) + + self.uploaded_chunks[part_number] = [chunk_file, etag] + + def cleanup(self): + for [file, _] in self.uploaded_chunks.values(): + os.remove(file) + + def get_file_content(self) -> FileContent: + assert not self.aborted + return self.file_content + + def upload_complete(self, etags: dict): + assert not self.aborted + # validate etags + expected_etags = {} + for part_number in self.uploaded_chunks.keys(): + expected_etags[part_number] = self.uploaded_chunks[part_number][1] + assert etags == expected_etags + + size = 0 + sha256 = hashlib.sha256() + + sorted_chunks = sorted(self.uploaded_chunks.keys()) + for part_number in sorted_chunks: + [chunk_path, _] = self.uploaded_chunks[part_number] + size += os.path.getsize(chunk_path) + with open(chunk_path, "rb") as f: + chunk_content = f.read() + sha256.update(chunk_content) + + self.file_content = FileContent(size, sha256.hexdigest()) + + def abort_upload(self): + self.aborted = True + + @staticmethod + def randomstr(): + return f"{random.randrange(10000)}-{int(time.time())}" + + +class CustomResponse: + """Custom response allows to override the "default" response generated by the server + with the "custom" response to simulate failure error code, unexpected response body or + network error. + + The server is represented by the `processor` parameter in `generate_response()` call. + """ + + def __init__( + self, + # If False, default response is always returned. 
+ # If True, response is defined by the current invocation count + # with respect to first_invocation / last_invocation / only_invocation + enabled=True, + # Custom code to return + code: Optional[int] = 200, + # Custom body to return + body: Optional[str] = None, + # Custom exception to raise + exception: Optional[Type[BaseException]] = None, + # Whether exception should be raised before calling processor() + # (so changing server state) + exception_happened_before_processing: bool = False, + # First invocation (1-based) at which return custom response + first_invocation: Optional[int] = None, + # Last invocation (1-based) at which return custom response + last_invocation: Optional[int] = None, + # Only invocation (1-based) at which return custom response + only_invocation: Optional[int] = None, + ): + self.enabled = enabled + self.code = code + self.body = body + self.exception = exception + self.exception_happened_before_processing = exception_happened_before_processing + self.first_invocation = first_invocation + self.last_invocation = last_invocation + self.only_invocation = only_invocation + + if self.only_invocation and (self.first_invocation or self.last_invocation): + raise ValueError("Cannot set both only invocation and first/last invocation") + + if self.exception_happened_before_processing and not self.exception: + raise ValueError("Exception is not defined") + + self.invocation_count = 0 + + def invocation_matches(self): + if not self.enabled: + return False + + self.invocation_count += 1 + + if self.only_invocation: + return self.invocation_count == self.only_invocation + + if self.first_invocation and self.invocation_count < self.first_invocation: + return False + if self.last_invocation and self.invocation_count > self.last_invocation: + return False + return True + + def generate_response(self, request: requests.Request, processor: Callable[[], list]): + activate_for_current_invocation = self.invocation_matches() + + if activate_for_current_invocation and self.exception and self.exception_happened_before_processing: + # if network exception is thrown while processing a request, it's not defined + # if server actually processed the request (and so changed its state) + raise self.exception + + custom_response = [self.code, self.body or "", {}] + + if activate_for_current_invocation: + if self.code and 400 <= self.code < 500: + # if server returns client error, it's not supposed to change its state, + # so we're not calling processor() + [code, body, headers] = custom_response + else: + # we're calling processor() but override its response with the custom one + processor() + [code, body, headers] = custom_response + else: + [code, body, headers] = processor() + + if activate_for_current_invocation and self.exception: + # self.exception_happened_before_processing is False + raise self.exception + + resp = requests.Response() + + resp.request = request + resp.status_code = code + resp._content = body.encode() + + for key in headers: + resp.headers[key] = headers[key] + + return resp + + +class MultipartUploadTestCase: + """Test case for multipart upload of a file. Multipart uploads are used on AWS and Azure. 
+ + Multipart upload via presigned URLs involves multiple HTTP requests: + - initiating upload (call to Databricks Files API) + - requesting upload part URLs (calls to Databricks Files API) + - uploading data in chunks (calls to cloud storage provider or Databricks storage proxy) + - completing the upload (call to Databricks Files API) + - requesting abort upload URL (call to Databricks Files API) + - aborting the upload (call to cloud storage provider or Databricks storage proxy) + + Test case uses requests-mock library to mock all these requests. Within a test, mocks use + shared server state that tracks the upload. Mocks generate the "default" (successful) response. + + Response of each call can be modified by parameterising a respective `CustomResponse` object. + """ + + path = "/test.txt" + + expired_url_aws_response = ( + '' + "AuthenticationFailedServer failed to authenticate " + "the request. Make sure the value of Authorization header is formed " + "correctly including the signature.\nRequestId:1abde581-601e-0028-" + "4a6d-5c3952000000\nTime:2025-01-01T16:54:20.5343181ZSignature not valid in the specified " + "time frame: Start [Wed, 01 Jan 2025 16:38:41 GMT] - Expiry [Wed, " + "01 Jan 2025 16:53:45 GMT] - Current [Wed, 01 Jan 2025 16:54:20 " + "GMT]" + ) + + expired_url_azure_response = ( + '\nAccessDenied' + "Request has expired" + "142025-01-01T17:47:13Z" + "2025-01-01T17:48:01Z" + "JY66KDXM4CXBZ7X2n8Qayqg60rbvut9P7pk0" + "" + ) + + # TODO test for overwrite = false + + def __init__( + self, + name: str, + stream_size: int, # size of uploaded file or, technically, stream + multipart_upload_chunk_size: Optional[int] = None, + sdk_retry_timeout_seconds: Optional[int] = None, + multipart_upload_max_retries: Optional[int] = None, + multipart_upload_batch_url_count: Optional[int] = None, + custom_response_on_initiate=CustomResponse(enabled=False), + custom_response_on_create_multipart_url=CustomResponse(enabled=False), + custom_response_on_upload=CustomResponse(enabled=False), + custom_response_on_complete=CustomResponse(enabled=False), + custom_response_on_create_abort_url=CustomResponse(enabled=False), + custom_response_on_abort=CustomResponse(enabled=False), + # exception which is expected to be thrown (so upload is expected to have failed) + expected_exception_type: Optional[Type[BaseException]] = None, + # if abort is expected to be called + expected_aborted: bool = False, + ): + self.name = name + self.stream_size = stream_size + self.multipart_upload_chunk_size = multipart_upload_chunk_size + self.sdk_retry_timeout_seconds = sdk_retry_timeout_seconds + self.multipart_upload_max_retries = multipart_upload_max_retries + self.multipart_upload_batch_url_count = multipart_upload_batch_url_count + self.custom_response_on_initiate = copy.deepcopy(custom_response_on_initiate) + self.custom_response_on_create_multipart_url = copy.deepcopy(custom_response_on_create_multipart_url) + self.custom_response_on_upload = copy.deepcopy(custom_response_on_upload) + self.custom_response_on_complete = copy.deepcopy(custom_response_on_complete) + self.custom_response_on_create_abort_url = copy.deepcopy(custom_response_on_create_abort_url) + self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort) + self.expected_exception_type = expected_exception_type + self.expected_aborted: bool = expected_aborted + + def setup_session_mock(self, session_mock: requests_mock.Mocker, server_state: MultipartUploadServerState): + + def custom_matcher(request): + request_url = urlparse(request.url) + 
request_query = parse_qs(request_url.query) + + # initial request + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["initiate-upload"] + and request.method == "POST" + ): + + assert MultipartUploadTestCase.is_auth_header_present(request) + assert request.text is None + + def processor(): + response_json = {"multipart_upload": {"session_token": server_state.session_token}} + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_initiate.generate_response(request, processor) + + # multipart upload, create upload part URLs + elif ( + request_url.hostname == "localhost" + and request_url.path == "/api/2.0/fs/create-upload-part-urls" + and request.method == "POST" + ): + + assert MultipartUploadTestCase.is_auth_header_present(request) + + request_json = request.json() + assert request_json.keys() == {"count", "expire_time", "path", "session_token", "start_part_number"} + assert request_json["path"] == self.path + assert request_json["session_token"] == server_state.session_token + + start_part_number = int(request_json["start_part_number"]) + count = int(request_json["count"]) + assert count >= 1 + + expire_time = MultipartUploadTestCase.parse_and_validate_expire_time(request_json["expire_time"]) + + def processor(): + response_nodes = [] + for part_number in range(start_part_number, start_part_number + count): + upload_part_url = server_state.create_upload_chunk_url(self.path, part_number, expire_time) + response_nodes.append( + { + "part_number": part_number, + "url": upload_part_url, + "headers": [{"name": "name1", "value": "value1"}], + } + ) + + response_json = {"upload_part_urls": response_nodes} + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_create_multipart_url.generate_response(request, processor) + + # multipart upload, uploading part + elif request.url.startswith(MultipartUploadServerState.upload_chunk_url_prefix) and request.method == "PUT": + + assert not MultipartUploadTestCase.is_auth_header_present(request) + + url_path = request.url[len(MultipartUploadServerState.abort_upload_url_prefix) :] + part_num = url_path.split("/")[-1] + assert url_path[: -len(part_num) - 1] == self.path + + def processor(): + body = request.body.read() + etag = "etag-" + MultipartUploadServerState.randomstr() + server_state.save_part(int(part_num), body, etag) + return [200, "", {"ETag": etag}] + + return self.custom_response_on_upload.generate_response(request, processor) + + # multipart upload, completion + elif ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["complete-upload"] + and request_query.get("upload_type") == ["multipart"] + and request.method == "POST" + ): + + assert MultipartUploadTestCase.is_auth_header_present(request) + assert [server_state.session_token] == request_query.get("session_token") + + def processor(): + request_json = request.json() + etags = {} + + for part in request_json["parts"]: + etags[part["part_number"]] = part["etag"] + + server_state.upload_complete(etags) + return [200, "", {}] + + return self.custom_response_on_complete.generate_response(request, processor) + + # create abort URL + elif request.url == "http://localhost/api/2.0/fs/create-abort-upload-url" and request.method == "POST": + assert MultipartUploadTestCase.is_auth_header_present(request) + request_json = request.json() + assert 
request_json["path"] == self.path + expire_time = MultipartUploadTestCase.parse_and_validate_expire_time(request_json["expire_time"]) + + def processor(): + response_json = { + "abort_upload_url": { + "url": server_state.create_abort_url(self.path, expire_time), + "headers": [{"name": "header1", "value": "headervalue1"}], + } + } + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_create_abort_url.generate_response(request, processor) + + # abort upload + elif ( + request.url.startswith(MultipartUploadServerState.abort_upload_url_prefix) + and request.method == "DELETE" + ): + assert not MultipartUploadTestCase.is_auth_header_present(request) + assert request.url[len(MultipartUploadServerState.abort_upload_url_prefix) :] == self.path + + def processor(): + server_state.abort_upload() + return [200, "", {}] + + return self.custom_response_on_abort.generate_response(request, processor) + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + @staticmethod + def setup_token_auth(config: Config): + pat_token = "some_pat_token" + config._header_factory = lambda: {"Authorization": f"Bearer {pat_token}"} + + @staticmethod + def is_auth_header_present(r: requests.Request): + return r.headers.get("Authorization") is not None + + @staticmethod + def parse_and_validate_expire_time(s: str) -> datetime: + expire_time = datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ") + expire_time = expire_time.replace(tzinfo=timezone.utc) # Explicitly add timezone + now = datetime.now(timezone.utc) + max_expiration = now + timedelta(hours=2) + assert now < expire_time < max_expiration + return expire_time + + def run(self, config: Config): + config = config.copy() + + MultipartUploadTestCase.setup_token_auth(config) + + if self.sdk_retry_timeout_seconds: + config.retry_timeout_seconds = self.sdk_retry_timeout_seconds + if self.multipart_upload_chunk_size: + config.multipart_upload_chunk_size = self.multipart_upload_chunk_size + if self.multipart_upload_max_retries: + config.multipart_upload_max_retries = self.multipart_upload_max_retries + if self.multipart_upload_batch_url_count: + config.multipart_upload_batch_url_count = self.multipart_upload_batch_url_count + config.enable_experimental_files_api_client = True + config.multipart_upload_min_stream_size = 0 # disable single-shot uploads + + file_content = os.urandom(self.stream_size) + + upload_state = MultipartUploadServerState() + + try: + w = WorkspaceClient(config=config) + with requests_mock.Mocker() as session_mock: + self.setup_session_mock(session_mock, upload_state) + + def upload(): + w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=True) + + if self.expected_exception_type is not None: + with pytest.raises(self.expected_exception_type): + upload() + else: + upload() + actual_content = upload_state.get_file_content() + assert actual_content == FileContent.from_bytes(file_content) + + assert upload_state.aborted == self.expected_aborted + + finally: + upload_state.cleanup() + + def __str__(self): + return self.name + + @staticmethod + def to_string(test_case): + return str(test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # -------------------------- failures on "initiate upload" -------------------------- + MultipartUploadTestCase( + "Initiate: 400 response is not retried", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(code=400, only_invocation=1), + expected_exception_type=BadRequest, + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + 
"Initiate: 403 response is not retried", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(code=403, only_invocation=1), + expected_exception_type=PermissionDenied, + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + "Initiate: 500 response is not retried", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(code=500, only_invocation=1), + expected_exception_type=InternalError, + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + "Initiate: non-JSON response is not retried", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(body="this is not a JSON", only_invocation=1), + expected_exception_type=requests.exceptions.JSONDecodeError, + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + "Initiate: meaningless JSON response is not retried", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(body='{"foo": 123}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + "Initiate: no session token in response is not retried", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse( + body='{"multipart_upload":{"session_token1": "token123"}}', only_invocation=1 + ), + expected_exception_type=ValueError, + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + "Initiate: permanent retryable exception", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError), + sdk_retry_timeout_seconds=30, # let's not wait 5 min (SDK default timeout) + expected_exception_type=TimeoutError, # SDK throws this if retries are taking too long + expected_aborted=False, # upload didn't start + ), + MultipartUploadTestCase( + "Initiate: intermittent retryable exception", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse( + exception=requests.ConnectionError, + # 3 calls fail, but request is successfully retried + first_invocation=1, + last_invocation=3, + ), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Initiate: intermittent retryable status code", + stream_size=1024 * 1024, + custom_response_on_initiate=CustomResponse( + code=429, + # 3 calls fail, then retry succeeds + first_invocation=1, + last_invocation=3, + ), + expected_aborted=False, + ), + # -------------------------- failures on "create upload URL" -------------------------- + MultipartUploadTestCase( + "Create upload URL: 400 response is not retied", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + code=400, + # 1 failure is enough + only_invocation=1, + ), + expected_exception_type=BadRequest, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: 500 error is not retied", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + expected_exception_type=InternalError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: non-JSON response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(body="this is not a JSON", only_invocation=1), + expected_exception_type=requests.exceptions.JSONDecodeError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: meaningless JSON response is not retried", + stream_size=1024 * 1024, + 
custom_response_on_create_multipart_url=CustomResponse(body='{"foo":123}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: meaningless JSON response is not retried 2", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(body='{"upload_part_urls":[]}', only_invocation=1), + expected_exception_type=ValueError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: meaningless JSON response is not retried 3", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + body='{"upload_part_urls":[{"url":""}]}', only_invocation=1 + ), + expected_exception_type=KeyError, # TODO we might want to make JSON parsing more reliable + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: permanent retryable exception", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(exception=requests.ConnectionError), + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + expected_exception_type=TimeoutError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Create upload URL: intermittent retryable exception", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + exception=requests.Timeout, + # happens only once, retry succeeds + only_invocation=1, + ), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Create upload URL: intermittent retryable exception 2", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse( + exception=requests.Timeout, + # 4th request for multipart URLs fails 3 times, then retry succeeds + first_invocation=4, + last_invocation=6, + ), + expected_aborted=False, + ), + # -------------------------- failures on chunk upload -------------------------- + MultipartUploadTestCase( + "Upload chunk: 403 response is not retried", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + # fail only once + only_invocation=1, + ), + expected_exception_type=PermissionDenied, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Upload chunk: 400 response is not retried", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=400, + # fail once, but not on the first chunk + only_invocation=3, + ), + expected_exception_type=BadRequest, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Upload chunk: 500 response is not retried", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse(code=500, only_invocation=5), + expected_exception_type=InternalError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Upload chunk: expired URL is retried on AWS", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, body=MultipartUploadTestCase.expired_url_aws_response, only_invocation=2 + ), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Upload chunk: expired URL is retried on Azure", + multipart_upload_max_retries=3, + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + 
custom_response_on_upload=CustomResponse( + code=403, + body=MultipartUploadTestCase.expired_url_azure_response, + # 3 failures don't exceed multipart_upload_max_retries + first_invocation=2, + last_invocation=4, + ), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Upload chunk: expired URL is retried on Azure, requesting urls by 6", + multipart_upload_max_retries=3, + multipart_upload_batch_url_count=6, + stream_size=100 * 1024 * 1024, # 100 chunks + multipart_upload_chunk_size=1 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + body=MultipartUploadTestCase.expired_url_azure_response, + # 3 failures don't exceed multipart_upload_max_retries + first_invocation=2, + last_invocation=4, + ), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Upload chunk: expired URL retry is exhausted", + multipart_upload_max_retries=3, + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + code=403, + body=MultipartUploadTestCase.expired_url_azure_response, + # 4 failures exceed multipart_upload_max_retries + first_invocation=2, + last_invocation=5, + ), + expected_exception_type=ValueError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Upload chunk: permanent retryable error", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + custom_response_on_upload=CustomResponse(exception=requests.ConnectionError, first_invocation=8), + expected_exception_type=TimeoutError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Upload chunk: permanent retryable status code", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + custom_response_on_upload=CustomResponse(code=429, first_invocation=8), + expected_exception_type=TimeoutError, + expected_aborted=True, + ), + MultipartUploadTestCase( + "Upload chunk: intermittent retryable error", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, first_invocation=2, last_invocation=5 + ), + expected_aborted=False, + ), + MultipartUploadTestCase( + "Upload chunk: intermittent retryable status code", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + custom_response_on_upload=CustomResponse(code=429, first_invocation=2, last_invocation=4), + expected_aborted=False, + ), + # -------------------------- failures on abort -------------------------- + MultipartUploadTestCase( + "Abort URL: 500 response", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse(code=400), + expected_exception_type=InternalError, # original error + expected_aborted=False, # server state didn't change to record abort + ), + MultipartUploadTestCase( + "Abort URL: 403 response", + stream_size=1024 * 1024, + custom_response_on_upload=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse(code=403), + expected_exception_type=InternalError, # original error + expected_aborted=False, # server state didn't change to record abort + ), + MultipartUploadTestCase( + "Abort URL: intermittent retryable error", + 
stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse(code=429, first_invocation=1, last_invocation=3), + expected_exception_type=InternalError, # original error + expected_aborted=True, # abort successfully called after abort URL creation is retried + ), + MultipartUploadTestCase( + "Abort URL: intermittent retryable error 2", + stream_size=1024 * 1024, + custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1), + custom_response_on_create_abort_url=CustomResponse( + exception=requests.Timeout, first_invocation=1, last_invocation=3 + ), + expected_exception_type=InternalError, # original error + expected_aborted=True, # abort successfully called after abort URL creation is retried + ), + MultipartUploadTestCase( + "Abort: exception", + stream_size=1024 * 1024, + # don't wait for 5 min (SDK default timeout) + sdk_retry_timeout_seconds=30, + custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1), + custom_response_on_abort=CustomResponse( + exception=requests.Timeout, + # this allows to change the server state to "aborted" + exception_happened_before_processing=False, + ), + expected_exception_type=PermissionDenied, # original error is reported + expected_aborted=True, + ), + # -------------------------- happy cases -------------------------- + MultipartUploadTestCase( + "Multipart upload successful: single chunk", + stream_size=1024 * 1024, # less than chunk size + multipart_upload_chunk_size=10 * 1024 * 1024, + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple chunks (aligned)", + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple chunks (aligned), upload urls by 3", + multipart_upload_batch_url_count=3, + stream_size=100 * 1024 * 1024, # 10 chunks + multipart_upload_chunk_size=10 * 1024 * 1024, + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple chunks (not aligned), upload urls by 1", + stream_size=100 * 1024 * 1024 + 1566, # 14 full chunks + remainder + multipart_upload_chunk_size=7 * 1024 * 1024 - 17, + ), + MultipartUploadTestCase( + "Multipart upload successful: multiple chunks (not aligned), upload urls by 5", + multipart_upload_batch_url_count=5, + stream_size=100 * 1024 * 1024 + 1566, # 14 full chunks + remainder + multipart_upload_chunk_size=7 * 1024 * 1024 - 17, + ), + ], + ids=MultipartUploadTestCase.to_string, +) +def test_multipart_upload(config: Config, test_case: MultipartUploadTestCase): + test_case.run(config) + + +class SingleShotUploadState: + + def __init__(self): + self.single_shot_file_content = None + + +class SingleShotUploadTestCase: + + def __init__(self, name: str, stream_size: int, multipart_upload_min_stream_size: int, expected_single_shot: bool): + self.name = name + self.stream_size = stream_size + self.multipart_upload_min_stream_size = multipart_upload_min_stream_size + self.expected_single_shot = expected_single_shot + + def __str__(self): + return self.name + + @staticmethod + def to_string(test_case): + return str(test_case) + + def run(self, config: Config): + config = config.copy() + config.enable_experimental_files_api_client = True + config.multipart_upload_min_stream_size = self.multipart_upload_min_stream_size + + file_content = os.urandom(self.stream_size) + + session = requests.Session() + with 
requests_mock.Mocker(session=session) as session_mock: + session_mock.get(f"http://localhost/api/2.0/fs/files{MultipartUploadTestCase.path}", status_code=200) + + upload_state = SingleShotUploadState() + + def custom_matcher(request): + request_url = urlparse(request.url) + request_query = parse_qs(request_url.query) + + if self.expected_single_shot: + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request.method == "PUT" + ): + body = request.body.read() + upload_state.single_shot_file_content = FileContent.from_bytes(body) + + resp = requests.Response() + resp.status_code = 204 + resp.request = request + resp._content = b"" + return resp + else: + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["initiate-upload"] + and request.method == "POST" + ): + + resp = requests.Response() + resp.status_code = 403 # this will throw, that's fine + resp.request = request + resp._content = b"" + return resp + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + w = WorkspaceClient(config=config) + w.files._api._api_client._session = session + + def upload(): + w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=True) + + if self.expected_single_shot: + upload() + actual_content = upload_state.single_shot_file_content + assert actual_content == FileContent.from_bytes(file_content) + else: + with pytest.raises(PermissionDenied): + upload() + + +@pytest.mark.parametrize( + "test_case", + [ + SingleShotUploadTestCase( + "Single-shot upload", + stream_size=1024 * 1024, + multipart_upload_min_stream_size=1024 * 1024 + 1, + expected_single_shot=True, + ), + SingleShotUploadTestCase( + "Multipart upload 1", + stream_size=1024 * 1024, + multipart_upload_min_stream_size=1024 * 1024, + expected_single_shot=False, + ), + SingleShotUploadTestCase( + "Multipart upload 2", + stream_size=1024 * 1024, + multipart_upload_min_stream_size=0, + expected_single_shot=False, + ), + ], + ids=SingleShotUploadTestCase.to_string, +) +def test_single_shot_upload(config: Config, test_case: SingleShotUploadTestCase): + test_case.run(config) + + +class ResumableUploadServerState: + resumable_upload_url_prefix = "https://cloud_provider.com/resumable-upload/" + abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/" + + def __init__(self, unconfirmed_delta: Union[int, list]): + self.unconfirmed_delta = unconfirmed_delta + self.confirmed_last_byte: Optional[int] = None # inclusive + self.uploaded_parts = [] + self.session_token = "token-" + MultipartUploadServerState.randomstr() + self.file_content = None + self.aborted = False + + def save_part(self, start_offset: int, end_offset_incl: int, part_content: bytes, file_size_s: str): + assert not self.aborted + + assert len(part_content) > 0 + if self.confirmed_last_byte: + assert start_offset == self.confirmed_last_byte + 1 + else: + assert start_offset == 0 + + assert end_offset_incl == start_offset + len(part_content) - 1 + + is_last_part = file_size_s != "*" + if is_last_part: + assert int(file_size_s) == end_offset_incl + 1 + else: + assert not self.file_content # last chunk should not have been uploaded yet + + if isinstance(self.unconfirmed_delta, int): + unconfirmed_delta = self.unconfirmed_delta + elif len(self.uploaded_parts) < len(self.unconfirmed_delta): + unconfirmed_delta = self.unconfirmed_delta[len(self.uploaded_parts)] + else: + 
unconfirmed_delta = self.unconfirmed_delta[-1]  # take the last delta
+
+        if unconfirmed_delta >= len(part_content):
+            unconfirmed_delta = 0  # otherwise we never finish
+
+        logger.info(
+            f"Saving part {len(self.uploaded_parts) + 1} of original size {len(part_content)} with unconfirmed delta {unconfirmed_delta}. is_last_part = {is_last_part}"
+        )
+
+        if unconfirmed_delta > 0:
+            part_content = part_content[:-unconfirmed_delta]
+
+        fd, chunk_file = mkstemp()
+        with open(fd, "wb") as f:
+            f.write(part_content)
+
+        self.uploaded_parts.append(chunk_file)
+
+        if is_last_part and unconfirmed_delta == 0:
+            size = 0
+            sha256 = hashlib.sha256()
+            for chunk_path in self.uploaded_parts:
+                size += os.path.getsize(chunk_path)
+                with open(chunk_path, "rb") as f:
+                    chunk_content = f.read()
+                    sha256.update(chunk_content)
+
+            assert size == end_offset_incl + 1
+            self.file_content = FileContent(size, sha256.hexdigest())
+
+        self.confirmed_last_byte = end_offset_incl - unconfirmed_delta
+
+    def create_abort_url(self, path: str, expire_time: datetime) -> str:
+        assert not self.aborted
+        self.issued_abort_url_expire_time = expire_time
+        return f"{self.abort_upload_url_prefix}{path}"
+
+    def cleanup(self):
+        for file in self.uploaded_parts:
+            os.remove(file)
+
+    def get_file_content(self) -> FileContent:
+        assert not self.aborted
+        return self.file_content
+
+    def abort_upload(self):
+        self.aborted = True
+
+
+class ResumableUploadTestCase:
+    """Test case for the resumable upload of a file. Resumable uploads are used on GCP.
+
+    A resumable upload involves multiple HTTP requests:
+    - initiating the upload (call to the Databricks Files API)
+    - requesting a resumable upload URL (call to the Databricks Files API)
+    - uploading chunks of data (calls to the cloud storage provider or the Databricks storage proxy)
+    - aborting the upload (call to the cloud storage provider or the Databricks storage proxy)
+
+    The test case uses the requests-mock library to mock all of these requests. Within a test,
+    the mocks share a server state object that tracks the upload, and they generate the "default"
+    (successful) response.
+
+    The response to each call can be overridden by configuring the corresponding `CustomResponse` object.
+    """
+
+    path = "/test.txt"
+
+    def __init__(
+        self,
+        name: str,
+        stream_size: int,
+        overwrite: bool = True,
+        multipart_upload_chunk_size: Optional[int] = None,
+        sdk_retry_timeout_seconds: Optional[int] = None,
+        multipart_upload_max_retries: Optional[int] = None,
+        # In a resumable upload, when replying to a chunk upload request, the server returns
+        # (confirms) the last accepted byte offset, after which the client should resume the upload.
+        #
+        # `unconfirmed_delta` defines the offset from the end of the chunk that remains
+        # "unconfirmed", i.e. the last accepted offset would be (range_end - unconfirmed_delta).
+        # Can be an int (same for all chunks) or a list (individual for each chunk).
+ unconfirmed_delta: Union[int, list] = 0, + custom_response_on_create_resumable_url=CustomResponse(enabled=False), + custom_response_on_upload=CustomResponse(enabled=False), + custom_response_on_status_check=CustomResponse(enabled=False), + custom_response_on_abort=CustomResponse(enabled=False), + # exception which is expected to be thrown (so upload is expected to have failed) + expected_exception_type: Optional[Type[BaseException]] = None, + # if abort is expected to be called + expected_aborted: bool = False, + ): + self.name = name + self.stream_size = stream_size + self.overwrite = overwrite + self.multipart_upload_chunk_size = multipart_upload_chunk_size + self.sdk_retry_timeout_seconds = sdk_retry_timeout_seconds + self.multipart_upload_max_retries = multipart_upload_max_retries + self.unconfirmed_delta = unconfirmed_delta + self.custom_response_on_create_resumable_url = copy.deepcopy(custom_response_on_create_resumable_url) + self.custom_response_on_upload = copy.deepcopy(custom_response_on_upload) + self.custom_response_on_status_check = copy.deepcopy(custom_response_on_status_check) + self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort) + self.expected_exception_type = expected_exception_type + self.expected_aborted: bool = expected_aborted + + def setup_session_mock(self, session_mock: requests_mock.Mocker, server_state: ResumableUploadServerState): + + def custom_matcher(request): + request_url = urlparse(request.url) + request_query = parse_qs(request_url.query) + + # initial request + if ( + request_url.hostname == "localhost" + and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}" + and request_query.get("action") == ["initiate-upload"] + and request.method == "POST" + ): + + assert MultipartUploadTestCase.is_auth_header_present(request) + assert request.text is None + + def processor(): + response_json = {"resumable_upload": {"session_token": server_state.session_token}} + return [200, json.dumps(response_json), {}] + + # Different initiate error responses have been verified by test_multipart_upload(), + # so we're always generating a "success" response. 
+ return CustomResponse(enabled=False).generate_response(request, processor) + + elif ( + request_url.hostname == "localhost" + and request_url.path == "/api/2.0/fs/create-resumable-upload-url" + and request.method == "POST" + ): + + assert MultipartUploadTestCase.is_auth_header_present(request) + + request_json = request.json() + assert request_json.keys() == {"path", "session_token"} + assert request_json["path"] == self.path + assert request_json["session_token"] == server_state.session_token + + def processor(): + resumable_upload_url = f"{ResumableUploadServerState.resumable_upload_url_prefix}{self.path}" + + response_json = { + "resumable_upload_url": { + "url": resumable_upload_url, + "headers": [{"name": "name1", "value": "value1"}], + } + } + return [200, json.dumps(response_json), {}] + + return self.custom_response_on_create_resumable_url.generate_response(request, processor) + + # resumable upload, uploading part + elif ( + request.url.startswith(ResumableUploadServerState.resumable_upload_url_prefix) + and request.method == "PUT" + ): + + assert not MultipartUploadTestCase.is_auth_header_present(request) + url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix) :] + assert url_path == self.path + + content_range_header = request.headers["Content-range"] + is_status_check_request = re.match("bytes \\*/\\*", content_range_header) + if is_status_check_request: + assert not request.body + response_customizer = self.custom_response_on_status_check + else: + response_customizer = self.custom_response_on_upload + + def processor(): + if not is_status_check_request: + body = request.body.read() + + match = re.match("bytes (\\d+)-(\\d+)/(.+)", content_range_header) + [range_start_s, range_end_s, file_size_s] = match.groups() + + server_state.save_part(int(range_start_s), int(range_end_s), body, file_size_s) + + if server_state.file_content: + # upload complete + return [200, "", {}] + else: + # more data expected + if server_state.confirmed_last_byte: + headers = {"Range": f"bytes=0-{server_state.confirmed_last_byte}"} + else: + headers = {} + return [308, "", headers] + + return response_customizer.generate_response(request, processor) + + # abort upload + elif ( + request.url.startswith(ResumableUploadServerState.resumable_upload_url_prefix) + and request.method == "DELETE" + ): + + assert not MultipartUploadTestCase.is_auth_header_present(request) + url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix) :] + assert url_path == self.path + + def processor(): + server_state.abort_upload() + return [200, "", {}] + + return self.custom_response_on_abort.generate_response(request, processor) + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + def run(self, config: Config): + config = config.copy() + if self.sdk_retry_timeout_seconds: + config.retry_timeout_seconds = self.sdk_retry_timeout_seconds + if self.multipart_upload_chunk_size: + config.multipart_upload_chunk_size = self.multipart_upload_chunk_size + if self.multipart_upload_max_retries: + config.multipart_upload_max_retries = self.multipart_upload_max_retries + config.enable_experimental_files_api_client = True + config.multipart_upload_min_stream_size = 0 # disable single-shot uploads + + MultipartUploadTestCase.setup_token_auth(config) + + file_content = os.urandom(self.stream_size) + + upload_state = ResumableUploadServerState(self.unconfirmed_delta) + + try: + with requests_mock.Mocker() as session_mock: + self.setup_session_mock(session_mock, 
upload_state) + w = WorkspaceClient(config=config) + + def upload(): + w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=self.overwrite) + + if self.expected_exception_type is not None: + with pytest.raises(self.expected_exception_type): + upload() + else: + upload() + actual_content = upload_state.get_file_content() + assert actual_content == FileContent.from_bytes(file_content) + + assert upload_state.aborted == self.expected_aborted + + finally: + upload_state.cleanup() + + def __str__(self): + return self.name + + @staticmethod + def to_string(test_case): + return str(test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # ------------------ failures on creating resumable upload URL ------------------ + ResumableUploadTestCase( + "Create resumable URL: 400 response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse( + code=400, + # 1 failure is enough + only_invocation=1, + ), + expected_exception_type=BadRequest, + expected_aborted=False, # upload didn't start + ), + ResumableUploadTestCase( + "Create resumable URL: 403 response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(code=403, only_invocation=1), + expected_exception_type=PermissionDenied, + expected_aborted=False, # upload didn't start + ), + ResumableUploadTestCase( + "Create resumable URL: 500 response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(code=500, only_invocation=1), + expected_exception_type=InternalError, + expected_aborted=False, # upload didn't start + ), + ResumableUploadTestCase( + "Create resumable URL: non-JSON response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(body="Foo bar", only_invocation=1), + expected_exception_type=requests.exceptions.JSONDecodeError, + expected_aborted=False, # upload didn't start + ), + ResumableUploadTestCase( + "Create resumable URL: meaningless JSON response is not retried", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse( + body='{"upload_part_urls":[{"url":""}]}', only_invocation=1 + ), + expected_exception_type=ValueError, + expected_aborted=False, # upload didn't start + ), + ResumableUploadTestCase( + "Create resumable URL: permanent retryable status code", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse(code=429), + sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout) + expected_exception_type=TimeoutError, + expected_aborted=False, # upload didn't start + ), + ResumableUploadTestCase( + "Create resumable URL: intermittent retryable exception is retried", + stream_size=1024 * 1024, + custom_response_on_create_resumable_url=CustomResponse( + exception=requests.Timeout, + # 3 failures total + first_invocation=1, + last_invocation=3, + ), + expected_aborted=False, # upload succeeds + ), + # ------------------ failures during upload ------------------ + ResumableUploadTestCase( + "Upload: retryable exception after file is uploaded", + stream_size=1024 * 1024, + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, + # this makes server state change before exception is thrown + exception_happened_before_processing=False, + ), + # Despite the returned error, file has been uploaded. We'll discover that + # on the next status check and consider upload completed. 
+ expected_aborted=False, + ), + ResumableUploadTestCase( + "Upload: retryable exception before file is uploaded, not enough retries", + stream_size=1024 * 1024, + multipart_upload_max_retries=3, + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, + # prevent server from saving this chunk + exception_happened_before_processing=True, + # fail 4 times, exceeding max_retries + first_invocation=1, + last_invocation=4, + ), + # File was never uploaded and we gave up retrying + expected_exception_type=requests.ConnectionError, + expected_aborted=True, + ), + ResumableUploadTestCase( + "Upload: retryable exception before file is uploaded, enough retries", + stream_size=1024 * 1024, + multipart_upload_max_retries=4, + custom_response_on_upload=CustomResponse( + exception=requests.ConnectionError, + # prevent server from saving this chunk + exception_happened_before_processing=True, + # fail 4 times, not exceeding max_retries + first_invocation=1, + last_invocation=4, + ), + # File was uploaded after retries + expected_aborted=False, + ), + ResumableUploadTestCase( + "Upload: intermittent 429 response: retried", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024, + multipart_upload_max_retries=3, + custom_response_on_upload=CustomResponse( + code=429, + # 3 failures not exceeding max_retries + first_invocation=2, + last_invocation=4, + ), + expected_aborted=False, # upload succeeded + ), + ResumableUploadTestCase( + "Upload: intermittent 429 response: retry exhausted", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=1 * 1024 * 1024, + multipart_upload_max_retries=3, + custom_response_on_upload=CustomResponse( + code=429, + # 4 failures exceeding max_retries + first_invocation=2, + last_invocation=5, + ), + expected_exception_type=TooManyRequests, + expected_aborted=True, + ), + # -------------- abort failures -------------- + ResumableUploadTestCase( + "Abort: client error", + stream_size=1024 * 1024, + # prevent chunk from being uploaded + custom_response_on_upload=CustomResponse(code=403), + # internal server error does not prevent server state change + custom_response_on_abort=CustomResponse(code=500), + expected_exception_type=PermissionDenied, + # abort returned error but was actually processed + expected_aborted=True, + ), + # -------------- file already exists -------------- + ResumableUploadTestCase( + "File already exists", + stream_size=1024 * 1024, + overwrite=False, + custom_response_on_upload=CustomResponse(code=412, only_invocation=1), + expected_exception_type=AlreadyExists, + expected_aborted=True, + ), + # -------------- success cases -------------- + ResumableUploadTestCase( + "Multiple chunks, zero unconfirmed delta", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + # server accepts all the chunks in full + unconfirmed_delta=0, + expected_aborted=False, + ), + ResumableUploadTestCase( + "Multiple small chunks, zero unconfirmed delta", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=100 * 1024, + # server accepts all the chunks in full + unconfirmed_delta=0, + expected_aborted=False, + ), + ResumableUploadTestCase( + "Multiple chunks, non-zero unconfirmed delta", + stream_size=100 * 1024 * 1024, + multipart_upload_chunk_size=7 * 1024 * 1024 + 566, + # for every chunk, server accepts all except last 239 bytes + unconfirmed_delta=239, + expected_aborted=False, + ), + ResumableUploadTestCase( + "Multiple chunks, variable unconfirmed delta", + stream_size=100 
* 1024 * 1024,
+            multipart_upload_chunk_size=7 * 1024 * 1024 + 566,
+            # for the first chunk, the server accepts all except the last 15 KiB
+            # for the second chunk, the server accepts it all
+            # for the 3rd chunk, the server accepts all except the last 25000 bytes
+            # for the 4th chunk, the server accepts all except the last 7 MiB
+            # from the 5th chunk onwards, the server accepts all except the last 5 bytes
+            unconfirmed_delta=[15 * 1024, 0, 25000, 7 * 1024 * 1024, 5],
+            expected_aborted=False,
+        ),
+    ],
+    ids=ResumableUploadTestCase.to_string,
+)
+def test_resumable_upload(config: Config, test_case: ResumableUploadTestCase):
+    test_case.run(config)
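
For reference, the resumable-upload mocks above emulate a header-driven protocol: a status check sends `Content-Range: bytes */*` with an empty body, a chunk upload sends `Content-Range: bytes <start>-<end>/<total>` (with `*` while the total stream size is still unknown), and a 308 response reports the last confirmed byte via a `Range: bytes=0-<offset>` header. The minimal client-side sketch below illustrates that bookkeeping; the helper names (`parse_confirmed_end`, `next_content_range`) are hypothetical and are not part of the SDK or of this change.

# Illustrative sketch only; mirrors the header handling emulated by the mocks above.
import re
from typing import Optional


def parse_confirmed_end(range_header: Optional[str]) -> Optional[int]:
    # A 308 response may carry e.g. "Range: bytes=0-1048575"; no header means nothing is confirmed yet.
    if not range_header:
        return None
    match = re.match(r"bytes=0-(\d+)", range_header)
    return int(match.group(1)) if match else None


def next_content_range(confirmed_end: Optional[int], chunk_len: int, total_size: Optional[int]) -> str:
    # Build the Content-Range for the next chunk; "*" is sent while the total size is still unknown.
    start = 0 if confirmed_end is None else confirmed_end + 1
    end = start + chunk_len - 1
    total = "*" if total_size is None else str(total_size)
    return f"bytes {start}-{end}/{total}"


STATUS_CHECK_RANGE = "bytes */*"  # sent with an empty body to ask the server how much it has stored

# Usage: nothing confirmed yet, then resuming after the server confirmed bytes 0-1048575 of a 5 MiB file.
assert next_content_range(None, 1024 * 1024, None) == "bytes 0-1048575/*"
assert next_content_range(1048575, 1024 * 1024, 5 * 1024 * 1024) == "bytes 1048576-2097151/5242880"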