diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 60fa3c75f..93b6f623c 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,62 +1,41 @@ import logging -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass +from concurrent.futures import ThreadPoolExecutor, Future from typing import List, Union from databricks.sql.cloudfetch.downloader import ( ResultSetDownloadHandler, DownloadableResultSettings, + DownloadedFile, ) from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) -@dataclass -class DownloadedFile: - """ - Class for the result file and metadata. - - Attributes: - file_bytes (bytes): Downloaded file in bytes. - start_row_offset (int): The offset of the starting row in relation to the full result. - row_count (int): Number of rows the file represents in the result. - """ - - file_bytes: bytes - start_row_offset: int - row_count: int - - class ResultFileDownloadManager: - def __init__(self, max_download_threads: int, lz4_compressed: bool): - self.download_handlers: List[ResultSetDownloadHandler] = [] - self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1) - self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed) - self.fetch_need_retry = False - self.num_consecutive_result_file_download_retries = 0 - - def add_file_links( - self, t_spark_arrow_result_links: List[TSparkArrowResultLink] - ) -> None: - """ - Create download handler for each cloud fetch link. - - Args: - t_spark_arrow_result_links: List of cloud fetch links consisting of file URL and metadata. - """ - for link in t_spark_arrow_result_links: + def __init__( + self, + links: List[TSparkArrowResultLink], + max_download_threads: int, + lz4_compressed: bool, + ): + self._pending_links: List[TSparkArrowResultLink] = [] + for link in links: if link.rowCount <= 0: continue logger.debug( - "ResultFileDownloadManager.add_file_links: start offset {}, row count: {}".format( + "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format( link.startRowOffset, link.rowCount ) ) - self.download_handlers.append( - ResultSetDownloadHandler(self.downloadable_result_settings, link) - ) + self._pending_links.append(link) + + self._download_tasks: List[Future[DownloadedFile]] = [] + self._max_download_threads: int = max_download_threads + self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads) + + self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) def get_next_downloaded_file( self, next_row_offset: int @@ -73,143 +52,49 @@ def get_next_downloaded_file( Args: next_row_offset (int): The offset of the starting row of the next file we want data from. 
""" - # No more files to download from this batch of links - if not self.download_handlers: - self._shutdown_manager() - return None - - # Remove handlers we don't need anymore - self._remove_past_handlers(next_row_offset) - # Schedule the downloads + # Make sure the download queue is always full self._schedule_downloads() - # Find next file - idx = self._find_next_file_index(next_row_offset) - if idx is None: + # No more files to download from this batch of links + if len(self._download_tasks) == 0: self._shutdown_manager() return None - handler = self.download_handlers[idx] - # Check (and wait) for download status - if self._check_if_download_successful(handler): - link = handler.result_link - logger.debug( - "ResultFileDownloadManager: file found for row index {}: start {}, row count: {}".format( - next_row_offset, link.startRowOffset, link.rowCount - ) - ) - # Buffer should be empty so set buffer to new ArrowQueue with result_file - result = DownloadedFile( - handler.result_file, - handler.result_link.startRowOffset, - handler.result_link.rowCount, - ) - self.download_handlers.pop(idx) - # Return True upon successful download to continue loop and not force a retry - return result - else: + task = self._download_tasks.pop(0) + # Future's `result()` method will wait for the call to complete, and return + # the value returned by the call. If the call throws an exception - `result()` + # will throw the same exception + file = task.result() + if (next_row_offset < file.start_row_offset) or ( + next_row_offset > file.start_row_offset + file.row_count + ): logger.debug( - "ResultFileDownloadManager: cannot find file for row index {}".format( - next_row_offset + "ResultFileDownloadManager: file does not contain row {}, start {}, row count {}".format( + next_row_offset, file.start_row_offset, file.row_count ) ) - # Download was not successful for next download item, force a retry - self._shutdown_manager() - return None - - def _remove_past_handlers(self, next_row_offset: int): - logger.debug( - "ResultFileDownloadManager: removing past handlers, current offset: {}".format( - next_row_offset - ) - ) - # Any link in which its start to end range doesn't include the next row to be fetched does not need downloading - i = 0 - while i < len(self.download_handlers): - result_link = self.download_handlers[i].result_link - logger.debug( - "- checking result link: start {}, row count: {}, current offset: {}".format( - result_link.startRowOffset, result_link.rowCount, next_row_offset - ) - ) - if result_link.startRowOffset + result_link.rowCount > next_row_offset: - i += 1 - continue - self.download_handlers.pop(i) + return file def _schedule_downloads(self): - # Schedule downloads for all download handlers if not already scheduled. + """ + While download queue has a capacity, peek pending links and submit them to thread pool. 
+ """ logger.debug("ResultFileDownloadManager: schedule downloads") - for handler in self.download_handlers: - if handler.is_download_scheduled: - continue - try: - logger.debug( - "- start: {}, row count: {}".format( - handler.result_link.startRowOffset, handler.result_link.rowCount - ) - ) - self.thread_pool.submit(handler.run) - except Exception as e: - logger.error(e) - break - handler.is_download_scheduled = True - - def _find_next_file_index(self, next_row_offset: int): - logger.debug( - "ResultFileDownloadManager: trying to find file for row {}".format( - next_row_offset - ) - ) - # Get the handler index of the next file in order - next_indices = [ - i - for i, handler in enumerate(self.download_handlers) - if handler.is_download_scheduled - # TODO: shouldn't `next_row_offset` be tested against the range, not just start row offset? - and handler.result_link.startRowOffset == next_row_offset - ] - - for i in next_indices: - link = self.download_handlers[i].result_link + while (len(self._download_tasks) < self._max_download_threads) and ( + len(self._pending_links) > 0 + ): + link = self._pending_links.pop(0) logger.debug( - "- found file: start {}, row count {}".format( - link.startRowOffset, link.rowCount - ) + "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount) ) - - return next_indices[0] if len(next_indices) > 0 else None - - def _check_if_download_successful(self, handler: ResultSetDownloadHandler): - # Check (and wait until download finishes) if download was successful - if not handler.is_file_download_successful(): - if handler.is_link_expired: - self.fetch_need_retry = True - return False - elif handler.is_download_timedout: - # Consecutive file retries should not exceed threshold in settings - if ( - self.num_consecutive_result_file_download_retries - >= self.downloadable_result_settings.max_consecutive_file_download_retries - ): - self.fetch_need_retry = True - return False - self.num_consecutive_result_file_download_retries += 1 - - # Re-submit handler run to thread pool and recursively check download status - self.thread_pool.submit(handler.run) - return self._check_if_download_successful(handler) - else: - self.fetch_need_retry = True - return False - - self.num_consecutive_result_file_download_retries = 0 - self.fetch_need_retry = False - return True + handler = ResultSetDownloadHandler(self._downloadable_result_settings, link) + task = self._thread_pool.submit(handler.run) + self._download_tasks.append(task) def _shutdown_manager(self): # Clear download handlers and shutdown the thread pool - self.download_handlers = [] - self.thread_pool.shutdown(wait=False) + self._pending_links = [] + self._download_tasks = [] + self._thread_pool.shutdown(wait=False) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 019c4ef92..61ae26ac5 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -2,14 +2,43 @@ from dataclasses import dataclass import requests +from requests.adapters import HTTPAdapter, Retry import lz4.frame -import threading import time from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink +from databricks.sql.exc import Error + logger = logging.getLogger(__name__) +# TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library. 
+# But DatabricksRetryPolicy needs to be updated first, since currently it works only with Thrift requests +retryPolicy = Retry( + total=5, # max retry attempts + backoff_factor=1, # min delay, 1 second + backoff_max=60, # max delay, 60 seconds + # retry status codes 100 and below, 429 (Too Many Requests), and 500 and above, + # excluding 501 (Not Implemented) + status_forcelist=[*range(0, 101), 429, 500, *range(502, 1000)], +) + + +@dataclass +class DownloadedFile: + """ + Class for the result file and metadata. + + Attributes: + file_bytes (bytes): Downloaded file in bytes. + start_row_offset (int): The offset of the starting row in relation to the full result. + row_count (int): Number of rows the file represents in the result. + """ + + file_bytes: bytes + start_row_offset: int + row_count: int + @dataclass class DownloadableResultSettings: @@ -29,111 +58,78 @@ class DownloadableResultSettings: max_consecutive_file_download_retries: int = 0 -class ResultSetDownloadHandler(threading.Thread): +class ResultSetDownloadHandler: def __init__( self, - downloadable_result_settings: DownloadableResultSettings, - t_spark_arrow_result_link: TSparkArrowResultLink, + settings: DownloadableResultSettings, + link: TSparkArrowResultLink, ): - super().__init__() - self.settings = downloadable_result_settings - self.result_link = t_spark_arrow_result_link - self.is_download_scheduled = False - self.is_download_finished = threading.Event() - self.is_file_downloaded_successfully = False - self.is_link_expired = False - self.is_download_timedout = False - self.result_file = None - - def is_file_download_successful(self) -> bool: - """ - Check and report if cloud fetch file downloaded successfully. - - This function will block until a file download finishes or until a timeout. - """ - timeout = ( - self.settings.download_timeout - if self.settings.download_timeout > 0 - else None - ) - try: - if not self.is_download_finished.wait(timeout=timeout): - self.is_download_timedout = True - logger.debug( - "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format( - self.settings.download_timeout, - self.result_link.startRowOffset, - self.result_link.startRowOffset + self.result_link.rowCount, - ) - ) - return False - except Exception as e: - logger.error(e) - return False - return self.is_file_downloaded_successfully + self.settings = settings + self.link = link - def run(self): + def run(self) -> DownloadedFile: """ Download the file described in the cloud fetch link. - This function checks if the link has or is expiring, gets the file via a requests session, decompresses the file, and signals to waiting threads that the download is finished and whether it was successful. + This function checks whether the link has expired or is about to expire, downloads the file via an HTTP session with built-in retries, decompresses it if it is LZ4-compressed, and returns the result as a DownloadedFile. Any failure is raised to the caller.
""" - self._reset() + + logger.debug( + "ResultSetDownloadHandler: starting file download, offset {}, row count {}".format( + self.link.startRowOffset, self.link.rowCount + ) + ) # Check if link is already expired or is expiring - if ResultSetDownloadHandler.check_link_expired( - self.result_link, self.settings.link_expiry_buffer_secs - ): - self.is_link_expired = True - return + ResultSetDownloadHandler._validate_link( + self.link, self.settings.link_expiry_buffer_secs + ) session = requests.Session() - session.timeout = self.settings.download_timeout + session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) + session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) try: # Get the file via HTTP request - response = session.get(self.result_link.fileLink) - - if not response.ok: - self.is_file_downloaded_successfully = False - return + response = session.get( + self.link.fileLink, timeout=self.settings.download_timeout + ) + response.raise_for_status() # Save (and decompress if needed) the downloaded file compressed_data = response.content decompressed_data = ( - ResultSetDownloadHandler.decompress_data(compressed_data) + ResultSetDownloadHandler._decompress_data(compressed_data) if self.settings.is_lz4_compressed else compressed_data ) - self.result_file = decompressed_data # The size of the downloaded file should match the size specified from TSparkArrowResultLink - self.is_file_downloaded_successfully = ( - len(self.result_file) == self.result_link.bytesNum + if len(decompressed_data) != self.link.bytesNum: + logger.debug( + "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( + len(decompressed_data), self.link.bytesNum + ) + ) + + logger.debug( + "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( + self.link.startRowOffset, self.link.rowCount + ) ) - except Exception as e: - logger.error(e) - self.is_file_downloaded_successfully = False + return DownloadedFile( + decompressed_data, + self.link.startRowOffset, + self.link.rowCount, + ) finally: - session and session.close() - # Awaken threads waiting for this to be true which signals the run is complete - self.is_download_finished.set() - - def _reset(self): - """ - Reset download-related flags for every retry of run() - """ - self.is_file_downloaded_successfully = False - self.is_link_expired = False - self.is_download_timedout = False - self.is_download_finished = threading.Event() + if session: + session.close() @staticmethod - def check_link_expired( - link: TSparkArrowResultLink, expiry_buffer_secs: int - ) -> bool: + def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): """ Check if a link has expired or will expire. @@ -142,14 +138,13 @@ def check_link_expired( """ current_time = int(time.time()) if ( - link.expiryTime < current_time - or link.expiryTime - current_time < expiry_buffer_secs + link.expiryTime <= current_time + or link.expiryTime - current_time <= expiry_buffer_secs ): - return True - return False + raise Error("CloudFetch link has expired") @staticmethod - def decompress_data(compressed_data: bytes) -> bytes: + def _decompress_data(compressed_data: bytes) -> bytes: """ Decompress lz4 frame compressed data. 
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 6063eca1a..e3e1696e5 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -169,9 +169,8 @@ def __init__( ) ) self.download_manager = ResultFileDownloadManager( - self.max_download_threads, self.lz4_compressed + result_links or [], self.max_download_threads, self.lz4_compressed ) - self.download_manager.add_file_links(result_links or []) self.table = self._create_next_table() self.table_row_index = 0 diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index e5611ce62..e9dfd712d 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -49,7 +49,8 @@ def test_initializer_adds_links(self, mock_create_next_table): result_links = self.create_result_links(10) queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10) - assert len(queue.download_manager.download_handlers) == 10 + assert len(queue.download_manager._pending_links) == 10 + assert len(queue.download_manager._download_tasks) == 0 mock_create_next_table.assert_called() def test_initializer_no_links_to_add(self): @@ -57,7 +58,8 @@ def test_initializer_no_links_to_add(self): result_links = [] queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10) - assert len(queue.download_manager.download_handlers) == 0 + assert len(queue.download_manager._pending_links) == 0 + assert len(queue.download_manager._download_tasks) == 0 assert queue.table is None @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) @@ -65,7 +67,7 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): queue = utils.CloudFetchQueue(MagicMock(), result_links=[], max_download_threads=10) assert queue._create_next_table() is None - assert mock_get_next_downloaded_file.called_with(0) + mock_get_next_downloaded_file.assert_called_with(0) @patch("databricks.sql.utils.create_arrow_table_from_arrow_file") @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", @@ -76,8 +78,8 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) expected_result = self.make_arrow_table() - assert mock_create_arrow_table.called_with(b"1234567890", True, schema_bytes, description) - assert mock_get_next_downloaded_file.called_with(0) + mock_get_next_downloaded_file.assert_called_with(0) + mock_create_arrow_table.assert_called_with(b"1234567890", description) assert queue.table == expected_result assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -130,20 +132,6 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") - def test_next_n_rows_more_than_one_table(self, mock_create_next_table): - mock_create_next_table.return_value = self.make_arrow_table() - schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) - assert queue.table == self.make_arrow_table() - assert queue.table.num_rows == 4 - assert queue.table_row_index == 0 - - result = 
queue.next_n_rows(7) - assert result.num_rows == 7 - assert queue.table_row_index == 3 - assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] @@ -165,6 +153,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): assert queue.table is None result = queue.next_n_rows(100) + mock_create_next_table.assert_called() assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 97bf407aa..7a35e65aa 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -2,7 +2,6 @@ from unittest.mock import patch, MagicMock import databricks.sql.cloudfetch.download_manager as download_manager -import databricks.sql.cloudfetch.downloader as downloader from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink @@ -11,10 +10,8 @@ class DownloadManagerTests(unittest.TestCase): Unit tests for checking download manager logic. """ - def create_download_manager(self): - max_download_threads = 10 - lz4_compressed = True - return download_manager.ResultFileDownloadManager(max_download_threads, lz4_compressed) + def create_download_manager(self, links, max_download_threads=10, lz4_compressed=True): + return download_manager.ResultFileDownloadManager(links, max_download_threads, lz4_compressed) def create_result_link( self, @@ -36,172 +33,25 @@ def create_result_links(self, num_files: int, start_row_offset: int = 0): def test_add_file_links_zero_row_count(self): links = [self.create_result_link(row_count=0, bytes_num=0)] - manager = self.create_download_manager() - manager.add_file_links(links) + manager = self.create_download_manager(links) - assert not manager.download_handlers + assert len(manager._pending_links) == 0 # the only link supplied contains no data, so should be skipped + assert len(manager._download_tasks) == 0 def test_add_file_links_success(self): links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) + manager = self.create_download_manager(links) - assert len(manager.download_handlers) == 10 - - def test_remove_past_handlers_one(self): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - - manager._remove_past_handlers(8000) - assert len(manager.download_handlers) == 9 - - def test_remove_past_handlers_all(self): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - - manager._remove_past_handlers(8000*10) - assert len(manager.download_handlers) == 0 - - @patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_schedule_downloads_partial_already_scheduled(self, mock_submit): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - - for i in range(5): - manager.download_handlers[i].is_download_scheduled = True - - manager._schedule_downloads() - assert mock_submit.call_count == 5 - assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10 - - 
@patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_schedule_downloads_will_not_schedule_twice(self, mock_submit): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - - for i in range(5): - manager.download_handlers[i].is_download_scheduled = True - - manager._schedule_downloads() - assert mock_submit.call_count == 5 - assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 10 - - manager._schedule_downloads() - assert mock_submit.call_count == 5 - - @patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")]) - def test_schedule_downloads_submit_fails(self, mock_submit): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - - manager._schedule_downloads() - assert mock_submit.call_count == 2 - assert sum([1 if handler.is_download_scheduled else 0 for handler in manager.download_handlers]) == 1 - - @patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_find_next_file_index_all_scheduled_next_row_0(self, mock_submit): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - manager._schedule_downloads() - - assert manager._find_next_file_index(0) == 0 + assert len(manager._pending_links) == len(links) + assert len(manager._download_tasks) == 0 @patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_find_next_file_index_all_scheduled_next_row_7999(self, mock_submit): + def test_schedule_downloads(self, mock_submit): + max_download_threads = 4 links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - manager._schedule_downloads() - - assert manager._find_next_file_index(7999) is None - - @patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_find_next_file_index_all_scheduled_next_row_8000(self, mock_submit): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - manager._schedule_downloads() - - assert manager._find_next_file_index(8000) == 1 - - @patch("concurrent.futures.ThreadPoolExecutor.submit", side_effect=[True, KeyError("foo")]) - def test_find_next_file_index_one_scheduled_next_row_8000(self, mock_submit): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) - manager._schedule_downloads() + manager = self.create_download_manager(links, max_download_threads=max_download_threads) - assert manager._find_next_file_index(8000) is None - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=True) - @patch("concurrent.futures.ThreadPoolExecutor.submit") - def test_check_if_download_successful_happy(self, mock_submit, mock_is_file_download_successful): - links = self.create_result_links(num_files=10) - manager = self.create_download_manager() - manager.add_file_links(links) manager._schedule_downloads() - - status = manager._check_if_download_successful(manager.download_handlers[0]) - assert status - assert manager.num_consecutive_result_file_download_retries == 0 - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_link_expired(self, mock_is_file_download_successful): - manager = 
self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_link_expired = True - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_download_timed_out_no_retries(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_download_timedout = True - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry - - @patch("concurrent.futures.ThreadPoolExecutor.submit") - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_download_timed_out_1_retry(self, mock_is_file_download_successful, mock_submit): - manager = self.create_download_manager() - manager.downloadable_result_settings = download_manager.DownloadableResultSettings( - is_lz4_compressed=True, - download_timeout=0, - max_consecutive_file_download_retries=1, - ) - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - handler.is_download_timedout = True - - status = manager._check_if_download_successful(handler) - assert mock_is_file_download_successful.call_count == 2 - assert mock_submit.call_count == 1 - assert not status - assert manager.fetch_need_retry - - @patch("databricks.sql.cloudfetch.downloader.ResultSetDownloadHandler.is_file_download_successful", - return_value=False) - def test_check_if_download_successful_other_reason(self, mock_is_file_download_successful): - manager = self.create_download_manager() - handler = downloader.ResultSetDownloadHandler(manager.downloadable_result_settings, self.create_result_link()) - - status = manager._check_if_download_successful(handler) - mock_is_file_download_successful.assert_called() - assert not status - assert manager.fetch_need_retry + assert mock_submit.call_count == max_download_threads + assert len(manager._pending_links) == len(links) - max_download_threads + assert len(manager._download_tasks) == max_download_threads diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 6e13c9496..e138cdbb9 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -1,7 +1,17 @@ import unittest from unittest.mock import Mock, patch, MagicMock +import requests + import databricks.sql.cloudfetch.downloader as downloader +from databricks.sql.exc import Error + + +def create_response(**kwargs) -> requests.Response: + result = requests.Response() + for k, v in kwargs.items(): + setattr(result, k, v) + return result class DownloaderTests(unittest.TestCase): @@ -16,9 +26,11 @@ def test_run_link_expired(self, mock_time): # Already expired result_link.expiryTime = 999 d = downloader.ResultSetDownloadHandler(settings, result_link) - assert not d.is_link_expired - d.run() - assert d.is_link_expired + + with self.assertRaises(Error) as context: + d.run() + self.assertTrue('link has expired' in context.exception.message) + mock_time.assert_called_once() @patch('time.time', 
return_value=1000) @@ -28,83 +40,58 @@ def test_run_link_past_expiry_buffer(self, mock_time): # Within the expiry buffer time result_link.expiryTime = 1004 d = downloader.ResultSetDownloadHandler(settings, result_link) - assert not d.is_link_expired - d.run() - assert d.is_link_expired + + with self.assertRaises(Error) as context: + d.run() + self.assertTrue('link has expired' in context.exception.message) + mock_time.assert_called_once() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=None))) @patch('time.time', return_value=1000) def test_run_get_response_not_ok(self, mock_time, mock_session): + mock_session.return_value.get.return_value = create_response(status_code=404) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0) settings.download_timeout = 0 settings.use_proxy = False result_link = Mock(expiryTime=1001) d = downloader.ResultSetDownloadHandler(settings, result_link) - d.run() - - assert not d.is_file_downloaded_successfully - assert d.is_download_finished.is_set() - - @patch('requests.Session', - return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 9)))) - @patch('time.time', return_value=1000) - def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False) - result_link = Mock(bytesNum=100, expiryTime=1001) - - d = downloader.ResultSetDownloadHandler(settings, result_link) - d.run() - - assert not d.is_file_downloaded_successfully - assert d.is_download_finished.is_set() - - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) - @patch('time.time', return_value=1000) - def test_run_compressed_data_length_incorrect(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) - settings.is_lz4_compressed = True - result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = \ - b'\x04"M\x18h@Z\x00\x00\x00\x00\x00\x00\x00\xec\x14\x00\x00\x00\xaf1234567890\n\x008P67890\x00\x00\x00\x00' - - d = downloader.ResultSetDownloadHandler(settings, result_link) - d.run() - - assert not d.is_file_downloaded_successfully - assert d.is_download_finished.is_set() + with self.assertRaises(requests.exceptions.HTTPError) as context: + d.run() + self.assertTrue('404' in str(context.exception)) - @patch('requests.Session', - return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 10)))) + @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=None))) @patch('time.time', return_value=1000) def test_run_uncompressed_successful(self, mock_time, mock_session): + file_bytes = b"1234567890" * 10 + mock_session.return_value.get.return_value = create_response(status_code=200, _content=file_bytes) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) d = downloader.ResultSetDownloadHandler(settings, result_link) - d.run() + file = d.run() - assert d.result_file == b"1234567890" * 10 - assert d.is_file_downloaded_successfully - assert d.is_download_finished.is_set() + assert file.file_bytes == b"1234567890" * 10 @patch('requests.Session', 
return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) @patch('time.time', return_value=1000) def test_run_compressed_successful(self, mock_time, mock_session): + file_bytes = b"1234567890" * 10 + compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' + mock_session.return_value.get.return_value = create_response(status_code=200, _content=compressed_bytes) + settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = \ - b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' d = downloader.ResultSetDownloadHandler(settings, result_link) - d.run() + file = d.run() - assert d.result_file == b"1234567890" * 10 - assert d.is_file_downloaded_successfully - assert d.is_download_finished.is_set() + assert file.file_bytes == b"1234567890" * 10 @patch('requests.Session.get', side_effect=ConnectionError('foo')) @patch('time.time', return_value=1000) @@ -115,10 +102,8 @@ def test_download_connection_error(self, mock_time, mock_session): b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' d = downloader.ResultSetDownloadHandler(settings, result_link) - d.run() - - assert not d.is_file_downloaded_successfully - assert d.is_download_finished.is_set() + with self.assertRaises(ConnectionError): + d.run() @patch('requests.Session.get', side_effect=TimeoutError('foo')) @patch('time.time', return_value=1000) @@ -129,27 +114,5 @@ def test_download_timeout(self, mock_time, mock_session): b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' d = downloader.ResultSetDownloadHandler(settings, result_link) - d.run() - - assert not d.is_file_downloaded_successfully - assert d.is_download_finished.is_set() - - @patch("threading.Event.wait", return_value=True) - def test_is_file_download_successful_has_finished(self, mock_wait): - for timeout in [0, 1]: - with self.subTest(timeout=timeout): - settings = Mock(download_timeout=timeout) - result_link = Mock() - handler = downloader.ResultSetDownloadHandler(settings, result_link) - - status = handler.is_file_download_successful() - assert status == handler.is_file_downloaded_successfully - - def test_is_file_download_successful_times_outs(self): - settings = Mock(download_timeout=1) - result_link = Mock() - handler = downloader.ResultSetDownloadHandler(settings, result_link) - - status = handler.is_file_download_successful() - assert not status - assert handler.is_download_timedout + with self.assertRaises(TimeoutError): + d.run()
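For reference, the hard-coded lz4 frames used as fixtures in these tests can be regenerated with the lz4 package the driver already depends on. A quick sketch (the exact frame bytes may differ from the fixtures depending on lz4 version and options):

import lz4.frame

file_bytes = b"1234567890" * 10
frame = lz4.frame.compress(file_bytes)  # produces an lz4 frame comparable to the fixture bytes
assert lz4.frame.decompress(frame) == file_bytes  # the round-trip _decompress_data must perform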