Skip to content

Commit 6229848

Browse files
remove irrelevant changes
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 349c021 commit 6229848

File tree

6 files changed

+131
-97
lines changed

6 files changed

+131
-97
lines changed

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]:
9898

9999
@dataclass
100100
class CreateSessionRequest:
101-
"""Request to create a new session."""
101+
"""Representation of a request to create a new session."""
102102

103103
warehouse_id: str
104104
session_confs: Optional[Dict[str, str]] = None
@@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]:
123123

124124
@dataclass
125125
class DeleteSessionRequest:
126-
"""Request to delete a session."""
126+
"""Representation of a request to delete a session."""
127127

128128
warehouse_id: str
129129
session_id: str

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
146146

147147
@dataclass
148148
class CreateSessionResponse:
149-
"""Response from creating a new session."""
149+
"""Representation of the response from creating a new session."""
150150

151151
session_id: str
152152

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
import requests
4-
from typing import Callable, Dict, Any, Optional, Union, List, Tuple
4+
from typing import Callable, Dict, Any, Optional, List, Tuple
55
from urllib.parse import urljoin
66

77
from databricks.sql.auth.authenticators import AuthProvider

src/databricks/sql/backend/thrift_backend.py

Lines changed: 71 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,21 @@
33
import logging
44
import math
55
import time
6-
import uuid
76
import threading
87
from typing import List, Union, Any, TYPE_CHECKING
98

109
if TYPE_CHECKING:
1110
from databricks.sql.client import Cursor
1211

13-
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1412
from databricks.sql.backend.types import (
1513
CommandState,
1614
SessionId,
1715
CommandId,
18-
BackendType,
19-
guid_to_hex_id,
2016
ExecuteResponse,
2117
)
2218
from databricks.sql.backend.utils import guid_to_hex_id
2319

