
Commit 6d63df0

Normalise Execution Response (clean backend interfaces) (#587)
Squashed commit messages:

* [squash from exec-sea] bring over execution phase changes
* remove excess test
* add docstring
* remove exec func in sea backend
* remove excess files
* remove excess models
* remove excess sea backend tests
* cleanup
* re-introduce get_schema_desc
* remove SeaResultSet
* clean imports and attributes
* pass CommandId to ExecResp
* remove changes in types
* add back essential types (ExecResponse, from_sea_state)
* fix fetch types
* excess imports
* reduce diff by maintaining logs
* fix int test types
* move guid_to_hex_id import to utils
* reduce diff in guid utils import
* move arrow_schema_bytes back into ExecuteResult
* maintain log
* remove unnecessary assignment
* remove unnecessary tuple response
* remove unnecessary verbose mocking
* move Queue construction to ResultSet
* move description to List[Tuple]
* formatting (black)
* reduce diff (remove explicit tuple conversion)
* remove has_more_rows from ExecuteResponse
* remove unnecessary has_more_rows calc
* default has_more_rows to True
* return has_more_rows from ExecResponse conversion during GetRespMetadata
* remove unnecessary replacement
* better mocked backend naming
* remove has_more_rows test in ExecuteResponse
* introduce replacement of original has_more_rows read test
* call correct method in test_use_arrow_schema
* call correct method in test_fall_back_to_hive_schema
* re-introduce result response read test
* simplify test
* remove excess fetch_results mocks
* more minimal changes to thrift_backend tests
* move back to old table types
* remove outdated arrow_schema_bytes return
* remove duplicate import
* rephrase model docstrings to explicitly denote that they are representations and not used over the wire
* has_more_rows -> is_direct_results
* switch docstring format to align with Connection class

Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 0887bc1 commit 6d63df0

16 files changed: +343 −220 lines
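The central change in this commit is the normalized `ExecuteResponse`: backends now return a plain response record plus a separate `is_direct_results` flag, instead of embedding a row queue (`arrow_queue`) and `has_more_rows` in the response. A minimal sketch of that shape, with field names taken from the diffs below; the type annotations and defaults here are assumptions, not the library's actual definitions:

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

@dataclass
class ExecuteResponse:
    # Field names mirror the diff; real types are CommandId/CommandState etc.
    command_id: Any
    status: Any
    description: Optional[List[Tuple]] = None
    has_been_closed_server_side: bool = False
    lz4_compressed: bool = True
    is_staging_operation: bool = False
    arrow_schema_bytes: Optional[bytes] = None
    result_format: Optional[Any] = None

# Backends return the response together with a separate direct-results flag
# rather than carrying a row queue inside the response itself.
response = ExecuteResponse(command_id="cmd-123", status="SUCCEEDED")
result = (response, True)  # (execute_response, is_direct_results)
```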

src/databricks/sql/backend/databricks_client.py

Lines changed: 0 additions & 2 deletions

@@ -16,8 +16,6 @@
 
 from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.backend.types import SessionId, CommandId, CommandState
-from databricks.sql.utils import ExecuteResponse
-from databricks.sql.types import SSLOptions
 
 # Forward reference for type hints
 from typing import TYPE_CHECKING

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 
 @dataclass
 class CreateSessionRequest:
-    """Request to create a new session."""
+    """Representation of a request to create a new session."""
 
     warehouse_id: str
     session_confs: Optional[Dict[str, str]] = None
@@ -29,7 +29,7 @@ def to_dict(self) -> Dict[str, Any]:
 
 @dataclass
 class DeleteSessionRequest:
-    """Request to delete a session."""
+    """Representation of a request to delete a session."""
 
     warehouse_id: str
     session_id: str
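The hunk header above references a `to_dict` serializer on these request dataclasses. A minimal sketch of the pattern: the real `to_dict` body is not shown in this diff, so the include-only-when-set logic here is an assumption:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class CreateSessionRequest:
    """Representation of a request to create a new session."""

    warehouse_id: str
    session_confs: Optional[Dict[str, str]] = None

    def to_dict(self) -> Dict[str, Any]:
        # Hypothetical serializer: emit optional fields only when they are set.
        result: Dict[str, Any] = {"warehouse_id": self.warehouse_id}
        if self.session_confs is not None:
            result["session_confs"] = self.session_confs
        return result

req = CreateSessionRequest(warehouse_id="abc123")
print(req.to_dict())  # {'warehouse_id': 'abc123'}
```

The docstring change in this file ("Representation of …") underlines that these dataclasses model the payload; serialization to the wire format happens explicitly via `to_dict`.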

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
 
 @dataclass
 class CreateSessionResponse:
-    """Response from creating a new session."""
+    """Representation of the response from creating a new session."""
 
     session_id: str
 
src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 import json
 import logging
 import requests
-from typing import Callable, Dict, Any, Optional, Union, List, Tuple
+from typing import Callable, Dict, Any, Optional, List, Tuple
 from urllib.parse import urljoin
 
 from databricks.sql.auth.authenticators import AuthProvider

src/databricks/sql/backend/thrift_backend.py

Lines changed: 88 additions & 65 deletions

@@ -3,23 +3,21 @@
 import logging
 import math
 import time
-import uuid
 import threading
-from typing import List, Optional, Union, Any, TYPE_CHECKING
+from typing import List, Union, Any, TYPE_CHECKING
 
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
-    from databricks.sql.result_set import ResultSet, ThriftResultSet
 
-from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
 from databricks.sql.backend.types import (
     CommandState,
     SessionId,
     CommandId,
-    BackendType,
+    ExecuteResponse,
 )
 from databricks.sql.backend.utils import guid_to_hex_id
 
+
 try:
     import pyarrow
 except ImportError:
@@ -42,7 +40,7 @@
 )
 
 from databricks.sql.utils import (
-    ExecuteResponse,
+    ResultSetQueueFactory,
     _bound,
     RequestErrorInfo,
     NoRetryReason,
@@ -53,6 +51,7 @@
 )
 from databricks.sql.types import SSLOptions
 from databricks.sql.backend.databricks_client import DatabricksClient
+from databricks.sql.result_set import ResultSet, ThriftResultSet
 
 logger = logging.getLogger(__name__)
 
@@ -758,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
         )
         direct_results = resp.directResults
         has_been_closed_server_side = direct_results and direct_results.closeOperation
-        has_more_rows = (
+
+        is_direct_results = (
             (not direct_results)
             or (not direct_results.resultSet)
             or direct_results.resultSet.hasMoreRows
         )
+
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )
@@ -778,42 +779,28 @@ def _results_message_to_execute_response(self, resp, operation_state):
             schema_bytes = None
 
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
-        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
-        if direct_results and direct_results.resultSet:
-            assert direct_results.resultSet.results.startRowOffset == 0
-            assert direct_results.resultSetMetadata
-
-            arrow_queue_opt = ResultSetQueueFactory.build_queue(
-                row_set_type=t_result_set_metadata_resp.resultFormat,
-                t_row_set=direct_results.resultSet.results,
-                arrow_schema_bytes=schema_bytes,
-                max_download_threads=self.max_download_threads,
-                lz4_compressed=lz4_compressed,
-                description=description,
-                ssl_options=self._ssl_options,
-            )
-        else:
-            arrow_queue_opt = None
-
         command_id = CommandId.from_thrift_handle(resp.operationHandle)
 
-        return ExecuteResponse(
-            arrow_queue=arrow_queue_opt,
-            status=CommandState.from_thrift_state(operation_state),
-            has_been_closed_server_side=has_been_closed_server_side,
-            has_more_rows=has_more_rows,
-            lz4_compressed=lz4_compressed,
-            is_staging_operation=is_staging_operation,
+        status = CommandState.from_thrift_state(operation_state)
+        if status is None:
+            raise ValueError(f"Unknown command state: {operation_state}")
+
+        execute_response = ExecuteResponse(
             command_id=command_id,
+            status=status,
             description=description,
+            has_been_closed_server_side=has_been_closed_server_side,
+            lz4_compressed=lz4_compressed,
+            is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
             arrow_schema_bytes=schema_bytes,
+            result_format=t_result_set_metadata_resp.resultFormat,
         )
 
+        return execute_response, is_direct_results
+
     def get_execution_result(
         self, command_id: CommandId, cursor: "Cursor"
     ) -> "ResultSet":
-        from databricks.sql.result_set import ThriftResultSet
-
         thrift_handle = command_id.to_thrift_handle()
         if not thrift_handle:
             raise ValueError("Not a valid Thrift command ID")
@@ -835,9 +822,6 @@ def get_execution_result(
 
         t_result_set_metadata_resp = resp.resultSetMetadata
 
-        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
-        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
-        has_more_rows = resp.hasMoreRows
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )
@@ -852,26 +836,21 @@
         else:
             schema_bytes = None
 
-        queue = ResultSetQueueFactory.build_queue(
-            row_set_type=resp.resultSetMetadata.resultFormat,
-            t_row_set=resp.results,
-            arrow_schema_bytes=schema_bytes,
-            max_download_threads=self.max_download_threads,
-            lz4_compressed=lz4_compressed,
-            description=description,
-            ssl_options=self._ssl_options,
-        )
+        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
+        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
+        is_direct_results = resp.hasMoreRows
+
+        status = self.get_query_state(command_id)
 
         execute_response = ExecuteResponse(
-            arrow_queue=queue,
-            status=CommandState.from_thrift_state(resp.status),
+            command_id=command_id,
+            status=status,
+            description=description,
             has_been_closed_server_side=False,
-            has_more_rows=has_more_rows,
             lz4_compressed=lz4_compressed,
             is_staging_operation=is_staging_operation,
-            command_id=command_id,
-            description=description,
             arrow_schema_bytes=schema_bytes,
+            result_format=t_result_set_metadata_resp.resultFormat,
         )
 
         return ThriftResultSet(
@@ -881,6 +860,10 @@
             buffer_size_bytes=cursor.buffer_size_bytes,
             arraysize=cursor.arraysize,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
+            t_row_set=resp.results,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
         )
 
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -947,8 +930,6 @@ def execute_command(
         async_op=False,
         enforce_embedded_schema_correctness=False,
     ) -> Union["ResultSet", None]:
-        from databricks.sql.result_set import ThriftResultSet
-
         thrift_handle = session_id.to_thrift_handle()
         if not thrift_handle:
             raise ValueError("Not a valid Thrift session ID")
@@ -995,7 +976,13 @@
             self._handle_execute_response_async(resp, cursor)
             return None
         else:
-            execute_response = self._handle_execute_response(resp, cursor)
+            execute_response, is_direct_results = self._handle_execute_response(
+                resp, cursor
+            )
+
+            t_row_set = None
+            if resp.directResults and resp.directResults.resultSet:
+                t_row_set = resp.directResults.resultSet.results
 
             return ThriftResultSet(
                 connection=cursor.connection,
@@ -1004,6 +991,10 @@
                 buffer_size_bytes=max_bytes,
                 arraysize=max_rows,
                 use_cloud_fetch=use_cloud_fetch,
+                t_row_set=t_row_set,
+                max_download_threads=self.max_download_threads,
+                ssl_options=self._ssl_options,
+                is_direct_results=is_direct_results,
            )
 
     def get_catalogs(
@@ -1013,8 +1004,6 @@ def get_catalogs(
         max_bytes: int,
         cursor: "Cursor",
     ) -> "ResultSet":
-        from databricks.sql.result_set import ThriftResultSet
-
         thrift_handle = session_id.to_thrift_handle()
         if not thrift_handle:
             raise ValueError("Not a valid Thrift session ID")
@@ -1027,7 +1016,13 @@
         )
         resp = self.make_request(self._client.GetCatalogs, req)
 
-        execute_response = self._handle_execute_response(resp, cursor)
+        execute_response, is_direct_results = self._handle_execute_response(
+            resp, cursor
+        )
+
+        t_row_set = None
+        if resp.directResults and resp.directResults.resultSet:
+            t_row_set = resp.directResults.resultSet.results
 
         return ThriftResultSet(
             connection=cursor.connection,
@@ -1036,6 +1031,10 @@
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
+            t_row_set=t_row_set,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
         )
 
     def get_schemas(
@@ -1047,8 +1046,6 @@ def get_schemas(
         catalog_name=None,
         schema_name=None,
     ) -> "ResultSet":
-        from databricks.sql.result_set import ThriftResultSet
-
         thrift_handle = session_id.to_thrift_handle()
         if not thrift_handle:
             raise ValueError("Not a valid Thrift session ID")
@@ -1063,7 +1060,13 @@
         )
         resp = self.make_request(self._client.GetSchemas, req)
 
-        execute_response = self._handle_execute_response(resp, cursor)
+        execute_response, is_direct_results = self._handle_execute_response(
+            resp, cursor
+        )
+
+        t_row_set = None
+        if resp.directResults and resp.directResults.resultSet:
+            t_row_set = resp.directResults.resultSet.results
 
         return ThriftResultSet(
             connection=cursor.connection,
@@ -1072,6 +1075,10 @@
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
+            t_row_set=t_row_set,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
         )
 
     def get_tables(
@@ -1085,8 +1092,6 @@ def get_tables(
         table_name=None,
         table_types=None,
     ) -> "ResultSet":
-        from databricks.sql.result_set import ThriftResultSet
-
         thrift_handle = session_id.to_thrift_handle()
         if not thrift_handle:
             raise ValueError("Not a valid Thrift session ID")
@@ -1103,7 +1108,13 @@
         )
         resp = self.make_request(self._client.GetTables, req)
 
-        execute_response = self._handle_execute_response(resp, cursor)
+        execute_response, is_direct_results = self._handle_execute_response(
+            resp, cursor
+        )
+
+        t_row_set = None
+        if resp.directResults and resp.directResults.resultSet:
+            t_row_set = resp.directResults.resultSet.results
 
         return ThriftResultSet(
             connection=cursor.connection,
@@ -1112,6 +1123,10 @@
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
+            t_row_set=t_row_set,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
         )
 
     def get_columns(
@@ -1125,8 +1140,6 @@ def get_columns(
         table_name=None,
         column_name=None,
     ) -> "ResultSet":
-        from databricks.sql.result_set import ThriftResultSet
-
         thrift_handle = session_id.to_thrift_handle()
         if not thrift_handle:
             raise ValueError("Not a valid Thrift session ID")
@@ -1143,7 +1156,13 @@
         )
         resp = self.make_request(self._client.GetColumns, req)
 
-        execute_response = self._handle_execute_response(resp, cursor)
+        execute_response, is_direct_results = self._handle_execute_response(
+            resp, cursor
+        )
+
+        t_row_set = None
+        if resp.directResults and resp.directResults.resultSet:
+            t_row_set = resp.directResults.resultSet.results
 
         return ThriftResultSet(
             connection=cursor.connection,
@@ -1152,6 +1171,10 @@
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
+            t_row_set=t_row_set,
+            max_download_threads=self.max_download_threads,
+            ssl_options=self._ssl_options,
+            is_direct_results=is_direct_results,
        )
 
     def _handle_execute_response(self, resp, cursor):
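The recurring pattern across the backend methods above is the main refactor: the backend no longer builds the row queue itself, but passes the raw Thrift row set (`t_row_set`) plus download/SSL settings to `ThriftResultSet`, which owns queue construction. A rough sketch of that ownership change, using simplified stand-in types; the real `ThriftResultSet` constructor takes more parameters than shown here:

```python
from typing import Any, List, Optional

class SimpleQueue:
    """Stand-in for the queue that ResultSetQueueFactory would build."""
    def __init__(self, rows: List[Any]):
        self.rows = rows

class ThriftResultSetSketch:
    # Hypothetical, simplified constructor mirroring the new call sites:
    # the result set, not the backend, turns the raw row set into a queue.
    def __init__(self, t_row_set: Optional[List[Any]], is_direct_results: bool):
        self.is_direct_results = is_direct_results
        self.queue = SimpleQueue(t_row_set) if t_row_set is not None else None

# Backend side: hand the raw results through instead of building the queue.
rs = ThriftResultSetSketch(t_row_set=[(1,), (2,)], is_direct_results=False)
```

This keeps the backend interface "clean" in the sense of the PR title: backends translate protocol responses into a normalized `ExecuteResponse`, and all result-materialization concerns live in the result set.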
