Skip to content

Commit 8ea5cf4

Browse files
ensure backend client returns a ResultSet type in backend tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 574df21 commit 8ea5cf4

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

tests/unit/test_thrift_backend.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from databricks.sql import *
1919
from databricks.sql.auth.authenticators import AuthProvider
2020
from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
21+
from databricks.sql.result_set import ResultSet, ThriftResultSet
2122

2223

2324
def retry_policy_factory():
@@ -1146,7 +1147,10 @@ def test_execute_statement_calls_client_and_handle_execute_response(
11461147
thrift_backend._handle_execute_response = Mock()
11471148
cursor_mock = Mock()
11481149

1149-
thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock)
1150+
result = thrift_backend.execute_command("foo", Mock(), 100, 200, Mock(), cursor_mock)
1151+
# Verify the result is a ResultSet
1152+
self.assertIsInstance(result, ResultSet)
1153+
11501154
# Check call to client
11511155
req = tcli_service_instance.ExecuteStatement.call_args[0][0]
11521156
get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
@@ -1175,7 +1179,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
11751179
thrift_backend._handle_execute_response = Mock()
11761180
cursor_mock = Mock()
11771181

1178-
thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
1182+
result = thrift_backend.get_catalogs(Mock(), 100, 200, cursor_mock)
1183+
# Verify the result is a ResultSet
1184+
self.assertIsInstance(result, ResultSet)
1185+
11791186
# Check call to client
11801187
req = tcli_service_instance.GetCatalogs.call_args[0][0]
11811188
get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
@@ -1203,14 +1210,17 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12031210
thrift_backend._handle_execute_response = Mock()
12041211
cursor_mock = Mock()
12051212

1206-
thrift_backend.get_schemas(
1213+
result = thrift_backend.get_schemas(
12071214
Mock(),
12081215
100,
12091216
200,
12101217
cursor_mock,
12111218
catalog_name="catalog_pattern",
12121219
schema_name="schema_pattern",
12131220
)
1221+
# Verify the result is a ResultSet
1222+
self.assertIsInstance(result, ResultSet)
1223+
12141224
# Check call to client
12151225
req = tcli_service_instance.GetSchemas.call_args[0][0]
12161226
get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
@@ -1240,7 +1250,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
12401250
thrift_backend._handle_execute_response = Mock()
12411251
cursor_mock = Mock()
12421252

1243-
thrift_backend.get_tables(
1253+
result = thrift_backend.get_tables(
12441254
Mock(),
12451255
100,
12461256
200,
@@ -1250,6 +1260,9 @@ def test_get_tables_calls_client_and_handle_execute_response(
12501260
table_name="table_pattern",
12511261
table_types=["type1", "type2"],
12521262
)
1263+
# Verify the result is a ResultSet
1264+
self.assertIsInstance(result, ResultSet)
1265+
12531266
# Check call to client
12541267
req = tcli_service_instance.GetTables.call_args[0][0]
12551268
get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)
@@ -1281,7 +1294,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
12811294
thrift_backend._handle_execute_response = Mock()
12821295
cursor_mock = Mock()
12831296

1284-
thrift_backend.get_columns(
1297+
result = thrift_backend.get_columns(
12851298
Mock(),
12861299
100,
12871300
200,
@@ -1291,6 +1304,9 @@ def test_get_columns_calls_client_and_handle_execute_response(
12911304
table_name="table_pattern",
12921305
column_name="column_pattern",
12931306
)
1307+
# Verify the result is a ResultSet
1308+
self.assertIsInstance(result, ResultSet)
1309+
12941310
# Check call to client
12951311
req = tcli_service_instance.GetColumns.call_args[0][0]
12961312
get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200)

0 commit comments

Comments
 (0)