diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py
index 3334fa94f..591aafc44 100644
--- a/databricks/sdk/config.py
+++ b/databricks/sdk/config.py
@@ -1,5 +1,6 @@
import configparser
import copy
+import datetime
import logging
import os
import pathlib
@@ -98,6 +99,49 @@ class Config:
files_api_client_download_max_total_recovers = None
files_api_client_download_max_total_recovers_without_progressing = 1
+ # File multipart upload parameters
+ # ----------------------
+
+    # Minimal input stream size (in bytes) required to use multipart / resumable uploads.
+    # For smaller files it's more efficient to make a single-shot upload request.
+    # When uploading a file, the SDK initially buffers this many bytes from the input stream.
+    # This parameter can be smaller or larger than multipart_upload_chunk_size.
+ multipart_upload_min_stream_size: int = 5 * 1024 * 1024
+
+ # Maximum number of presigned URLs that can be requested at a time.
+ #
+    # The more URLs we request at once, the higher the chance that some of them expire
+    # before we get to use them. We only discover that a presigned URL has expired *after*
+    # sending the corresponding input stream partition to the server, so to retry the upload
+    # of that partition we must rewind the stream. A non-seekable stream cannot be rewound,
+    # so we would have to abort the upload. To reduce the chance of this, we request
+    # presigned URLs one by one and use them immediately.
+ multipart_upload_batch_url_count: int = 1
+
+ # Size of the chunk to use for multipart uploads.
+ #
+    # The smaller the chunk, the lower the chance of network errors (or of the URL expiring),
+    # but the more requests we'll make.
+    # For AWS, the minimum part size is 5 MiB: https://docs.aws.amazon.com/AmazonS3/latest/userguide/qfacts.html
+    # For GCP, the minimum is 256 KiB (and the recommended chunk size is a multiple of 256 KiB).
+    # boto3 uses 8 MiB: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig
+ multipart_upload_chunk_size: int = 10 * 1024 * 1024
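+    # For example, with the default 10 MiB chunk size a 100 MiB stream is uploaded in 10 parts.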
+
+    # Requested presigned URL expiration; we use the maximum duration of 1 hour.
+ multipart_upload_url_expiration_duration: datetime.timedelta = datetime.timedelta(hours=1)
+
+ # This is not a "wall time" cutoff for the whole upload request,
+ # but a maximum time between consecutive data reception events (even 1 byte) from the server
+ multipart_upload_single_chunk_upload_timeout_seconds: float = 60
+
+    # Cap on the number of custom retries during incremental uploads:
+    # 1) multipart: an upload part URL has expired, so new upload URLs must be requested to continue the upload;
+    # 2) resumable: a chunk upload produced a retryable response (or exception), so the upload status must be
+    #    retrieved to continue the upload.
+    # In these two cases the standard SDK retries (capped by the `retry_timeout_seconds` option) are not used.
+    # Note that the retry counter is reset when the upload successfully resumes.
+ multipart_upload_max_retries = 3
+
def __init__(
self,
*,
diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py
index 6a9b263bd..8d9923b4f 100644
--- a/databricks/sdk/mixins/files.py
+++ b/databricks/sdk/mixins/files.py
@@ -1,26 +1,36 @@
from __future__ import annotations
import base64
+import datetime
import logging
import os
import pathlib
import platform
+import re
import shutil
import sys
+import xml.etree.ElementTree as ET
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
+from datetime import timedelta
from io import BytesIO
from types import TracebackType
-from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable,
- Optional, Type, Union)
+from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Callable, Generator,
+ Iterable, Optional, Type, Union)
from urllib import parse
+import requests
+import requests.adapters
from requests import RequestException
-from .._base_client import _RawResponse, _StreamingResponse
+from .._base_client import _BaseClient, _RawResponse, _StreamingResponse
from .._property import _cached_property
-from ..errors import NotFound
+from ..config import Config
+from ..errors import AlreadyExists, NotFound
+from ..errors.customizer import _RetryAfterCustomizer
+from ..errors.mapper import _error_mapper
+from ..retries import retried
from ..service import files
from ..service._internal import _escape_multi_segment_path_parameter
from ..service.files import DownloadResponse
@@ -683,9 +693,13 @@ def delete(self, path: str, *, recursive=False):
class FilesExt(files.FilesAPI):
__doc__ = files.FilesAPI.__doc__
+ # note that these error codes are retryable only for idempotent operations
+ _RETRYABLE_STATUS_CODES = [408, 429, 500, 502, 503, 504]
+
def __init__(self, api_client, config: Config):
super().__init__(api_client)
self._config = config.copy()
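+        # Number of extra bytes to read ahead during resumable uploads to detect the end of the
+        # input stream (see the read-ahead comment in _perform_resumable_upload).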
+ self._multipart_upload_read_ahead_bytes = 1
def download(self, file_path: str) -> DownloadResponse:
"""Download a file.
@@ -703,7 +717,7 @@ def download(self, file_path: str) -> DownloadResponse:
:returns: :class:`DownloadResponse`
"""
- initial_response: DownloadResponse = self._download_raw_stream(
+ initial_response: DownloadResponse = self._open_download_stream(
file_path=file_path,
start_byte_offset=0,
if_unmodified_since_timestamp=None,
@@ -713,12 +727,594 @@ def download(self, file_path: str) -> DownloadResponse:
initial_response.contents._response = wrapped_response
return initial_response
- def _download_raw_stream(
+ def upload(self, file_path: str, contents: BinaryIO, *, overwrite: Optional[bool] = None):
+ """Upload a file.
+
+ Uploads a file. The file contents should be sent as the request body as raw bytes (an
+ octet stream); do not encode or otherwise modify the bytes before sending. The contents of the
+ resulting file will be exactly the bytes sent in the request body. If the request is successful, there
+ is no response body.
+
+ :param file_path: str
+ The absolute remote path of the target file.
+ :param contents: BinaryIO
+ :param overwrite: bool (optional)
+ If true, an existing file will be overwritten. When not specified, assumed True.
+ """
+
+ # Upload empty and small files with one-shot upload.
+ pre_read_buffer = contents.read(self._config.multipart_upload_min_stream_size)
+ if len(pre_read_buffer) < self._config.multipart_upload_min_stream_size:
+ _LOG.debug(
+ f"Using one-shot upload for input stream of size {len(pre_read_buffer)} below {self._config.multipart_upload_min_stream_size} bytes"
+ )
+ return super().upload(file_path=file_path, contents=BytesIO(pre_read_buffer), overwrite=overwrite)
+
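+        # For larger streams, ask the server to initiate an upload; depending on the cloud,
+        # the response selects either multipart upload (AWS/Azure) or resumable upload (GCP).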
+ query = {"action": "initiate-upload"}
+ if overwrite is not None:
+ query["overwrite"] = overwrite
+
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+ initiate_upload_response = self._api.do(
+ "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}", query=query
+ )
+
+ if initiate_upload_response.get("multipart_upload"):
+ cloud_provider_session = self._create_cloud_provider_session()
+ session_token = initiate_upload_response["multipart_upload"].get("session_token")
+ if not session_token:
+ raise ValueError(f"Unexpected server response: {initiate_upload_response}")
+
+ try:
+ self._perform_multipart_upload(
+ file_path, contents, session_token, pre_read_buffer, cloud_provider_session
+ )
+ except Exception as e:
+ _LOG.info(f"Aborting multipart upload on error: {e}")
+ try:
+ self._abort_multipart_upload(file_path, session_token, cloud_provider_session)
+ except BaseException as ex:
+ _LOG.warning(f"Failed to abort upload: {ex}")
+ # ignore, abort is a best-effort
+ finally:
+ # rethrow original exception
+ raise e from None
+
+ elif initiate_upload_response.get("resumable_upload"):
+ cloud_provider_session = self._create_cloud_provider_session()
+ session_token = initiate_upload_response["resumable_upload"]["session_token"]
+ self._perform_resumable_upload(
+ file_path, contents, session_token, overwrite, pre_read_buffer, cloud_provider_session
+ )
+ else:
+ raise ValueError(f"Unexpected server response: {initiate_upload_response}")
+
+ def _perform_multipart_upload(
self,
- file_path: str,
- start_byte_offset: int,
- if_unmodified_since_timestamp: Optional[str] = None,
+ target_path: str,
+ input_stream: BinaryIO,
+ session_token: str,
+ pre_read_buffer: bytes,
+ cloud_provider_session: requests.Session,
+ ):
+ """
+ Performs multipart upload using presigned URLs on AWS and Azure:
+ https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html
+ """
+ current_part_number = 1
+ etags: dict = {}
+
+ # Why are we buffering the current chunk?
+ # AWS and Azure don't support traditional "Transfer-encoding: chunked", so we must
+ # provide each chunk size up front. In case of a non-seekable input stream we need
+ # to buffer a chunk before uploading to know its size. This also allows us to rewind
+ # the stream before retrying on request failure.
+ # AWS signed chunked upload: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html
+ # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-tune-upload-download-python#buffering-during-uploads
+
+ chunk_offset = 0 # used only for logging
+
+ # This buffer is expected to contain at least multipart_upload_chunk_size bytes.
+        # Note that initially the buffer can be bigger (from pre_read_buffer).
+ buffer = pre_read_buffer
+
+ retry_count = 0
+ eof = False
+ while not eof:
+ # If needed, buffer the next chunk.
+ buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream)
+ if len(buffer) == 0:
+ # End of stream, no need to request the next block of upload URLs.
+ break
+
+ _LOG.debug(
+ f"Multipart upload: requesting next {self._config.multipart_upload_batch_url_count} upload URLs starting from part {current_part_number}"
+ )
+
+ body: dict = {
+ "path": target_path,
+ "session_token": session_token,
+ "start_part_number": current_part_number,
+ "count": self._config.multipart_upload_batch_url_count,
+ "expire_time": self._get_url_expire_time(),
+ }
+
+ headers = {"Content-Type": "application/json"}
+
+ # Requesting URLs for the same set of parts is an idempotent operation, safe to retry.
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+ upload_part_urls_response = self._api.do(
+ "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body
+ )
+
+ upload_part_urls = upload_part_urls_response.get("upload_part_urls", [])
+ if len(upload_part_urls) == 0:
+ raise ValueError(f"Unexpected server response: {upload_part_urls_response}")
+
+ for upload_part_url in upload_part_urls:
+ buffer = FilesExt._fill_buffer(buffer, self._config.multipart_upload_chunk_size, input_stream)
+ actual_buffer_length = len(buffer)
+ if actual_buffer_length == 0:
+ eof = True
+ break
+
+ url = upload_part_url["url"]
+ required_headers = upload_part_url.get("headers", [])
+ assert current_part_number == upload_part_url["part_number"]
+
+ headers: dict = {"Content-Type": "application/octet-stream"}
+ for h in required_headers:
+ headers[h["name"]] = h["value"]
+
+ actual_chunk_length = min(actual_buffer_length, self._config.multipart_upload_chunk_size)
+ _LOG.debug(
+ f"Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]"
+ )
+
+ chunk = BytesIO(buffer[:actual_chunk_length])
+
+ def rewind():
+ chunk.seek(0, os.SEEK_SET)
+
+ def perform():
+ return cloud_provider_session.request(
+ "PUT",
+ url,
+ headers=headers,
+ data=chunk,
+ timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ )
+
+ upload_response = self._retry_idempotent_operation(perform, rewind)
+
+ if upload_response.status_code in (200, 201):
+ # Chunk upload successful
+
+ chunk_offset += actual_chunk_length
+
+ etag = upload_response.headers.get("ETag", "")
+ etags[current_part_number] = etag
+
+ # Discard uploaded bytes
+ buffer = buffer[actual_chunk_length:]
+
+ # Reset retry count when progressing along the stream
+ retry_count = 0
+
+ elif FilesExt._is_url_expired_response(upload_response):
+ if retry_count < self._config.multipart_upload_max_retries:
+ retry_count += 1
+ _LOG.debug("Upload URL expired")
+                        # Preserve the buffer so we'll upload the current part again using the next upload URL
+                    else:
+                        # Don't confuse the user with an unrelated "Permission denied" error.
+                        raise ValueError("Unsuccessful chunk upload: upload URL expired")
+
+ else:
+ message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}"
+ _LOG.warning(message)
+ mapped_error = _error_mapper(upload_response, {})
+ raise mapped_error or ValueError(message)
+
+ current_part_number += 1
+
+ _LOG.debug(
+ f"Completing multipart upload after uploading {len(etags)} parts of up to {self._config.multipart_upload_chunk_size} bytes"
+ )
+
+ query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token}
+ headers = {"Content-Type": "application/json"}
+ body: dict = {}
+
+ parts = []
+        for part_number, etag in sorted(etags.items()):
+            parts.append({"part_number": part_number, "etag": etag})
+
+ body["parts"] = parts
+
+ # Completing upload is an idempotent operation, safe to retry.
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+ self._api.do(
+ "POST",
+ f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(target_path)}",
+ query=query,
+ headers=headers,
+ body=body,
+ )
+
+ @staticmethod
+    def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO) -> bytes:
+        """
+        Tries to fill the given buffer to contain at least `desired_min_size` bytes by reading from the input stream.
+ """
+ bytes_to_read = max(0, desired_min_size - len(buffer))
+ if bytes_to_read > 0:
+ next_buf = input_stream.read(bytes_to_read)
+ new_buffer = buffer + next_buf
+ return new_buffer
+ else:
+ # we have already buffered enough data
+ return buffer
+
+ @staticmethod
+    def _is_url_expired_response(response: requests.Response) -> bool:
+ """
+ Checks if response matches one of the known "URL expired" responses from the cloud storage providers.
+ """
+ if response.status_code != 403:
+ return False
+
+ try:
+ xml_root = ET.fromstring(response.content)
+ if xml_root.tag != "Error":
+ return False
+
+ code = xml_root.find("Code")
+ if code is None:
+ return False
+
+ if code.text == "AuthenticationFailed":
+ # Azure
+ details = xml_root.find("AuthenticationErrorDetail")
+ if details is not None and "Signature not valid in the specified time frame" in details.text:
+ return True
+
+ if code.text == "AccessDenied":
+ # AWS
+ message = xml_root.find("Message")
+ if message is not None and message.text == "Request has expired":
+ return True
+
+ except ET.ParseError:
+ pass
+
+ return False
+
+ def _perform_resumable_upload(
+ self,
+ target_path: str,
+ input_stream: BinaryIO,
+ session_token: str,
+        overwrite: Optional[bool],
+ pre_read_buffer: bytes,
+ cloud_provider_session: requests.Session,
+ ):
+ """
+ Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads
+ """
+
+        # The session URI we're using expires after a week.
+
+ # Why are we buffering the current chunk?
+        # When using the resumable upload API we upload data in chunks. During a chunk upload the
+        # server responds with the "received offset" confirming how much data it has stored so far,
+        # so we should continue uploading from that offset. (Note this is not a failure but
+        # expected behaviour as per the docs.) However, the input stream might have been consumed
+        # beyond that offset, since the server may have read more data than it confirmed received,
+        # or some data may have been pre-cached by e.g. the OS or a proxy. So, to continue the
+        # upload, we must rewind the input stream back to the byte after the "received offset".
+        # This is not possible for a non-seekable input stream, so we must buffer the whole last
+        # chunk and seek inside the buffer. By always uploading from the buffer we fully support
+        # non-seekable streams.
+
+ # Why are we doing read-ahead?
+        # It's not possible to upload an empty chunk, as the "Content-Range" header format does not
+        # support this. So if the current chunk happens to finish exactly at the end of the stream,
+        # we need to know that and mark the chunk as the last one (by passing the real file size in
+        # the "Content-Range" header) when uploading it. To detect whether we're at the end of the
+        # stream we read "ahead" some extra bytes but do not upload them immediately. If nothing
+        # could be read ahead, it means we're at the end of the stream.
+        # By contrast, in a multipart upload we can decide to complete the upload *after* the
+        # last chunk has been sent.
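+        # For example, with a 10 MiB chunk size and a 1-byte read-ahead we buffer 10 MiB + 1 bytes;
+        # if the extra read returns nothing, the buffered chunk is the final one and its
+        # "Content-Range" header carries the real file size instead of "*".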
+
+ body: dict = {"path": target_path, "session_token": session_token}
+
+ headers = {"Content-Type": "application/json"}
+
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+ resumable_upload_url_response = self._api.do(
+ "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body
+ )
+
+ resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url")
+ if not resumable_upload_url_node:
+ raise ValueError(f"Unexpected server response: {resumable_upload_url_response}")
+
+ resumable_upload_url = resumable_upload_url_node.get("url")
+ if not resumable_upload_url:
+ raise ValueError(f"Unexpected server response: {resumable_upload_url_response}")
+
+ required_headers = resumable_upload_url_node.get("headers", [])
+
+ try:
+ # We will buffer this many bytes: one chunk + read-ahead block.
+ # Note buffer may contain more data initially (from pre_read_buffer).
+ min_buffer_size = self._config.multipart_upload_chunk_size + self._multipart_upload_read_ahead_bytes
+
+ buffer = pre_read_buffer
+
+ # How many bytes in the buffer were confirmed to be received by the server.
+ # All the remaining bytes in the buffer must be uploaded.
+ uploaded_bytes_count = 0
+
+ chunk_offset = 0
+
+ retry_count = 0
+ while True:
+ # If needed, fill the buffer to contain at least min_buffer_size bytes
+ # (unless end of stream), discarding already uploaded bytes.
+ bytes_to_read = max(0, min_buffer_size - (len(buffer) - uploaded_bytes_count))
+ next_buf = input_stream.read(bytes_to_read)
+ buffer = buffer[uploaded_bytes_count:] + next_buf
+
+ if len(next_buf) < bytes_to_read:
+ # This is the last chunk in the stream.
+ # Let's upload all the remaining bytes in one go.
+ actual_chunk_length = len(buffer)
+ file_size = chunk_offset + actual_chunk_length
+ else:
+ # More chunks expected, let's upload current chunk (excluding read-ahead block).
+ actual_chunk_length = self._config.multipart_upload_chunk_size
+ file_size = "*"
+
+ headers: dict = {"Content-Type": "application/octet-stream"}
+ for h in required_headers:
+ headers[h["name"]] = h["value"]
+
+ chunk_last_byte_offset = chunk_offset + actual_chunk_length - 1
+ content_range_header = f"bytes {chunk_offset}-{chunk_last_byte_offset}/{file_size}"
+ _LOG.debug(f"Uploading chunk: {content_range_header}")
+ headers["Content-Range"] = content_range_header
+
+ def retrieve_upload_status() -> Optional[requests.Response]:
+ def perform():
+ return cloud_provider_session.request(
+ "PUT",
+ resumable_upload_url,
+ headers={"Content-Range": "bytes */*"},
+ data=b"",
+ timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ )
+
+ try:
+ return self._retry_idempotent_operation(perform)
+ except RequestException:
+ _LOG.warning("Failed to retrieve upload status")
+ return None
+
+ try:
+ upload_response = cloud_provider_session.request(
+ "PUT",
+ resumable_upload_url,
+ headers=headers,
+ data=BytesIO(buffer[:actual_chunk_length]),
+ timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ )
+
+ # https://cloud.google.com/storage/docs/performing-resumable-uploads#resume-upload
+ # If an upload request is terminated before receiving a response, or if you receive
+ # a 503 or 500 response, then you need to resume the interrupted upload from where it left off.
+
+                    # Let's follow that approach for all potentially retryable status codes.
+                    # Together with the except block below, this replicates the logic of _retry_idempotent_operation().
+ if upload_response.status_code in self._RETRYABLE_STATUS_CODES:
+ if retry_count < self._config.multipart_upload_max_retries:
+ retry_count += 1
+                            # If status retrieval fails, let the original upload_response be handled as an error.
+ upload_response = retrieve_upload_status() or upload_response
+ else:
+ # we received non-retryable response, reset retry count
+ retry_count = 0
+
+ except RequestException as e:
+ # Let's do the same for retryable network errors.
+ if _BaseClient._is_retryable(e) and retry_count < self._config.multipart_upload_max_retries:
+ retry_count += 1
+ upload_response = retrieve_upload_status()
+ if not upload_response:
+ # rethrow original exception
+ raise e from None
+ else:
+ # rethrow original exception
+ raise e from None
+
+ if upload_response.status_code in (200, 201):
+ if file_size == "*":
+ raise ValueError(
+ f"Received unexpected status {upload_response.status_code} before reaching end of stream"
+ )
+
+ # upload complete
+ break
+
+ elif upload_response.status_code == 308:
+ # chunk accepted (or check-status succeeded), let's determine received offset to resume from there
+ range_string = upload_response.headers.get("Range")
+ confirmed_offset = self._extract_range_offset(range_string)
+ _LOG.debug(f"Received confirmed offset: {confirmed_offset}")
+
+ if confirmed_offset:
+ if confirmed_offset < chunk_offset - 1 or confirmed_offset > chunk_last_byte_offset:
+ raise ValueError(
+ f"Unexpected received offset: {confirmed_offset} is outside of expected range, chunk offset: {chunk_offset}, chunk last byte offset: {chunk_last_byte_offset}"
+ )
+ else:
+ if chunk_offset > 0:
+ raise ValueError(
+ f"Unexpected received offset: {confirmed_offset} is outside of expected range, chunk offset: {chunk_offset}, chunk last byte offset: {chunk_last_byte_offset}"
+ )
+
+ # We have just uploaded a part of chunk starting from offset "chunk_offset" and ending
+ # at offset "confirmed_offset" (inclusive), so the next chunk will start at
+ # offset "confirmed_offset + 1"
+ if confirmed_offset:
+ next_chunk_offset = confirmed_offset + 1
+ else:
+ next_chunk_offset = chunk_offset
+ uploaded_bytes_count = next_chunk_offset - chunk_offset
+ chunk_offset = next_chunk_offset
+
+ elif upload_response.status_code == 412 and not overwrite:
+                    # Assuming this is the only possible reason for a 412 when overwrite is disabled.
+ # Full message in this case: "At least one of the pre-conditions you specified did not hold."
+ raise AlreadyExists("The file being created already exists.")
+
+ else:
+ message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}"
+ _LOG.warning(message)
+ mapped_error = _error_mapper(upload_response, {})
+ raise mapped_error or ValueError(message)
+
+ except Exception as e:
+ _LOG.info(f"Aborting resumable upload on error: {e}")
+ try:
+ self._abort_resumable_upload(resumable_upload_url, required_headers, cloud_provider_session)
+ except BaseException as ex:
+ _LOG.warning(f"Failed to abort upload: {ex}")
+ # ignore, abort is a best-effort
+ finally:
+ # rethrow original exception
+ raise e from None
+
+ @staticmethod
+ def _extract_range_offset(range_string: Optional[str]) -> Optional[int]:
+ """Parses the response range header to extract the last byte."""
+ if not range_string:
+ return None # server did not yet confirm any bytes
+
+ if match := re.match("bytes=0-(\\d+)", range_string):
+ return int(match.group(1))
+ else:
+ raise ValueError(f"Cannot parse response header: Range: {range_string}")
+
+    def _get_url_expire_time(self) -> str:
+        """Generates the expiration time for presigned URLs in the format required by the API."""
+ current_time = datetime.datetime.now(datetime.timezone.utc)
+ expire_time = current_time + self._config.multipart_upload_url_expiration_duration
+        # From the Google Protobuf documentation:
+        # In JSON format, the Timestamp type is encoded as a string in the
+        # [RFC 3339](https://www.ietf.org/rfc/rfc3339.txt) format. That is, the
+        # format is "{year}-{month}-{day}T{hour}:{min}:{sec}[.{frac_sec}]Z".
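+        # For example, at 12:00:00 UTC on 2025-01-01 with the default 1-hour duration this
+        # returns "2025-01-01T13:00:00Z".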
+ return expire_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+
+ def _abort_multipart_upload(self, target_path: str, session_token: str, cloud_provider_session: requests.Session):
+ """Aborts ongoing multipart upload session to clean up incomplete file."""
+ body: dict = {"path": target_path, "session_token": session_token, "expire_time": self._get_url_expire_time()}
+
+ headers = {"Content-Type": "application/json"}
+
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
+ abort_url_response = self._api.do("POST", "/api/2.0/fs/create-abort-upload-url", headers=headers, body=body)
+
+ abort_upload_url_node = abort_url_response["abort_upload_url"]
+ abort_url = abort_upload_url_node["url"]
+ required_headers = abort_upload_url_node.get("headers", [])
+
+ headers: dict = {"Content-Type": "application/octet-stream"}
+ for h in required_headers:
+ headers[h["name"]] = h["value"]
+
+ def perform():
+ return cloud_provider_session.request(
+ "DELETE",
+ abort_url,
+ headers=headers,
+ data=b"",
+ timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ )
+
+ abort_response = self._retry_idempotent_operation(perform)
+
+ if abort_response.status_code not in (200, 201):
+ raise ValueError(abort_response)
+
+ def _abort_resumable_upload(
+ self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session
+ ):
+ """Aborts ongoing resumable upload session to clean up incomplete file."""
+ headers: dict = {}
+ for h in required_headers:
+ headers[h["name"]] = h["value"]
+
+ def perform():
+ return cloud_provider_session.request(
+ "DELETE",
+ resumable_upload_url,
+ headers=headers,
+ data=b"",
+ timeout=self._config.multipart_upload_single_chunk_upload_timeout_seconds,
+ )
+
+ abort_response = self._retry_idempotent_operation(perform)
+
+ if abort_response.status_code not in (200, 201):
+ raise ValueError(abort_response)
+
+ def _create_cloud_provider_session(self):
+ """Creates a separate session which does not inherit auth headers from BaseClient session."""
+ session = requests.Session()
+
+ # following session config in _BaseClient
+ http_adapter = requests.adapters.HTTPAdapter(
+ self._config.max_connection_pools or 20, self._config.max_connections_per_pool or 20, pool_block=True
+ )
+ session.mount("https://", http_adapter)
+ # presigned URL for storage proxy can use plain HTTP
+ session.mount("http://", http_adapter)
+ return session
+
+ def _retry_idempotent_operation(
+        self, operation: Callable[[], requests.Response], before_retry: Optional[Callable] = None
+ ) -> requests.Response:
+ """Perform given idempotent operation with necessary retries. Since operation is idempotent it's
+ safe to retry it for response codes where server state might have changed.
+ """
+
+ def delegate():
+ response = operation()
+ if response.status_code in self._RETRYABLE_STATUS_CODES:
+ attrs = {}
+ # this will assign "retry_after_secs" to the attrs, essentially making exception look retryable
+ _RetryAfterCustomizer().customize_error(response, attrs)
+ raise _error_mapper(response, attrs)
+ else:
+ return response
+
+ # following _BaseClient timeout
+ retry_timeout_seconds = self._config.retry_timeout_seconds or 300
+
+ return retried(
+ timeout=timedelta(seconds=retry_timeout_seconds),
+ # also retry on network errors (connection error, connection timeout)
+ # where we believe request didn't reach the server
+ is_retryable=_BaseClient._is_retryable,
+ before_retry=before_retry,
+ )(delegate)()
+
+ def _open_download_stream(
+ self, file_path: str, start_byte_offset: int, if_unmodified_since_timestamp: Optional[str] = None
) -> DownloadResponse:
+ """Opens a download stream from given offset, performing necessary retries."""
headers = {
"Accept": "application/octet-stream",
}
@@ -737,6 +1333,7 @@ def _download_raw_stream(
"content-type",
"last-modified",
]
+ # Method _api.do() takes care of retrying and will raise an exception in case of failure.
res = self._api.do(
"GET",
f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}",
@@ -753,12 +1350,12 @@ def _download_raw_stream(
return result
- def _wrap_stream(self, file_path: str, downloadResponse: DownloadResponse):
- underlying_response = _ResilientIterator._extract_raw_response(downloadResponse)
+ def _wrap_stream(self, file_path: str, download_response: DownloadResponse):
+ underlying_response = _ResilientIterator._extract_raw_response(download_response)
return _ResilientResponse(
self,
file_path,
- downloadResponse.last_modified,
+ download_response.last_modified,
offset=0,
underlying_response=underlying_response,
)
@@ -859,7 +1456,7 @@ def _recover(self) -> bool:
_LOG.debug("Trying to recover from offset " + str(self._offset))
# following call includes all the required network retries
- downloadResponse = self._api._download_raw_stream(self._file_path, self._offset, self._file_last_modified)
+ downloadResponse = self._api._open_download_stream(self._file_path, self._offset, self._file_last_modified)
underlying_response = _ResilientIterator._extract_raw_response(downloadResponse)
self._underlying_iterator = underlying_response.iter_content(
chunk_size=self._chunk_size, decode_unicode=False
diff --git a/tests/test_files.py b/tests/test_files.py
index 50e6cb470..e25035523 100644
--- a/tests/test_files.py
+++ b/tests/test_files.py
@@ -1,14 +1,28 @@
+import copy
+import hashlib
+import io
+import json
import logging
import os
+import random
import re
+import time
from dataclasses import dataclass
-from typing import List, Union
+from datetime import datetime, timedelta, timezone
+from tempfile import mkstemp
+from typing import Callable, List, Optional, Type, Union
+from urllib.parse import parse_qs, urlparse
import pytest
+import requests
+import requests_mock
from requests import RequestException
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import Config
+from databricks.sdk.errors.platform import (AlreadyExists, BadRequest,
+ InternalError, PermissionDenied,
+ TooManyRequests)
logger = logging.getLogger(__name__)
@@ -392,3 +406,1452 @@ class _Constants:
)
def test_download_recover(config: Config, test_case: DownloadTestCase):
test_case.run(config)
+
+
+class FileContent:
+
+ def __init__(self, length: int, checksum: str):
+ self._length = length
+ self.checksum = checksum
+
+ @classmethod
+ def from_bytes(cls, data: bytes):
+ sha256 = hashlib.sha256()
+ sha256.update(data)
+ return FileContent(len(data), sha256.hexdigest())
+
+ def __repr__(self):
+ return f"Length: {self._length}, checksum: {self.checksum}"
+
+ def __eq__(self, other):
+ if not isinstance(other, FileContent):
+ return NotImplemented
+ return self._length == other._length and self.checksum == other.checksum
+
+
+class MultipartUploadServerState:
+ upload_chunk_url_prefix = "https://cloud_provider.com/upload-chunk/"
+ abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/"
+
+ def __init__(self):
+ self.issued_multipart_urls = {} # part_number -> expiration_time
+ self.uploaded_chunks = {} # part_number -> [chunk file path, etag]
+ self.session_token = "token-" + MultipartUploadServerState.randomstr()
+ self.file_content = None
+ self.issued_abort_url_expire_time = None
+ self.aborted = False
+
+ def create_upload_chunk_url(self, path: str, part_number: int, expire_time: datetime) -> str:
+ assert not self.aborted
+ # client may have requested a URL for the same part if retrying on network error
+ self.issued_multipart_urls[part_number] = expire_time
+ return f"{self.upload_chunk_url_prefix}{path}/{part_number}"
+
+ def create_abort_url(self, path: str, expire_time: datetime) -> str:
+ assert not self.aborted
+ self.issued_abort_url_expire_time = expire_time
+ return f"{self.abort_upload_url_prefix}{path}"
+
+ def save_part(self, part_number: int, part_content: bytes, etag: str):
+ assert not self.aborted
+ assert len(part_content) > 0
+
+ logger.info(f"Saving part {part_number} of size {len(part_content)}")
+
+ # chunk might already have been uploaded
+ existing_chunk = self.uploaded_chunks.get(part_number)
+ if existing_chunk:
+ chunk_file = existing_chunk[0]
+ with open(chunk_file, "wb") as f:
+ f.write(part_content)
+ else:
+ fd, chunk_file = mkstemp()
+ with open(fd, "wb") as f:
+ f.write(part_content)
+
+ self.uploaded_chunks[part_number] = [chunk_file, etag]
+
+ def cleanup(self):
+ for [file, _] in self.uploaded_chunks.values():
+ os.remove(file)
+
+ def get_file_content(self) -> FileContent:
+ assert not self.aborted
+ return self.file_content
+
+ def upload_complete(self, etags: dict):
+ assert not self.aborted
+ # validate etags
+ expected_etags = {}
+ for part_number in self.uploaded_chunks.keys():
+ expected_etags[part_number] = self.uploaded_chunks[part_number][1]
+ assert etags == expected_etags
+
+ size = 0
+ sha256 = hashlib.sha256()
+
+ sorted_chunks = sorted(self.uploaded_chunks.keys())
+ for part_number in sorted_chunks:
+ [chunk_path, _] = self.uploaded_chunks[part_number]
+ size += os.path.getsize(chunk_path)
+ with open(chunk_path, "rb") as f:
+ chunk_content = f.read()
+ sha256.update(chunk_content)
+
+ self.file_content = FileContent(size, sha256.hexdigest())
+
+ def abort_upload(self):
+ self.aborted = True
+
+ @staticmethod
+ def randomstr():
+ return f"{random.randrange(10000)}-{int(time.time())}"
+
+
+class CustomResponse:
+ """Custom response allows to override the "default" response generated by the server
+ with the "custom" response to simulate failure error code, unexpected response body or
+ network error.
+
+ The server is represented by the `processor` parameter in `generate_response()` call.
+ """
+
+ def __init__(
+ self,
+ # If False, default response is always returned.
+ # If True, response is defined by the current invocation count
+ # with respect to first_invocation / last_invocation / only_invocation
+ enabled=True,
+ # Custom code to return
+ code: Optional[int] = 200,
+ # Custom body to return
+ body: Optional[str] = None,
+ # Custom exception to raise
+ exception: Optional[Type[BaseException]] = None,
+        # Whether the exception should be raised before calling processor()
+        # (i.e., before the server state is changed)
+ exception_happened_before_processing: bool = False,
+ # First invocation (1-based) at which return custom response
+ first_invocation: Optional[int] = None,
+ # Last invocation (1-based) at which return custom response
+ last_invocation: Optional[int] = None,
+ # Only invocation (1-based) at which return custom response
+ only_invocation: Optional[int] = None,
+ ):
+ self.enabled = enabled
+ self.code = code
+ self.body = body
+ self.exception = exception
+ self.exception_happened_before_processing = exception_happened_before_processing
+ self.first_invocation = first_invocation
+ self.last_invocation = last_invocation
+ self.only_invocation = only_invocation
+
+ if self.only_invocation and (self.first_invocation or self.last_invocation):
+ raise ValueError("Cannot set both only invocation and first/last invocation")
+
+ if self.exception_happened_before_processing and not self.exception:
+ raise ValueError("Exception is not defined")
+
+ self.invocation_count = 0
+
+ def invocation_matches(self):
+ if not self.enabled:
+ return False
+
+ self.invocation_count += 1
+
+ if self.only_invocation:
+ return self.invocation_count == self.only_invocation
+
+ if self.first_invocation and self.invocation_count < self.first_invocation:
+ return False
+ if self.last_invocation and self.invocation_count > self.last_invocation:
+ return False
+ return True
+
+ def generate_response(self, request: requests.Request, processor: Callable[[], list]):
+ activate_for_current_invocation = self.invocation_matches()
+
+ if activate_for_current_invocation and self.exception and self.exception_happened_before_processing:
+            # If a network exception is thrown while a request is being processed, it is undefined
+            # whether the server actually processed the request (and so changed its state).
+ raise self.exception
+
+ custom_response = [self.code, self.body or "", {}]
+
+ if activate_for_current_invocation:
+ if self.code and 400 <= self.code < 500:
+                # If the server returns a client error, it's not supposed to change its state,
+                # so we don't call processor().
+                [code, body, headers] = custom_response
+            else:
+                # We call processor() but override its response with the custom one.
+ processor()
+ [code, body, headers] = custom_response
+ else:
+ [code, body, headers] = processor()
+
+ if activate_for_current_invocation and self.exception:
+ # self.exception_happened_before_processing is False
+ raise self.exception
+
+ resp = requests.Response()
+
+ resp.request = request
+ resp.status_code = code
+ resp._content = body.encode()
+
+ for key in headers:
+ resp.headers[key] = headers[key]
+
+ return resp
+
+
+class MultipartUploadTestCase:
+ """Test case for multipart upload of a file. Multipart uploads are used on AWS and Azure.
+
+ Multipart upload via presigned URLs involves multiple HTTP requests:
+ - initiating upload (call to Databricks Files API)
+ - requesting upload part URLs (calls to Databricks Files API)
+ - uploading data in chunks (calls to cloud storage provider or Databricks storage proxy)
+ - completing the upload (call to Databricks Files API)
+ - requesting abort upload URL (call to Databricks Files API)
+ - aborting the upload (call to cloud storage provider or Databricks storage proxy)
+
+    The test case uses the requests-mock library to mock all these requests. Within a test, the mocks
+    share server state that tracks the upload and generate the "default" (successful) response.
+
+    The response to each call can be modified by configuring the corresponding `CustomResponse` object.
+ """
+
+ path = "/test.txt"
+
+    expired_url_aws_response = (
+        '<?xml version="1.0" encoding="UTF-8"?>\n<Error><Code>AccessDenied</Code>'
+        "<Message>Request has expired</Message><X-Amz-Expires>14</X-Amz-Expires>"
+        "<Expires>2025-01-01T17:47:13Z</Expires><ServerTime>2025-01-01T17:48:01Z</ServerTime>"
+        "<RequestId>JY66KDXM4CXBZ7X2n8Qayqg60rbvut9P7pk0</RequestId></Error>"
+    )
+
+    expired_url_azure_response = (
+        '<?xml version="1.0" encoding="utf-8"?><Error><Code>AuthenticationFailed</Code>'
+        "<Message>Server failed to authenticate the request. Make sure the value of Authorization header is formed "
+        "correctly including the signature.\nRequestId:1abde581-601e-0028-"
+        "4a6d-5c3952000000\nTime:2025-01-01T16:54:20.5343181Z</Message>"
+        "<AuthenticationErrorDetail>Signature not valid in the specified "
+        "time frame: Start [Wed, 01 Jan 2025 16:38:41 GMT] - Expiry [Wed, "
+        "01 Jan 2025 16:53:45 GMT] - Current [Wed, 01 Jan 2025 16:54:20 "
+        "GMT]</AuthenticationErrorDetail></Error>"
+    )
+
+ # TODO test for overwrite = false
+
+ def __init__(
+ self,
+ name: str,
+ stream_size: int, # size of uploaded file or, technically, stream
+ multipart_upload_chunk_size: Optional[int] = None,
+ sdk_retry_timeout_seconds: Optional[int] = None,
+ multipart_upload_max_retries: Optional[int] = None,
+ multipart_upload_batch_url_count: Optional[int] = None,
+ custom_response_on_initiate=CustomResponse(enabled=False),
+ custom_response_on_create_multipart_url=CustomResponse(enabled=False),
+ custom_response_on_upload=CustomResponse(enabled=False),
+ custom_response_on_complete=CustomResponse(enabled=False),
+ custom_response_on_create_abort_url=CustomResponse(enabled=False),
+ custom_response_on_abort=CustomResponse(enabled=False),
+ # exception which is expected to be thrown (so upload is expected to have failed)
+ expected_exception_type: Optional[Type[BaseException]] = None,
+ # if abort is expected to be called
+ expected_aborted: bool = False,
+ ):
+ self.name = name
+ self.stream_size = stream_size
+ self.multipart_upload_chunk_size = multipart_upload_chunk_size
+ self.sdk_retry_timeout_seconds = sdk_retry_timeout_seconds
+ self.multipart_upload_max_retries = multipart_upload_max_retries
+ self.multipart_upload_batch_url_count = multipart_upload_batch_url_count
+ self.custom_response_on_initiate = copy.deepcopy(custom_response_on_initiate)
+ self.custom_response_on_create_multipart_url = copy.deepcopy(custom_response_on_create_multipart_url)
+ self.custom_response_on_upload = copy.deepcopy(custom_response_on_upload)
+ self.custom_response_on_complete = copy.deepcopy(custom_response_on_complete)
+ self.custom_response_on_create_abort_url = copy.deepcopy(custom_response_on_create_abort_url)
+ self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort)
+ self.expected_exception_type = expected_exception_type
+ self.expected_aborted: bool = expected_aborted
+
+ def setup_session_mock(self, session_mock: requests_mock.Mocker, server_state: MultipartUploadServerState):
+
+ def custom_matcher(request):
+ request_url = urlparse(request.url)
+ request_query = parse_qs(request_url.query)
+
+ # initial request
+ if (
+ request_url.hostname == "localhost"
+ and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}"
+ and request_query.get("action") == ["initiate-upload"]
+ and request.method == "POST"
+ ):
+
+ assert MultipartUploadTestCase.is_auth_header_present(request)
+ assert request.text is None
+
+ def processor():
+ response_json = {"multipart_upload": {"session_token": server_state.session_token}}
+ return [200, json.dumps(response_json), {}]
+
+ return self.custom_response_on_initiate.generate_response(request, processor)
+
+ # multipart upload, create upload part URLs
+ elif (
+ request_url.hostname == "localhost"
+ and request_url.path == "/api/2.0/fs/create-upload-part-urls"
+ and request.method == "POST"
+ ):
+
+ assert MultipartUploadTestCase.is_auth_header_present(request)
+
+ request_json = request.json()
+ assert request_json.keys() == {"count", "expire_time", "path", "session_token", "start_part_number"}
+ assert request_json["path"] == self.path
+ assert request_json["session_token"] == server_state.session_token
+
+ start_part_number = int(request_json["start_part_number"])
+ count = int(request_json["count"])
+ assert count >= 1
+
+ expire_time = MultipartUploadTestCase.parse_and_validate_expire_time(request_json["expire_time"])
+
+ def processor():
+ response_nodes = []
+ for part_number in range(start_part_number, start_part_number + count):
+ upload_part_url = server_state.create_upload_chunk_url(self.path, part_number, expire_time)
+ response_nodes.append(
+ {
+ "part_number": part_number,
+ "url": upload_part_url,
+ "headers": [{"name": "name1", "value": "value1"}],
+ }
+ )
+
+ response_json = {"upload_part_urls": response_nodes}
+ return [200, json.dumps(response_json), {}]
+
+ return self.custom_response_on_create_multipart_url.generate_response(request, processor)
+
+ # multipart upload, uploading part
+ elif request.url.startswith(MultipartUploadServerState.upload_chunk_url_prefix) and request.method == "PUT":
+
+ assert not MultipartUploadTestCase.is_auth_header_present(request)
+
+                url_path = request.url[len(MultipartUploadServerState.upload_chunk_url_prefix) :]
+ part_num = url_path.split("/")[-1]
+ assert url_path[: -len(part_num) - 1] == self.path
+
+ def processor():
+ body = request.body.read()
+ etag = "etag-" + MultipartUploadServerState.randomstr()
+ server_state.save_part(int(part_num), body, etag)
+ return [200, "", {"ETag": etag}]
+
+ return self.custom_response_on_upload.generate_response(request, processor)
+
+ # multipart upload, completion
+ elif (
+ request_url.hostname == "localhost"
+ and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}"
+ and request_query.get("action") == ["complete-upload"]
+ and request_query.get("upload_type") == ["multipart"]
+ and request.method == "POST"
+ ):
+
+ assert MultipartUploadTestCase.is_auth_header_present(request)
+ assert [server_state.session_token] == request_query.get("session_token")
+
+ def processor():
+ request_json = request.json()
+ etags = {}
+
+ for part in request_json["parts"]:
+ etags[part["part_number"]] = part["etag"]
+
+ server_state.upload_complete(etags)
+ return [200, "", {}]
+
+ return self.custom_response_on_complete.generate_response(request, processor)
+
+ # create abort URL
+ elif request.url == "http://localhost/api/2.0/fs/create-abort-upload-url" and request.method == "POST":
+ assert MultipartUploadTestCase.is_auth_header_present(request)
+ request_json = request.json()
+ assert request_json["path"] == self.path
+ expire_time = MultipartUploadTestCase.parse_and_validate_expire_time(request_json["expire_time"])
+
+ def processor():
+ response_json = {
+ "abort_upload_url": {
+ "url": server_state.create_abort_url(self.path, expire_time),
+ "headers": [{"name": "header1", "value": "headervalue1"}],
+ }
+ }
+ return [200, json.dumps(response_json), {}]
+
+ return self.custom_response_on_create_abort_url.generate_response(request, processor)
+
+ # abort upload
+ elif (
+ request.url.startswith(MultipartUploadServerState.abort_upload_url_prefix)
+ and request.method == "DELETE"
+ ):
+ assert not MultipartUploadTestCase.is_auth_header_present(request)
+ assert request.url[len(MultipartUploadServerState.abort_upload_url_prefix) :] == self.path
+
+ def processor():
+ server_state.abort_upload()
+ return [200, "", {}]
+
+ return self.custom_response_on_abort.generate_response(request, processor)
+
+ return None
+
+ session_mock.add_matcher(matcher=custom_matcher)
+
+ @staticmethod
+ def setup_token_auth(config: Config):
+ pat_token = "some_pat_token"
+ config._header_factory = lambda: {"Authorization": f"Bearer {pat_token}"}
+
+ @staticmethod
+ def is_auth_header_present(r: requests.Request):
+ return r.headers.get("Authorization") is not None
+
+ @staticmethod
+ def parse_and_validate_expire_time(s: str) -> datetime:
+ expire_time = datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ")
+ expire_time = expire_time.replace(tzinfo=timezone.utc) # Explicitly add timezone
+ now = datetime.now(timezone.utc)
+ max_expiration = now + timedelta(hours=2)
+ assert now < expire_time < max_expiration
+ return expire_time
+
+ def run(self, config: Config):
+ config = config.copy()
+
+ MultipartUploadTestCase.setup_token_auth(config)
+
+ if self.sdk_retry_timeout_seconds:
+ config.retry_timeout_seconds = self.sdk_retry_timeout_seconds
+ if self.multipart_upload_chunk_size:
+ config.multipart_upload_chunk_size = self.multipart_upload_chunk_size
+ if self.multipart_upload_max_retries:
+ config.multipart_upload_max_retries = self.multipart_upload_max_retries
+ if self.multipart_upload_batch_url_count:
+ config.multipart_upload_batch_url_count = self.multipart_upload_batch_url_count
+ config.enable_experimental_files_api_client = True
+ config.multipart_upload_min_stream_size = 0 # disable single-shot uploads
+
+ file_content = os.urandom(self.stream_size)
+
+ upload_state = MultipartUploadServerState()
+
+ try:
+ w = WorkspaceClient(config=config)
+ with requests_mock.Mocker() as session_mock:
+ self.setup_session_mock(session_mock, upload_state)
+
+ def upload():
+ w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=True)
+
+ if self.expected_exception_type is not None:
+ with pytest.raises(self.expected_exception_type):
+ upload()
+ else:
+ upload()
+ actual_content = upload_state.get_file_content()
+ assert actual_content == FileContent.from_bytes(file_content)
+
+ assert upload_state.aborted == self.expected_aborted
+
+ finally:
+ upload_state.cleanup()
+
+ def __str__(self):
+ return self.name
+
+ @staticmethod
+ def to_string(test_case):
+ return str(test_case)
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ # -------------------------- failures on "initiate upload" --------------------------
+ MultipartUploadTestCase(
+ "Initiate: 400 response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(code=400, only_invocation=1),
+ expected_exception_type=BadRequest,
+ expected_aborted=False, # upload didn't start
+ ),
+ MultipartUploadTestCase(
+ "Initiate: 403 response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(code=403, only_invocation=1),
+ expected_exception_type=PermissionDenied,
+ expected_aborted=False, # upload didn't start
+ ),
+ MultipartUploadTestCase(
+ "Initiate: 500 response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(code=500, only_invocation=1),
+ expected_exception_type=InternalError,
+ expected_aborted=False, # upload didn't start
+ ),
+ MultipartUploadTestCase(
+ "Initiate: non-JSON response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(body="this is not a JSON", only_invocation=1),
+ expected_exception_type=requests.exceptions.JSONDecodeError,
+ expected_aborted=False, # upload didn't start
+ ),
+ MultipartUploadTestCase(
+ "Initiate: meaningless JSON response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(body='{"foo": 123}', only_invocation=1),
+ expected_exception_type=ValueError,
+ expected_aborted=False, # upload didn't start
+ ),
+ MultipartUploadTestCase(
+ "Initiate: no session token in response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(
+ body='{"multipart_upload":{"session_token1": "token123"}}', only_invocation=1
+ ),
+ expected_exception_type=ValueError,
+ expected_aborted=False, # upload didn't start
+ ),
+ MultipartUploadTestCase(
+ "Initiate: permanent retryable exception",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(exception=requests.ConnectionError),
+ sdk_retry_timeout_seconds=30, # let's not wait 5 min (SDK default timeout)
+ expected_exception_type=TimeoutError, # SDK throws this if retries are taking too long
+ expected_aborted=False, # upload didn't start
+ ),
+ MultipartUploadTestCase(
+ "Initiate: intermittent retryable exception",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(
+ exception=requests.ConnectionError,
+ # 3 calls fail, but request is successfully retried
+ first_invocation=1,
+ last_invocation=3,
+ ),
+ expected_aborted=False,
+ ),
+ MultipartUploadTestCase(
+ "Initiate: intermittent retryable status code",
+ stream_size=1024 * 1024,
+ custom_response_on_initiate=CustomResponse(
+ code=429,
+ # 3 calls fail, then retry succeeds
+ first_invocation=1,
+ last_invocation=3,
+ ),
+ expected_aborted=False,
+ ),
+ # -------------------------- failures on "create upload URL" --------------------------
+ MultipartUploadTestCase(
+ "Create upload URL: 400 response is not retied",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(
+ code=400,
+ # 1 failure is enough
+ only_invocation=1,
+ ),
+ expected_exception_type=BadRequest,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: 500 error is not retied",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1),
+ expected_exception_type=InternalError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: non-JSON response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(body="this is not a JSON", only_invocation=1),
+ expected_exception_type=requests.exceptions.JSONDecodeError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: meaningless JSON response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(body='{"foo":123}', only_invocation=1),
+ expected_exception_type=ValueError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: meaningless JSON response is not retried 2",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(body='{"upload_part_urls":[]}', only_invocation=1),
+ expected_exception_type=ValueError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: meaningless JSON response is not retried 3",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(
+ body='{"upload_part_urls":[{"url":""}]}', only_invocation=1
+ ),
+ expected_exception_type=KeyError, # TODO we might want to make JSON parsing more reliable
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: permanent retryable exception",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(exception=requests.ConnectionError),
+ sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout)
+ expected_exception_type=TimeoutError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: intermittent retryable exception",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(
+ exception=requests.Timeout,
+ # happens only once, retry succeeds
+ only_invocation=1,
+ ),
+ expected_aborted=False,
+ ),
+ MultipartUploadTestCase(
+ "Create upload URL: intermittent retryable exception 2",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(
+ exception=requests.Timeout,
+ # 4th request for multipart URLs fails 3 times, then retry succeeds
+ first_invocation=4,
+ last_invocation=6,
+ ),
+ expected_aborted=False,
+ ),
+ # -------------------------- failures on chunk upload --------------------------
+ MultipartUploadTestCase(
+ "Upload chunk: 403 response is not retried",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ code=403,
+ # fail only once
+ only_invocation=1,
+ ),
+ expected_exception_type=PermissionDenied,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: 400 response is not retried",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ code=400,
+ # fail once, but not on the first chunk
+ only_invocation=3,
+ ),
+ expected_exception_type=BadRequest,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: 500 response is not retried",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(code=500, only_invocation=5),
+ expected_exception_type=InternalError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: expired URL is retried on AWS",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ code=403, body=MultipartUploadTestCase.expired_url_aws_response, only_invocation=2
+ ),
+ expected_aborted=False,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: expired URL is retried on Azure",
+ multipart_upload_max_retries=3,
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ code=403,
+ body=MultipartUploadTestCase.expired_url_azure_response,
+ # 3 failures don't exceed multipart_upload_max_retries
+ first_invocation=2,
+ last_invocation=4,
+ ),
+ expected_aborted=False,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: expired URL is retried on Azure, requesting urls by 6",
+ multipart_upload_max_retries=3,
+ multipart_upload_batch_url_count=6,
+ stream_size=100 * 1024 * 1024, # 100 chunks
+ multipart_upload_chunk_size=1 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ code=403,
+ body=MultipartUploadTestCase.expired_url_azure_response,
+ # 3 failures don't exceed multipart_upload_max_retries
+ first_invocation=2,
+ last_invocation=4,
+ ),
+ expected_aborted=False,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: expired URL retry is exhausted",
+ multipart_upload_max_retries=3,
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ code=403,
+ body=MultipartUploadTestCase.expired_url_azure_response,
+ # 4 failures exceed multipart_upload_max_retries
+ first_invocation=2,
+ last_invocation=5,
+ ),
+ expected_exception_type=ValueError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: permanent retryable error",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout)
+ custom_response_on_upload=CustomResponse(exception=requests.ConnectionError, first_invocation=8),
+ expected_exception_type=TimeoutError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: permanent retryable status code",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout)
+ custom_response_on_upload=CustomResponse(code=429, first_invocation=8),
+ expected_exception_type=TimeoutError,
+ expected_aborted=True,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: intermittent retryable error",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ exception=requests.ConnectionError, first_invocation=2, last_invocation=5
+ ),
+ expected_aborted=False,
+ ),
+ MultipartUploadTestCase(
+ "Upload chunk: intermittent retryable status code",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ custom_response_on_upload=CustomResponse(code=429, first_invocation=2, last_invocation=4),
+ expected_aborted=False,
+ ),
+ # -------------------------- failures on abort --------------------------
+ MultipartUploadTestCase(
+            "Abort URL: 400 response",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1),
+ custom_response_on_create_abort_url=CustomResponse(code=400),
+ expected_exception_type=InternalError, # original error
+ expected_aborted=False, # server state didn't change to record abort
+ ),
+ MultipartUploadTestCase(
+ "Abort URL: 403 response",
+ stream_size=1024 * 1024,
+ custom_response_on_upload=CustomResponse(code=500, only_invocation=1),
+ custom_response_on_create_abort_url=CustomResponse(code=403),
+ expected_exception_type=InternalError, # original error
+ expected_aborted=False, # server state didn't change to record abort
+ ),
+ MultipartUploadTestCase(
+            "Abort URL: intermittent retryable status code",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1),
+ custom_response_on_create_abort_url=CustomResponse(code=429, first_invocation=1, last_invocation=3),
+ expected_exception_type=InternalError, # original error
+ expected_aborted=True, # abort successfully called after abort URL creation is retried
+ ),
+ MultipartUploadTestCase(
+            "Abort URL: intermittent retryable exception",
+ stream_size=1024 * 1024,
+ custom_response_on_create_multipart_url=CustomResponse(code=500, only_invocation=1),
+ custom_response_on_create_abort_url=CustomResponse(
+ exception=requests.Timeout, first_invocation=1, last_invocation=3
+ ),
+ expected_exception_type=InternalError, # original error
+ expected_aborted=True, # abort successfully called after abort URL creation is retried
+ ),
+ MultipartUploadTestCase(
+ "Abort: exception",
+ stream_size=1024 * 1024,
+ # don't wait for 5 min (SDK default timeout)
+ sdk_retry_timeout_seconds=30,
+ custom_response_on_create_multipart_url=CustomResponse(code=403, only_invocation=1),
+ custom_response_on_abort=CustomResponse(
+ exception=requests.Timeout,
+                # the exception is raised after processing, so the server state still changes to "aborted"
+ exception_happened_before_processing=False,
+ ),
+ expected_exception_type=PermissionDenied, # original error is reported
+ expected_aborted=True,
+ ),
+ # -------------------------- happy cases --------------------------
+ MultipartUploadTestCase(
+ "Multipart upload successful: single chunk",
+ stream_size=1024 * 1024, # less than chunk size
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ ),
+ MultipartUploadTestCase(
+ "Multipart upload successful: multiple chunks (aligned)",
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ ),
+ MultipartUploadTestCase(
+            "Multipart upload successful: multiple chunks (aligned), upload URLs by 3",
+ multipart_upload_batch_url_count=3,
+ stream_size=100 * 1024 * 1024, # 10 chunks
+ multipart_upload_chunk_size=10 * 1024 * 1024,
+ ),
+ MultipartUploadTestCase(
+            "Multipart upload successful: multiple chunks (not aligned), upload URLs by 1",
+ stream_size=100 * 1024 * 1024 + 1566, # 14 full chunks + remainder
+ multipart_upload_chunk_size=7 * 1024 * 1024 - 17,
+ ),
+ MultipartUploadTestCase(
+            "Multipart upload successful: multiple chunks (not aligned), upload URLs by 5",
+ multipart_upload_batch_url_count=5,
+ stream_size=100 * 1024 * 1024 + 1566, # 14 full chunks + remainder
+ multipart_upload_chunk_size=7 * 1024 * 1024 - 17,
+ ),
+ ],
+ ids=MultipartUploadTestCase.to_string,
+)
+def test_multipart_upload(config: Config, test_case: MultipartUploadTestCase):
+ test_case.run(config)
+
+
+class SingleShotUploadState:
+
+ def __init__(self):
+ self.single_shot_file_content = None
+
+
+class SingleShotUploadTestCase:
+
+ def __init__(self, name: str, stream_size: int, multipart_upload_min_stream_size: int, expected_single_shot: bool):
+ self.name = name
+ self.stream_size = stream_size
+ self.multipart_upload_min_stream_size = multipart_upload_min_stream_size
+ self.expected_single_shot = expected_single_shot
+
+ def __str__(self):
+ return self.name
+
+ @staticmethod
+ def to_string(test_case):
+ return str(test_case)
+
+ def run(self, config: Config):
+ config = config.copy()
+ config.enable_experimental_files_api_client = True
+ config.multipart_upload_min_stream_size = self.multipart_upload_min_stream_size
+
+ file_content = os.urandom(self.stream_size)
+
+ session = requests.Session()
+ with requests_mock.Mocker(session=session) as session_mock:
+ session_mock.get(f"http://localhost/api/2.0/fs/files{MultipartUploadTestCase.path}", status_code=200)
+
+ upload_state = SingleShotUploadState()
+
+ def custom_matcher(request):
+ request_url = urlparse(request.url)
+ request_query = parse_qs(request_url.query)
+
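+                # When a single-shot upload is expected, the client should PUT the whole stream
+                # directly to the Files API; otherwise it should start a multipart upload with an
+                # "initiate-upload" POST (which this matcher fails with 403 to end the test early).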
+ if self.expected_single_shot:
+ if (
+ request_url.hostname == "localhost"
+ and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}"
+ and request.method == "PUT"
+ ):
+ body = request.body.read()
+ upload_state.single_shot_file_content = FileContent.from_bytes(body)
+
+ resp = requests.Response()
+ resp.status_code = 204
+ resp.request = request
+ resp._content = b""
+ return resp
+ else:
+ if (
+ request_url.hostname == "localhost"
+ and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}"
+ and request_query.get("action") == ["initiate-upload"]
+ and request.method == "POST"
+ ):
+
+ resp = requests.Response()
+                        resp.status_code = 403  # raises PermissionDenied, which the test expects
+ resp.request = request
+ resp._content = b""
+ return resp
+
+ return None
+
+ session_mock.add_matcher(matcher=custom_matcher)
+
+ w = WorkspaceClient(config=config)
+ w.files._api._api_client._session = session
+
+ def upload():
+ w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=True)
+
+ if self.expected_single_shot:
+ upload()
+ actual_content = upload_state.single_shot_file_content
+ assert actual_content == FileContent.from_bytes(file_content)
+ else:
+ with pytest.raises(PermissionDenied):
+ upload()
+
+
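+# The cases below exercise the multipart_upload_min_stream_size threshold: a stream strictly
+# smaller than the threshold is uploaded in a single shot, otherwise the multipart path is taken.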
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ SingleShotUploadTestCase(
+ "Single-shot upload",
+ stream_size=1024 * 1024,
+ multipart_upload_min_stream_size=1024 * 1024 + 1,
+ expected_single_shot=True,
+ ),
+ SingleShotUploadTestCase(
+ "Multipart upload 1",
+ stream_size=1024 * 1024,
+ multipart_upload_min_stream_size=1024 * 1024,
+ expected_single_shot=False,
+ ),
+ SingleShotUploadTestCase(
+ "Multipart upload 2",
+ stream_size=1024 * 1024,
+ multipart_upload_min_stream_size=0,
+ expected_single_shot=False,
+ ),
+ ],
+ ids=SingleShotUploadTestCase.to_string,
+)
+def test_single_shot_upload(config: Config, test_case: SingleShotUploadTestCase):
+ test_case.run(config)
+
+
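+# A minimal, illustrative sketch (assumed values; not used by the tests below) of the
+# Content-Range header conventions that the resumable-upload mock parses:
+#   - chunk upload:  "bytes {start}-{end_inclusive}/{total_size or *}"
+#   - status check:  "bytes */*"
+def _content_range_example():
+    header = "bytes 0-1048575/*"  # first 1 MiB of a stream whose total size is not yet known
+    start_s, end_s, size_s = re.match("bytes (\\d+)-(\\d+)/(.+)", header).groups()
+    assert (int(start_s), int(end_s), size_s) == (0, 1048575, "*")
+
+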
+class ResumableUploadServerState:
+ resumable_upload_url_prefix = "https://cloud_provider.com/resumable-upload/"
+ abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/"
+
+ def __init__(self, unconfirmed_delta: Union[int, list]):
+ self.unconfirmed_delta = unconfirmed_delta
+ self.confirmed_last_byte: Optional[int] = None # inclusive
+ self.uploaded_parts = []
+ self.session_token = "token-" + MultipartUploadServerState.randomstr()
+ self.file_content = None
+ self.aborted = False
+
+ def save_part(self, start_offset: int, end_offset_incl: int, part_content: bytes, file_size_s: str):
+ assert not self.aborted
+
+ assert len(part_content) > 0
+        if self.confirmed_last_byte is not None:
+ assert start_offset == self.confirmed_last_byte + 1
+ else:
+ assert start_offset == 0
+
+ assert end_offset_incl == start_offset + len(part_content) - 1
+
+ is_last_part = file_size_s != "*"
+ if is_last_part:
+ assert int(file_size_s) == end_offset_incl + 1
+ else:
+ assert not self.file_content # last chunk should not have been uploaded yet
+
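+        # Pick how many trailing bytes of this part the server will leave "unconfirmed":
+        # either a fixed delta (int) or a per-part value from the list (the last value is
+        # reused once the list is exhausted).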
+ if isinstance(self.unconfirmed_delta, int):
+ unconfirmed_delta = self.unconfirmed_delta
+ elif len(self.uploaded_parts) < len(self.unconfirmed_delta):
+ unconfirmed_delta = self.unconfirmed_delta[len(self.uploaded_parts)]
+ else:
+ unconfirmed_delta = self.unconfirmed_delta[-1] # take the last delta
+
+ if unconfirmed_delta >= len(part_content):
+ unconfirmed_delta = 0 # otherwise we never finish
+
+ logger.info(
+ f"Saving part {len(self.uploaded_parts) + 1} of original size {len(part_content)} with unconfirmed delta {unconfirmed_delta}. is_last_part = {is_last_part}"
+ )
+
+ if unconfirmed_delta > 0:
+ part_content = part_content[:-unconfirmed_delta]
+
+ fd, chunk_file = mkstemp()
+ with open(fd, "wb") as f:
+ f.write(part_content)
+
+ self.uploaded_parts.append(chunk_file)
+
+ if is_last_part and unconfirmed_delta == 0:
+ size = 0
+ sha256 = hashlib.sha256()
+ for chunk_path in self.uploaded_parts:
+ size += os.path.getsize(chunk_path)
+ with open(chunk_path, "rb") as f:
+ chunk_content = f.read()
+ sha256.update(chunk_content)
+
+ assert size == end_offset_incl + 1
+ self.file_content = FileContent(size, sha256.hexdigest())
+
+ self.confirmed_last_byte = end_offset_incl - unconfirmed_delta
+
+ def create_abort_url(self, path: str, expire_time: datetime) -> str:
+ assert not self.aborted
+ self.issued_abort_url_expire_time = expire_time
+ return f"{self.abort_upload_url_prefix}{path}"
+
+ def cleanup(self):
+ for file in self.uploaded_parts:
+ os.remove(file)
+
+ def get_file_content(self) -> FileContent:
+ assert not self.aborted
+ return self.file_content
+
+ def abort_upload(self):
+ self.aborted = True
+
+
+class ResumableUploadTestCase:
+ """Test case for resumable upload of a file. Resumable uploads are used on GCP.
+
+    A resumable upload involves multiple HTTP requests:
+    - initiating the upload (call to the Databricks Files API)
+    - requesting a resumable upload URL (call to the Databricks Files API)
+    - uploading chunks of data (calls to the cloud storage provider or the Databricks storage proxy)
+    - aborting the upload (call to the cloud storage provider or the Databricks storage proxy)
+
+    The test case uses the requests-mock library to mock all these requests. Within a test, the mocks
+    share server state that tracks the upload. By default, the mocks generate a successful response.
+
+    The response to each call can be modified by parameterising the respective `CustomResponse` object.
+ """
+
+ path = "/test.txt"
+
+ def __init__(
+ self,
+ name: str,
+ stream_size: int,
+ overwrite: bool = True,
+ multipart_upload_chunk_size: Optional[int] = None,
+ sdk_retry_timeout_seconds: Optional[int] = None,
+ multipart_upload_max_retries: Optional[int] = None,
+        # In a resumable upload, when replying to a chunk upload request, the server returns
+        # (confirms) the last byte offset it accepted, after which the client should resume the upload.
+        #
+        # `unconfirmed_delta` defines the offset from the end of the chunk that remains
+        # "unconfirmed", i.e. the last accepted offset is (range_end - unconfirmed_delta).
+        # It can be an int (same for all chunks) or a list (one value per chunk).
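+        # Example (assumed numbers): for a chunk sent as "bytes 0-1048575/*" with
+        # unconfirmed_delta=239, the server confirms bytes 0-1048336 and the client
+        # resumes the upload from offset 1048337.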
+ unconfirmed_delta: Union[int, list] = 0,
+ custom_response_on_create_resumable_url=CustomResponse(enabled=False),
+ custom_response_on_upload=CustomResponse(enabled=False),
+ custom_response_on_status_check=CustomResponse(enabled=False),
+ custom_response_on_abort=CustomResponse(enabled=False),
+        # exception that is expected to be raised (i.e. the upload is expected to fail)
+ expected_exception_type: Optional[Type[BaseException]] = None,
+        # whether abort is expected to be called
+ expected_aborted: bool = False,
+ ):
+ self.name = name
+ self.stream_size = stream_size
+ self.overwrite = overwrite
+ self.multipart_upload_chunk_size = multipart_upload_chunk_size
+ self.sdk_retry_timeout_seconds = sdk_retry_timeout_seconds
+ self.multipart_upload_max_retries = multipart_upload_max_retries
+ self.unconfirmed_delta = unconfirmed_delta
+ self.custom_response_on_create_resumable_url = copy.deepcopy(custom_response_on_create_resumable_url)
+ self.custom_response_on_upload = copy.deepcopy(custom_response_on_upload)
+ self.custom_response_on_status_check = copy.deepcopy(custom_response_on_status_check)
+ self.custom_response_on_abort = copy.deepcopy(custom_response_on_abort)
+ self.expected_exception_type = expected_exception_type
+ self.expected_aborted: bool = expected_aborted
+
+ def setup_session_mock(self, session_mock: requests_mock.Mocker, server_state: ResumableUploadServerState):
+
+ def custom_matcher(request):
+ request_url = urlparse(request.url)
+ request_query = parse_qs(request_url.query)
+
+ # initial request
+ if (
+ request_url.hostname == "localhost"
+ and request_url.path == f"/api/2.0/fs/files{MultipartUploadTestCase.path}"
+ and request_query.get("action") == ["initiate-upload"]
+ and request.method == "POST"
+ ):
+
+ assert MultipartUploadTestCase.is_auth_header_present(request)
+ assert request.text is None
+
+ def processor():
+ response_json = {"resumable_upload": {"session_token": server_state.session_token}}
+ return [200, json.dumps(response_json), {}]
+
+                # Error responses to the initiate request are already covered by test_multipart_upload(),
+                # so here we always generate a success response.
+ return CustomResponse(enabled=False).generate_response(request, processor)
+
+ elif (
+ request_url.hostname == "localhost"
+ and request_url.path == "/api/2.0/fs/create-resumable-upload-url"
+ and request.method == "POST"
+ ):
+
+ assert MultipartUploadTestCase.is_auth_header_present(request)
+
+ request_json = request.json()
+ assert request_json.keys() == {"path", "session_token"}
+ assert request_json["path"] == self.path
+ assert request_json["session_token"] == server_state.session_token
+
+ def processor():
+ resumable_upload_url = f"{ResumableUploadServerState.resumable_upload_url_prefix}{self.path}"
+
+ response_json = {
+ "resumable_upload_url": {
+ "url": resumable_upload_url,
+ "headers": [{"name": "name1", "value": "value1"}],
+ }
+ }
+ return [200, json.dumps(response_json), {}]
+
+ return self.custom_response_on_create_resumable_url.generate_response(request, processor)
+
+ # resumable upload, uploading part
+ elif (
+ request.url.startswith(ResumableUploadServerState.resumable_upload_url_prefix)
+ and request.method == "PUT"
+ ):
+
+ assert not MultipartUploadTestCase.is_auth_header_present(request)
+ url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix) :]
+ assert url_path == self.path
+
+ content_range_header = request.headers["Content-range"]
+ is_status_check_request = re.match("bytes \\*/\\*", content_range_header)
+ if is_status_check_request:
+ assert not request.body
+ response_customizer = self.custom_response_on_status_check
+ else:
+ response_customizer = self.custom_response_on_upload
+
+ def processor():
+ if not is_status_check_request:
+ body = request.body.read()
+
+ match = re.match("bytes (\\d+)-(\\d+)/(.+)", content_range_header)
+ [range_start_s, range_end_s, file_size_s] = match.groups()
+
+ server_state.save_part(int(range_start_s), int(range_end_s), body, file_size_s)
+
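+                    # Mimic the resumable-upload protocol: 200 once all bytes are persisted,
+                    # otherwise 308 with a "Range" header indicating the last confirmed byte
+                    # (no header if nothing has been confirmed yet).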
+ if server_state.file_content:
+ # upload complete
+ return [200, "", {}]
+ else:
+ # more data expected
+                        if server_state.confirmed_last_byte is not None:
+ headers = {"Range": f"bytes=0-{server_state.confirmed_last_byte}"}
+ else:
+ headers = {}
+ return [308, "", headers]
+
+ return response_customizer.generate_response(request, processor)
+
+ # abort upload
+ elif (
+ request.url.startswith(ResumableUploadServerState.resumable_upload_url_prefix)
+ and request.method == "DELETE"
+ ):
+
+ assert not MultipartUploadTestCase.is_auth_header_present(request)
+ url_path = request.url[len(ResumableUploadServerState.resumable_upload_url_prefix) :]
+ assert url_path == self.path
+
+ def processor():
+ server_state.abort_upload()
+ return [200, "", {}]
+
+ return self.custom_response_on_abort.generate_response(request, processor)
+
+ return None
+
+ session_mock.add_matcher(matcher=custom_matcher)
+
+ def run(self, config: Config):
+ config = config.copy()
+ if self.sdk_retry_timeout_seconds:
+ config.retry_timeout_seconds = self.sdk_retry_timeout_seconds
+ if self.multipart_upload_chunk_size:
+ config.multipart_upload_chunk_size = self.multipart_upload_chunk_size
+ if self.multipart_upload_max_retries:
+ config.multipart_upload_max_retries = self.multipart_upload_max_retries
+ config.enable_experimental_files_api_client = True
+ config.multipart_upload_min_stream_size = 0 # disable single-shot uploads
+
+ MultipartUploadTestCase.setup_token_auth(config)
+
+ file_content = os.urandom(self.stream_size)
+
+ upload_state = ResumableUploadServerState(self.unconfirmed_delta)
+
+ try:
+ with requests_mock.Mocker() as session_mock:
+ self.setup_session_mock(session_mock, upload_state)
+ w = WorkspaceClient(config=config)
+
+ def upload():
+ w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=self.overwrite)
+
+ if self.expected_exception_type is not None:
+ with pytest.raises(self.expected_exception_type):
+ upload()
+ else:
+ upload()
+ actual_content = upload_state.get_file_content()
+ assert actual_content == FileContent.from_bytes(file_content)
+
+ assert upload_state.aborted == self.expected_aborted
+
+ finally:
+ upload_state.cleanup()
+
+ def __str__(self):
+ return self.name
+
+ @staticmethod
+ def to_string(test_case):
+ return str(test_case)
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ # ------------------ failures on creating resumable upload URL ------------------
+ ResumableUploadTestCase(
+ "Create resumable URL: 400 response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_resumable_url=CustomResponse(
+ code=400,
+ # 1 failure is enough
+ only_invocation=1,
+ ),
+ expected_exception_type=BadRequest,
+ expected_aborted=False, # upload didn't start
+ ),
+ ResumableUploadTestCase(
+ "Create resumable URL: 403 response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_resumable_url=CustomResponse(code=403, only_invocation=1),
+ expected_exception_type=PermissionDenied,
+ expected_aborted=False, # upload didn't start
+ ),
+ ResumableUploadTestCase(
+ "Create resumable URL: 500 response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_resumable_url=CustomResponse(code=500, only_invocation=1),
+ expected_exception_type=InternalError,
+ expected_aborted=False, # upload didn't start
+ ),
+ ResumableUploadTestCase(
+ "Create resumable URL: non-JSON response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_resumable_url=CustomResponse(body="Foo bar", only_invocation=1),
+ expected_exception_type=requests.exceptions.JSONDecodeError,
+ expected_aborted=False, # upload didn't start
+ ),
+ ResumableUploadTestCase(
+ "Create resumable URL: meaningless JSON response is not retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_resumable_url=CustomResponse(
+ body='{"upload_part_urls":[{"url":""}]}', only_invocation=1
+ ),
+ expected_exception_type=ValueError,
+ expected_aborted=False, # upload didn't start
+ ),
+ ResumableUploadTestCase(
+ "Create resumable URL: permanent retryable status code",
+ stream_size=1024 * 1024,
+ custom_response_on_create_resumable_url=CustomResponse(code=429),
+ sdk_retry_timeout_seconds=30, # don't wait for 5 min (SDK default timeout)
+ expected_exception_type=TimeoutError,
+ expected_aborted=False, # upload didn't start
+ ),
+ ResumableUploadTestCase(
+ "Create resumable URL: intermittent retryable exception is retried",
+ stream_size=1024 * 1024,
+ custom_response_on_create_resumable_url=CustomResponse(
+ exception=requests.Timeout,
+ # 3 failures total
+ first_invocation=1,
+ last_invocation=3,
+ ),
+ expected_aborted=False, # upload succeeds
+ ),
+ # ------------------ failures during upload ------------------
+ ResumableUploadTestCase(
+ "Upload: retryable exception after file is uploaded",
+ stream_size=1024 * 1024,
+ custom_response_on_upload=CustomResponse(
+ exception=requests.ConnectionError,
+                # the server state changes before the exception is thrown
+ exception_happened_before_processing=False,
+ ),
+            # Despite the error returned to the client, the file has been uploaded. The client
+            # discovers that on the next status check and considers the upload complete.
+ expected_aborted=False,
+ ),
+ ResumableUploadTestCase(
+ "Upload: retryable exception before file is uploaded, not enough retries",
+ stream_size=1024 * 1024,
+ multipart_upload_max_retries=3,
+ custom_response_on_upload=CustomResponse(
+ exception=requests.ConnectionError,
+ # prevent server from saving this chunk
+ exception_happened_before_processing=True,
+ # fail 4 times, exceeding max_retries
+ first_invocation=1,
+ last_invocation=4,
+ ),
+ # File was never uploaded and we gave up retrying
+ expected_exception_type=requests.ConnectionError,
+ expected_aborted=True,
+ ),
+ ResumableUploadTestCase(
+ "Upload: retryable exception before file is uploaded, enough retries",
+ stream_size=1024 * 1024,
+ multipart_upload_max_retries=4,
+ custom_response_on_upload=CustomResponse(
+ exception=requests.ConnectionError,
+ # prevent server from saving this chunk
+ exception_happened_before_processing=True,
+ # fail 4 times, not exceeding max_retries
+ first_invocation=1,
+ last_invocation=4,
+ ),
+ # File was uploaded after retries
+ expected_aborted=False,
+ ),
+ ResumableUploadTestCase(
+ "Upload: intermittent 429 response: retried",
+ stream_size=100 * 1024 * 1024,
+ multipart_upload_chunk_size=7 * 1024 * 1024,
+ multipart_upload_max_retries=3,
+ custom_response_on_upload=CustomResponse(
+ code=429,
+ # 3 failures not exceeding max_retries
+ first_invocation=2,
+ last_invocation=4,
+ ),
+ expected_aborted=False, # upload succeeded
+ ),
+ ResumableUploadTestCase(
+ "Upload: intermittent 429 response: retry exhausted",
+ stream_size=100 * 1024 * 1024,
+ multipart_upload_chunk_size=1 * 1024 * 1024,
+ multipart_upload_max_retries=3,
+ custom_response_on_upload=CustomResponse(
+ code=429,
+ # 4 failures exceeding max_retries
+ first_invocation=2,
+ last_invocation=5,
+ ),
+ expected_exception_type=TooManyRequests,
+ expected_aborted=True,
+ ),
+ # -------------- abort failures --------------
+ ResumableUploadTestCase(
+ "Abort: client error",
+ stream_size=1024 * 1024,
+ # prevent chunk from being uploaded
+ custom_response_on_upload=CustomResponse(code=403),
+ # internal server error does not prevent server state change
+ custom_response_on_abort=CustomResponse(code=500),
+ expected_exception_type=PermissionDenied,
+ # abort returned error but was actually processed
+ expected_aborted=True,
+ ),
+ # -------------- file already exists --------------
+ ResumableUploadTestCase(
+ "File already exists",
+ stream_size=1024 * 1024,
+ overwrite=False,
+ custom_response_on_upload=CustomResponse(code=412, only_invocation=1),
+ expected_exception_type=AlreadyExists,
+ expected_aborted=True,
+ ),
+ # -------------- success cases --------------
+ ResumableUploadTestCase(
+ "Multiple chunks, zero unconfirmed delta",
+ stream_size=100 * 1024 * 1024,
+ multipart_upload_chunk_size=7 * 1024 * 1024 + 566,
+ # server accepts all the chunks in full
+ unconfirmed_delta=0,
+ expected_aborted=False,
+ ),
+ ResumableUploadTestCase(
+ "Multiple small chunks, zero unconfirmed delta",
+ stream_size=100 * 1024 * 1024,
+ multipart_upload_chunk_size=100 * 1024,
+ # server accepts all the chunks in full
+ unconfirmed_delta=0,
+ expected_aborted=False,
+ ),
+ ResumableUploadTestCase(
+ "Multiple chunks, non-zero unconfirmed delta",
+ stream_size=100 * 1024 * 1024,
+ multipart_upload_chunk_size=7 * 1024 * 1024 + 566,
+ # for every chunk, server accepts all except last 239 bytes
+ unconfirmed_delta=239,
+ expected_aborted=False,
+ ),
+ ResumableUploadTestCase(
+ "Multiple chunks, variable unconfirmed delta",
+ stream_size=100 * 1024 * 1024,
+ multipart_upload_chunk_size=7 * 1024 * 1024 + 566,
+            # for the 1st chunk, the server accepts all except the last 15 KiB
+            # for the 2nd chunk, the server accepts it all
+            # for the 3rd chunk, the server accepts all except the last 25000 bytes
+            # for the 4th chunk, the server accepts all except the last 7 MiB
+            # from the 5th chunk onwards, the server accepts all except the last 5 bytes
+ unconfirmed_delta=[15 * 1024, 0, 25000, 7 * 1024 * 1024, 5],
+ expected_aborted=False,
+ ),
+ ],
+ ids=ResumableUploadTestCase.to_string,
+)
+def test_resumable_upload(config: Config, test_case: ResumableUploadTestCase):
+ test_case.run(config)