[PECO-1751] Refactor CloudFetch downloader: handle files sequentially #405

Merged · 4 commits · Jul 11, 2024
207 changes: 46 additions & 161 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -1,62 +1,41 @@
import logging

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor, Future
from typing import List, Union

from databricks.sql.cloudfetch.downloader import (
ResultSetDownloadHandler,
DownloadableResultSettings,
DownloadedFile,
)
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)


@dataclass
class DownloadedFile:
"""
Class for the result file and metadata.

Attributes:
file_bytes (bytes): Downloaded file in bytes.
start_row_offset (int): The offset of the starting row in relation to the full result.
row_count (int): Number of rows the file represents in the result.
"""

file_bytes: bytes
start_row_offset: int
row_count: int


class ResultFileDownloadManager:
def __init__(self, max_download_threads: int, lz4_compressed: bool):
self.download_handlers: List[ResultSetDownloadHandler] = []
self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self.fetch_need_retry = False
self.num_consecutive_result_file_download_retries = 0

def add_file_links(
self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
) -> None:
"""
Create download handler for each cloud fetch link.

Args:
t_spark_arrow_result_links: List of cloud fetch links consisting of file URL and metadata.
"""
for link in t_spark_arrow_result_links:
def __init__(
self,
links: List[TSparkArrowResultLink],
max_download_threads: int,
lz4_compressed: bool,
):
self._pending_links: List[TSparkArrowResultLink] = []
for link in links:
if link.rowCount <= 0:
continue
logger.debug(
"ResultFileDownloadManager.add_file_links: start offset {}, row count: {}".format(
"ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format(
link.startRowOffset, link.rowCount
)
)
self.download_handlers.append(
ResultSetDownloadHandler(self.downloadable_result_settings, link)
)
self._pending_links.append(link)

self._download_tasks: List[Future[DownloadedFile]] = []
self._max_download_threads: int = max_download_threads
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)

def get_next_downloaded_file(
self, next_row_offset: int
@@ -73,143 +52,49 @@ def get_next_downloaded_file(
Args:
next_row_offset (int): The offset of the starting row of the next file we want data from.
"""
# No more files to download from this batch of links
if not self.download_handlers:
self._shutdown_manager()
return None

# Remove handlers we don't need anymore
self._remove_past_handlers(next_row_offset)

# Schedule the downloads
# Make sure the download queue is always full
self._schedule_downloads()

# Find next file
idx = self._find_next_file_index(next_row_offset)
if idx is None:
# No more files to download from this batch of links
if len(self._download_tasks) == 0:
self._shutdown_manager()
return None
handler = self.download_handlers[idx]

# Check (and wait) for download status
if self._check_if_download_successful(handler):
link = handler.result_link
logger.debug(
"ResultFileDownloadManager: file found for row index {}: start {}, row count: {}".format(
next_row_offset, link.startRowOffset, link.rowCount
)
)
# Buffer should be empty so set buffer to new ArrowQueue with result_file
result = DownloadedFile(
handler.result_file,
handler.result_link.startRowOffset,
handler.result_link.rowCount,
)
self.download_handlers.pop(idx)
# Return True upon successful download to continue loop and not force a retry
return result
else:
task = self._download_tasks.pop(0)
# Future's `result()` method will wait for the call to complete, and return
# the value returned by the call. If the call throws an exception - `result()`
# will throw the same exception
file = task.result()
if (next_row_offset < file.start_row_offset) or (
next_row_offset > file.start_row_offset + file.row_count
):
logger.debug(
"ResultFileDownloadManager: cannot find file for row index {}".format(
next_row_offset
"ResultFileDownloadManager: file does not contain row {}, start {}, row count {}".format(
next_row_offset, file.start_row_offset, file.row_count
)
)

# Download was not successful for next download item, force a retry
self._shutdown_manager()
return None

def _remove_past_handlers(self, next_row_offset: int):
logger.debug(
"ResultFileDownloadManager: removing past handlers, current offset: {}".format(
next_row_offset
)
)
# Any link in which its start to end range doesn't include the next row to be fetched does not need downloading
i = 0
while i < len(self.download_handlers):
result_link = self.download_handlers[i].result_link
logger.debug(
"- checking result link: start {}, row count: {}, current offset: {}".format(
result_link.startRowOffset, result_link.rowCount, next_row_offset
)
)
if result_link.startRowOffset + result_link.rowCount > next_row_offset:
i += 1
continue
self.download_handlers.pop(i)
return file

def _schedule_downloads(self):
# Schedule downloads for all download handlers if not already scheduled.
"""
While download queue has a capacity, peek pending links and submit them to thread pool.
"""
logger.debug("ResultFileDownloadManager: schedule downloads")
for handler in self.download_handlers:
if handler.is_download_scheduled:
continue
try:
logger.debug(
"- start: {}, row count: {}".format(
handler.result_link.startRowOffset, handler.result_link.rowCount
)
)
self.thread_pool.submit(handler.run)
except Exception as e:
logger.error(e)
break
handler.is_download_scheduled = True

def _find_next_file_index(self, next_row_offset: int):
logger.debug(
"ResultFileDownloadManager: trying to find file for row {}".format(
next_row_offset
)
)
# Get the handler index of the next file in order
next_indices = [
i
for i, handler in enumerate(self.download_handlers)
if handler.is_download_scheduled
# TODO: shouldn't `next_row_offset` be tested against the range, not just start row offset?
and handler.result_link.startRowOffset == next_row_offset
]

for i in next_indices:
link = self.download_handlers[i].result_link
while (len(self._download_tasks) < self._max_download_threads) and (
len(self._pending_links) > 0
):
link = self._pending_links.pop(0)
logger.debug(
"- found file: start {}, row count {}".format(
link.startRowOffset, link.rowCount
)
"- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
)

return next_indices[0] if len(next_indices) > 0 else None

def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
# Check (and wait until download finishes) if download was successful
if not handler.is_file_download_successful():
if handler.is_link_expired:
self.fetch_need_retry = True
return False
elif handler.is_download_timedout:
# Consecutive file retries should not exceed threshold in settings
if (
self.num_consecutive_result_file_download_retries
>= self.downloadable_result_settings.max_consecutive_file_download_retries
):
self.fetch_need_retry = True
return False
self.num_consecutive_result_file_download_retries += 1

# Re-submit handler run to thread pool and recursively check download status
self.thread_pool.submit(handler.run)
return self._check_if_download_successful(handler)
else:
self.fetch_need_retry = True
return False

self.num_consecutive_result_file_download_retries = 0
self.fetch_need_retry = False
return True
handler = ResultSetDownloadHandler(self._downloadable_result_settings, link)
task = self._thread_pool.submit(handler.run)
self._download_tasks.append(task)

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool
self.download_handlers = []
self.thread_pool.shutdown(wait=False)
self._pending_links = []
self._download_tasks = []
self._thread_pool.shutdown(wait=False)
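
For readers skimming the diff, here is a minimal, self-contained sketch of the sequential pattern this refactor introduces: a queue of pending links, a bounded list of in-flight `Future`s, and a consumer that pops tasks in submission order. All names below (`FakeLink`, `FakeFile`, `fake_download`, `SequentialDownloadManager`) are stand-ins invented for illustration; only the overall shape mirrors the new `ResultFileDownloadManager`.

```python
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeLink:
    """Stand-in for TSparkArrowResultLink: only the fields the manager reads."""
    startRowOffset: int
    rowCount: int


@dataclass
class FakeFile:
    """Stand-in for DownloadedFile."""
    file_bytes: bytes
    start_row_offset: int
    row_count: int


def fake_download(link: FakeLink) -> FakeFile:
    """Stand-in for ResultSetDownloadHandler.run: pretends to download one link."""
    return FakeFile(b"arrow-bytes", link.startRowOffset, link.rowCount)


class SequentialDownloadManager:
    """Toy model of ResultFileDownloadManager after this refactor."""

    def __init__(self, links: List[FakeLink], max_download_threads: int):
        self._pending_links = [link for link in links if link.rowCount > 0]
        self._download_tasks: List[Future] = []
        self._max_download_threads = max_download_threads
        self._thread_pool = ThreadPoolExecutor(max_workers=max_download_threads)

    def get_next_downloaded_file(self) -> Optional[FakeFile]:
        self._schedule_downloads()  # keep the in-flight window full
        if not self._download_tasks:
            self._thread_pool.shutdown(wait=False)
            return None
        task = self._download_tasks.pop(0)  # files come back strictly in link order
        return task.result()  # blocks until done; re-raises any worker exception

    def _schedule_downloads(self):
        # Submit pending links until the number of in-flight tasks hits the cap.
        while self._pending_links and len(self._download_tasks) < self._max_download_threads:
            link = self._pending_links.pop(0)
            self._download_tasks.append(self._thread_pool.submit(fake_download, link))


links = [FakeLink(0, 100), FakeLink(100, 100), FakeLink(200, 50)]
manager = SequentialDownloadManager(links, max_download_threads=2)
while (file := manager.get_next_downloaded_file()) is not None:
    print(file.start_row_offset, file.row_count)
```

The key design choice the diff makes is the same one shown here: instead of tracking per-handler state (`is_download_scheduled`, offset bookkeeping, retry counters), the manager only holds an ordered list of pending links and at most `max_download_threads` outstanding futures, so the next file is always the head of the task list.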
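
One detail worth calling out from the new `get_next_downloaded_file`: the comment above `task.result()` notes that a `Future` re-raises whatever exception the worker threw. A tiny standalone example of that behaviour (the URL and `flaky_download` function are made up):

```python
from concurrent.futures import ThreadPoolExecutor


def flaky_download(url: str) -> bytes:
    # Any exception raised in the worker is captured and stored on the Future.
    raise TimeoutError(f"could not fetch {url}")


with ThreadPoolExecutor(max_workers=1) as pool:
    task = pool.submit(flaky_download, "https://example.com/result_0.arrow")
    try:
        data = task.result()  # blocks until the task finishes, then re-raises TimeoutError here
    except TimeoutError as exc:
        print("download failed:", exc)  # caller decides whether to retry or give up
```

This is consistent with the diff dropping the manager's own retry bookkeeping (`fetch_need_retry`, `num_consecutive_result_file_download_retries`): a failed download now surfaces as an exception at the point the caller asks for that file, rather than being retried inside the manager.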