3
3
import logging
4
4
import math
5
5
import time
6
+ import uuid
6
7
import threading
7
8
from typing import List , Union , Any , TYPE_CHECKING
8
9
9
10
if TYPE_CHECKING :
10
11
from databricks .sql .client import Cursor
11
12
13
+ from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
12
14
from databricks .sql .backend .types import (
13
15
CommandState ,
14
16
SessionId ,
15
17
CommandId ,
18
+ BackendType ,
19
+ guid_to_hex_id ,
16
20
ExecuteResponse ,
17
21
)
18
22
from databricks .sql .backend .utils import guid_to_hex_id
19
23
20
-
21
24
try :
22
25
import pyarrow
23
26
except ImportError :
@@ -757,13 +760,11 @@ def _results_message_to_execute_response(self, resp, operation_state):
757
760
)
758
761
direct_results = resp .directResults
759
762
has_been_closed_server_side = direct_results and direct_results .closeOperation
760
-
761
- is_direct_results = (
763
+ has_more_rows = (
762
764
(not direct_results )
763
765
or (not direct_results .resultSet )
764
766
or direct_results .resultSet .hasMoreRows
765
767
)
766
-
767
768
description = self ._hive_schema_to_description (
768
769
t_result_set_metadata_resp .schema
769
770
)
@@ -779,25 +780,43 @@ def _results_message_to_execute_response(self, resp, operation_state):
779
780
schema_bytes = None
780
781
781
782
lz4_compressed = t_result_set_metadata_resp .lz4Compressed
783
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
784
+ if direct_results and direct_results .resultSet :
785
+ assert direct_results .resultSet .results .startRowOffset == 0
786
+ assert direct_results .resultSetMetadata
787
+
788
+ arrow_queue_opt = ResultSetQueueFactory .build_queue (
789
+ row_set_type = t_result_set_metadata_resp .resultFormat ,
790
+ t_row_set = direct_results .resultSet .results ,
791
+ arrow_schema_bytes = schema_bytes ,
792
+ max_download_threads = self .max_download_threads ,
793
+ lz4_compressed = lz4_compressed ,
794
+ description = description ,
795
+ ssl_options = self ._ssl_options ,
796
+ )
797
+ else :
798
+ arrow_queue_opt = None
799
+
782
800
command_id = CommandId .from_thrift_handle (resp .operationHandle )
783
801
784
802
status = CommandState .from_thrift_state (operation_state )
785
803
if status is None :
786
804
raise ValueError (f"Unknown command state: { operation_state } " )
787
805
788
- execute_response = ExecuteResponse (
789
- command_id = command_id ,
790
- status = status ,
791
- description = description ,
792
- has_been_closed_server_side = has_been_closed_server_side ,
793
- lz4_compressed = lz4_compressed ,
794
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
795
- arrow_schema_bytes = schema_bytes ,
796
- result_format = t_result_set_metadata_resp .resultFormat ,
806
+ return (
807
+ ExecuteResponse (
808
+ command_id = command_id ,
809
+ status = status ,
810
+ description = description ,
811
+ has_more_rows = has_more_rows ,
812
+ results_queue = arrow_queue_opt ,
813
+ has_been_closed_server_side = has_been_closed_server_side ,
814
+ lz4_compressed = lz4_compressed ,
815
+ is_staging_operation = is_staging_operation ,
816
+ ),
817
+ schema_bytes ,
797
818
)
798
819
799
- return execute_response , is_direct_results
800
-
801
820
def get_execution_result (
802
821
self , command_id : CommandId , cursor : "Cursor"
803
822
) -> "ResultSet" :
@@ -822,6 +841,9 @@ def get_execution_result(
822
841
823
842
t_result_set_metadata_resp = resp .resultSetMetadata
824
843
844
+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
845
+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
846
+ has_more_rows = resp .hasMoreRows
825
847
description = self ._hive_schema_to_description (
826
848
t_result_set_metadata_resp .schema
827
849
)
@@ -836,21 +858,25 @@ def get_execution_result(
836
858
else :
837
859
schema_bytes = None
838
860
839
- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
840
- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
841
- is_direct_results = resp .hasMoreRows
842
-
843
- status = self .get_query_state (command_id )
861
+ queue = ResultSetQueueFactory .build_queue (
862
+ row_set_type = resp .resultSetMetadata .resultFormat ,
863
+ t_row_set = resp .results ,
864
+ arrow_schema_bytes = schema_bytes ,
865
+ max_download_threads = self .max_download_threads ,
866
+ lz4_compressed = lz4_compressed ,
867
+ description = description ,
868
+ ssl_options = self ._ssl_options ,
869
+ )
844
870
845
871
execute_response = ExecuteResponse (
846
872
command_id = command_id ,
847
873
status = status ,
848
874
description = description ,
875
+ has_more_rows = has_more_rows ,
876
+ results_queue = queue ,
849
877
has_been_closed_server_side = False ,
850
878
lz4_compressed = lz4_compressed ,
851
879
is_staging_operation = is_staging_operation ,
852
- arrow_schema_bytes = schema_bytes ,
853
- result_format = t_result_set_metadata_resp .resultFormat ,
854
880
)
855
881
856
882
return ThriftResultSet (
@@ -860,10 +886,7 @@ def get_execution_result(
860
886
buffer_size_bytes = cursor .buffer_size_bytes ,
861
887
arraysize = cursor .arraysize ,
862
888
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
863
- t_row_set = resp .results ,
864
- max_download_threads = self .max_download_threads ,
865
- ssl_options = self ._ssl_options ,
866
- is_direct_results = is_direct_results ,
889
+ arrow_schema_bytes = schema_bytes ,
867
890
)
868
891
869
892
def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -976,25 +999,18 @@ def execute_command(
976
999
self ._handle_execute_response_async (resp , cursor )
977
1000
return None
978
1001
else :
979
- execute_response , is_direct_results = self ._handle_execute_response (
1002
+ execute_response , arrow_schema_bytes = self ._handle_execute_response (
980
1003
resp , cursor
981
1004
)
982
1005
983
- t_row_set = None
984
- if resp .directResults and resp .directResults .resultSet :
985
- t_row_set = resp .directResults .resultSet .results
986
-
987
1006
return ThriftResultSet (
988
1007
connection = cursor .connection ,
989
1008
execute_response = execute_response ,
990
1009
thrift_client = self ,
991
1010
buffer_size_bytes = max_bytes ,
992
1011
arraysize = max_rows ,
993
1012
use_cloud_fetch = use_cloud_fetch ,
994
- t_row_set = t_row_set ,
995
- max_download_threads = self .max_download_threads ,
996
- ssl_options = self ._ssl_options ,
997
- is_direct_results = is_direct_results ,
1013
+ arrow_schema_bytes = arrow_schema_bytes ,
998
1014
)
999
1015
1000
1016
def get_catalogs (
@@ -1016,25 +1032,18 @@ def get_catalogs(
1016
1032
)
1017
1033
resp = self .make_request (self ._client .GetCatalogs , req )
1018
1034
1019
- execute_response , is_direct_results = self ._handle_execute_response (
1035
+ execute_response , arrow_schema_bytes = self ._handle_execute_response (
1020
1036
resp , cursor
1021
1037
)
1022
1038
1023
- t_row_set = None
1024
- if resp .directResults and resp .directResults .resultSet :
1025
- t_row_set = resp .directResults .resultSet .results
1026
-
1027
1039
return ThriftResultSet (
1028
1040
connection = cursor .connection ,
1029
1041
execute_response = execute_response ,
1030
1042
thrift_client = self ,
1031
1043
buffer_size_bytes = max_bytes ,
1032
1044
arraysize = max_rows ,
1033
1045
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1034
- t_row_set = t_row_set ,
1035
- max_download_threads = self .max_download_threads ,
1036
- ssl_options = self ._ssl_options ,
1037
- is_direct_results = is_direct_results ,
1046
+ arrow_schema_bytes = arrow_schema_bytes ,
1038
1047
)
1039
1048
1040
1049
def get_schemas (
@@ -1060,25 +1069,18 @@ def get_schemas(
1060
1069
)
1061
1070
resp = self .make_request (self ._client .GetSchemas , req )
1062
1071
1063
- execute_response , is_direct_results = self ._handle_execute_response (
1072
+ execute_response , arrow_schema_bytes = self ._handle_execute_response (
1064
1073
resp , cursor
1065
1074
)
1066
1075
1067
- t_row_set = None
1068
- if resp .directResults and resp .directResults .resultSet :
1069
- t_row_set = resp .directResults .resultSet .results
1070
-
1071
1076
return ThriftResultSet (
1072
1077
connection = cursor .connection ,
1073
1078
execute_response = execute_response ,
1074
1079
thrift_client = self ,
1075
1080
buffer_size_bytes = max_bytes ,
1076
1081
arraysize = max_rows ,
1077
1082
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1078
- t_row_set = t_row_set ,
1079
- max_download_threads = self .max_download_threads ,
1080
- ssl_options = self ._ssl_options ,
1081
- is_direct_results = is_direct_results ,
1083
+ arrow_schema_bytes = arrow_schema_bytes ,
1082
1084
)
1083
1085
1084
1086
def get_tables (
@@ -1108,25 +1110,18 @@ def get_tables(
1108
1110
)
1109
1111
resp = self .make_request (self ._client .GetTables , req )
1110
1112
1111
- execute_response , is_direct_results = self ._handle_execute_response (
1113
+ execute_response , arrow_schema_bytes = self ._handle_execute_response (
1112
1114
resp , cursor
1113
1115
)
1114
1116
1115
- t_row_set = None
1116
- if resp .directResults and resp .directResults .resultSet :
1117
- t_row_set = resp .directResults .resultSet .results
1118
-
1119
1117
return ThriftResultSet (
1120
1118
connection = cursor .connection ,
1121
1119
execute_response = execute_response ,
1122
1120
thrift_client = self ,
1123
1121
buffer_size_bytes = max_bytes ,
1124
1122
arraysize = max_rows ,
1125
1123
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1126
- t_row_set = t_row_set ,
1127
- max_download_threads = self .max_download_threads ,
1128
- ssl_options = self ._ssl_options ,
1129
- is_direct_results = is_direct_results ,
1124
+ arrow_schema_bytes = arrow_schema_bytes ,
1130
1125
)
1131
1126
1132
1127
def get_columns (
@@ -1156,25 +1151,18 @@ def get_columns(
1156
1151
)
1157
1152
resp = self .make_request (self ._client .GetColumns , req )
1158
1153
1159
- execute_response , is_direct_results = self ._handle_execute_response (
1154
+ execute_response , arrow_schema_bytes = self ._handle_execute_response (
1160
1155
resp , cursor
1161
1156
)
1162
1157
1163
- t_row_set = None
1164
- if resp .directResults and resp .directResults .resultSet :
1165
- t_row_set = resp .directResults .resultSet .results
1166
-
1167
1158
return ThriftResultSet (
1168
1159
connection = cursor .connection ,
1169
1160
execute_response = execute_response ,
1170
1161
thrift_client = self ,
1171
1162
buffer_size_bytes = max_bytes ,
1172
1163
arraysize = max_rows ,
1173
1164
use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1174
- t_row_set = t_row_set ,
1175
- max_download_threads = self .max_download_threads ,
1176
- ssl_options = self ._ssl_options ,
1177
- is_direct_results = is_direct_results ,
1165
+ arrow_schema_bytes = arrow_schema_bytes ,
1178
1166
)
1179
1167
1180
1168
def _handle_execute_response (self , resp , cursor ):
0 commit comments