3
3
import logging
4
4
import math
5
5
import time
6
- import uuid
7
6
import threading
8
7
from typing import List , Union , Any , TYPE_CHECKING
9
8
10
9
if TYPE_CHECKING :
11
10
from databricks .sql .client import Cursor
12
11
13
- from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
14
12
from databricks .sql .backend .types import (
15
13
CommandState ,
16
14
SessionId ,
17
15
CommandId ,
18
- BackendType ,
19
- guid_to_hex_id ,
20
16
ExecuteResponse ,
21
17
)
22
18
from databricks .sql .backend .utils import guid_to_hex_id
23
19
20
+
24
21
try :
25
22
import pyarrow
26
23
except ImportError :
@@ -760,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
760
757
)
761
758
direct_results = resp .directResults
762
759
has_been_closed_server_side = direct_results and direct_results .closeOperation
763
- has_more_rows = (
760
+
761
+ is_direct_results = (
764
762
(not direct_results )
765
763
or (not direct_results .resultSet )
766
764
or direct_results .resultSet .hasMoreRows
767
765
)
766
+
768
767
description = self ._hive_schema_to_description (
769
768
t_result_set_metadata_resp .schema
770
769
)
@@ -780,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
780
779
schema_bytes = None
781
780
782
781
lz4_compressed = t_result_set_metadata_resp .lz4Compressed
783
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
784
- if direct_results and direct_results .resultSet :
785
- assert direct_results .resultSet .results .startRowOffset == 0
786
- assert direct_results .resultSetMetadata
787
-
788
- arrow_queue_opt = ResultSetQueueFactory .build_queue (
789
- row_set_type = t_result_set_metadata_resp .resultFormat ,
790
- t_row_set = direct_results .resultSet .results ,
791
- arrow_schema_bytes = schema_bytes ,
792
- max_download_threads = self .max_download_threads ,
793
- lz4_compressed = lz4_compressed ,
794
- description = description ,
795
- ssl_options = self ._ssl_options ,
796
- )
797
- else :
798
- arrow_queue_opt = None
799
-
800
782
command_id = CommandId .from_thrift_handle (resp .operationHandle )
801
783
802
784
status = CommandState .from_thrift_state (operation_state )
803
785
if status is None :
804
786
raise ValueError (f"Unknown command state: { operation_state } " )
805
787
806
- return (
807
- ExecuteResponse (
808
- command_id = command_id ,
809
- status = status ,
810
- description = description ,
811
- has_more_rows = has_more_rows ,
812
- results_queue = arrow_queue_opt ,
813
- has_been_closed_server_side = has_been_closed_server_side ,
814
- lz4_compressed = lz4_compressed ,
815
- is_staging_operation = is_staging_operation ,
816
- ),
817
- schema_bytes ,
788
+ execute_response = ExecuteResponse (
789
+ command_id = command_id ,
790
+ status = status ,
791
+ description = description ,
792
+ has_been_closed_server_side = has_been_closed_server_side ,
793
+ lz4_compressed = lz4_compressed ,
794
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
795
+ arrow_schema_bytes = schema_bytes ,
796
+ result_format = t_result_set_metadata_resp .resultFormat ,
818
797
)
819
798
799
+ return execute_response , is_direct_results
800
+
820
801
def get_execution_result (
821
802
self , command_id : CommandId , cursor : "Cursor"
822
803
) -> "ResultSet" :
@@ -841,9 +822,6 @@ def get_execution_result(
841
822
842
823
t_result_set_metadata_resp = resp .resultSetMetadata
843
824
844
- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
845
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
846
- has_more_rows = resp .hasMoreRows
847
825
description = self ._hive_schema_to_description (
848
826
t_result_set_metadata_resp .schema
849
827
)
@@ -858,25 +836,21 @@ def get_execution_result(
858
836
else :
859
837
schema_bytes = None
860
838
861
- queue = ResultSetQueueFactory .build_queue (
862
- row_set_type = resp .resultSetMetadata .resultFormat ,
863
- t_row_set = resp .results ,
864
- arrow_schema_bytes = schema_bytes ,
865
- max_download_threads = self .max_download_threads ,
866
- lz4_compressed = lz4_compressed ,
867
- description = description ,
868
- ssl_options = self ._ssl_options ,
869
- )
839
+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
840
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
841
+ is_direct_results = resp .hasMoreRows
842
+
843
+ status = self .get_query_state (command_id )
870
844
871
845
execute_response = ExecuteResponse (
872
846
command_id = command_id ,
873
847
status = status ,
874
848
description = description ,
875
- has_more_rows = has_more_rows ,
876
- results_queue = queue ,
877
849
has_been_closed_server_side = False ,
878
850
lz4_compressed = lz4_compressed ,
879
851
is_staging_operation = is_staging_operation ,
852
+ arrow_schema_bytes = schema_bytes ,
853
+ result_format = t_result_set_metadata_resp .resultFormat ,
880
854
)
881
855
882
856
return ThriftResultSet (
@@ -886,7 +860,10 @@ def get_execution_result(
886
860
buffer_size_bytes = cursor .buffer_size_bytes ,
887
861
arraysize = cursor .arraysize ,
888
862
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
889
- arrow_schema_bytes = schema_bytes ,
863
+ t_row_set = resp .results ,
864
+ max_download_threads = self .max_download_threads ,
865
+ ssl_options = self ._ssl_options ,
866
+ is_direct_results = is_direct_results ,
890
867
)
891
868
892
869
def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -999,18 +976,25 @@ def execute_command(
999
976
self ._handle_execute_response_async (resp , cursor )
1000
977
return None
1001
978
else :
1002
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
979
+ execute_response , is_direct_results = self ._handle_execute_response (
1003
980
resp , cursor
1004
981
)
1005
982
983
+ t_row_set = None
984
+ if resp .directResults and resp .directResults .resultSet :
985
+ t_row_set = resp .directResults .resultSet .results
986
+
1006
987
return ThriftResultSet (
1007
988
connection = cursor .connection ,
1008
989
execute_response = execute_response ,
1009
990
thrift_client = self ,
1010
991
buffer_size_bytes = max_bytes ,
1011
992
arraysize = max_rows ,
1012
993
use_cloud_fetch = use_cloud_fetch ,
1013
- arrow_schema_bytes = arrow_schema_bytes ,
994
+ t_row_set = t_row_set ,
995
+ max_download_threads = self .max_download_threads ,
996
+ ssl_options = self ._ssl_options ,
997
+ is_direct_results = is_direct_results ,
1014
998
)
1015
999
1016
1000
def get_catalogs (
@@ -1032,18 +1016,25 @@ def get_catalogs(
1032
1016
)
1033
1017
resp = self .make_request (self ._client .GetCatalogs , req )
1034
1018
1035
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1019
+ execute_response , is_direct_results = self ._handle_execute_response (
1036
1020
resp , cursor
1037
1021
)
1038
1022
1023
+ t_row_set = None
1024
+ if resp .directResults and resp .directResults .resultSet :
1025
+ t_row_set = resp .directResults .resultSet .results
1026
+
1039
1027
return ThriftResultSet (
1040
1028
connection = cursor .connection ,
1041
1029
execute_response = execute_response ,
1042
1030
thrift_client = self ,
1043
1031
buffer_size_bytes = max_bytes ,
1044
1032
arraysize = max_rows ,
1045
1033
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1046
- arrow_schema_bytes = arrow_schema_bytes ,
1034
+ t_row_set = t_row_set ,
1035
+ max_download_threads = self .max_download_threads ,
1036
+ ssl_options = self ._ssl_options ,
1037
+ is_direct_results = is_direct_results ,
1047
1038
)
1048
1039
1049
1040
def get_schemas (
@@ -1069,18 +1060,25 @@ def get_schemas(
1069
1060
)
1070
1061
resp = self .make_request (self ._client .GetSchemas , req )
1071
1062
1072
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1063
+ execute_response , is_direct_results = self ._handle_execute_response (
1073
1064
resp , cursor
1074
1065
)
1075
1066
1067
+ t_row_set = None
1068
+ if resp .directResults and resp .directResults .resultSet :
1069
+ t_row_set = resp .directResults .resultSet .results
1070
+
1076
1071
return ThriftResultSet (
1077
1072
connection = cursor .connection ,
1078
1073
execute_response = execute_response ,
1079
1074
thrift_client = self ,
1080
1075
buffer_size_bytes = max_bytes ,
1081
1076
arraysize = max_rows ,
1082
1077
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1083
- arrow_schema_bytes = arrow_schema_bytes ,
1078
+ t_row_set = t_row_set ,
1079
+ max_download_threads = self .max_download_threads ,
1080
+ ssl_options = self ._ssl_options ,
1081
+ is_direct_results = is_direct_results ,
1084
1082
)
1085
1083
1086
1084
def get_tables (
@@ -1110,18 +1108,25 @@ def get_tables(
1110
1108
)
1111
1109
resp = self .make_request (self ._client .GetTables , req )
1112
1110
1113
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1111
+ execute_response , is_direct_results = self ._handle_execute_response (
1114
1112
resp , cursor
1115
1113
)
1116
1114
1115
+ t_row_set = None
1116
+ if resp .directResults and resp .directResults .resultSet :
1117
+ t_row_set = resp .directResults .resultSet .results
1118
+
1117
1119
return ThriftResultSet (
1118
1120
connection = cursor .connection ,
1119
1121
execute_response = execute_response ,
1120
1122
thrift_client = self ,
1121
1123
buffer_size_bytes = max_bytes ,
1122
1124
arraysize = max_rows ,
1123
1125
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1124
- arrow_schema_bytes = arrow_schema_bytes ,
1126
+ t_row_set = t_row_set ,
1127
+ max_download_threads = self .max_download_threads ,
1128
+ ssl_options = self ._ssl_options ,
1129
+ is_direct_results = is_direct_results ,
1125
1130
)
1126
1131
1127
1132
def get_columns (
@@ -1151,18 +1156,25 @@ def get_columns(
1151
1156
)
1152
1157
resp = self .make_request (self ._client .GetColumns , req )
1153
1158
1154
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1159
+ execute_response , is_direct_results = self ._handle_execute_response (
1155
1160
resp , cursor
1156
1161
)
1157
1162
1163
+ t_row_set = None
1164
+ if resp .directResults and resp .directResults .resultSet :
1165
+ t_row_set = resp .directResults .resultSet .results
1166
+
1158
1167
return ThriftResultSet (
1159
1168
connection = cursor .connection ,
1160
1169
execute_response = execute_response ,
1161
1170
thrift_client = self ,
1162
1171
buffer_size_bytes = max_bytes ,
1163
1172
arraysize = max_rows ,
1164
1173
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1165
- arrow_schema_bytes = arrow_schema_bytes ,
1174
+ t_row_set = t_row_set ,
1175
+ max_download_threads = self .max_download_threads ,
1176
+ ssl_options = self ._ssl_options ,
1177
+ is_direct_results = is_direct_results ,
1166
1178
)
1167
1179
1168
1180
def _handle_execute_response (self , resp , cursor ):
0 commit comments