3
3
import logging
4
4
import math
5
5
import time
6
- import uuid
7
6
import threading
8
7
from typing import List , Union , Any , TYPE_CHECKING
9
8
10
9
if TYPE_CHECKING :
11
10
from databricks .sql .client import Cursor
12
11
13
- from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
14
12
from databricks .sql .backend .types import (
15
13
CommandState ,
16
14
SessionId ,
17
15
CommandId ,
18
- BackendType ,
19
- guid_to_hex_id ,
20
16
ExecuteResponse ,
21
17
)
22
18
19
+
23
20
try :
24
21
import pyarrow
25
22
except ImportError :
@@ -759,11 +756,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
759
756
)
760
757
direct_results = resp .directResults
761
758
has_been_closed_server_side = direct_results and direct_results .closeOperation
762
- has_more_rows = (
759
+
760
+ is_direct_results = (
763
761
(not direct_results )
764
762
or (not direct_results .resultSet )
765
763
or direct_results .resultSet .hasMoreRows
766
764
)
765
+
767
766
description = self ._hive_schema_to_description (
768
767
t_result_set_metadata_resp .schema
769
768
)
@@ -779,43 +778,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
779
778
schema_bytes = None
780
779
781
780
lz4_compressed = t_result_set_metadata_resp .lz4Compressed
782
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
783
- if direct_results and direct_results .resultSet :
784
- assert direct_results .resultSet .results .startRowOffset == 0
785
- assert direct_results .resultSetMetadata
786
-
787
- arrow_queue_opt = ResultSetQueueFactory .build_queue (
788
- row_set_type = t_result_set_metadata_resp .resultFormat ,
789
- t_row_set = direct_results .resultSet .results ,
790
- arrow_schema_bytes = schema_bytes ,
791
- max_download_threads = self .max_download_threads ,
792
- lz4_compressed = lz4_compressed ,
793
- description = description ,
794
- ssl_options = self ._ssl_options ,
795
- )
796
- else :
797
- arrow_queue_opt = None
798
-
799
781
command_id = CommandId .from_thrift_handle (resp .operationHandle )
800
782
801
783
status = CommandState .from_thrift_state (operation_state )
802
784
if status is None :
803
785
raise ValueError (f"Unknown command state: { operation_state } " )
804
786
805
- return (
806
- ExecuteResponse (
807
- command_id = command_id ,
808
- status = status ,
809
- description = description ,
810
- has_more_rows = has_more_rows ,
811
- results_queue = arrow_queue_opt ,
812
- has_been_closed_server_side = has_been_closed_server_side ,
813
- lz4_compressed = lz4_compressed ,
814
- is_staging_operation = is_staging_operation ,
815
- ),
816
- schema_bytes ,
787
+ execute_response = ExecuteResponse (
788
+ command_id = command_id ,
789
+ status = status ,
790
+ description = description ,
791
+ has_been_closed_server_side = has_been_closed_server_side ,
792
+ lz4_compressed = lz4_compressed ,
793
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
794
+ arrow_schema_bytes = schema_bytes ,
795
+ result_format = t_result_set_metadata_resp .resultFormat ,
817
796
)
818
797
798
+ return execute_response , is_direct_results
799
+
819
800
def get_execution_result (
820
801
self , command_id : CommandId , cursor : "Cursor"
821
802
) -> "ResultSet" :
@@ -840,9 +821,6 @@ def get_execution_result(
840
821
841
822
t_result_set_metadata_resp = resp .resultSetMetadata
842
823
843
- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
844
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
845
- has_more_rows = resp .hasMoreRows
846
824
description = self ._hive_schema_to_description (
847
825
t_result_set_metadata_resp .schema
848
826
)
@@ -857,27 +835,23 @@ def get_execution_result(
857
835
else :
858
836
schema_bytes = None
859
837
860
- queue = ResultSetQueueFactory .build_queue (
861
- row_set_type = resp .resultSetMetadata .resultFormat ,
862
- t_row_set = resp .results ,
863
- arrow_schema_bytes = schema_bytes ,
864
- max_download_threads = self .max_download_threads ,
865
- lz4_compressed = lz4_compressed ,
866
- description = description ,
867
- ssl_options = self ._ssl_options ,
868
- )
838
+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
839
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
840
+ is_direct_results = resp .hasMoreRows
841
+
842
+ status = self .get_query_state (command_id )
869
843
870
844
status = self .get_query_state (command_id )
871
845
872
846
execute_response = ExecuteResponse (
873
847
command_id = command_id ,
874
848
status = status ,
875
849
description = description ,
876
- has_more_rows = has_more_rows ,
877
- results_queue = queue ,
878
850
has_been_closed_server_side = False ,
879
851
lz4_compressed = lz4_compressed ,
880
852
is_staging_operation = is_staging_operation ,
853
+ arrow_schema_bytes = schema_bytes ,
854
+ result_format = t_result_set_metadata_resp .resultFormat ,
881
855
)
882
856
883
857
return ThriftResultSet (
@@ -887,7 +861,10 @@ def get_execution_result(
887
861
buffer_size_bytes = cursor .buffer_size_bytes ,
888
862
arraysize = cursor .arraysize ,
889
863
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
890
- arrow_schema_bytes = schema_bytes ,
864
+ t_row_set = resp .results ,
865
+ max_download_threads = self .max_download_threads ,
866
+ ssl_options = self ._ssl_options ,
867
+ is_direct_results = is_direct_results ,
891
868
)
892
869
893
870
def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -1000,18 +977,25 @@ def execute_command(
1000
977
self ._handle_execute_response_async (resp , cursor )
1001
978
return None
1002
979
else :
1003
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
980
+ execute_response , is_direct_results = self ._handle_execute_response (
1004
981
resp , cursor
1005
982
)
1006
983
984
+ t_row_set = None
985
+ if resp .directResults and resp .directResults .resultSet :
986
+ t_row_set = resp .directResults .resultSet .results
987
+
1007
988
return ThriftResultSet (
1008
989
connection = cursor .connection ,
1009
990
execute_response = execute_response ,
1010
991
thrift_client = self ,
1011
992
buffer_size_bytes = max_bytes ,
1012
993
arraysize = max_rows ,
1013
994
use_cloud_fetch = use_cloud_fetch ,
1014
- arrow_schema_bytes = arrow_schema_bytes ,
995
+ t_row_set = t_row_set ,
996
+ max_download_threads = self .max_download_threads ,
997
+ ssl_options = self ._ssl_options ,
998
+ is_direct_results = is_direct_results ,
1015
999
)
1016
1000
1017
1001
def get_catalogs (
@@ -1033,18 +1017,25 @@ def get_catalogs(
1033
1017
)
1034
1018
resp = self .make_request (self ._client .GetCatalogs , req )
1035
1019
1036
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1020
+ execute_response , is_direct_results = self ._handle_execute_response (
1037
1021
resp , cursor
1038
1022
)
1039
1023
1024
+ t_row_set = None
1025
+ if resp .directResults and resp .directResults .resultSet :
1026
+ t_row_set = resp .directResults .resultSet .results
1027
+
1040
1028
return ThriftResultSet (
1041
1029
connection = cursor .connection ,
1042
1030
execute_response = execute_response ,
1043
1031
thrift_client = self ,
1044
1032
buffer_size_bytes = max_bytes ,
1045
1033
arraysize = max_rows ,
1046
1034
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1047
- arrow_schema_bytes = arrow_schema_bytes ,
1035
+ t_row_set = t_row_set ,
1036
+ max_download_threads = self .max_download_threads ,
1037
+ ssl_options = self ._ssl_options ,
1038
+ is_direct_results = is_direct_results ,
1048
1039
)
1049
1040
1050
1041
def get_schemas (
@@ -1070,18 +1061,25 @@ def get_schemas(
1070
1061
)
1071
1062
resp = self .make_request (self ._client .GetSchemas , req )
1072
1063
1073
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1064
+ execute_response , is_direct_results = self ._handle_execute_response (
1074
1065
resp , cursor
1075
1066
)
1076
1067
1068
+ t_row_set = None
1069
+ if resp .directResults and resp .directResults .resultSet :
1070
+ t_row_set = resp .directResults .resultSet .results
1071
+
1077
1072
return ThriftResultSet (
1078
1073
connection = cursor .connection ,
1079
1074
execute_response = execute_response ,
1080
1075
thrift_client = self ,
1081
1076
buffer_size_bytes = max_bytes ,
1082
1077
arraysize = max_rows ,
1083
1078
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1084
- arrow_schema_bytes = arrow_schema_bytes ,
1079
+ t_row_set = t_row_set ,
1080
+ max_download_threads = self .max_download_threads ,
1081
+ ssl_options = self ._ssl_options ,
1082
+ is_direct_results = is_direct_results ,
1085
1083
)
1086
1084
1087
1085
def get_tables (
@@ -1111,18 +1109,25 @@ def get_tables(
1111
1109
)
1112
1110
resp = self .make_request (self ._client .GetTables , req )
1113
1111
1114
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1112
+ execute_response , is_direct_results = self ._handle_execute_response (
1115
1113
resp , cursor
1116
1114
)
1117
1115
1116
+ t_row_set = None
1117
+ if resp .directResults and resp .directResults .resultSet :
1118
+ t_row_set = resp .directResults .resultSet .results
1119
+
1118
1120
return ThriftResultSet (
1119
1121
connection = cursor .connection ,
1120
1122
execute_response = execute_response ,
1121
1123
thrift_client = self ,
1122
1124
buffer_size_bytes = max_bytes ,
1123
1125
arraysize = max_rows ,
1124
1126
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1125
- arrow_schema_bytes = arrow_schema_bytes ,
1127
+ t_row_set = t_row_set ,
1128
+ max_download_threads = self .max_download_threads ,
1129
+ ssl_options = self ._ssl_options ,
1130
+ is_direct_results = is_direct_results ,
1126
1131
)
1127
1132
1128
1133
def get_columns (
@@ -1152,18 +1157,25 @@ def get_columns(
1152
1157
)
1153
1158
resp = self .make_request (self ._client .GetColumns , req )
1154
1159
1155
- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1160
+ execute_response , is_direct_results = self ._handle_execute_response (
1156
1161
resp , cursor
1157
1162
)
1158
1163
1164
+ t_row_set = None
1165
+ if resp .directResults and resp .directResults .resultSet :
1166
+ t_row_set = resp .directResults .resultSet .results
1167
+
1159
1168
return ThriftResultSet (
1160
1169
connection = cursor .connection ,
1161
1170
execute_response = execute_response ,
1162
1171
thrift_client = self ,
1163
1172
buffer_size_bytes = max_bytes ,
1164
1173
arraysize = max_rows ,
1165
1174
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1166
- arrow_schema_bytes = arrow_schema_bytes ,
1175
+ t_row_set = t_row_set ,
1176
+ max_download_threads = self .max_download_threads ,
1177
+ ssl_options = self ._ssl_options ,
1178
+ is_direct_results = is_direct_results ,
1167
1179
)
1168
1180
1169
1181
def _handle_execute_response (self , resp , cursor ):
0 commit comments