Skip to content

Commit 170f339

Browse files
Merge branch 'exec-resp-norm' into fetch-json-inline
2 parents 71b451a + 73bc282 commit 170f339

File tree

7 files changed

+204
-161
lines changed

7 files changed

+204
-161
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 74 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,20 @@
33
import logging
44
import math
55
import time
6-
import uuid
76
import threading
87
from typing import List, Union, Any, TYPE_CHECKING
98

109
if TYPE_CHECKING:
1110
from databricks.sql.client import Cursor
1211

13-
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1412
from databricks.sql.backend.types import (
1513
CommandState,
1614
SessionId,
1715
CommandId,
18-
BackendType,
19-
guid_to_hex_id,
2016
ExecuteResponse,
2117
)
18+
from databricks.sql.backend.utils import guid_to_hex_id
19+
2220

2321
try:
2422
import pyarrow
@@ -759,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
759757
)
760758
direct_results = resp.directResults
761759
has_been_closed_server_side = direct_results and direct_results.closeOperation
760+
762761
has_more_rows = (
763762
(not direct_results)
764763
or (not direct_results.resultSet)
765764
or direct_results.resultSet.hasMoreRows
766765
)
766+
767767
description = self._hive_schema_to_description(
768768
t_result_set_metadata_resp.schema
769769
)
@@ -779,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
779779
schema_bytes = None
780780

781781
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
782-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
783-
if direct_results and direct_results.resultSet:
784-
assert direct_results.resultSet.results.startRowOffset == 0
785-
assert direct_results.resultSetMetadata
786-
787-
arrow_queue_opt = ThriftResultSetQueueFactory.build_queue(
788-
row_set_type=t_result_set_metadata_resp.resultFormat,
789-
t_row_set=direct_results.resultSet.results,
790-
arrow_schema_bytes=schema_bytes,
791-
max_download_threads=self.max_download_threads,
792-
lz4_compressed=lz4_compressed,
793-
description=description,
794-
ssl_options=self._ssl_options,
795-
)
796-
else:
797-
arrow_queue_opt = None
798-
799782
command_id = CommandId.from_thrift_handle(resp.operationHandle)
800783

801784
status = CommandState.from_thrift_state(operation_state)
802785
if status is None:
803786
raise ValueError(f"Unknown command state: {operation_state}")
804787

