Skip to content

Commit 40f79b5

Browse files
Merge branch 'sea-res-set' into fetch-json-inline
2 parents 170f339 + 65e7c6b commit 40f79b5

File tree

2 files changed

+37
-42
lines changed

src/databricks/sql/result_set.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ def __init__(
4242
command_id: CommandId,
4343
status: CommandState,
4444
has_been_closed_server_side: bool = False,
45-
has_more_rows: bool = False,
4645
results_queue=None,
4746
description=None,
4847
is_staging_operation: bool = False,
48+
lz4_compressed: bool = False,
49+
arrow_schema_bytes: bytes = b"",
4950
):
5051
"""
5152
A ResultSet manages the results of a single command.
@@ -73,9 +74,10 @@ def __init__(
7374
self.command_id = command_id
7475
self.status = status
7576
self.has_been_closed_server_side = has_been_closed_server_side
76-
self.has_more_rows = has_more_rows
7777
self.results = results_queue
7878
self._is_staging_operation = is_staging_operation
79+
self.lz4_compressed = lz4_compressed
80+
self._arrow_schema_bytes = arrow_schema_bytes
7981

8082
def __iter__(self):
8183
while True:
@@ -179,9 +181,24 @@ def __init__(
179181
has_more_rows: Whether there are more rows to fetch
180182
"""
181183
# Initialize ThriftResultSet-specific attributes
182-
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
183184
self._use_cloud_fetch = use_cloud_fetch
184-
self.lz4_compressed = execute_response.lz4_compressed
185+
self.has_more_rows = has_more_rows
186+
187+
# Build the results queue if t_row_set is provided
188+
results_queue = None
189+
if t_row_set and execute_response.result_format is not None:
190+
from databricks.sql.utils import ThriftResultSetQueueFactory
191+
192+
# Create the results queue using the provided format
193+
results_queue = ThriftResultSetQueueFactory.build_queue(
194+
row_set_type=execute_response.result_format,
195+
t_row_set=t_row_set,
196+
arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
197+
max_download_threads=max_download_threads,
198+
lz4_compressed=execute_response.lz4_compressed,
199+
description=execute_response.description,
200+
ssl_options=ssl_options,
201+
)
185202

186203
# Build the results queue if t_row_set is provided
187204
results_queue = None
@@ -208,10 +225,11 @@ def __init__(
208225
command_id=execute_response.command_id,
209226
status=execute_response.status,
210227
has_been_closed_server_side=execute_response.has_been_closed_server_side,
211-
has_more_rows=has_more_rows,
212228
results_queue=results_queue,
213229
description=execute_response.description,
214230
is_staging_operation=execute_response.is_staging_operation,
231+
lz4_compressed=execute_response.lz4_compressed,
232+
arrow_schema_bytes=execute_response.arrow_schema_bytes,
215233
)
216234

217235
# Initialize results queue if not provided
@@ -442,7 +460,7 @@ def map_col_type(type_):
442460

443461

444462
class SeaResultSet(ResultSet):
445-
"""ResultSet implementation for the SEA backend."""
463+
"""ResultSet implementation for SEA backend."""
446464

447465
def __init__(
448466
self,
@@ -451,17 +469,20 @@ def __init__(
451469
sea_client: "SeaDatabricksClient",
452470
buffer_size_bytes: int = 104857600,
453471
arraysize: int = 10000,
472+
result_data=None,
473+
manifest=None,
454474
):
455475
"""
456476
Initialize a SeaResultSet with the response from a SEA query execution.
457477
458478
Args:
459479
connection: The parent connection
480+
execute_response: Response from the execute command
460481
sea_client: The SeaDatabricksClient instance for direct access
461482
buffer_size_bytes: Buffer size for fetching results
462483
arraysize: Default number of rows to fetch
463-
execute_response: Response from the execute command (new style)
464-
sea_response: Direct SEA response (legacy style)
484+
result_data: Result data from SEA response (optional)
485+
manifest: Manifest from SEA response (optional)
465486
"""
466487

467488
queue = SeaResultSetQueueFactory.build_queue(
@@ -480,10 +501,10 @@ def __init__(
480501
command_id=execute_response.command_id,
481502
status=execute_response.status,
482503
has_been_closed_server_side=execute_response.has_been_closed_server_side,
483-
has_more_rows=execute_response.has_more_rows,
484-
results_queue=queue,
485504
description=execute_response.description,
486505
is_staging_operation=execute_response.is_staging_operation,
506+
lz4_compressed=execute_response.lz4_compressed,
507+
arrow_schema_bytes=execute_response.arrow_schema_bytes,
487508
)
488509

489510
def _convert_to_row_objects(self, rows):
@@ -505,29 +526,9 @@ def _convert_to_row_objects(self, rows):
505526

506527
def _fill_results_buffer(self):
507528
"""Fill the results buffer from the backend."""
508-
return None
509-
510-
def _convert_rows_to_arrow_table(self, rows):
511-
"""Convert rows to Arrow table."""
512-
if not self.description:
513-
return pyarrow.Table.from_pylist([])
514-
515-
# Create dict of column data
516-
column_data = {}
517-
column_names = [col[0] for col in self.description]
518-
519-
for i, name in enumerate(column_names):
520-
column_data[name] = [row[i] for row in rows]
521-
522-
return pyarrow.Table.from_pydict(column_data)
523-
524-
def _create_empty_arrow_table(self):
525-
"""Create an empty Arrow table with the correct schema."""
526-
if not self.description:
527-
return pyarrow.Table.from_pylist([])
528-
529-
column_names = [col[0] for col in self.description]
530-
return pyarrow.Table.from_pydict({name: [] for name in column_names})
529+
raise NotImplementedError(
530+
"_fill_results_buffer is not implemented for SEA backend"
531+
)
531532

532533
def fetchone(self) -> Optional[Row]:
533534
"""
@@ -571,15 +572,8 @@ def fetchall(self) -> List[Row]:
571572
"""
572573
Fetch all (remaining) rows of a query result, returning them as a list of rows.
573574
"""
574-
# Note: We check for the specific queue type to maintain consistency with ThriftResultSet
575-
if isinstance(self.results, JsonQueue):
576-
rows = self.results.remaining_rows()
577-
self._next_row_index += len(rows)
578575

579-
# Convert to Row objects
580-
return self._convert_to_row_objects(rows)
581-
else:
582-
raise NotImplementedError("Unsupported queue type")
576+
raise NotImplementedError("fetchall is not implemented for SEA backend")
583577

584578
def fetchmany_arrow(self, size: int) -> Any:
585579
"""Fetch the next set of rows as an Arrow table."""

tests/unit/test_sea_result_set.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def test_fill_results_buffer_not_implemented(
195195
)
196196

197197
with pytest.raises(
198-
NotImplementedError, match="fetchone is not implemented for SEA backend"
198+
NotImplementedError,
199+
match="_fill_results_buffer is not implemented for SEA backend",
199200
):
200201
result_set._fill_results_buffer()

0 commit comments

Comments (0)