Skip to content

Commit c07f709

Browse files
move queue and result set into SEA specific dir
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent fa2359d commit c07f709

File tree

6 files changed

+184
-291
lines changed

6 files changed

+184
-291
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 175 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
from __future__ import annotations
22

33
from abc import ABC
4-
from typing import List, Optional, Tuple
4+
from typing import List, Optional, Tuple, Union
5+
6+
try:
7+
import pyarrow
8+
except ImportError:
9+
pyarrow = None
10+
11+
import dateutil
512

613
from databricks.sql.backend.sea.backend import SeaDatabricksClient
7-
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
14+
from databricks.sql.backend.sea.models.base import (
15+
ExternalLink,
16+
ResultData,
17+
ResultManifest,
18+
)
819
from databricks.sql.backend.sea.utils.constants import ResultFormat
920
from databricks.sql.exc import ProgrammingError
10-
from databricks.sql.utils import ResultSetQueue
21+
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
22+
from databricks.sql.types import SSLOptions
23+
from databricks.sql.utils import CloudFetchQueue, ResultSetQueue
24+
25+
import logging
26+
27+
logger = logging.getLogger(__name__)
1128

1229

1330
class SeaResultSetQueueFactory(ABC):
@@ -42,8 +59,30 @@ def build_queue(
4259
return JsonQueue(sea_result_data.data)
4360
elif manifest.format == ResultFormat.ARROW_STREAM.value:
4461
# EXTERNAL_LINKS disposition
45-
raise NotImplementedError(
46-
"EXTERNAL_LINKS disposition is not implemented for SEA backend"
62+
if not max_download_threads:
63+
raise ValueError(
64+
"Max download threads is required for EXTERNAL_LINKS disposition"
65+
)
66+
if not ssl_options:
67+
raise ValueError(
68+
"SSL options are required for EXTERNAL_LINKS disposition"
69+
)
70+
if not sea_client:
71+
raise ValueError(
72+
"SEA client is required for EXTERNAL_LINKS disposition"
73+
)
74+
if not manifest:
75+
raise ValueError("Manifest is required for EXTERNAL_LINKS disposition")
76+
77+
return SeaCloudFetchQueue(
78+
initial_links=sea_result_data.external_links,
79+
max_download_threads=max_download_threads,
80+
ssl_options=ssl_options,
81+
sea_client=sea_client,
82+
statement_id=statement_id,
83+
total_chunk_count=manifest.total_chunk_count,
84+
lz4_compressed=lz4_compressed,
85+
description=description,
4786
)
4887
raise ProgrammingError("Invalid result format")
4988

@@ -69,3 +108,134 @@ def remaining_rows(self) -> List[List[str]]:
69108
slice = self.data_array[self.cur_row_index :]
70109
self.cur_row_index += len(slice)
71110
return slice
111+
112+
113+
class SeaCloudFetchQueue(CloudFetchQueue):
    """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend.

    The queue walks the result chunks one link at a time: the link for the
    current chunk is converted to the Thrift link type and handed to the
    download manager, a table is created for that chunk's row offset, and
    the link for ``next_chunk_index`` is then requested from the SEA client.
    """

    def __init__(
        self,
        initial_links: List["ExternalLink"],
        max_download_threads: int,
        ssl_options: SSLOptions,
        sea_client: "SeaDatabricksClient",
        statement_id: str,
        total_chunk_count: int,
        lz4_compressed: bool = False,
        description: Optional[List[Tuple]] = None,
    ):
        """
        Initialize the SEA CloudFetchQueue.

        Args:
            initial_links: Initial list of external links to download; must
                contain the link whose ``chunk_index`` is 0
            max_download_threads: Maximum number of download threads
            ssl_options: SSL options for downloads
            sea_client: SEA client for fetching additional links
            statement_id: Statement ID for the query
            total_chunk_count: Total number of chunks in the result set
                (only used for the debug log message here)
            lz4_compressed: Whether the data is LZ4 compressed
            description: Column descriptions

        Raises:
            ValueError: If ``initial_links`` is empty/None or has no link for
                chunk index 0.
        """
        # Function-scope import: the module-level imports of this file do not
        # bring ResultFileDownloadManager into scope, which would otherwise
        # make the constructor fail with NameError.
        from databricks.sql.cloudfetch.download_manager import (
            ResultFileDownloadManager,
        )

        # No Arrow schema bytes are available on this path, so the base class
        # is initialised with schema_bytes=None.
        super().__init__(
            max_download_threads=max_download_threads,
            ssl_options=ssl_options,
            schema_bytes=None,
            lz4_compressed=lz4_compressed,
            description=description,
        )

        self._sea_client = sea_client
        self._statement_id = statement_id

        logger.debug(
            "SeaCloudFetchQueue: Initialize CloudFetch loader for statement %s, total chunks: %s",
            statement_id,
            total_chunk_count,
        )

        # Tolerate a missing link list so the caller gets the descriptive
        # ValueError below instead of a TypeError from iterating None.
        initial_link = next(
            (link for link in (initial_links or []) if link.chunk_index == 0),
            None,
        )
        if initial_link is None:
            raise ValueError("No initial link found for chunk index 0")

        # Start with an empty manager; links are added one chunk at a time.
        self.download_manager = ResultFileDownloadManager(
            links=[],
            max_download_threads=max_download_threads,
            lz4_compressed=lz4_compressed,
            ssl_options=ssl_options,
        )

        # Track the current chunk we're processing
        self._current_chunk_link: Optional["ExternalLink"] = initial_link
        self._download_current_link()

        # Initialize table and position
        self.table = self._create_next_table()

    def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
        """Convert SEA external links to Thrift format for compatibility with existing download manager."""
        # A bare ``import dateutil`` does not guarantee that the ``parser``
        # submodule is loaded, so import it explicitly where it is used.
        from dateutil import parser as date_parser

        # Parse the ISO format expiration time into whole epoch seconds
        expiry_time = int(date_parser.parse(link.expiration).timestamp())
        return TSparkArrowResultLink(
            fileLink=link.external_link,
            expiryTime=expiry_time,
            rowCount=link.row_count,
            bytesNum=link.byte_count,
            startRowOffset=link.row_offset,
            httpHeaders=link.http_headers or {},
        )

    def _download_current_link(self) -> None:
        """Hand the current chunk link to the download manager (no-op if there is none)."""
        if not self._current_chunk_link:
            return None

        if not self.download_manager:
            logger.debug("SeaCloudFetchQueue: No download manager, returning")
            return None

        thrift_link = self._convert_to_thrift_link(self._current_chunk_link)
        self.download_manager.add_link(thrift_link)
        return None

    def _progress_chunk_link(self) -> None:
        """Advance to the link of the next chunk and queue it for download.

        Sets the current link to ``None`` once the current link reports no
        ``next_chunk_index``.
        """
        if not self._current_chunk_link:
            return None

        next_chunk_index = self._current_chunk_link.next_chunk_index

        if next_chunk_index is None:
            # Last chunk reached: nothing further to fetch.
            self._current_chunk_link = None
            return None

        try:
            self._current_chunk_link = self._sea_client.get_chunk_link(
                self._statement_id, next_chunk_index
            )
        except Exception as e:
            # NOTE(review): on failure the previous link stays current, so a
            # later _create_next_table() call will ask for the same row offset
            # again - confirm whether this should raise or clear the link.
            logger.error(
                "SeaCloudFetchQueue: Error fetching link for chunk %s: %s",
                next_chunk_index,
                e,
            )
            return None

        logger.debug(
            "SeaCloudFetchQueue: Progressed to link for chunk %s: %s",
            next_chunk_index,
            self._current_chunk_link,
        )
        self._download_current_link()
        return None

    def _create_next_table(self) -> Union["pyarrow.Table", None]:
        """Create next table by retrieving the logical next downloaded file."""
        if not self._current_chunk_link:
            logger.debug("SeaCloudFetchQueue: No current chunk link, returning")
            return None

        row_offset = self._current_chunk_link.row_offset
        arrow_table = self._create_table_at_offset(row_offset)

        # Move on to the following chunk only after the table for the current
        # offset has been requested.
        self._progress_chunk_link()

        return arrow_table

src/databricks/sql/backend/sea/result_set.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
196196
if size < 0:
197197
raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
198198

199-
if not isinstance(self.results, JsonQueue):
200-
raise NotImplementedError("fetchmany_arrow only supported for JSON data")
199+
results = self.results.next_n_rows(size)
200+
if isinstance(self.results, JsonQueue):
201+
results = self._convert_json_to_arrow_table(results)
201202

202-
results = self._convert_json_to_arrow_table(self.results.next_n_rows(size))
203203
self._next_row_index += results.num_rows
204204

205205
return results
@@ -209,10 +209,10 @@ def fetchall_arrow(self) -> "pyarrow.Table":
209209
Fetch all remaining rows as an Arrow table.
210210
"""
211211

212-
if not isinstance(self.results, JsonQueue):
213-
raise NotImplementedError("fetchall_arrow only supported for JSON data")
212+
results = self.results.remaining_rows()
213+
if isinstance(self.results, JsonQueue):
214+
results = self._convert_json_to_arrow_table(results)
214215

215-
results = self._convert_json_to_arrow_table(self.results.remaining_rows())
216216
self._next_row_index += results.num_rows
217217

218218
return results
@@ -229,7 +229,7 @@ def fetchone(self) -> Optional[Row]:
229229
if isinstance(self.results, JsonQueue):
230230
res = self._create_json_table(self.fetchmany_json(1))
231231
else:
232-
raise NotImplementedError("fetchone only supported for JSON data")
232+
res = self._convert_arrow_table(self.fetchmany_arrow(1))
233233

234234
return res[0] if res else None
235235

@@ -250,7 +250,7 @@ def fetchmany(self, size: int) -> List[Row]:
250250
if isinstance(self.results, JsonQueue):
251251
return self._create_json_table(self.fetchmany_json(size))
252252
else:
253-
raise NotImplementedError("fetchmany only supported for JSON data")
253+
return self._convert_arrow_table(self.fetchmany_arrow(size))
254254

255255
def fetchall(self) -> List[Row]:
256256
"""
@@ -263,4 +263,4 @@ def fetchall(self) -> List[Row]:
263263
if isinstance(self.results, JsonQueue):
264264
return self._create_json_table(self.fetchall_json())
265265
else:
266-
raise NotImplementedError("fetchall only supported for JSON data")
266+
return self._convert_arrow_table(self.fetchall_arrow())

src/databricks/sql/result_set.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
from databricks.sql.utils import (
2626
ColumnTable,
2727
ColumnQueue,
28-
JsonQueue,
29-
SeaResultSetQueueFactory,
3028
)
3129
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
3230

0 commit comments

Comments
 (0)