805-
return (
806-
ExecuteResponse(
807-
command_id=command_id,
808-
status=status,
809-
description=description,
810-
has_more_rows=has_more_rows,
811-
results_queue=arrow_queue_opt,
812-
has_been_closed_server_side=has_been_closed_server_side,
813-
lz4_compressed=lz4_compressed,
814-
is_staging_operation=is_staging_operation,
815-
),
816-
schema_bytes,
788+
execute_response = ExecuteResponse(
789+
command_id=command_id,
790+
status=status,
791+
description=description,
792+
has_been_closed_server_side=has_been_closed_server_side,
793+
lz4_compressed=lz4_compressed,
794+
is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
795+
arrow_schema_bytes=schema_bytes,
796+
result_format=t_result_set_metadata_resp.resultFormat,
817797
)
818798

799+
return execute_response, has_more_rows
800+
819801
def get_execution_result(
820802
self, command_id: CommandId, cursor: "Cursor"
821803
) -> "ResultSet":
@@ -840,9 +822,6 @@ def get_execution_result(
840822

841823
t_result_set_metadata_resp = resp.resultSetMetadata
842824

843-
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
844-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
845-
has_more_rows = resp.hasMoreRows
846825
description = self._hive_schema_to_description(
847826
t_result_set_metadata_resp.schema
848827
)
@@ -857,27 +836,21 @@ def get_execution_result(
857836
else:
858837
schema_bytes = None
859838

860-
queue = ThriftResultSetQueueFactory.build_queue(
861-
row_set_type=resp.resultSetMetadata.resultFormat,
862-
t_row_set=resp.results,
863-
arrow_schema_bytes=schema_bytes,
864-
max_download_threads=self.max_download_threads,
865-
lz4_compressed=lz4_compressed,
866-
description=description,
867-
ssl_options=self._ssl_options,
868-
)
839+
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
840+
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
841+
has_more_rows = resp.hasMoreRows
869842

870843
status = self.get_query_state(command_id)
871844

872845
execute_response = ExecuteResponse(
873846
command_id=command_id,
874847
status=status,
875848
description=description,
876-
has_more_rows=has_more_rows,
877-
results_queue=queue,
878849
has_been_closed_server_side=False,
879850
lz4_compressed=lz4_compressed,
880851
is_staging_operation=is_staging_operation,
852+
arrow_schema_bytes=schema_bytes,
853+
result_format=t_result_set_metadata_resp.resultFormat,
881854
)
882855

883856
return ThriftResultSet(
@@ -887,7 +860,10 @@ def get_execution_result(
887860
buffer_size_bytes=cursor.buffer_size_bytes,
888861
arraysize=cursor.arraysize,
889862
use_cloud_fetch=cursor.connection.use_cloud_fetch,
890-
arrow_schema_bytes=schema_bytes,
863+
t_row_set=resp.results,
864+
max_download_threads=self.max_download_threads,
865+
ssl_options=self._ssl_options,
866+
has_more_rows=has_more_rows,
891867
)
892868

893869
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -918,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
918894
self._check_command_not_in_error_or_closed_state(thrift_handle, poll_resp)
919895
state = CommandState.from_thrift_state(operation_state)
920896
if state is None:
921-
raise ValueError(f"Invalid operation state: {operation_state}")
897+
raise ValueError(f"Unknown command state: {operation_state}")
922898
return state
923899

924900
@staticmethod
@@ -1000,18 +976,25 @@ def execute_command(
1000976
self._handle_execute_response_async(resp, cursor)
1001977
return None
1002978
else:
1003-
execute_response, arrow_schema_bytes = self._handle_execute_response(
979+
execute_response, has_more_rows = self._handle_execute_response(
1004980
resp, cursor
1005981
)
1006982

983+
t_row_set = None
984+
if resp.directResults and resp.directResults.resultSet:
985+
t_row_set = resp.directResults.resultSet.results
986+
1007987
return ThriftResultSet(
1008988
connection=cursor.connection,
1009989
execute_response=execute_response,
1010990
thrift_client=self,
1011991
buffer_size_bytes=max_bytes,
1012992
arraysize=max_rows,
1013993
use_cloud_fetch=use_cloud_fetch,
1014-
arrow_schema_bytes=arrow_schema_bytes,
994+
t_row_set=t_row_set,
995+
max_download_threads=self.max_download_threads,
996+
ssl_options=self._ssl_options,
997+
has_more_rows=has_more_rows,
1015998
)
1016999

10171000
def get_catalogs(
@@ -1033,9 +1016,11 @@ def get_catalogs(
10331016
)
10341017
resp = self.make_request(self._client.GetCatalogs, req)
10351018

1036-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1037-
resp, cursor
1038-
)
1019+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
1020+
1021+
t_row_set = None
1022+
if resp.directResults and resp.directResults.resultSet:
1023+
t_row_set = resp.directResults.resultSet.results
10391024

10401025
return ThriftResultSet(
10411026
connection=cursor.connection,
@@ -1044,7 +1029,10 @@ def get_catalogs(
10441029
buffer_size_bytes=max_bytes,
10451030
arraysize=max_rows,
10461031
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1047-
arrow_schema_bytes=arrow_schema_bytes,
1032+
t_row_set=t_row_set,
1033+
max_download_threads=self.max_download_threads,
1034+
ssl_options=self._ssl_options,
1035+
has_more_rows=has_more_rows,
10481036
)
10491037

10501038
def get_schemas(
@@ -1070,9 +1058,11 @@ def get_schemas(
10701058
)
10711059
resp = self.make_request(self._client.GetSchemas, req)
10721060

1073-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1074-
resp, cursor
1075-
)
1061+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
1062+
1063+
t_row_set = None
1064+
if resp.directResults and resp.directResults.resultSet:
1065+
t_row_set = resp.directResults.resultSet.results
10761066

10771067
return ThriftResultSet(
10781068
connection=cursor.connection,
@@ -1081,7 +1071,10 @@ def get_schemas(
10811071
buffer_size_bytes=max_bytes,
10821072
arraysize=max_rows,
10831073
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1084-
arrow_schema_bytes=arrow_schema_bytes,
1074+
t_row_set=t_row_set,
1075+
max_download_threads=self.max_download_threads,
1076+
ssl_options=self._ssl_options,
1077+
has_more_rows=has_more_rows,
10851078
)
10861079

10871080
def get_tables(
@@ -1111,9 +1104,11 @@ def get_tables(
11111104
)
11121105
resp = self.make_request(self._client.GetTables, req)
11131106

1114-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1115-
resp, cursor
1116-
)
1107+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
1108+
1109+
t_row_set = None
1110+
if resp.directResults and resp.directResults.resultSet:
1111+
t_row_set = resp.directResults.resultSet.results
11171112

11181113
return ThriftResultSet(
11191114
connection=cursor.connection,
@@ -1122,7 +1117,10 @@ def get_tables(
11221117
buffer_size_bytes=max_bytes,
11231118
arraysize=max_rows,
11241119
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1125-
arrow_schema_bytes=arrow_schema_bytes,
1120+
t_row_set=t_row_set,
1121+
max_download_threads=self.max_download_threads,
1122+
ssl_options=self._ssl_options,
1123+
has_more_rows=has_more_rows,
11261124
)
11271125

11281126
def get_columns(
@@ -1152,9 +1150,11 @@ def get_columns(
11521150
)
11531151
resp = self.make_request(self._client.GetColumns, req)
11541152

1155-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1156-
resp, cursor
1157-
)
1153+
execute_response, has_more_rows = self._handle_execute_response(resp, cursor)
1154+
1155+
t_row_set = None
1156+
if resp.directResults and resp.directResults.resultSet:
1157+
t_row_set = resp.directResults.resultSet.results
11581158

11591159
return ThriftResultSet(
11601160
connection=cursor.connection,
@@ -1163,7 +1163,10 @@ def get_columns(
11631163
buffer_size_bytes=max_bytes,
11641164
arraysize=max_rows,
11651165
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1166-
arrow_schema_bytes=arrow_schema_bytes,
1166+
t_row_set=t_row_set,
1167+
max_download_threads=self.max_download_threads,
1168+
ssl_options=self._ssl_options,
1169+
has_more_rows=has_more_rows,
11671170
)
11681171

11691172
def _handle_execute_response(self, resp, cursor):
@@ -1177,11 +1180,7 @@ def _handle_execute_response(self, resp, cursor):
11771180
resp.directResults and resp.directResults.operationStatus,
11781181
)
11791182

1180-
(
1181-
execute_response,
1182-
arrow_schema_bytes,
1183-
) = self._results_message_to_execute_response(resp, final_operation_state)
1184-
return execute_response, arrow_schema_bytes
1183+
return self._results_message_to_execute_response(resp, final_operation_state)
11851184

11861185
def _handle_execute_response_async(self, resp, cursor):
11871186
command_id = CommandId.from_thrift_handle(resp.operationHandle)
@@ -1225,7 +1224,9 @@ def fetch_results(
12251224
)
12261225
)
12271226

1228-
queue = ThriftResultSetQueueFactory.build_queue(
1227+
from databricks.sql.utils import ResultSetQueueFactory
1228+
1229+
queue = ResultSetQueueFactory.build_queue(
12291230
row_set_type=resp.resultSetMetadata.resultFormat,
12301231
t_row_set=resp.results,
12311232
arrow_schema_bytes=arrow_schema_bytes,

src/databricks/sql/backend/types.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,11 +423,9 @@ class ExecuteResponse:
423423

424424
command_id: CommandId
425425
status: CommandState
426-
description: Optional[
427-
List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]]
428-
] = None
429-
has_more_rows: bool = False
430-
results_queue: Optional[Any] = None
426+
description: Optional[List[Tuple]] = None
431427
has_been_closed_server_side: bool = False
432428
lz4_compressed: bool = True
433429
is_staging_operation: bool = False
430+
arrow_schema_bytes: Optional[bytes] = None
431+
result_format: Optional[Any] = None

src/databricks/sql/result_set.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,10 @@ def __init__(
158158
buffer_size_bytes: int = 104857600,
159159
arraysize: int = 10000,
160160
use_cloud_fetch: bool = True,
161-
arrow_schema_bytes: Optional[bytes] = None,
161+
t_row_set=None,
162+
max_download_threads: int = 10,
163+
ssl_options=None,
164+
has_more_rows: bool = True,
162165
):
163166
"""
164167
Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -170,13 +173,32 @@ def __init__(
170173
buffer_size_bytes: Buffer size for fetching results
171174
arraysize: Default number of rows to fetch
172175
use_cloud_fetch: Whether to use cloud fetch for retrieving results
173-
arrow_schema_bytes: Arrow schema bytes for the result set
176+
t_row_set: The TRowSet containing result data (if available)
177+
max_download_threads: Maximum number of download threads for cloud fetch
178+
ssl_options: SSL options for cloud fetch
179+
has_more_rows: Whether there are more rows to fetch
174180
"""
175181
# Initialize ThriftResultSet-specific attributes
176-
self._arrow_schema_bytes = arrow_schema_bytes
182+
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
177183
self._use_cloud_fetch = use_cloud_fetch
178184
self.lz4_compressed = execute_response.lz4_compressed
179185

186+
# Build the results queue if t_row_set is provided
187+
results_queue = None
188+
if t_row_set and execute_response.result_format is not None:
189+
from databricks.sql.utils import ResultSetQueueFactory
190+
191+
# Create the results queue using the provided format
192+
results_queue = ResultSetQueueFactory.build_queue(
193+
row_set_type=execute_response.result_format,
194+
t_row_set=t_row_set,
195+
arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
196+
max_download_threads=max_download_threads,
197+
lz4_compressed=execute_response.lz4_compressed,
198+
description=execute_response.description,
199+
ssl_options=ssl_options,
200+
)
201+
180202
# Call parent constructor with common attributes
181203
super().__init__(
182204
connection=connection,
@@ -186,8 +208,8 @@ def __init__(
186208
command_id=execute_response.command_id,
187209
status=execute_response.status,
188210
has_been_closed_server_side=execute_response.has_been_closed_server_side,
189-
has_more_rows=execute_response.has_more_rows,
190-
results_queue=execute_response.results_queue,
211+
has_more_rows=has_more_rows,
212+
results_queue=results_queue,
191213
description=execute_response.description,
192214
is_staging_operation=execute_response.is_staging_operation,
193215
)

0 commit comments

Comments
 (0)