Skip to content

Commit 71b451a

Browse files
minimal fetch phase intro
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 33821f4 commit 71b451a

File tree

4 files changed

+186
-15
lines changed

4 files changed

+186
-15
lines changed

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def test_sea_sync_query_without_cloud_fetch():
122122
cursor.execute("SELECT 1 as test_value")
123123
logger.info("Query executed successfully with cloud fetch disabled")
124124

125+
rows = cursor.fetchall()
126+
logger.info(f"Rows: {rows}")
127+
125128
# Close resources
126129
cursor.close()
127130
connection.close()

src/databricks/sql/backend/thrift_backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@
4242
)
4343

4444
from databricks.sql.utils import (
45-
ResultSetQueueFactory,
45+
ThriftResultSetQueueFactory,
4646
_bound,
4747
RequestErrorInfo,
4848
NoRetryReason,
49-
ResultSetQueueFactory,
49+
ThriftResultSetQueueFactory,
5050
convert_arrow_based_set_to_arrow_table,
5151
convert_decimals_in_arrow_table,
5252
convert_column_based_set_to_arrow_table,
@@ -784,7 +784,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
784784
assert direct_results.resultSet.results.startRowOffset == 0
785785
assert direct_results.resultSetMetadata
786786

787-
arrow_queue_opt = ResultSetQueueFactory.build_queue(
787+
arrow_queue_opt = ThriftResultSetQueueFactory.build_queue(
788788
row_set_type=t_result_set_metadata_resp.resultFormat,
789789
t_row_set=direct_results.resultSet.results,
790790
arrow_schema_bytes=schema_bytes,
@@ -857,7 +857,7 @@ def get_execution_result(
857857
else:
858858
schema_bytes = None
859859

860-
queue = ResultSetQueueFactory.build_queue(
860+
queue = ThriftResultSetQueueFactory.build_queue(
861861
row_set_type=resp.resultSetMetadata.resultFormat,
862862
t_row_set=resp.results,
863863
arrow_schema_bytes=schema_bytes,
@@ -1225,7 +1225,7 @@ def fetch_results(
12251225
)
12261226
)
12271227

1228-
queue = ResultSetQueueFactory.build_queue(
1228+
queue = ThriftResultSetQueueFactory.build_queue(
12291229
row_set_type=resp.resultSetMetadata.resultFormat,
12301230
t_row_set=resp.results,
12311231
arrow_schema_bytes=arrow_schema_bytes,

src/databricks/sql/result_set.py

Lines changed: 111 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas
77

88
from databricks.sql.backend.sea.backend import SeaDatabricksClient
9+
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
910

1011
try:
1112
import pyarrow
@@ -19,7 +20,7 @@
1920
from databricks.sql.thrift_api.TCLIService import ttypes
2021
from databricks.sql.types import Row
2122
from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
22-
from databricks.sql.utils import ColumnTable, ColumnQueue
23+
from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue, SeaResultSetQueueFactory
2324
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
2425

2526
logger = logging.getLogger(__name__)
@@ -441,6 +442,14 @@ def __init__(
441442
sea_response: Direct SEA response (legacy style)
442443
"""
443444

445+
queue = SeaResultSetQueueFactory.build_queue(
446+
sea_result_data=execute_response.results_data,
447+
manifest=execute_response.results_manifest,
448+
statement_id=execute_response.command_id.to_sea_statement_id(),
449+
description=execute_response.description,
450+
schema_bytes=execute_response.arrow_schema_bytes,
451+
)
452+
444453
super().__init__(
445454
connection=connection,
446455
backend=sea_client,
@@ -450,42 +459,135 @@ def __init__(
450459
status=execute_response.status,
451460
has_been_closed_server_side=execute_response.has_been_closed_server_side,
452461
has_more_rows=execute_response.has_more_rows,
453-
results_queue=execute_response.results_queue,
462+
results_queue=queue,
454463
description=execute_response.description,
455464
is_staging_operation=execute_response.is_staging_operation,
456465
)
466+
467+
def _convert_to_row_objects(self, rows):
468+
"""
469+
Convert raw data rows to Row objects with named columns based on description.
470+
471+
Args:
472+
rows: List of raw data rows
473+
474+
Returns:
475+
List of Row objects with named columns
476+
"""
477+
if not self.description or not rows:
478+
return rows
479+
480+
column_names = [col[0] for col in self.description]
481+
ResultRow = Row(*column_names)
482+
return [ResultRow(*row) for row in rows]
457483

458484
def _fill_results_buffer(self):
459485
"""Fill the results buffer from the backend."""
460-
raise NotImplementedError("fetchone is not implemented for SEA backend")
486+
return None
487+
488+
def _convert_rows_to_arrow_table(self, rows):
489+
"""Convert rows to Arrow table."""
490+
if not self.description:
491+
return pyarrow.Table.from_pylist([])
492+
493+
# Create dict of column data
494+
column_data = {}
495+
column_names = [col[0] for col in self.description]
496+
497+
for i, name in enumerate(column_names):
498+
column_data[name] = [row[i] for row in rows]
499+
500+
return pyarrow.Table.from_pydict(column_data)
501+
502+
def _create_empty_arrow_table(self):
503+
"""Create an empty Arrow table with the correct schema."""
504+
if not self.description:
505+
return pyarrow.Table.from_pylist([])
506+
507+
column_names = [col[0] for col in self.description]
508+
return pyarrow.Table.from_pydict({name: [] for name in column_names})
461509

462510
def fetchone(self) -> Optional[Row]:
463511
"""
464512
Fetch the next row of a query result set, returning a single sequence,
465513
or None when no more data is available.
466514
"""
467-
468-
raise NotImplementedError("fetchone is not implemented for SEA backend")
515+
if isinstance(self.results, JsonQueue):
516+
rows = self.results.next_n_rows(1)
517+
if not rows:
518+
return None
519+
520+
# Convert to Row object
521+
converted_rows = self._convert_to_row_objects(rows)
522+
return converted_rows[0] if converted_rows else None
523+
else:
524+
raise NotImplementedError("Unsupported queue type")
469525

470526
def fetchmany(self, size: Optional[int] = None) -> List[Row]:
471527
"""
472528
Fetch the next set of rows of a query result, returning a list of rows.
473529
474530
An empty sequence is returned when no more rows are available.
475531
"""
532+
if size is None:
533+
size = self.arraysize
534+
535+
if size < 0:
536+
raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
537+
538+
# Note: We check for the specific queue type to maintain consistency with ThriftResultSet
539+
if isinstance(self.results, JsonQueue):
540+
rows = self.results.next_n_rows(size)
541+
self._next_row_index += len(rows)
476542

477-
raise NotImplementedError("fetchmany is not implemented for SEA backend")
543+
# Convert to Row objects
544+
return self._convert_to_row_objects(rows)
545+
else:
546+
raise NotImplementedError("Unsupported queue type")
478547

479548
def fetchall(self) -> List[Row]:
480549
"""
481550
Fetch all (remaining) rows of a query result, returning them as a list of rows.
482551
"""
483-
raise NotImplementedError("fetchall is not implemented for SEA backend")
552+
# Note: We check for the specific queue type to maintain consistency with ThriftResultSet
553+
if isinstance(self.results, JsonQueue):
554+
rows = self.results.remaining_rows()
555+
self._next_row_index += len(rows)
556+
557+
# Convert to Row objects
558+
return self._convert_to_row_objects(rows)
559+
else:
560+
raise NotImplementedError("Unsupported queue type")
484561

485562
def fetchmany_arrow(self, size: int) -> Any:
486563
"""Fetch the next set of rows as an Arrow table."""
487-
raise NotImplementedError("fetchmany_arrow is not implemented for SEA backend")
564+
if not pyarrow:
565+
raise ImportError("PyArrow is required for Arrow support")
566+
567+
if isinstance(self.results, JsonQueue):
568+
rows = self.fetchmany(size)
569+
if not rows:
570+
# Return empty Arrow table with schema
571+
return self._create_empty_arrow_table()
572+
573+
# Convert rows to Arrow table
574+
return self._convert_rows_to_arrow_table(rows)
575+
else:
576+
raise NotImplementedError("Unsupported queue type")
488577

489578
def fetchall_arrow(self) -> Any:
490579
"""Fetch all remaining rows as an Arrow table."""
491-
raise NotImplementedError("fetchall_arrow is not implemented for SEA backend")
580+
if not pyarrow:
581+
raise ImportError("PyArrow is required for Arrow support")
582+
583+
if isinstance(self.results, JsonQueue):
584+
rows = self.fetchall()
585+
if not rows:
586+
# Return empty Arrow table with schema
587+
return self._create_empty_arrow_table()
588+
589+
# Convert rows to Arrow table
590+
return self._convert_rows_to_arrow_table(rows)
591+
else:
592+
raise NotImplementedError("Unsupported queue type")
593+

src/databricks/sql/utils.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
import lz4.frame
1515

16+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
17+
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
18+
1619
try:
1720
import pyarrow
1821
except ImportError:
@@ -48,7 +51,7 @@ def remaining_rows(self):
4851
pass
4952

5053

51-
class ResultSetQueueFactory(ABC):
54+
class ThriftResultSetQueueFactory(ABC):
5255
@staticmethod
5356
def build_queue(
5457
row_set_type: TSparkRowSetType,
@@ -106,6 +109,69 @@ def build_queue(
106109
else:
107110
raise AssertionError("Row set type is not valid")
108111

112+
class SeaResultSetQueueFactory(ABC):
113+
@staticmethod
114+
def build_queue(
115+
sea_result_data: ResultData,
116+
manifest: Optional[ResultManifest],
117+
statement_id: str,
118+
description: Optional[List[Tuple[Any, ...]]] = None,
119+
schema_bytes: Optional[bytes] = None,
120+
max_download_threads: Optional[int] = None,
121+
ssl_options: Optional[SSLOptions] = None,
122+
sea_client: Optional["SeaDatabricksClient"] = None,
123+
lz4_compressed: bool = False,
124+
) -> ResultSetQueue:
125+
"""
126+
Factory method to build a result set queue for SEA backend.
127+
128+
Args:
129+
sea_result_data (ResultData): Result data from SEA response
130+
manifest (Optional[ResultManifest]): Manifest from SEA response, if available
131+
statement_id (str): Statement ID for the query
132+
description (Optional[List[Tuple[Any, ...]]]): Column descriptions
133+
schema_bytes (Optional[bytes]): Arrow schema bytes, if available
134+
max_download_threads (int): Maximum number of download threads
135+
ssl_options (SSLOptions): SSL options for downloads
136+
sea_client (SeaDatabricksClient): SEA client for fetching additional links
137+
lz4_compressed (bool): Whether the data is LZ4 compressed
138+
139+
Returns:
140+
ResultSetQueue: The appropriate queue for the result data
141+
"""
142+
143+
if sea_result_data.data is not None:
144+
# INLINE disposition with JSON_ARRAY format
145+
return JsonQueue(sea_result_data.data)
146+
elif sea_result_data.external_links is not None:
147+
# EXTERNAL_LINKS disposition
148+
raise NotImplementedError("EXTERNAL_LINKS disposition is not implemented for SEA backend")
149+
else:
150+
# Empty result set
151+
return JsonQueue([])
152+
153+
154+
class JsonQueue(ResultSetQueue):
155+
"""Queue implementation for JSON_ARRAY format data."""
156+
157+
def __init__(self, data_array):
158+
"""Initialize with JSON array data."""
159+
self.data_array = data_array
160+
self.cur_row_index = 0
161+
self.n_valid_rows = len(data_array)
162+
163+
def next_n_rows(self, num_rows):
164+
"""Get the next n rows from the data array."""
165+
length = min(num_rows, self.n_valid_rows - self.cur_row_index)
166+
slice = self.data_array[self.cur_row_index : self.cur_row_index + length]
167+
self.cur_row_index += length
168+
return slice
169+
170+
def remaining_rows(self):
171+
"""Get all remaining rows from the data array."""
172+
slice = self.data_array[self.cur_row_index :]
173+
self.cur_row_index += len(slice)
174+
return slice
109175

110176
class ColumnTable:
111177
def __init__(self, column_table, column_names):

0 commit comments

Comments
 (0)