Skip to content

Commit 4e07f1e

Browse files
align SeaResultSet with new structure
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent bf26ea3 commit 4e07f1e

File tree

2 files changed

+44
-16
lines changed

2 files changed

+44
-16
lines changed

src/databricks/sql/result_set.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from databricks.sql.thrift_api.TCLIService import ttypes
2020
from databricks.sql.types import Row
2121
from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
22-
from databricks.sql.utils import ColumnTable, ColumnQueue
22+
from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue
2323
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
2424

2525
logger = logging.getLogger(__name__)
@@ -41,10 +41,11 @@ def __init__(
4141
command_id: CommandId,
4242
status: CommandState,
4343
has_been_closed_server_side: bool = False,
44-
has_more_rows: bool = False,
4544
results_queue=None,
4645
description=None,
4746
is_staging_operation: bool = False,
47+
lz4_compressed: bool = False,
48+
arrow_schema_bytes: bytes = b"",
4849
):
4950
"""
5051
A ResultSet manages the results of a single command.
@@ -72,9 +73,10 @@ def __init__(
7273
self.command_id = command_id
7374
self.status = status
7475
self.has_been_closed_server_side = has_been_closed_server_side
75-
self.has_more_rows = has_more_rows
7676
self.results = results_queue
7777
self._is_staging_operation = is_staging_operation
78+
self.lz4_compressed = lz4_compressed
79+
self._arrow_schema_bytes = arrow_schema_bytes
7880

7981
def __iter__(self):
8082
while True:
@@ -157,7 +159,10 @@ def __init__(
157159
buffer_size_bytes: int = 104857600,
158160
arraysize: int = 10000,
159161
use_cloud_fetch: bool = True,
160-
arrow_schema_bytes: Optional[bytes] = None,
162+
t_row_set=None,
163+
max_download_threads: int = 10,
164+
ssl_options=None,
165+
has_more_rows: bool = True,
161166
):
162167
"""
163168
Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -169,12 +174,30 @@ def __init__(
169174
buffer_size_bytes: Buffer size for fetching results
170175
arraysize: Default number of rows to fetch
171176
use_cloud_fetch: Whether to use cloud fetch for retrieving results
172-
arrow_schema_bytes: Arrow schema bytes for the result set
177+
t_row_set: The TRowSet containing result data (if available)
178+
max_download_threads: Maximum number of download threads for cloud fetch
179+
ssl_options: SSL options for cloud fetch
180+
has_more_rows: Whether there are more rows to fetch
173181
"""
174182
# Initialize ThriftResultSet-specific attributes
175-
self._arrow_schema_bytes = arrow_schema_bytes
176183
self._use_cloud_fetch = use_cloud_fetch
177-
self.lz4_compressed = execute_response.lz4_compressed
184+
self.has_more_rows = has_more_rows
185+
186+
# Build the results queue if t_row_set is provided
187+
results_queue = None
188+
if t_row_set and execute_response.result_format is not None:
189+
from databricks.sql.utils import ThriftResultSetQueueFactory
190+
191+
# Create the results queue using the provided format
192+
results_queue = ThriftResultSetQueueFactory.build_queue(
193+
row_set_type=execute_response.result_format,
194+
t_row_set=t_row_set,
195+
arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
196+
max_download_threads=max_download_threads,
197+
lz4_compressed=execute_response.lz4_compressed,
198+
description=execute_response.description,
199+
ssl_options=ssl_options,
200+
)
178201

179202
# Call parent constructor with common attributes
180203
super().__init__(
@@ -185,10 +208,11 @@ def __init__(
185208
command_id=execute_response.command_id,
186209
status=execute_response.status,
187210
has_been_closed_server_side=execute_response.has_been_closed_server_side,
188-
has_more_rows=execute_response.has_more_rows,
189-
results_queue=execute_response.results_queue,
211+
results_queue=results_queue,
190212
description=execute_response.description,
191213
is_staging_operation=execute_response.is_staging_operation,
214+
lz4_compressed=execute_response.lz4_compressed,
215+
arrow_schema_bytes=execute_response.arrow_schema_bytes,
192216
)
193217

194218
# Initialize results queue if not provided
@@ -419,7 +443,7 @@ def map_col_type(type_):
419443

420444

421445
class SeaResultSet(ResultSet):
422-
"""ResultSet implementation for the SEA backend."""
446+
"""ResultSet implementation for SEA backend."""
423447

424448
def __init__(
425449
self,
@@ -428,17 +452,20 @@ def __init__(
428452
sea_client: "SeaDatabricksClient",
429453
buffer_size_bytes: int = 104857600,
430454
arraysize: int = 10000,
455+
result_data=None,
456+
manifest=None,
431457
):
432458
"""
433459
Initialize a SeaResultSet with the response from a SEA query execution.
434460
435461
Args:
436462
connection: The parent connection
463+
execute_response: Response from the execute command
437464
sea_client: The SeaDatabricksClient instance for direct access
438465
buffer_size_bytes: Buffer size for fetching results
439466
arraysize: Default number of rows to fetch
440-
execute_response: Response from the execute command (new style)
441-
sea_response: Direct SEA response (legacy style)
467+
result_data: Result data from SEA response (optional)
468+
manifest: Manifest from SEA response (optional)
442469
"""
443470

444471
super().__init__(
@@ -449,15 +476,15 @@ def __init__(
449476
command_id=execute_response.command_id,
450477
status=execute_response.status,
451478
has_been_closed_server_side=execute_response.has_been_closed_server_side,
452-
has_more_rows=execute_response.has_more_rows,
453-
results_queue=execute_response.results_queue,
454479
description=execute_response.description,
455480
is_staging_operation=execute_response.is_staging_operation,
481+
lz4_compressed=execute_response.lz4_compressed,
482+
arrow_schema_bytes=execute_response.arrow_schema_bytes,
456483
)
457484

458485
def _fill_results_buffer(self):
459486
"""Fill the results buffer from the backend."""
460-
raise NotImplementedError("fetchone is not implemented for SEA backend")
487+
raise NotImplementedError("fetchall_arrow is not implemented for SEA backend")
461488

462489
def fetchone(self) -> Optional[Row]:
463490
"""
@@ -480,6 +507,7 @@ def fetchall(self) -> List[Row]:
480507
"""
481508
Fetch all (remaining) rows of a query result, returning them as a list of rows.
482509
"""
510+
483511
raise NotImplementedError("fetchall is not implemented for SEA backend")
484512

485513
def fetchmany_arrow(self, size: int) -> Any:

tests/unit/test_sea_result_set.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,4 @@ def test_fill_results_buffer_not_implemented(
197197
with pytest.raises(
198198
NotImplementedError, match="fetchone is not implemented for SEA backend"
199199
):
200-
result_set._fill_results_buffer()
200+
result_set._fill_results_buffer()

0 commit comments

Comments
 (0)