3
3
import logging
4
4
import math
5
5
import time
6
- import uuid
7
6
import threading
8
7
from typing import List , Union , Any , TYPE_CHECKING
9
8
10
9
if TYPE_CHECKING :
11
10
from databricks .sql .client import Cursor
12
11
13
- from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
14
12
from databricks .sql .backend .types import (
15
13
CommandState ,
16
14
SessionId ,
17
15
CommandId ,
18
- BackendType ,
19
- guid_to_hex_id ,
20
16
ExecuteResponse ,
21
17
)
18
+ from databricks .sql .backend .utils import guid_to_hex_id
19
+
22
20
23
21
try :
24
22
import pyarrow
@@ -759,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
759
757
)
760
758
direct_results = resp .directResults
761
759
has_been_closed_server_side = direct_results and direct_results .closeOperation
760
+
762
761
has_more_rows = (
763
762
(not direct_results )
764
763
or (not direct_results .resultSet )
765
764
or direct_results .resultSet .hasMoreRows
766
765
)
766
+
767
767
description = self ._hive_schema_to_description (
768
768
t_result_set_metadata_resp .schema
769
769
)
@@ -779,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
779
779
schema_bytes = None
780
780
781
781
lz4_compressed = t_result_set_metadata_resp .lz4Compressed
782
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
783
- if direct_results and direct_results .resultSet :
784
- assert direct_results .resultSet .results .startRowOffset == 0
785
- assert direct_results .resultSetMetadata
786
-
787
- arrow_queue_opt = ThriftResultSetQueueFactory .build_queue (
788
- row_set_type = t_result_set_metadata_resp .resultFormat ,
789
- t_row_set = direct_results .resultSet .results ,
790
- arrow_schema_bytes = schema_bytes ,
791
- max_download_threads = self .max_download_threads ,
792
- lz4_compressed = lz4_compressed ,
793
- description = description ,
794
- ssl_options = self ._ssl_options ,
795
- )
796
- else :
797
- arrow_queue_opt = None
798
-
799
782
command_id = CommandId .from_thrift_handle (resp .operationHandle )
800
783
801
784
status = CommandState .from_thrift_state (operation_state )
802
785
if status is None :
803
786
raise ValueError (f"Unknown command state: { operation_state } " )
804
787
805
- return (
806
- ExecuteResponse (
807
- command_id = command_id ,
808
- status = status ,
809
- description = description ,
810
- has_more_rows = has_more_rows ,
811
- results_queue = arrow_queue_opt ,
812
- has_been_closed_server_side = has_been_closed_server_side ,
813
- lz4_compressed = lz4_compressed ,
814
- is_staging_operation = is_staging_operation ,
815
- ),
816
- schema_bytes ,
788
+ execute_response = ExecuteResponse (
789
+ command_id = command_id ,
790
+ status = status ,
791
+ description = description ,
792
+ has_been_closed_server_side = has_been_closed_server_side ,
793
+ lz4_compressed = lz4_compressed ,
794
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
795
+ arrow_schema_bytes = schema_bytes ,
796
+ result_format = t_result_set_metadata_resp .resultFormat ,
817
797
)
818
798
799
+ return execute_response , has_more_rows
800
+
819
801
def get_execution_result (
820
802
self , command_id : CommandId , cursor : "Cursor"
821
803
) -> "ResultSet" :
@@ -840,9 +822,6 @@ def get_execution_result(
840
822
841
823
t_result_set_metadata_resp = resp .resultSetMetadata
842
824
843
- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
844
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
845
- has_more_rows = resp .hasMoreRows
846
825
description = self ._hive_schema_to_description (
847
826
t_result_set_metadata_resp .schema
848
827
)
@@ -857,27 +836,21 @@ def get_execution_result(
857
836
else :
858
837
schema_bytes = None
859
838
860
- queue = ThriftResultSetQueueFactory .build_queue (
861
- row_set_type = resp .resultSetMetadata .resultFormat ,
862
- t_row_set = resp .results ,
863
- arrow_schema_bytes = schema_bytes ,
864
- max_download_threads = self .max_download_threads ,
865
- lz4_compressed = lz4_compressed ,
866
- description = description ,
867
- ssl_options = self ._ssl_options ,
868
- )
839
+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
840
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
841
+ has_more_rows = resp .hasMoreRows
869
842
870
843
status = self .get_query_state (command_id )
871
844
872
845
execute_response = ExecuteResponse (
873
846
command_id = command_id ,
874
847
status = status ,
875
848
description = description ,
876
- has_more_rows = has_more_rows ,
877
- results_queue = queue ,
878
849
has_been_closed_server_side = False ,
879
850
lz4_compressed = lz4_compressed ,
880
851
is_staging_operation = is_staging_operation ,
852
+ arrow_schema_bytes = schema_bytes ,
853
+ result_format = t_result_set_metadata_resp .resultFormat ,
881
854
)
882
855
883
856
return ThriftResultSet (
@@ -887,7 +860,10 @@ def get_execution_result(
887
860
buffer_size_bytes = cursor .buffer_size_bytes ,
888
861
arraysize = cursor .arraysize ,
889
862
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
890
- arrow_schema_bytes = schema_bytes ,
863
+ t_row_set = resp .results ,
864
+ max_download_threads = self .max_download_threads ,
865
+ ssl_options = self ._ssl_options ,
866
+ has_more_rows = has_more_rows ,
891
867
)
892
868
893
869
def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -918,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
918
894
self ._check_command_not_in_error_or_closed_state (thrift_handle , poll_resp )
919
895
state = CommandState .from_thrift_state (operation_state )
920
896
if state is None :
921
- raise ValueError (f"Invalid operation state: { operation_state } " )
897
+ raise ValueError (f"Unknown command state: { operation_state } " )
922
898
return state
923
899
924
900
@staticmethod
@@ -1000,18 +976,25 @@ def execute_command(
1000
976
self ._handle_execute_response_async (resp , cursor )
1001
977
return None
1002
978
else :
1003
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
979
+ execute_response , has_more_rows = self ._handle_execute_response (
1004
980
resp , cursor
1005
981
)
1006
982
983
+ t_row_set = None
984
+ if resp .directResults and resp .directResults .resultSet :
985
+ t_row_set = resp .directResults .resultSet .results
986
+
1007
987
return ThriftResultSet (
1008
988
connection = cursor .connection ,
1009
989
execute_response = execute_response ,
1010
990
thrift_client = self ,
1011
991
buffer_size_bytes = max_bytes ,
1012
992
arraysize = max_rows ,
1013
993
use_cloud_fetch = use_cloud_fetch ,
1014
- arrow_schema_bytes = arrow_schema_bytes ,
994
+ t_row_set = t_row_set ,
995
+ max_download_threads = self .max_download_threads ,
996
+ ssl_options = self ._ssl_options ,
997
+ has_more_rows = has_more_rows ,
1015
998
)
1016
999
1017
1000
def get_catalogs (
@@ -1033,9 +1016,11 @@ def get_catalogs(
1033
1016
)
1034
1017
resp = self .make_request (self ._client .GetCatalogs , req )
1035
1018
1036
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1037
- resp , cursor
1038
- )
1019
+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1020
+
1021
+ t_row_set = None
1022
+ if resp .directResults and resp .directResults .resultSet :
1023
+ t_row_set = resp .directResults .resultSet .results
1039
1024
1040
1025
return ThriftResultSet (
1041
1026
connection = cursor .connection ,
@@ -1044,7 +1029,10 @@ def get_catalogs(
1044
1029
buffer_size_bytes = max_bytes ,
1045
1030
arraysize = max_rows ,
1046
1031
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1047
- arrow_schema_bytes = arrow_schema_bytes ,
1032
+ t_row_set = t_row_set ,
1033
+ max_download_threads = self .max_download_threads ,
1034
+ ssl_options = self ._ssl_options ,
1035
+ has_more_rows = has_more_rows ,
1048
1036
)
1049
1037
1050
1038
def get_schemas (
@@ -1070,9 +1058,11 @@ def get_schemas(
1070
1058
)
1071
1059
resp = self .make_request (self ._client .GetSchemas , req )
1072
1060
1073
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1074
- resp , cursor
1075
- )
1061
+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1062
+
1063
+ t_row_set = None
1064
+ if resp .directResults and resp .directResults .resultSet :
1065
+ t_row_set = resp .directResults .resultSet .results
1076
1066
1077
1067
return ThriftResultSet (
1078
1068
connection = cursor .connection ,
@@ -1081,7 +1071,10 @@ def get_schemas(
1081
1071
buffer_size_bytes = max_bytes ,
1082
1072
arraysize = max_rows ,
1083
1073
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1084
- arrow_schema_bytes = arrow_schema_bytes ,
1074
+ t_row_set = t_row_set ,
1075
+ max_download_threads = self .max_download_threads ,
1076
+ ssl_options = self ._ssl_options ,
1077
+ has_more_rows = has_more_rows ,
1085
1078
)
1086
1079
1087
1080
def get_tables (
@@ -1111,9 +1104,11 @@ def get_tables(
1111
1104
)
1112
1105
resp = self .make_request (self ._client .GetTables , req )
1113
1106
1114
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1115
- resp , cursor
1116
- )
1107
+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1108
+
1109
+ t_row_set = None
1110
+ if resp .directResults and resp .directResults .resultSet :
1111
+ t_row_set = resp .directResults .resultSet .results
1117
1112
1118
1113
return ThriftResultSet (
1119
1114
connection = cursor .connection ,
@@ -1122,7 +1117,10 @@ def get_tables(
1122
1117
buffer_size_bytes = max_bytes ,
1123
1118
arraysize = max_rows ,
1124
1119
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1125
- arrow_schema_bytes = arrow_schema_bytes ,
1120
+ t_row_set = t_row_set ,
1121
+ max_download_threads = self .max_download_threads ,
1122
+ ssl_options = self ._ssl_options ,
1123
+ has_more_rows = has_more_rows ,
1126
1124
)
1127
1125
1128
1126
def get_columns (
@@ -1152,9 +1150,11 @@ def get_columns(
1152
1150
)
1153
1151
resp = self .make_request (self ._client .GetColumns , req )
1154
1152
1155
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1156
- resp , cursor
1157
- )
1153
+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1154
+
1155
+ t_row_set = None
1156
+ if resp .directResults and resp .directResults .resultSet :
1157
+ t_row_set = resp .directResults .resultSet .results
1158
1158
1159
1159
return ThriftResultSet (
1160
1160
connection = cursor .connection ,
@@ -1163,7 +1163,10 @@ def get_columns(
1163
1163
buffer_size_bytes = max_bytes ,
1164
1164
arraysize = max_rows ,
1165
1165
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1166
- arrow_schema_bytes = arrow_schema_bytes ,
1166
+ t_row_set = t_row_set ,
1167
+ max_download_threads = self .max_download_threads ,
1168
+ ssl_options = self ._ssl_options ,
1169
+ has_more_rows = has_more_rows ,
1167
1170
)
1168
1171
1169
1172
def _handle_execute_response (self , resp , cursor ):
@@ -1177,11 +1180,7 @@ def _handle_execute_response(self, resp, cursor):
1177
1180
resp .directResults and resp .directResults .operationStatus ,
1178
1181
)
1179
1182
1180
- (
1181
- execute_response ,
1182
- arrow_schema_bytes ,
1183
- ) = self ._results_message_to_execute_response (resp , final_operation_state )
1184
- return execute_response , arrow_schema_bytes
1183
+ return self ._results_message_to_execute_response (resp , final_operation_state )
1185
1184
1186
1185
def _handle_execute_response_async (self , resp , cursor ):
1187
1186
command_id = CommandId .from_thrift_handle (resp .operationHandle )
@@ -1225,7 +1224,9 @@ def fetch_results(
1225
1224
)
1226
1225
)
1227
1226
1228
- queue = ThriftResultSetQueueFactory .build_queue (
1227
+ from databricks .sql .utils import ResultSetQueueFactory
1228
+
1229
+ queue = ResultSetQueueFactory .build_queue (
1229
1230
row_set_type = resp .resultSetMetadata .resultFormat ,
1230
1231
t_row_set = resp .results ,
1231
1232
arrow_schema_bytes = arrow_schema_bytes ,
0 commit comments