20+
2421
try:
2522
import pyarrow
2623
except ImportError:
@@ -760,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
760757
)
761758
direct_results = resp.directResults
762759
has_been_closed_server_side = direct_results and direct_results.closeOperation
763-
has_more_rows = (
760+
761+
is_direct_results = (
764762
(not direct_results)
765763
or (not direct_results.resultSet)
766764
or direct_results.resultSet.hasMoreRows
767765
)
766+
768767
description = self._hive_schema_to_description(
769768
t_result_set_metadata_resp.schema
770769
)
@@ -780,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
780779
schema_bytes = None
781780

782781
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
783-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
784-
if direct_results and direct_results.resultSet:
785-
assert direct_results.resultSet.results.startRowOffset == 0
786-
assert direct_results.resultSetMetadata
787-
788-
arrow_queue_opt = ResultSetQueueFactory.build_queue(
789-
row_set_type=t_result_set_metadata_resp.resultFormat,
790-
t_row_set=direct_results.resultSet.results,
791-
arrow_schema_bytes=schema_bytes,
792-
max_download_threads=self.max_download_threads,
793-
lz4_compressed=lz4_compressed,
794-
description=description,
795-
ssl_options=self._ssl_options,
796-
)
797-
else:
798-
arrow_queue_opt = None
799-
800782
command_id = CommandId.from_thrift_handle(resp.operationHandle)
801783

802784
status = CommandState.from_thrift_state(operation_state)
803785
if status is None:
804786
raise ValueError(f"Unknown command state: {operation_state}")
805787

806-
return (
807-
ExecuteResponse(
808-
command_id=command_id,
809-
status=status,
810-
description=description,
811-
has_more_rows=has_more_rows,
812-
results_queue=arrow_queue_opt,
813-
has_been_closed_server_side=has_been_closed_server_side,
814-
lz4_compressed=lz4_compressed,
815-
is_staging_operation=is_staging_operation,
816-
),
817-
schema_bytes,
788+
execute_response = ExecuteResponse(
789+
command_id=command_id,
790+
status=status,
791+
description=description,
792+
has_been_closed_server_side=has_been_closed_server_side,
793+
lz4_compressed=lz4_compressed,
794+
is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
795+
arrow_schema_bytes=schema_bytes,
796+
result_format=t_result_set_metadata_resp.resultFormat,
818797
)
819798

799+
return execute_response, is_direct_results
800+
820801
def get_execution_result(
821802
self, command_id: CommandId, cursor: "Cursor"
822803
) -> "ResultSet":
@@ -841,9 +822,6 @@ def get_execution_result(
841822

842823
t_result_set_metadata_resp = resp.resultSetMetadata
843824

844-
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
845-
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
846-
has_more_rows = resp.hasMoreRows
847825
description = self._hive_schema_to_description(
848826
t_result_set_metadata_resp.schema
849827
)
@@ -858,25 +836,21 @@ def get_execution_result(
858836
else:
859837
schema_bytes = None
860838

861-
queue = ResultSetQueueFactory.build_queue(
862-
row_set_type=resp.resultSetMetadata.resultFormat,
863-
t_row_set=resp.results,
864-
arrow_schema_bytes=schema_bytes,
865-
max_download_threads=self.max_download_threads,
866-
lz4_compressed=lz4_compressed,
867-
description=description,
868-
ssl_options=self._ssl_options,
869-
)
839+
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
840+
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
841+
is_direct_results = resp.hasMoreRows
842+
843+
status = self.get_query_state(command_id)
870844

871845
execute_response = ExecuteResponse(
872846
command_id=command_id,
873847
status=status,
874848
description=description,
875-
has_more_rows=has_more_rows,
876-
results_queue=queue,
877849
has_been_closed_server_side=False,
878850
lz4_compressed=lz4_compressed,
879851
is_staging_operation=is_staging_operation,
852+
arrow_schema_bytes=schema_bytes,
853+
result_format=t_result_set_metadata_resp.resultFormat,
880854
)
881855

882856
return ThriftResultSet(
@@ -886,7 +860,10 @@ def get_execution_result(
886860
buffer_size_bytes=cursor.buffer_size_bytes,
887861
arraysize=cursor.arraysize,
888862
use_cloud_fetch=cursor.connection.use_cloud_fetch,
889-
arrow_schema_bytes=schema_bytes,
863+
t_row_set=resp.results,
864+
max_download_threads=self.max_download_threads,
865+
ssl_options=self._ssl_options,
866+
is_direct_results=is_direct_results,
890867
)
891868

892869
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -999,18 +976,25 @@ def execute_command(
999976
self._handle_execute_response_async(resp, cursor)
1000977
return None
1001978
else:
1002-
execute_response, arrow_schema_bytes = self._handle_execute_response(
979+
execute_response, is_direct_results = self._handle_execute_response(
1003980
resp, cursor
1004981
)
1005982

983+
t_row_set = None
984+
if resp.directResults and resp.directResults.resultSet:
985+
t_row_set = resp.directResults.resultSet.results
986+
1006987
return ThriftResultSet(
1007988
connection=cursor.connection,
1008989
execute_response=execute_response,
1009990
thrift_client=self,
1010991
buffer_size_bytes=max_bytes,
1011992
arraysize=max_rows,
1012993
use_cloud_fetch=use_cloud_fetch,
1013-
arrow_schema_bytes=arrow_schema_bytes,
994+
t_row_set=t_row_set,
995+
max_download_threads=self.max_download_threads,
996+
ssl_options=self._ssl_options,
997+
is_direct_results=is_direct_results,
1014998
)
1015999

10161000
def get_catalogs(
@@ -1032,18 +1016,25 @@ def get_catalogs(
10321016
)
10331017
resp = self.make_request(self._client.GetCatalogs, req)
10341018

1035-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1019+
execute_response, is_direct_results = self._handle_execute_response(
10361020
resp, cursor
10371021
)
10381022

1023+
t_row_set = None
1024+
if resp.directResults and resp.directResults.resultSet:
1025+
t_row_set = resp.directResults.resultSet.results
1026+
10391027
return ThriftResultSet(
10401028
connection=cursor.connection,
10411029
execute_response=execute_response,
10421030
thrift_client=self,
10431031
buffer_size_bytes=max_bytes,
10441032
arraysize=max_rows,
10451033
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1046-
arrow_schema_bytes=arrow_schema_bytes,
1034+
t_row_set=t_row_set,
1035+
max_download_threads=self.max_download_threads,
1036+
ssl_options=self._ssl_options,
1037+
is_direct_results=is_direct_results,
10471038
)
10481039

10491040
def get_schemas(
@@ -1069,18 +1060,25 @@ def get_schemas(
10691060
)
10701061
resp = self.make_request(self._client.GetSchemas, req)
10711062

1072-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1063+
execute_response, is_direct_results = self._handle_execute_response(
10731064
resp, cursor
10741065
)
10751066

1067+
t_row_set = None
1068+
if resp.directResults and resp.directResults.resultSet:
1069+
t_row_set = resp.directResults.resultSet.results
1070+
10761071
return ThriftResultSet(
10771072
connection=cursor.connection,
10781073
execute_response=execute_response,
10791074
thrift_client=self,
10801075
buffer_size_bytes=max_bytes,
10811076
arraysize=max_rows,
10821077
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1083-
arrow_schema_bytes=arrow_schema_bytes,
1078+
t_row_set=t_row_set,
1079+
max_download_threads=self.max_download_threads,
1080+
ssl_options=self._ssl_options,
1081+
is_direct_results=is_direct_results,
10841082
)
10851083

10861084
def get_tables(
@@ -1110,18 +1108,25 @@ def get_tables(
11101108
)
11111109
resp = self.make_request(self._client.GetTables, req)
11121110

1113-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1111+
execute_response, is_direct_results = self._handle_execute_response(
11141112
resp, cursor
11151113
)
11161114

1115+
t_row_set = None
1116+
if resp.directResults and resp.directResults.resultSet:
1117+
t_row_set = resp.directResults.resultSet.results
1118+
11171119
return ThriftResultSet(
11181120
connection=cursor.connection,
11191121
execute_response=execute_response,
11201122
thrift_client=self,
11211123
buffer_size_bytes=max_bytes,
11221124
arraysize=max_rows,
11231125
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1124-
arrow_schema_bytes=arrow_schema_bytes,
1126+
t_row_set=t_row_set,
1127+
max_download_threads=self.max_download_threads,
1128+
ssl_options=self._ssl_options,
1129+
is_direct_results=is_direct_results,
11251130
)
11261131

11271132
def get_columns(
@@ -1151,18 +1156,25 @@ def get_columns(
11511156
)
11521157
resp = self.make_request(self._client.GetColumns, req)
11531158

1154-
execute_response, arrow_schema_bytes = self._handle_execute_response(
1159+
execute_response, is_direct_results = self._handle_execute_response(
11551160
resp, cursor
11561161
)
11571162

1163+
t_row_set = None
1164+
if resp.directResults and resp.directResults.resultSet:
1165+
t_row_set = resp.directResults.resultSet.results
1166+
11581167
return ThriftResultSet(
11591168
connection=cursor.connection,
11601169
execute_response=execute_response,
11611170
thrift_client=self,
11621171
buffer_size_bytes=max_bytes,
11631172
arraysize=max_rows,
11641173
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1165-
arrow_schema_bytes=arrow_schema_bytes,
1174+
t_row_set=t_row_set,
1175+
max_download_threads=self.max_download_threads,
1176+
ssl_options=self._ssl_options,
1177+
is_direct_results=is_direct_results,
11661178
)
11671179

11681180
def _handle_execute_response(self, resp, cursor):

0 commit comments

Comments
 (0)