Skip to content

Commit a0705bc

Browse files
add fetchmany_arrow and fetchall_arrow
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 715cc13 commit a0705bc

File tree

2 files changed

+42
-76
lines changed

2 files changed

+42
-76
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -302,74 +302,6 @@ def get_allowed_session_configurations() -> List[str]:
302302
"""
303303
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
304304

305-
def _get_schema_bytes(self, sea_response) -> Optional[bytes]:
    """
    Download the first result chunk of a SEA response for use as schema bytes.

    The first chunk (chunk_index 0) is taken from the response's external
    links when present; otherwise its link is requested from the server via
    ``self.get_chunk_links``. The chunk is then downloaded and, when the
    manifest reports ``LZ4_FRAME`` compression, decompressed.

    Args:
        sea_response: The response from the SEA API (accessed as a dict)

    Returns:
        bytes: The full, decompressed content of the first chunk, or None
        when no link to the first chunk can be determined. The payload is
        returned whole; it is not trimmed down to the schema message.

    Raises:
        Error: If the download does not answer with HTTP 200
    """
    # ``or`` (rather than a .get default) also covers keys that are present
    # but carry a null value.
    result_data = sea_response.get("result") or {}
    external_links = result_data.get("external_links") or []

    if not external_links:
        return None

    # Find the first chunk (chunk_index = 0) among the inline links
    first_chunk = next(
        (link for link in external_links if link.get("chunk_index") == 0),
        None,
    )

    if not first_chunk:
        # Not included inline: try to fetch the first chunk link from the server
        statement_id = sea_response.get("statement_id")
        if not statement_id:
            return None

        chunks_response = self.get_chunk_links(statement_id, 0)
        if not chunks_response.external_links:
            return None

        # Server-side links are objects; view them as a dict so both
        # sources can be read the same way below.
        first_chunk = vars(chunks_response.external_links[0])

    external_link = first_chunk.get("external_link")
    if not external_link:
        return None

    http_headers = first_chunk.get("http_headers") or {}

    # Imported lazily so the early-return paths above do not need the
    # HTTP / compression libraries at all.
    import requests

    # NOTE(review): no timeout is passed, so a stalled server can block this
    # call indefinitely - consider supplying one from the connection settings.
    http_response = requests.get(
        external_link,
        headers=http_headers,
        verify=self.ssl_options.tls_verify,
    )

    if http_response.status_code != 200:
        raise Error(f"Failed to download schema bytes: {http_response.text}")

    data = http_response.content
    manifest_data = sea_response.get("manifest") or {}
    if manifest_data.get("result_compression") == "LZ4_FRAME":
        import lz4.frame

        data = lz4.frame.decompress(data)

    return data
372-
373305
def _results_message_to_execute_response(self, sea_response, command_id):
374306
"""
375307
Convert a SEA response to an ExecuteResponse and extract result data.
@@ -412,13 +344,6 @@ def _results_message_to_execute_response(self, sea_response, command_id):
412344
)
413345
description = columns if columns else None
414346

415-
# Extract schema bytes for Arrow format
416-
schema_bytes = None
417-
format = manifest_data.get("format")
418-
if format == "ARROW_STREAM":
419-
# For ARROW format, we need to get the schema bytes
420-
schema_bytes = self._get_schema_bytes(sea_response)
421-
422347
# Check for compression
423348
lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME"
424349

@@ -473,7 +398,7 @@ def _results_message_to_execute_response(self, sea_response, command_id):
473398
has_been_closed_server_side=False,
474399
lz4_compressed=lz4_compressed,
475400
is_staging_operation=False,
476-
arrow_schema_bytes=schema_bytes,
401+
arrow_schema_bytes=None,
477402
result_format=manifest_data.get("format"),
478403
)
479404

src/databricks/sql/result_set.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,16 @@ def fetchall(self) -> List[Row]:
154154
"""Fetch all remaining rows of a query result."""
155155
pass
156156

157+
@abstractmethod
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
    """
    Fetch the next set of rows as an Arrow table.

    Args:
        size: Number of rows to fetch

    Returns:
        PyArrow Table containing the fetched rows

    Raises:
        ValueError: If size is negative (as enforced by the concrete
            implementation in this module)
    """
    pass
161+
162+
@abstractmethod
def fetchall_arrow(self) -> "pyarrow.Table":
    """
    Fetch all remaining rows as an Arrow table.

    Returns:
        PyArrow Table containing every row not yet fetched
    """
    pass
166+
157167
def close(self) -> None:
158168
"""
159169
Close the result set.
@@ -537,6 +547,37 @@ def fetchall_json(self):
537547

538548
return results
539549

550+
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
    """
    Fetch the next set of rows as an Arrow table.

    Args:
        size: Number of rows to fetch; must be >= 0

    Returns:
        The object produced by ``self.results.next_n_rows(size)``, expected
        to be a PyArrow Table containing the fetched rows

    Raises:
        ValueError: If size is negative
    """
    if size < 0:
        raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")

    results = self.results.next_n_rows(size)
    # Advance the cursor by the number of rows actually returned rather
    # than by the requested ``size``.
    self._next_row_index += results.num_rows

    return results
571+
572+
def fetchall_arrow(self) -> "pyarrow.Table":
    """
    Fetch all remaining rows as an Arrow table.

    Returns:
        The object produced by ``self.results.remaining_rows()``, expected
        to be a PyArrow Table containing every row not yet fetched
    """
    results = self.results.remaining_rows()
    # Advance the cursor by the number of rows just handed out.
    self._next_row_index += results.num_rows

    return results
580+
540581
def fetchone(self) -> Optional[Row]:
541582
"""
542583
Fetch the next row of a query result set, returning a single sequence,

0 commit comments

Comments
 (0)