3
3
import logging
4
4
import math
5
5
import time
6
- import uuid
7
6
import threading
8
- from typing import List , Optional , Union , Any , TYPE_CHECKING
7
+ from typing import List , Union , Any , TYPE_CHECKING
9
8
10
9
if TYPE_CHECKING :
11
10
from databricks .sql .client import Cursor
12
- from databricks .sql .result_set import ResultSet , ThriftResultSet
13
11
14
- from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
15
12
from databricks .sql .backend .types import (
16
13
CommandState ,
17
14
SessionId ,
18
15
CommandId ,
19
- BackendType ,
16
+ ExecuteResponse ,
20
17
)
21
18
from databricks .sql .backend .utils import guid_to_hex_id
22
19
20
+
23
21
try :
24
22
import pyarrow
25
23
except ImportError :
42
40
)
43
41
44
42
from databricks .sql .utils import (
45
- ExecuteResponse ,
43
+ ResultSetQueueFactory ,
46
44
_bound ,
47
45
RequestErrorInfo ,
48
46
NoRetryReason ,
53
51
)
54
52
from databricks .sql .types import SSLOptions
55
53
from databricks .sql .backend .databricks_client import DatabricksClient
54
+ from databricks .sql .result_set import ResultSet , ThriftResultSet
56
55
57
56
logger = logging .getLogger (__name__ )
58
57
@@ -758,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
758
757
)
759
758
direct_results = resp .directResults
760
759
has_been_closed_server_side = direct_results and direct_results .closeOperation
761
- has_more_rows = (
760
+
761
+ is_direct_results = (
762
762
(not direct_results )
763
763
or (not direct_results .resultSet )
764
764
or direct_results .resultSet .hasMoreRows
765
765
)
766
+
766
767
description = self ._hive_schema_to_description (
767
768
t_result_set_metadata_resp .schema
768
769
)
@@ -778,42 +779,28 @@ def _results_message_to_execute_response(self, resp, operation_state):
778
779
schema_bytes = None
779
780
780
781
lz4_compressed = t_result_set_metadata_resp .lz4Compressed
781
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
782
- if direct_results and direct_results .resultSet :
783
- assert direct_results .resultSet .results .startRowOffset == 0
784
- assert direct_results .resultSetMetadata
785
-
786
- arrow_queue_opt = ResultSetQueueFactory .build_queue (
787
- row_set_type = t_result_set_metadata_resp .resultFormat ,
788
- t_row_set = direct_results .resultSet .results ,
789
- arrow_schema_bytes = schema_bytes ,
790
- max_download_threads = self .max_download_threads ,
791
- lz4_compressed = lz4_compressed ,
792
- description = description ,
793
- ssl_options = self ._ssl_options ,
794
- )
795
- else :
796
- arrow_queue_opt = None
797
-
798
782
command_id = CommandId .from_thrift_handle (resp .operationHandle )
799
783
800
- return ExecuteResponse (
801
- arrow_queue = arrow_queue_opt ,
802
- status = CommandState .from_thrift_state (operation_state ),
803
- has_been_closed_server_side = has_been_closed_server_side ,
804
- has_more_rows = has_more_rows ,
805
- lz4_compressed = lz4_compressed ,
806
- is_staging_operation = is_staging_operation ,
784
+ status = CommandState .from_thrift_state (operation_state )
785
+ if status is None :
786
+ raise ValueError (f"Unknown command state: { operation_state } " )
787
+
788
+ execute_response = ExecuteResponse (
807
789
command_id = command_id ,
790
+ status = status ,
808
791
description = description ,
792
+ has_been_closed_server_side = has_been_closed_server_side ,
793
+ lz4_compressed = lz4_compressed ,
794
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
809
795
arrow_schema_bytes = schema_bytes ,
796
+ result_format = t_result_set_metadata_resp .resultFormat ,
810
797
)
811
798
799
+ return execute_response , is_direct_results
800
+
812
801
def get_execution_result (
813
802
self , command_id : CommandId , cursor : "Cursor"
814
803
) -> "ResultSet" :
815
- from databricks .sql .result_set import ThriftResultSet
816
-
817
804
thrift_handle = command_id .to_thrift_handle ()
818
805
if not thrift_handle :
819
806
raise ValueError ("Not a valid Thrift command ID" )
@@ -835,9 +822,6 @@ def get_execution_result(
835
822
836
823
t_result_set_metadata_resp = resp .resultSetMetadata
837
824
838
- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
839
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
840
- has_more_rows = resp .hasMoreRows
841
825
description = self ._hive_schema_to_description (
842
826
t_result_set_metadata_resp .schema
843
827
)
@@ -852,26 +836,21 @@ def get_execution_result(
852
836
else :
853
837
schema_bytes = None
854
838
855
- queue = ResultSetQueueFactory .build_queue (
856
- row_set_type = resp .resultSetMetadata .resultFormat ,
857
- t_row_set = resp .results ,
858
- arrow_schema_bytes = schema_bytes ,
859
- max_download_threads = self .max_download_threads ,
860
- lz4_compressed = lz4_compressed ,
861
- description = description ,
862
- ssl_options = self ._ssl_options ,
863
- )
839
+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
840
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
841
+ is_direct_results = resp .hasMoreRows
842
+
843
+ status = self .get_query_state (command_id )
864
844
865
845
execute_response = ExecuteResponse (
866
- arrow_queue = queue ,
867
- status = CommandState .from_thrift_state (resp .status ),
846
+ command_id = command_id ,
847
+ status = status ,
848
+ description = description ,
868
849
has_been_closed_server_side = False ,
869
- has_more_rows = has_more_rows ,
870
850
lz4_compressed = lz4_compressed ,
871
851
is_staging_operation = is_staging_operation ,
872
- command_id = command_id ,
873
- description = description ,
874
852
arrow_schema_bytes = schema_bytes ,
853
+ result_format = t_result_set_metadata_resp .resultFormat ,
875
854
)
876
855
877
856
return ThriftResultSet (
@@ -881,6 +860,10 @@ def get_execution_result(
881
860
buffer_size_bytes = cursor .buffer_size_bytes ,
882
861
arraysize = cursor .arraysize ,
883
862
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
863
+ t_row_set = resp .results ,
864
+ max_download_threads = self .max_download_threads ,
865
+ ssl_options = self ._ssl_options ,
866
+ is_direct_results = is_direct_results ,
884
867
)
885
868
886
869
def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -947,8 +930,6 @@ def execute_command(
947
930
async_op = False ,
948
931
enforce_embedded_schema_correctness = False ,
949
932
) -> Union ["ResultSet" , None ]:
950
- from databricks .sql .result_set import ThriftResultSet
951
-
952
933
thrift_handle = session_id .to_thrift_handle ()
953
934
if not thrift_handle :
954
935
raise ValueError ("Not a valid Thrift session ID" )
@@ -995,7 +976,13 @@ def execute_command(
995
976
self ._handle_execute_response_async (resp , cursor )
996
977
return None
997
978
else :
998
- execute_response = self ._handle_execute_response (resp , cursor )
979
+ execute_response , is_direct_results = self ._handle_execute_response (
980
+ resp , cursor
981
+ )
982
+
983
+ t_row_set = None
984
+ if resp .directResults and resp .directResults .resultSet :
985
+ t_row_set = resp .directResults .resultSet .results
999
986
1000
987
return ThriftResultSet (
1001
988
connection = cursor .connection ,
@@ -1004,6 +991,10 @@ def execute_command(
1004
991
buffer_size_bytes = max_bytes ,
1005
992
arraysize = max_rows ,
1006
993
use_cloud_fetch = use_cloud_fetch ,
994
+ t_row_set = t_row_set ,
995
+ max_download_threads = self .max_download_threads ,
996
+ ssl_options = self ._ssl_options ,
997
+ is_direct_results = is_direct_results ,
1007
998
)
1008
999
1009
1000
def get_catalogs (
@@ -1013,8 +1004,6 @@ def get_catalogs(
1013
1004
max_bytes : int ,
1014
1005
cursor : "Cursor" ,
1015
1006
) -> "ResultSet" :
1016
- from databricks .sql .result_set import ThriftResultSet
1017
-
1018
1007
thrift_handle = session_id .to_thrift_handle ()
1019
1008
if not thrift_handle :
1020
1009
raise ValueError ("Not a valid Thrift session ID" )
@@ -1027,7 +1016,13 @@ def get_catalogs(
1027
1016
)
1028
1017
resp = self .make_request (self ._client .GetCatalogs , req )
1029
1018
1030
- execute_response = self ._handle_execute_response (resp , cursor )
1019
+ execute_response , is_direct_results = self ._handle_execute_response (
1020
+ resp , cursor
1021
+ )
1022
+
1023
+ t_row_set = None
1024
+ if resp .directResults and resp .directResults .resultSet :
1025
+ t_row_set = resp .directResults .resultSet .results
1031
1026
1032
1027
return ThriftResultSet (
1033
1028
connection = cursor .connection ,
@@ -1036,6 +1031,10 @@ def get_catalogs(
1036
1031
buffer_size_bytes = max_bytes ,
1037
1032
arraysize = max_rows ,
1038
1033
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1034
+ t_row_set = t_row_set ,
1035
+ max_download_threads = self .max_download_threads ,
1036
+ ssl_options = self ._ssl_options ,
1037
+ is_direct_results = is_direct_results ,
1039
1038
)
1040
1039
1041
1040
def get_schemas (
@@ -1047,8 +1046,6 @@ def get_schemas(
1047
1046
catalog_name = None ,
1048
1047
schema_name = None ,
1049
1048
) -> "ResultSet" :
1050
- from databricks .sql .result_set import ThriftResultSet
1051
-
1052
1049
thrift_handle = session_id .to_thrift_handle ()
1053
1050
if not thrift_handle :
1054
1051
raise ValueError ("Not a valid Thrift session ID" )
@@ -1063,7 +1060,13 @@ def get_schemas(
1063
1060
)
1064
1061
resp = self .make_request (self ._client .GetSchemas , req )
1065
1062
1066
- execute_response = self ._handle_execute_response (resp , cursor )
1063
+ execute_response , is_direct_results = self ._handle_execute_response (
1064
+ resp , cursor
1065
+ )
1066
+
1067
+ t_row_set = None
1068
+ if resp .directResults and resp .directResults .resultSet :
1069
+ t_row_set = resp .directResults .resultSet .results
1067
1070
1068
1071
return ThriftResultSet (
1069
1072
connection = cursor .connection ,
@@ -1072,6 +1075,10 @@ def get_schemas(
1072
1075
buffer_size_bytes = max_bytes ,
1073
1076
arraysize = max_rows ,
1074
1077
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1078
+ t_row_set = t_row_set ,
1079
+ max_download_threads = self .max_download_threads ,
1080
+ ssl_options = self ._ssl_options ,
1081
+ is_direct_results = is_direct_results ,
1075
1082
)
1076
1083
1077
1084
def get_tables (
@@ -1085,8 +1092,6 @@ def get_tables(
1085
1092
table_name = None ,
1086
1093
table_types = None ,
1087
1094
) -> "ResultSet" :
1088
- from databricks .sql .result_set import ThriftResultSet
1089
-
1090
1095
thrift_handle = session_id .to_thrift_handle ()
1091
1096
if not thrift_handle :
1092
1097
raise ValueError ("Not a valid Thrift session ID" )
@@ -1103,7 +1108,13 @@ def get_tables(
1103
1108
)
1104
1109
resp = self .make_request (self ._client .GetTables , req )
1105
1110
1106
- execute_response = self ._handle_execute_response (resp , cursor )
1111
+ execute_response , is_direct_results = self ._handle_execute_response (
1112
+ resp , cursor
1113
+ )
1114
+
1115
+ t_row_set = None
1116
+ if resp .directResults and resp .directResults .resultSet :
1117
+ t_row_set = resp .directResults .resultSet .results
1107
1118
1108
1119
return ThriftResultSet (
1109
1120
connection = cursor .connection ,
@@ -1112,6 +1123,10 @@ def get_tables(
1112
1123
buffer_size_bytes = max_bytes ,
1113
1124
arraysize = max_rows ,
1114
1125
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1126
+ t_row_set = t_row_set ,
1127
+ max_download_threads = self .max_download_threads ,
1128
+ ssl_options = self ._ssl_options ,
1129
+ is_direct_results = is_direct_results ,
1115
1130
)
1116
1131
1117
1132
def get_columns (
@@ -1125,8 +1140,6 @@ def get_columns(
1125
1140
table_name = None ,
1126
1141
column_name = None ,
1127
1142
) -> "ResultSet" :
1128
- from databricks .sql .result_set import ThriftResultSet
1129
-
1130
1143
thrift_handle = session_id .to_thrift_handle ()
1131
1144
if not thrift_handle :
1132
1145
raise ValueError ("Not a valid Thrift session ID" )
@@ -1143,7 +1156,13 @@ def get_columns(
1143
1156
)
1144
1157
resp = self .make_request (self ._client .GetColumns , req )
1145
1158
1146
- execute_response = self ._handle_execute_response (resp , cursor )
1159
+ execute_response , is_direct_results = self ._handle_execute_response (
1160
+ resp , cursor
1161
+ )
1162
+
1163
+ t_row_set = None
1164
+ if resp .directResults and resp .directResults .resultSet :
1165
+ t_row_set = resp .directResults .resultSet .results
1147
1166
1148
1167
return ThriftResultSet (
1149
1168
connection = cursor .connection ,
@@ -1152,6 +1171,10 @@ def get_columns(
1152
1171
buffer_size_bytes = max_bytes ,
1153
1172
arraysize = max_rows ,
1154
1173
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1174
+ t_row_set = t_row_set ,
1175
+ max_download_threads = self .max_download_threads ,
1176
+ ssl_options = self ._ssl_options ,
1177
+ is_direct_results = is_direct_results ,
1155
1178
)
1156
1179
1157
1180
def _handle_execute_response (self , resp , cursor ):
0 commit comments