Skip to content

Commit 64e58b0

Browse files
remove un-necessary changes in thrift backend tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent fd52356 commit 64e58b0

File tree

1 file changed

+71
-35
lines changed

1 file changed

+71
-35
lines changed

tests/unit/test_thrift_backend.py

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,10 @@ def test_handle_execute_response_sets_compression_in_direct_results(
623623
status=Mock(),
624624
operationHandle=Mock(),
625625
directResults=ttypes.TSparkDirectResults(
626-
operationStatus=op_status,
626+
operationStatus=ttypes.TGetOperationStatusResp(
627+
status=self.okay_status,
628+
operationState=ttypes.TOperationState.FINISHED_STATE,
629+
),
627630
resultSetMetadata=ttypes.TGetResultSetMetadataResp(
628631
status=self.okay_status,
629632
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
@@ -832,9 +835,10 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self):
832835
thrift_backend._handle_execute_response(error_resp, Mock())
833836
self.assertIn("this is a bad error", str(cm.exception))
834837

838+
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
835839
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
836840
def test_handle_execute_response_can_handle_without_direct_results(
837-
self, tcli_service_class
841+
self, tcli_service_class, mock_result_set
838842
):
839843
tcli_service_instance = tcli_service_class.return_value
840844

@@ -878,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results(
878882
auth_provider=AuthProvider(),
879883
ssl_options=SSLOptions(),
880884
)
881-
execute_response, _ = thrift_backend._handle_execute_response(
882-
execute_resp, Mock()
883-
)
884-
885+
(
886+
execute_response,
887+
_,
888+
) = thrift_backend._handle_execute_response(execute_resp, Mock())
885889
self.assertEqual(
886890
execute_response.status,
887891
CommandState.SUCCEEDED,
@@ -947,8 +951,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
947951
tcli_service_instance.GetResultSetMetadata.return_value = (
948952
t_get_result_set_metadata_resp
949953
)
954+
tcli_service_instance.GetOperationStatus.return_value = (
955+
ttypes.TGetOperationStatusResp(
956+
status=self.okay_status,
957+
operationState=ttypes.TOperationState.FINISHED_STATE,
958+
)
959+
)
950960
thrift_backend = self._make_fake_thrift_backend()
951-
execute_response, arrow_schema_bytes = thrift_backend._handle_execute_response(
961+
execute_response, _ = thrift_backend._handle_execute_response(
952962
t_execute_resp, Mock()
953963
)
954964

@@ -973,8 +983,14 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
973983
)
974984

975985
tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req
986+
tcli_service_instance.GetOperationStatus.return_value = (
987+
ttypes.TGetOperationStatusResp(
988+
status=self.okay_status,
989+
operationState=ttypes.TOperationState.FINISHED_STATE,
990+
)
991+
)
976992
thrift_backend = self._make_fake_thrift_backend()
977-
thrift_backend._handle_execute_response(t_execute_resp, Mock())
993+
_, _ = thrift_backend._handle_execute_response(t_execute_resp, Mock())
978994

979995
self.assertEqual(
980996
hive_schema_mock,
@@ -988,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
9881004
def test_handle_execute_response_reads_has_more_rows_in_direct_results(
9891005
self, tcli_service_class, build_queue
9901006
):
991-
for has_more_rows, resp_type in itertools.product(
1007+
for is_direct_results, resp_type in itertools.product(
9921008
[True, False], self.execute_response_types
9931009
):
994-
with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
1010+
with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type):
9951011
tcli_service_instance = tcli_service_class.return_value
9961012
results_mock = Mock()
9971013
results_mock.startRowOffset = 0
@@ -1003,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10031019
resultSetMetadata=self.metadata_resp,
10041020
resultSet=ttypes.TFetchResultsResp(
10051021
status=self.okay_status,
1006-
hasMoreRows=has_more_rows,
1022+
hasMoreRows=is_direct_results,
10071023
results=results_mock,
10081024
),
10091025
closeOperation=Mock(),
@@ -1019,11 +1035,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10191035
)
10201036
thrift_backend = self._make_fake_thrift_backend()
10211037

1022-
execute_response, _ = thrift_backend._handle_execute_response(
1023-
execute_resp, Mock()
1024-
)
1038+
(
1039+
execute_response,
1040+
has_more_rows_result,
1041+
) = thrift_backend._handle_execute_response(execute_resp, Mock())
10251042

1026-
self.assertEqual(has_more_rows, execute_response.has_more_rows)
1043+
self.assertEqual(is_direct_results, has_more_rows_result)
10271044

10281045
@patch(
10291046
"databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()
@@ -1032,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10321049
def test_handle_execute_response_reads_has_more_rows_in_result_response(
10331050
self, tcli_service_class, build_queue
10341051
):
1035-
for has_more_rows, resp_type in itertools.product(
1052+
for is_direct_results, resp_type in itertools.product(
10361053
[True, False], self.execute_response_types
10371054
):
1038-
with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type):
1055+
with self.subTest(is_direct_results=is_direct_results, resp_type=resp_type):
10391056
tcli_service_instance = tcli_service_class.return_value
10401057
results_mock = MagicMock()
10411058
results_mock.startRowOffset = 0
@@ -1048,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
10481065

10491066
fetch_results_resp = ttypes.TFetchResultsResp(
10501067
status=self.okay_status,
1051-
hasMoreRows=has_more_rows,
1068+
hasMoreRows=is_direct_results,
10521069
results=results_mock,
10531070
resultSetMetadata=ttypes.TGetResultSetMetadataResp(
10541071
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET
@@ -1081,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
10811098
description=Mock(),
10821099
)
10831100

1084-
self.assertEqual(has_more_rows, has_more_rows_resp)
1101+
self.assertEqual(is_direct_results, has_more_rows_resp)
10851102

10861103
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
10871104
def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
@@ -1136,9 +1153,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
11361153

11371154
self.assertEqual(arrow_queue.n_valid_rows, 15 * 10)
11381155

1156+
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
11391157
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
11401158
def test_execute_statement_calls_client_and_handle_execute_response(
1141-
self, tcli_service_class
1159+
self, tcli_service_class, mock_result_set
11421160
):
11431161
tcli_service_instance = tcli_service_class.return_value
11441162
response = Mock()
@@ -1151,14 +1169,15 @@ def test_execute_statement_calls_client_and_handle_execute_response(
11511169
auth_provider=AuthProvider(),
11521170
ssl_options=SSLOptions(),
11531171
)
1154-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1172+
thrift_backend._handle_execute_response = Mock()
1173+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
11551174
cursor_mock = Mock()
11561175

11571176
result = thrift_backend.execute_command(
11581177
"foo", Mock(), 100, 200, Mock(), cursor_mock
11591178
)
11601179
# Verify the result is a ResultSet
1161-
self.assertIsInstance(result, ResultSet)
1180+
self.assertEqual(result, mock_result_set.return_value)
11621181

11631182
# Check call to client
11641183
req = tcli_service_instance.ExecuteStatement.call_args[0][0]
@@ -1170,9 +1189,10 @@ def test_execute_statement_calls_client_and_handle_execute_response(
11701189
response, cursor_mock
11711190
)
11721191

1192+
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
11731193
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
11741194
def test_get_catalogs_calls_client_and_handle_execute_response(
1175-
self, tcli_service_class
1195+
self, tcli_service_class, mock_result_set
11761196
):
11771197
tcli_service_instance = tcli_service_class.return_value
11781198
response = Mock()
@@ -1185,12 +1205,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
11851205
auth_provider=AuthProvider(),
11861206
ssl_options=SSLOptions(),
11871207
)
1188-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1208+
thrift_backend._handle_execute_response = Mock()
1209+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
11891210
cursor_mock = Mock()
11901211

11911212
result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
11921213
# Verify the result is a ResultSet
1193-
self.assertIsInstance(result, ResultSet)
1214+
self.assertEqual(result, mock_result_set.return_value)
11941215

11951216
# Check call to client
11961217
req = tcli_service_instance.GetCatalogs.call_args[0][0]
@@ -1201,9 +1222,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
12011222
response, cursor_mock
12021223
)
12031224

1225+
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
12041226
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
12051227
def test_get_schemas_calls_client_and_handle_execute_response(
1206-
self, tcli_service_class
1228+
self, tcli_service_class, mock_result_set
12071229
):
12081230
tcli_service_instance = tcli_service_class.return_value
12091231
response = Mock()
@@ -1216,7 +1238,8 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12161238
auth_provider=AuthProvider(),
12171239
ssl_options=SSLOptions(),
12181240
)
1219-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1241+
thrift_backend._handle_execute_response = Mock()
1242+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
12201243
cursor_mock = Mock()
12211244

12221245
result = thrift_backend.get_schemas(
@@ -1228,7 +1251,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12281251
schema_name="schema_pattern",
12291252
)
12301253
# Verify the result is a ResultSet
1231-
self.assertIsInstance(result, ResultSet)
1254+
self.assertEqual(result, mock_result_set.return_value)
12321255

12331256
# Check call to client
12341257
req = tcli_service_instance.GetSchemas.call_args[0][0]
@@ -1241,9 +1264,10 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12411264
response, cursor_mock
12421265
)
12431266

1267+
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
12441268
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
12451269
def test_get_tables_calls_client_and_handle_execute_response(
1246-
self, tcli_service_class
1270+
self, tcli_service_class, mock_result_set
12471271
):
12481272
tcli_service_instance = tcli_service_class.return_value
12491273
response = Mock()
@@ -1256,7 +1280,8 @@ def test_get_tables_calls_client_and_handle_execute_response(
12561280
auth_provider=AuthProvider(),
12571281
ssl_options=SSLOptions(),
12581282
)
1259-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1283+
thrift_backend._handle_execute_response = Mock()
1284+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
12601285
cursor_mock = Mock()
12611286

12621287
result = thrift_backend.get_tables(
@@ -1270,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
12701295
table_types=["type1", "type2"],
12711296
)
12721297
# Verify the result is a ResultSet
1273-
self.assertIsInstance(result, ResultSet)
1298+
self.assertEqual(result, mock_result_set.return_value)
12741299

12751300
# Check call to client
12761301
req = tcli_service_instance.GetTables.call_args[0][0]
@@ -1285,9 +1310,10 @@ def test_get_tables_calls_client_and_handle_execute_response(
12851310
response, cursor_mock
12861311
)
12871312

1313+
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
12881314
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
12891315
def test_get_columns_calls_client_and_handle_execute_response(
1290-
self, tcli_service_class
1316+
self, tcli_service_class, mock_result_set
12911317
):
12921318
tcli_service_instance = tcli_service_class.return_value
12931319
response = Mock()
@@ -1300,7 +1326,8 @@ def test_get_columns_calls_client_and_handle_execute_response(
13001326
auth_provider=AuthProvider(),
13011327
ssl_options=SSLOptions(),
13021328
)
1303-
thrift_backend._handle_execute_response = Mock(return_value=(Mock(), Mock()))
1329+
thrift_backend._handle_execute_response = Mock()
1330+
thrift_backend._handle_execute_response.return_value = (Mock(), Mock())
13041331
cursor_mock = Mock()
13051332

13061333
result = thrift_backend.get_columns(
@@ -1314,7 +1341,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
13141341
column_name="column_pattern",
13151342
)
13161343
# Verify the result is a ResultSet
1317-
self.assertIsInstance(result, ResultSet)
1344+
self.assertEqual(result, mock_result_set.return_value)
13181345

13191346
# Check call to client
13201347
req = tcli_service_instance.GetColumns.call_args[0][0]
@@ -2203,14 +2230,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
22032230
str(cm.exception),
22042231
)
22052232

2233+
@patch("databricks.sql.backend.thrift_backend.ThriftResultSet")
22062234
@patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True)
22072235
@patch(
22082236
"databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response"
22092237
)
22102238
def test_execute_command_sets_complex_type_fields_correctly(
2211-
self, mock_handle_execute_response, tcli_service_class
2239+
self, mock_handle_execute_response, tcli_service_class, mock_result_set
22122240
):
22132241
tcli_service_instance = tcli_service_class.return_value
2242+
# Set up the mock to return a tuple with two values
2243+
mock_execute_response = Mock()
2244+
mock_arrow_schema = Mock()
2245+
mock_handle_execute_response.return_value = (
2246+
mock_execute_response,
2247+
mock_arrow_schema,
2248+
)
2249+
22142250
# Iterate through each possible combination of native types (True, False and unset)
22152251
for complex, timestamp, decimals in itertools.product(
22162252
[True, False, None], [True, False, None], [True, False, None]

0 commit comments

Comments
 (0)