diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py
index a8311ee3..aa09a04c 100644
--- a/src/databricks/sql/backend/sea/queue.py
+++ b/src/databricks/sql/backend/sea/queue.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 from abc import ABC
-from typing import List, Optional, Tuple, Union
+import threading
+from typing import Dict, List, Optional, Tuple, Union
 
 from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
 
@@ -111,6 +112,107 @@ def remaining_rows(self) -> List[List[str]]:
         return slice
 
 
+class LinkFetcher:
+    """Background worker that walks the chunk-link chain for a statement and
+    feeds each link to the download manager as soon as it is known."""
+
+    def __init__(
+        self,
+        download_manager: ResultFileDownloadManager,
+        backend: "SeaDatabricksClient",
+        statement_id: str,
+        current_chunk_link: Optional["ExternalLink"] = None,
+    ):
+        self.download_manager = download_manager
+        self.backend = backend
+        self._statement_id = statement_id
+        self._current_chunk_link = current_chunk_link
+
+        self._shutdown_event = threading.Event()
+
+        # Guards chunk_index_to_link and _download_completed; readers blocked
+        # in get_chunk_link() wait on it until the worker publishes a link.
+        self._link_data_update = threading.Condition()
+        self._download_completed = False
+        self.chunk_index_to_link: Dict[int, "ExternalLink"] = {}
+
+    def _set_current_chunk_link(self, link: "ExternalLink"):
+        # Publish a fetched link and wake any reader blocked in get_chunk_link().
+        with self._link_data_update:
+            self.chunk_index_to_link[link.chunk_index] = link
+            self._link_data_update.notify_all()
+
+    def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
+        """Return the link for chunk_index, blocking until the worker has
+        fetched it, the link chain is exhausted, or shutdown is requested."""
+        if chunk_index is None:
+            return None
+        with self._link_data_update:
+            while chunk_index not in self.chunk_index_to_link and not (
+                self._download_completed or self._shutdown_event.is_set()
+            ):
+                self._link_data_update.wait()
+            return self.chunk_index_to_link.get(chunk_index, None)
+
+    def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
+        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
+        # Parse the ISO format expiration time
+        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
+        return TSparkArrowResultLink(
+            fileLink=link.external_link,
+            expiryTime=expiry_time,
+            rowCount=link.row_count,
+            bytesNum=link.byte_count,
+            startRowOffset=link.row_offset,
+            httpHeaders=link.http_headers or {},
+        )
+
+    def _progress_chunk_link(self):
+        """Progress to the next chunk link."""
+        if not self._current_chunk_link:
+            return None
+
+        next_chunk_index = self._current_chunk_link.next_chunk_index
+
+        if next_chunk_index is None:
+            self._current_chunk_link = None
+            return None
+
+        try:
+            self._current_chunk_link = self.backend.get_chunk_link(
+                self._statement_id, next_chunk_index
+            )
+        except Exception as e:
+            logger.error(
+                "LinkFetcher: Error fetching link for chunk {}: {}".format(
+                    next_chunk_index, e
+                )
+            )
+            self._current_chunk_link = None
+
+    def _worker_loop(self):
+        while not (self._shutdown_event.is_set() or self._current_chunk_link is None):
+            self._set_current_chunk_link(self._current_chunk_link)
+            self.download_manager.add_link(
+                self._convert_to_thrift_link(self._current_chunk_link)
+            )
+
+            self._progress_chunk_link()
+
+        # Chain exhausted (or shutdown): release any readers still waiting.
+        with self._link_data_update:
+            self._download_completed = True
+            self._link_data_update.notify_all()
+
+    def start(self):
+        self._worker_thread = threading.Thread(target=self._worker_loop)
+        self._worker_thread.start()
+
+    def stop(self):
+        self._shutdown_event.set()
+        self._worker_thread.join()
+
+
 class SeaCloudFetchQueue(CloudFetchQueue):
     """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""
 
@@ -160,6 +262,7 @@ def __init__(
         initial_link = next((l for l in initial_links if l.chunk_index == 0), None)
         if not initial_link:
             return
+        self.current_chunk_index = initial_link.chunk_index
 
         self.download_manager = ResultFileDownloadManager(
             links=[],
@@ -168,75 +271,23 @@ def __init__(
             ssl_options=ssl_options,
         )
 
-        # Track the current chunk we're processing
-        self._current_chunk_link: Optional["ExternalLink"] = initial_link
-        self._download_current_link()
+        self.link_fetcher = LinkFetcher(
+            self.download_manager, self._sea_client, statement_id, initial_link
+        )
+        self.link_fetcher.start()
 
         # Initialize table and position
         self.table = self._create_next_table()
 
-    def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
-        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
-        # Parse the ISO format expiration time
-        expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
-        return TSparkArrowResultLink(
-            fileLink=link.external_link,
-            expiryTime=expiry_time,
-            rowCount=link.row_count,
-            bytesNum=link.byte_count,
-            startRowOffset=link.row_offset,
-            httpHeaders=link.http_headers or {},
-        )
-
-    def _download_current_link(self):
-        """Download the current chunk link."""
-        if not self._current_chunk_link:
-            return None
-
-        if not self.download_manager:
-            logger.debug("SeaCloudFetchQueue: No download manager, returning")
-            return None
-
-        thrift_link = self._convert_to_thrift_link(self._current_chunk_link)
-        self.download_manager.add_link(thrift_link)
-
-    def _progress_chunk_link(self):
-        """Progress to the next chunk link."""
-        if not self._current_chunk_link:
-            return None
-
-        next_chunk_index = self._current_chunk_link.next_chunk_index
-
-        if next_chunk_index is None:
-            self._current_chunk_link = None
-            return None
-
-        try:
-            self._current_chunk_link = self._sea_client.get_chunk_link(
-                self._statement_id, next_chunk_index
-            )
-        except Exception as e:
-            logger.error(
-                "SeaCloudFetchQueue: Error fetching link for chunk {}: {}".format(
-                    next_chunk_index, e
-                )
-            )
-            return None
-
-        logger.debug(
-            f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}"
-        )
-        self._download_current_link()
-
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         """Create next table by retrieving the logical next downloaded file."""
-        if not self._current_chunk_link:
-            logger.debug("SeaCloudFetchQueue: No current chunk link, returning")
+        current_chunk_link = self.link_fetcher.get_chunk_link(self.current_chunk_index)
+        if not current_chunk_link:
             return None
 
-        row_offset = self._current_chunk_link.row_offset
+        row_offset = current_chunk_link.row_offset
         arrow_table = self._create_table_at_offset(row_offset)
 
-        self._progress_chunk_link()
+        self.current_chunk_index = current_chunk_link.next_chunk_index
 
         return arrow_table