     CursorAlreadyClosedError,
 )
 from databricks.sql.thrift_api.TCLIService import ttypes
-from databricks.sql.thrift_backend import ThriftBackend
+from databricks.sql.thrift_backend import ThriftDatabricksClient
+from databricks.sql.db_client_interface import DatabricksClient
 from databricks.sql.utils import (
     ExecuteResponse,
     ParamEscaper,
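Note: the new `databricks/sql/db_client_interface.py` module is imported here but not shown in this diff, so its contents are an assumption. Judging from the calls made through `self.backend` in the rest of the file, the interface would need roughly the shape sketched below; this is illustrative only, not the actual file.

# Hypothetical sketch of databricks/sql/db_client_interface.py, inferred from usage in this file.
from abc import ABC, abstractmethod


class DatabricksClient(ABC):
    """Backend-agnostic operations that Cursor and ResultSet rely on below."""

    @abstractmethod
    def execute_command(self, operation, session_handle, max_rows, max_bytes, **kwargs):
        """Run a statement and return an ExecuteResponse."""

    @abstractmethod
    def get_query_state(self, op_handle):
        """Return the operation state of an asynchronously started command."""

    @abstractmethod
    def get_execution_result(self, op_handle, cursor):
        """Fetch the result of a previously started async command."""

    @abstractmethod
    def get_catalogs(self, session_handle, max_rows, max_bytes, **kwargs):
        """Metadata lookup; get_schemas/get_tables/get_columns would follow the same shape."""

    @abstractmethod
    def cancel_command(self, op_handle):
        """Cancel a running command."""

    @abstractmethod
    def close_command(self, op_handle):
        """Close a command on the server."""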
@@ -336,7 +337,7 @@ def cursor(

         cursor = Cursor(
             self,
-            self.session.thrift_backend,
+            self.session.backend,
             arraysize=arraysize,
             result_buffer_size_bytes=buffer_size_bytes,
         )
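The only change in `Connection.cursor()` is which attribute of the session is handed to the new `Cursor`; the public entry point is untouched. A quick usage sketch (hostname, HTTP path, and token are placeholders):

from databricks import sql

with sql.connect(
    server_hostname="example.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/abc123",          # placeholder
    access_token="<personal-access-token>",          # placeholder
) as connection:
    with connection.cursor() as cursor:  # Cursor is now built around connection.session.backend
        cursor.execute("SELECT 1")
        print(cursor.fetchall())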
@@ -369,7 +370,7 @@ class Cursor:
     def __init__(
         self,
         connection: Connection,
-        thrift_backend: ThriftBackend,
+        backend: DatabricksClient,
         result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
         arraysize: int = DEFAULT_ARRAY_SIZE,
     ) -> None:
@@ -388,7 +389,7 @@ def __init__(
         # Note that Cursor closed => active result set closed, but not vice versa
         self.open = True
         self.executing_command_id = None
-        self.thrift_backend = thrift_backend
+        self.backend = backend
         self.active_op_handle = None
         self.escaper = ParamEscaper()
         self.lastrowid = None
@@ -753,7 +754,8 @@ def execute(

         self._check_not_closed()
         self._close_and_clear_active_result_set()
-        execute_response = self.thrift_backend.execute_command(
+        print("here")
+        execute_response = self.backend.execute_command(
             operation=prepared_operation,
             session_handle=self.connection.session._session_handle,
             max_rows=self.arraysize,
@@ -768,15 +770,15 @@ def execute(
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
-            self.thrift_backend,
+            self.backend,
             self.buffer_size_bytes,
             self.arraysize,
             self.connection.use_cloud_fetch,
         )

         if execute_response.is_staging_operation:
             self._handle_staging_operation(
-                staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
+                staging_allowed_local_path=self.backend.staging_allowed_local_path
             )

         return self
@@ -816,7 +818,7 @@ def execute_async(

         self._check_not_closed()
         self._close_and_clear_active_result_set()
-        self.thrift_backend.execute_command(
+        self.backend.execute_command(
             operation=prepared_operation,
             session_handle=self.connection.session._session_handle,
             max_rows=self.arraysize,
@@ -838,7 +840,7 @@ def get_query_state(self) -> "TOperationState":
         :return:
         """
         self._check_not_closed()
-        return self.thrift_backend.get_query_state(self.active_op_handle)
+        return self.backend.get_query_state(self.active_op_handle)

     def is_query_pending(self):
         """
@@ -868,20 +870,20 @@ def get_async_execution_result(self):

         operation_state = self.get_query_state()
         if operation_state == ttypes.TOperationState.FINISHED_STATE:
-            execute_response = self.thrift_backend.get_execution_result(
+            execute_response = self.backend.get_execution_result(
                 self.active_op_handle, self
             )
             self.active_result_set = ResultSet(
                 self.connection,
                 execute_response,
-                self.thrift_backend,
+                self.backend,
                 self.buffer_size_bytes,
                 self.arraysize,
             )

             if execute_response.is_staging_operation:
                 self._handle_staging_operation(
-                    staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
+                    staging_allowed_local_path=self.backend.staging_allowed_local_path
                 )

             return self
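The same rename flows through the asynchronous path (`execute_async`, `get_query_state`, `is_query_pending`, `get_async_execution_result`), which callers typically drive with a polling loop along these lines (continuing from the cursor above; the query text is illustrative):

import time

cursor.execute_async("SELECT 42 AS answer")  # illustrative query

# Poll until the server reports a terminal state, then materialize the result set.
while cursor.is_query_pending():
    time.sleep(1)

cursor.get_async_execution_result()
print(cursor.fetchall())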
@@ -913,7 +915,7 @@ def catalogs(self) -> "Cursor":
         """
         self._check_not_closed()
         self._close_and_clear_active_result_set()
-        execute_response = self.thrift_backend.get_catalogs(
+        execute_response = self.backend.get_catalogs(
             session_handle=self.connection.session._session_handle,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -922,9 +924,10 @@ def catalogs(self) -> "Cursor":
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
-            self.thrift_backend,
+            self.backend,
             self.buffer_size_bytes,
             self.arraysize,
+            self.connection.use_cloud_fetch,
         )
         return self

@@ -939,7 +942,7 @@ def schemas(
         """
         self._check_not_closed()
         self._close_and_clear_active_result_set()
-        execute_response = self.thrift_backend.get_schemas(
+        execute_response = self.backend.get_schemas(
             session_handle=self.connection.session._session_handle,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -950,9 +953,10 @@ def schemas(
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
-            self.thrift_backend,
+            self.backend,
             self.buffer_size_bytes,
             self.arraysize,
+            self.connection.use_cloud_fetch,
         )
         return self

@@ -972,7 +976,7 @@ def tables(
         self._check_not_closed()
         self._close_and_clear_active_result_set()

-        execute_response = self.thrift_backend.get_tables(
+        execute_response = self.backend.get_tables(
             session_handle=self.connection.session._session_handle,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -985,9 +989,10 @@ def tables(
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
-            self.thrift_backend,
+            self.backend,
             self.buffer_size_bytes,
             self.arraysize,
+            self.connection.use_cloud_fetch,
         )
         return self

@@ -1007,7 +1012,7 @@ def columns(
         self._check_not_closed()
         self._close_and_clear_active_result_set()

-        execute_response = self.thrift_backend.get_columns(
+        execute_response = self.backend.get_columns(
             session_handle=self.connection.session._session_handle,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
@@ -1020,9 +1025,10 @@ def columns(
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
-            self.thrift_backend,
+            self.backend,
             self.buffer_size_bytes,
             self.arraysize,
+            self.connection.use_cloud_fetch,
        )
         return self

@@ -1097,7 +1103,7 @@ def cancel(self) -> None:
         This method can be called from another thread.
         """
         if self.active_op_handle is not None:
-            self.thrift_backend.cancel_command(self.active_op_handle)
+            self.backend.cancel_command(self.active_op_handle)
         else:
             logger.warning(
                 "Attempting to cancel a command, but there is no "
@@ -1172,7 +1178,7 @@ def __init__(
         self,
         connection: Connection,
         execute_response: ExecuteResponse,
-        thrift_backend: ThriftBackend,
+        backend: DatabricksClient,
         result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -1182,8 +1188,10 @@ def __init__(

         :param connection: The parent connection that was used to execute this command
         :param execute_response: A `ExecuteResponse` class returned by a command execution
-        :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch
-        amount :param arraysize: The max number of rows to fetch at a time (PEP-249)
+        :param backend: The DatabricksClient instance to use for fetching results
+        :param result_buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch amount
+        :param arraysize: The max number of rows to fetch at a time (PEP-249)
+        :param use_cloud_fetch: Whether to use cloud fetch for retrieving results
         """
         self.connection = connection
         self.command_id = execute_response.command_handle
@@ -1193,7 +1201,7 @@ def __init__(
         self.buffer_size_bytes = result_buffer_size_bytes
         self.lz4_compressed = execute_response.lz4_compressed
         self.arraysize = arraysize
-        self.thrift_backend = thrift_backend
+        self.backend = backend
         self.description = execute_response.description
         self._arrow_schema_bytes = execute_response.arrow_schema_bytes
         self._next_row_index = 0
@@ -1216,8 +1224,15 @@ def __iter__(self):
                 break

     def _fill_results_buffer(self):
-        # At initialization or if the server does not have cloud fetch result links available
-        results, has_more_rows = self.thrift_backend.fetch_results(
+        if not isinstance(self.backend, ThriftDatabricksClient):
+            # This specific logic is for Thrift. SEA will have its own way.
+            raise NotImplementedError(
+                "Fetching further result batches is currently only implemented for the Thrift backend."
+            )
+
+        # Now we know self.backend is ThriftDatabricksClient, so it has fetch_results
+        thrift_backend_instance = self.backend  # type: ThriftDatabricksClient
+        results, has_more_rows = thrift_backend_instance.fetch_results(
             op_handle=self.command_id,
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
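`fetch_results` is not part of the backend interface, so the buffer refill first narrows the abstract `backend` to the concrete Thrift client before calling it. A self-contained sketch of that narrowing pattern, using illustrative names rather than the real classes:

from abc import ABC, abstractmethod


class Backend(ABC):
    @abstractmethod
    def execute(self, operation: str):
        ...


class ThriftOnlyBackend(Backend):
    def execute(self, operation: str):
        return operation

    def fetch_more(self):
        # Capability that exists only on this concrete backend.
        return []


def refill(backend: Backend):
    # Mirror of the isinstance guard above: refuse non-Thrift backends,
    # then use the narrowed type's extra method.
    if not isinstance(backend, ThriftOnlyBackend):
        raise NotImplementedError("Only the Thrift backend can fetch further batches.")
    return backend.fetch_more()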
@@ -1433,19 +1448,20 @@ def close(self) -> None:
         If the connection has not been closed, and the cursor has not already
         been closed on the server for some other reason, issue a request to the server to close it.
         """
+        # TODO: the state is still thrift specific, define some ENUM for status that each service has to map to
         try:
             if (
-                self.op_state != self.thrift_backend.CLOSED_OP_STATE
+                self.op_state != ttypes.TOperationState.CLOSED_STATE
                 and not self.has_been_closed_server_side
                 and self.connection.open
             ):
-                self.thrift_backend.close_command(self.command_id)
+                self.backend.close_command(self.command_id)
         except RequestError as e:
             if isinstance(e.args[1], CursorAlreadyClosedError):
                 logger.info("Operation was canceled by a prior request")
         finally:
             self.has_been_closed_server_side = True
-            self.op_state = self.thrift_backend.CLOSED_OP_STATE
+            self.op_state = ttypes.TOperationState.CLOSED_STATE

     @staticmethod
     def _get_schema_description(table_schema_message):
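The new TODO in `close()` asks for a backend-neutral status enum instead of raw Thrift states. That enum is not part of this diff; one possible shape, mapping only the two Thrift states this file actually references, might look like this:

from enum import Enum

from databricks.sql.thrift_api.TCLIService import ttypes


class CommandState(Enum):
    """Hypothetical service-agnostic command status; each backend would map into it."""

    SUCCEEDED = "succeeded"
    CLOSED = "closed"


# Partial Thrift-side mapping, limited to the states used in close() and get_async_execution_result();
# the remaining states would be filled in by the Thrift backend.
THRIFT_STATE_TO_COMMAND_STATE = {
    ttypes.TOperationState.FINISHED_STATE: CommandState.SUCCEEDED,
    ttypes.TOperationState.CLOSED_STATE: CommandState.CLOSED,
}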