
Commit 787f1f7

Merge branch 'sea-migration' into sea-test-scripts
2 parents: 3e22c6c + 6d63df0

10 files changed: +236 −158 lines changed


src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 @dataclass
 class CreateSessionRequest:
-    """Request to create a new session."""
+    """Representation of a request to create a new session."""
 
     warehouse_id: str
     session_confs: Optional[Dict[str, str]] = None
@@ -29,7 +29,7 @@ def to_dict(self) -> Dict[str, Any]:
 
 @dataclass
 class DeleteSessionRequest:
-    """Request to delete a session."""
+    """Representation of a request to delete a session."""
 
     warehouse_id: str
     session_id: str
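
For context, a minimal usage sketch of the renamed request models (not part of the commit; the ids and the "ansi_mode" conf key are placeholder values, and to_dict() is assumed to serialize the dataclass fields into the SEA request payload, per the hunk header above):

from databricks.sql.backend.sea.models.requests import (
    CreateSessionRequest,
    DeleteSessionRequest,
)

# Placeholder values; field names come from the diff above.
create_req = CreateSessionRequest(
    warehouse_id="warehouse-id",
    session_confs={"ansi_mode": "true"},  # example conf key (assumption)
)
payload = create_req.to_dict()  # dict body for the create-session call

delete_req = DeleteSessionRequest(warehouse_id="warehouse-id", session_id="session-id")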

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 @dataclass
 class CreateSessionResponse:
-    """Response from creating a new session."""
+    """Representation of the response from creating a new session."""
 
     session_id: str
 
src/databricks/sql/backend/thrift_backend.py

Lines changed: 71 additions & 59 deletions
@@ -3,23 +3,20 @@
 import logging
 import math
 import time
-import uuid
 import threading
 from typing import List, Union, Any, TYPE_CHECKING
 
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
 
-from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
 from databricks.sql.backend.types import (
     CommandState,
     SessionId,
     CommandId,
-    BackendType,
-    guid_to_hex_id,
     ExecuteResponse,
 )
 
+
 try:
     import pyarrow
 except ImportError:
@@ -759,11 +756,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
         )
         direct_results = resp.directResults
         has_been_closed_server_side = direct_results and direct_results.closeOperation
-        has_more_rows = (
+
+        is_direct_results = (
             (not direct_results)
             or (not direct_results.resultSet)
             or direct_results.resultSet.hasMoreRows
         )
+
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )
@@ -779,43 +778,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
             schema_bytes = None
 
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
-        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
-        if direct_results and direct_results.resultSet:
-            assert direct_results.resultSet.results.startRowOffset == 0
-            assert direct_results.resultSetMetadata
-
-            arrow_queue_opt = ResultSetQueueFactory.build_queue(
-                row_set_type=t_result_set_metadata_resp.resultFormat,
-                t_row_set=direct_results.resultSet.results,
-                arrow_schema_bytes=schema_bytes,
-                max_download_threads=self.max_download_threads,
-                lz4_compressed=lz4_compressed,
-                description=description,
-                ssl_options=self._ssl_options,
-            )
-        else:
-            arrow_queue_opt = None
-
         command_id = CommandId.from_thrift_handle(resp.operationHandle)
 
         status = CommandState.from_thrift_state(operation_state)
         if status is None:
             raise ValueError(f"Unknown command state: {operation_state}")
 
-        return (
-            ExecuteResponse(
-                command_id=command_id,
-                status=status,
-                description=description,
-                has_more_rows=has_more_rows,
-                results_queue=arrow_queue_opt,
-                has_been_closed_server_side=has_been_closed_server_side,
-                lz4_compressed=lz4_compressed,
-                is_staging_operation=is_staging_operation,
-            ),
-            schema_bytes,
+        execute_response = ExecuteResponse(
+            command_id=command_id,
+            status=status,
+            description=description,
+            has_been_closed_server_side=has_been_closed_server_side,
+            lz4_compressed=lz4_compressed,
+            is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
+            arrow_schema_bytes=schema_bytes,
+            result_format=t_result_set_metadata_resp.resultFormat,
         )
 
+        return execute_response, is_direct_results
+
     def get_execution_result(
         self, command_id: CommandId, cursor: "Cursor"
     ) -> "ResultSet":
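
Net effect of this hunk: _results_message_to_execute_response no longer builds a results queue or returns schema_bytes; it returns the slimmed-down ExecuteResponse plus an is_direct_results flag. A sketch of the new contract as an annotation (the annotation itself is not in the diff):

from typing import Tuple

from databricks.sql.backend.types import ExecuteResponse


def _results_message_to_execute_response(self, resp, operation_state) -> Tuple[ExecuteResponse, bool]:
    ...  # body as in the hunk above; the bool is is_direct_results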
@@ -840,9 +821,6 @@ def get_execution_result(
 
         t_result_set_metadata_resp = resp.resultSetMetadata
 
-        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
-        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
-        has_more_rows = resp.hasMoreRows
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )
@@ -857,27 +835,23 @@ def get_execution_result(
         else:
             schema_bytes = None
 
-        queue = ResultSetQueueFactory.build_queue(
-            row_set_type=resp.resultSetMetadata.resultFormat,
-            t_row_set=resp.results,
-            arrow_schema_bytes=schema_bytes,
-            max_download_threads=self.max_download_threads,
-            lz4_compressed=lz4_compressed,
-            description=description,
-            ssl_options=self._ssl_options,
-        )
+        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
+        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
+        is_direct_results = resp.hasMoreRows
+
+        status = self.get_query_state(command_id)
 
         status = self.get_query_state(command_id)
 
         execute_response = ExecuteResponse(
             command_id=command_id,
             status=status,
             description=description,
-            has_more_rows=has_more_rows,
-            results_queue=queue,
             has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=is_staging_operation,
+            arrow_schema_bytes=schema_bytes,
+            result_format=t_result_set_metadata_resp.resultFormat,
         )
 
         return ThriftResultSet(
@@ -887,7 +861,10 @@ def get_execution_result(
             buffer_size_bytes=cursor.buffer_size_bytes,
             arraysize=cursor.arraysize,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
-            arrow_schema_bytes=schema_bytes,
+            t_row_set=resp.results,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
         )
 
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
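
The queue construction removed from _results_message_to_execute_response and get_execution_result presumably moves behind the new ThriftResultSet parameters (t_row_set, max_download_threads, ssl_options, is_direct_results); ThriftResultSet itself is not in this diff, so the following is only a sketch of that assumption, reusing the build_queue keywords deleted above and the new ExecuteResponse fields from types.py:

from databricks.sql.utils import ResultSetQueueFactory  # import path assumed


def build_queue_for_result_set(execute_response, t_row_set, max_download_threads, ssl_options):
    """Hypothetical helper mirroring where the removed build_queue call presumably lands."""
    if t_row_set is None:
        return None
    return ResultSetQueueFactory.build_queue(
        row_set_type=execute_response.result_format,
        t_row_set=t_row_set,
        arrow_schema_bytes=execute_response.arrow_schema_bytes,
        max_download_threads=max_download_threads,
        lz4_compressed=execute_response.lz4_compressed,
        description=execute_response.description,
        ssl_options=ssl_options,
    )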
@@ -1000,18 +977,25 @@ def execute_command(
             self._handle_execute_response_async(resp, cursor)
             return None
         else:
-            execute_response, arrow_schema_bytes = self._handle_execute_response(
+            execute_response, is_direct_results = self._handle_execute_response(
                 resp, cursor
             )
 
+            t_row_set = None
+            if resp.directResults and resp.directResults.resultSet:
+                t_row_set = resp.directResults.resultSet.results
+
             return ThriftResultSet(
                 connection=cursor.connection,
                 execute_response=execute_response,
                 thrift_client=self,
                 buffer_size_bytes=max_bytes,
                 arraysize=max_rows,
                 use_cloud_fetch=use_cloud_fetch,
-                arrow_schema_bytes=arrow_schema_bytes,
+                t_row_set=t_row_set,
+                max_download_threads=self.max_download_threads,
+                ssl_options=self._ssl_options,
+                is_direct_results=is_direct_results,
             )
 
     def get_catalogs(
@@ -1033,18 +1017,25 @@ def get_catalogs(
         )
         resp = self.make_request(self._client.GetCatalogs, req)
 
-        execute_response, arrow_schema_bytes = self._handle_execute_response(
+        execute_response, is_direct_results = self._handle_execute_response(
             resp, cursor
         )
 
+        t_row_set = None
+        if resp.directResults and resp.directResults.resultSet:
+            t_row_set = resp.directResults.resultSet.results
+
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
             thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
-            arrow_schema_bytes=arrow_schema_bytes,
+            t_row_set=t_row_set,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
         )
 
     def get_schemas(
def get_schemas(
@@ -1070,18 +1061,25 @@ def get_schemas(
10701061
)
10711062
resp = self.make_request(self._client.GetSchemas, req)
10721063

1073-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1064+
execute_response, is_direct_results = self._handle_execute_response(
10741065
resp, cursor
10751066
)
10761067

1068+
t_row_set = None
1069+
if resp.directResults and resp.directResults.resultSet:
1070+
t_row_set = resp.directResults.resultSet.results
1071+
10771072
return ThriftResultSet(
10781073
connection=cursor.connection,
10791074
execute_response=execute_response,
10801075
thrift_client=self,
10811076
buffer_size_bytes=max_bytes,
10821077
arraysize=max_rows,
10831078
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1084-
arrow_schema_bytes=arrow_schema_bytes,
1079+
t_row_set=t_row_set,
1080+
max_download_threads=self.max_download_threads,
1081+
ssl_options=self._ssl_options,
1082+
is_direct_results=is_direct_results,
10851083
)
10861084

10871085
def get_tables(
@@ -1111,18 +1109,25 @@ def get_tables(
11111109
)
11121110
resp = self.make_request(self._client.GetTables, req)
11131111

1114-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1112+
execute_response, is_direct_results = self._handle_execute_response(
11151113
resp, cursor
11161114
)
11171115

1116+
t_row_set = None
1117+
if resp.directResults and resp.directResults.resultSet:
1118+
t_row_set = resp.directResults.resultSet.results
1119+
11181120
return ThriftResultSet(
11191121
connection=cursor.connection,
11201122
execute_response=execute_response,
11211123
thrift_client=self,
11221124
buffer_size_bytes=max_bytes,
11231125
arraysize=max_rows,
11241126
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1125-
arrow_schema_bytes=arrow_schema_bytes,
1127+
t_row_set=t_row_set,
1128+
max_download_threads=self.max_download_threads,
1129+
ssl_options=self._ssl_options,
1130+
is_direct_results=is_direct_results,
11261131
)
11271132

11281133
def get_columns(
@@ -1152,18 +1157,25 @@ def get_columns(
         )
         resp = self.make_request(self._client.GetColumns, req)
 
-        execute_response, arrow_schema_bytes = self._handle_execute_response(
+        execute_response, is_direct_results = self._handle_execute_response(
             resp, cursor
         )
 
+        t_row_set = None
+        if resp.directResults and resp.directResults.resultSet:
+            t_row_set = resp.directResults.resultSet.results
+
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
             thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
-            arrow_schema_bytes=arrow_schema_bytes,
+            t_row_set=t_row_set,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
        )
 
     def _handle_execute_response(self, resp, cursor):
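
The same four-line t_row_set extraction now appears in execute_command, get_catalogs, get_schemas, get_tables, and get_columns above. A hypothetical helper (not in this commit) that would factor it out:

def _extract_direct_row_set(resp):
    """Return the direct-results row set from a Thrift response, or None if absent."""
    if resp.directResults and resp.directResults.resultSet:
        return resp.directResults.resultSet.results
    return None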

src/databricks/sql/backend/types.py

Lines changed: 3 additions & 5 deletions
@@ -423,11 +423,9 @@ class ExecuteResponse:
 
     command_id: CommandId
     status: CommandState
-    description: Optional[
-        List[Tuple[str, str, None, None, Optional[int], Optional[int], bool]]
-    ] = None
-    has_more_rows: bool = False
-    results_queue: Optional[Any] = None
+    description: Optional[List[Tuple]] = None
     has_been_closed_server_side: bool = False
     lz4_compressed: bool = True
     is_staging_operation: bool = False
+    arrow_schema_bytes: Optional[bytes] = None
+    result_format: Optional[Any] = None
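
A minimal illustration of filling the reshaped ExecuteResponse (placeholder values; only the field names and the two from_thrift_* helpers are taken from this diff):

from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse


def make_execute_response(operation_handle, operation_state, schema_bytes=None):
    """Hypothetical constructor showing the new field set."""
    status = CommandState.from_thrift_state(operation_state)
    if status is None:
        raise ValueError(f"Unknown command state: {operation_state}")
    return ExecuteResponse(
        command_id=CommandId.from_thrift_handle(operation_handle),
        status=status,
        description=[("col", "string", None, None, None, None, True)],  # placeholder cursor description
        has_been_closed_server_side=False,
        lz4_compressed=True,
        is_staging_operation=False,
        arrow_schema_bytes=schema_bytes,  # serialized Arrow schema bytes, when available
        result_format=None,  # Thrift result-format value in practice (assumption)
    )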
