@@ -623,7 +623,10 @@ def test_handle_execute_response_sets_compression_in_direct_results(
623
623
status = Mock (),
624
624
operationHandle = Mock (),
625
625
directResults = ttypes .TSparkDirectResults (
626
- operationStatus = op_status ,
626
+ operationStatus = ttypes .TGetOperationStatusResp (
627
+ status = self .okay_status ,
628
+ operationState = ttypes .TOperationState .FINISHED_STATE ,
629
+ ),
627
630
resultSetMetadata = ttypes .TGetResultSetMetadataResp (
628
631
status = self .okay_status ,
629
632
resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET ,
@@ -832,9 +835,10 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self):
832
835
thrift_backend ._handle_execute_response (error_resp , Mock ())
833
836
self .assertIn ("this is a bad error" , str (cm .exception ))
834
837
838
+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
835
839
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
836
840
def test_handle_execute_response_can_handle_without_direct_results (
837
- self , tcli_service_class
841
+ self , tcli_service_class , mock_result_set
838
842
):
839
843
tcli_service_instance = tcli_service_class .return_value
840
844
@@ -878,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results(
878
882
auth_provider = AuthProvider (),
879
883
ssl_options = SSLOptions (),
880
884
)
881
- execute_response , _ = thrift_backend . _handle_execute_response (
882
- execute_resp , Mock ()
883
- )
884
-
885
+ (
886
+ execute_response ,
887
+ _ ,
888
+ ) = thrift_backend . _handle_execute_response ( execute_resp , Mock ())
885
889
self .assertEqual (
886
890
execute_response .status ,
887
891
CommandState .SUCCEEDED ,
@@ -947,8 +951,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
947
951
tcli_service_instance .GetResultSetMetadata .return_value = (
948
952
t_get_result_set_metadata_resp
949
953
)
954
+ tcli_service_instance .GetOperationStatus .return_value = (
955
+ ttypes .TGetOperationStatusResp (
956
+ status = self .okay_status ,
957
+ operationState = ttypes .TOperationState .FINISHED_STATE ,
958
+ )
959
+ )
950
960
thrift_backend = self ._make_fake_thrift_backend ()
951
- execute_response , arrow_schema_bytes = thrift_backend ._handle_execute_response (
961
+ execute_response , _ = thrift_backend ._handle_execute_response (
952
962
t_execute_resp , Mock ()
953
963
)
954
964
@@ -973,8 +983,14 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
973
983
)
974
984
975
985
tcli_service_instance .GetResultSetMetadata .return_value = hive_schema_req
986
+ tcli_service_instance .GetOperationStatus .return_value = (
987
+ ttypes .TGetOperationStatusResp (
988
+ status = self .okay_status ,
989
+ operationState = ttypes .TOperationState .FINISHED_STATE ,
990
+ )
991
+ )
976
992
thrift_backend = self ._make_fake_thrift_backend ()
977
- thrift_backend ._handle_execute_response (t_execute_resp , Mock ())
993
+ _ , _ = thrift_backend ._handle_execute_response (t_execute_resp , Mock ())
978
994
979
995
self .assertEqual (
980
996
hive_schema_mock ,
@@ -988,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
988
1004
def test_handle_execute_response_reads_has_more_rows_in_direct_results (
989
1005
self , tcli_service_class , build_queue
990
1006
):
991
- for has_more_rows , resp_type in itertools .product (
1007
+ for is_direct_results , resp_type in itertools .product (
992
1008
[True , False ], self .execute_response_types
993
1009
):
994
- with self .subTest (has_more_rows = has_more_rows , resp_type = resp_type ):
1010
+ with self .subTest (is_direct_results = is_direct_results , resp_type = resp_type ):
995
1011
tcli_service_instance = tcli_service_class .return_value
996
1012
results_mock = Mock ()
997
1013
results_mock .startRowOffset = 0
@@ -1003,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
1003
1019
resultSetMetadata = self .metadata_resp ,
1004
1020
resultSet = ttypes .TFetchResultsResp (
1005
1021
status = self .okay_status ,
1006
- hasMoreRows = has_more_rows ,
1022
+ hasMoreRows = is_direct_results ,
1007
1023
results = results_mock ,
1008
1024
),
1009
1025
closeOperation = Mock (),
@@ -1019,11 +1035,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
1019
1035
)
1020
1036
thrift_backend = self ._make_fake_thrift_backend ()
1021
1037
1022
- execute_response , _ = thrift_backend ._handle_execute_response (
1023
- execute_resp , Mock ()
1024
- )
1038
+ (
1039
+ execute_response ,
1040
+ has_more_rows_result ,
1041
+ ) = thrift_backend ._handle_execute_response (execute_resp , Mock ())
1025
1042
1026
- self .assertEqual (has_more_rows , execute_response . has_more_rows )
1043
+ self .assertEqual (is_direct_results , has_more_rows_result )
1027
1044
1028
1045
@patch (
1029
1046
"databricks.sql.utils.ResultSetQueueFactory.build_queue" , return_value = Mock ()
@@ -1032,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
1032
1049
def test_handle_execute_response_reads_has_more_rows_in_result_response (
1033
1050
self , tcli_service_class , build_queue
1034
1051
):
1035
- for has_more_rows , resp_type in itertools .product (
1052
+ for is_direct_results , resp_type in itertools .product (
1036
1053
[True , False ], self .execute_response_types
1037
1054
):
1038
- with self .subTest (has_more_rows = has_more_rows , resp_type = resp_type ):
1055
+ with self .subTest (is_direct_results = is_direct_results , resp_type = resp_type ):
1039
1056
tcli_service_instance = tcli_service_class .return_value
1040
1057
results_mock = MagicMock ()
1041
1058
results_mock .startRowOffset = 0
@@ -1048,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
1048
1065
1049
1066
fetch_results_resp = ttypes .TFetchResultsResp (
1050
1067
status = self .okay_status ,
1051
- hasMoreRows = has_more_rows ,
1068
+ hasMoreRows = is_direct_results ,
1052
1069
results = results_mock ,
1053
1070
resultSetMetadata = ttypes .TGetResultSetMetadataResp (
1054
1071
resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET
@@ -1081,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
1081
1098
description = Mock (),
1082
1099
)
1083
1100
1084
- self .assertEqual (has_more_rows , has_more_rows_resp )
1101
+ self .assertEqual (is_direct_results , has_more_rows_resp )
1085
1102
1086
1103
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
1087
1104
def test_arrow_batches_row_count_are_respected (self , tcli_service_class ):
@@ -1136,9 +1153,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
1136
1153
1137
1154
self .assertEqual (arrow_queue .n_valid_rows , 15 * 10 )
1138
1155
1156
+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
1139
1157
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
1140
1158
def test_execute_statement_calls_client_and_handle_execute_response (
1141
- self , tcli_service_class
1159
+ self , tcli_service_class , mock_result_set
1142
1160
):
1143
1161
tcli_service_instance = tcli_service_class .return_value
1144
1162
response = Mock ()
@@ -1151,14 +1169,15 @@ def test_execute_statement_calls_client_and_handle_execute_response(
1151
1169
auth_provider = AuthProvider (),
1152
1170
ssl_options = SSLOptions (),
1153
1171
)
1154
- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1172
+ thrift_backend ._handle_execute_response = Mock ()
1173
+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
1155
1174
cursor_mock = Mock ()
1156
1175
1157
1176
result = thrift_backend .execute_command (
1158
1177
"foo" , Mock (), 100 , 200 , Mock (), cursor_mock
1159
1178
)
1160
1179
# Verify the result is a ResultSet
1161
- self .assertIsInstance (result , ResultSet )
1180
+ self .assertEqual (result , mock_result_set . return_value )
1162
1181
1163
1182
# Check call to client
1164
1183
req = tcli_service_instance .ExecuteStatement .call_args [0 ][0 ]
@@ -1170,9 +1189,10 @@ def test_execute_statement_calls_client_and_handle_execute_response(
1170
1189
response , cursor_mock
1171
1190
)
1172
1191
1192
+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
1173
1193
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
1174
1194
def test_get_catalogs_calls_client_and_handle_execute_response (
1175
- self , tcli_service_class
1195
+ self , tcli_service_class , mock_result_set
1176
1196
):
1177
1197
tcli_service_instance = tcli_service_class .return_value
1178
1198
response = Mock ()
@@ -1185,12 +1205,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
1185
1205
auth_provider = AuthProvider (),
1186
1206
ssl_options = SSLOptions (),
1187
1207
)
1188
- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1208
+ thrift_backend ._handle_execute_response = Mock ()
1209
+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
1189
1210
cursor_mock = Mock ()
1190
1211
1191
1212
result = thrift_backend .get_catalogs (Mock (), 100 , 200 , cursor_mock )
1192
1213
# Verify the result is a ResultSet
1193
- self .assertIsInstance (result , ResultSet )
1214
+ self .assertEqual (result , mock_result_set . return_value )
1194
1215
1195
1216
# Check call to client
1196
1217
req = tcli_service_instance .GetCatalogs .call_args [0 ][0 ]
@@ -1201,9 +1222,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
1201
1222
response , cursor_mock
1202
1223
)
1203
1224
1225
+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
1204
1226
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
1205
1227
def test_get_schemas_calls_client_and_handle_execute_response (
1206
- self , tcli_service_class
1228
+ self , tcli_service_class , mock_result_set
1207
1229
):
1208
1230
tcli_service_instance = tcli_service_class .return_value
1209
1231
response = Mock ()
@@ -1216,7 +1238,8 @@ def test_get_schemas_calls_client_and_handle_execute_response(
1216
1238
auth_provider = AuthProvider (),
1217
1239
ssl_options = SSLOptions (),
1218
1240
)
1219
- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1241
+ thrift_backend ._handle_execute_response = Mock ()
1242
+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
1220
1243
cursor_mock = Mock ()
1221
1244
1222
1245
result = thrift_backend .get_schemas (
@@ -1228,7 +1251,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
1228
1251
schema_name = "schema_pattern" ,
1229
1252
)
1230
1253
# Verify the result is a ResultSet
1231
- self .assertIsInstance (result , ResultSet )
1254
+ self .assertEqual (result , mock_result_set . return_value )
1232
1255
1233
1256
# Check call to client
1234
1257
req = tcli_service_instance .GetSchemas .call_args [0 ][0 ]
@@ -1241,9 +1264,10 @@ def test_get_schemas_calls_client_and_handle_execute_response(
1241
1264
response , cursor_mock
1242
1265
)
1243
1266
1267
+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
1244
1268
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
1245
1269
def test_get_tables_calls_client_and_handle_execute_response (
1246
- self , tcli_service_class
1270
+ self , tcli_service_class , mock_result_set
1247
1271
):
1248
1272
tcli_service_instance = tcli_service_class .return_value
1249
1273
response = Mock ()
@@ -1256,7 +1280,8 @@ def test_get_tables_calls_client_and_handle_execute_response(
1256
1280
auth_provider = AuthProvider (),
1257
1281
ssl_options = SSLOptions (),
1258
1282
)
1259
- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1283
+ thrift_backend ._handle_execute_response = Mock ()
1284
+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
1260
1285
cursor_mock = Mock ()
1261
1286
1262
1287
result = thrift_backend .get_tables (
@@ -1270,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
1270
1295
table_types = ["type1" , "type2" ],
1271
1296
)
1272
1297
# Verify the result is a ResultSet
1273
- self .assertIsInstance (result , ResultSet )
1298
+ self .assertEqual (result , mock_result_set . return_value )
1274
1299
1275
1300
# Check call to client
1276
1301
req = tcli_service_instance .GetTables .call_args [0 ][0 ]
@@ -1285,9 +1310,10 @@ def test_get_tables_calls_client_and_handle_execute_response(
1285
1310
response , cursor_mock
1286
1311
)
1287
1312
1313
+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
1288
1314
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
1289
1315
def test_get_columns_calls_client_and_handle_execute_response (
1290
- self , tcli_service_class
1316
+ self , tcli_service_class , mock_result_set
1291
1317
):
1292
1318
tcli_service_instance = tcli_service_class .return_value
1293
1319
response = Mock ()
@@ -1300,7 +1326,8 @@ def test_get_columns_calls_client_and_handle_execute_response(
1300
1326
auth_provider = AuthProvider (),
1301
1327
ssl_options = SSLOptions (),
1302
1328
)
1303
- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1329
+ thrift_backend ._handle_execute_response = Mock ()
1330
+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
1304
1331
cursor_mock = Mock ()
1305
1332
1306
1333
result = thrift_backend .get_columns (
@@ -1314,7 +1341,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
1314
1341
column_name = "column_pattern" ,
1315
1342
)
1316
1343
# Verify the result is a ResultSet
1317
- self .assertIsInstance (result , ResultSet )
1344
+ self .assertEqual (result , mock_result_set . return_value )
1318
1345
1319
1346
# Check call to client
1320
1347
req = tcli_service_instance .GetColumns .call_args [0 ][0 ]
@@ -2203,14 +2230,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
2203
2230
str (cm .exception ),
2204
2231
)
2205
2232
2233
+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
2206
2234
@patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
2207
2235
@patch (
2208
2236
"databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response"
2209
2237
)
2210
2238
def test_execute_command_sets_complex_type_fields_correctly (
2211
- self , mock_handle_execute_response , tcli_service_class
2239
+ self , mock_handle_execute_response , tcli_service_class , mock_result_set
2212
2240
):
2213
2241
tcli_service_instance = tcli_service_class .return_value
2242
+ # Set up the mock to return a tuple with two values
2243
+ mock_execute_response = Mock ()
2244
+ mock_arrow_schema = Mock ()
2245
+ mock_handle_execute_response .return_value = (
2246
+ mock_execute_response ,
2247
+ mock_arrow_schema ,
2248
+ )
2249
+
2214
2250
# Iterate through each possible combination of native types (True, False and unset)
2215
2251
for complex , timestamp , decimals in itertools .product (
2216
2252
[True , False , None ], [True , False , None ], [True , False , None ]
0 commit comments