Skip to content

Commit 7925954

Browse files
change names of ThriftBackend -> ThriftDatabricksClient in tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent f9fcb79 commit 7925954

File tree

4 files changed

+93
-92
lines changed

4 files changed

+93
-92
lines changed

tests/unit/test_client.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
THandleIdentifier,
1616
TOperationType,
1717
)
18-
from databricks.sql.thrift_backend import ThriftBackend
18+
from databricks.sql.thrift_backend import ThriftDatabricksClient
1919

2020
import databricks.sql
2121
import databricks.sql.client as client
@@ -27,10 +27,10 @@
2727
from tests.unit.test_arrow_queue import ArrowQueueSuite
2828

2929

30-
class ThriftBackendMockFactory:
30+
class ThriftDatabricksClientMockFactory:
3131
@classmethod
3232
def new(cls):
33-
ThriftBackendMock = Mock(spec=ThriftBackend)
33+
ThriftBackendMock = Mock(spec=ThriftDatabricksClient)
3434
ThriftBackendMock.return_value = ThriftBackendMock
3535

3636
cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None)
@@ -80,7 +80,7 @@ class ClientTestSuite(unittest.TestCase):
8080
"access_token": "tok",
8181
}
8282

83-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
83+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME, ThriftDatabricksClientMockFactory.new())
8484
@patch("%s.client.ResultSet" % PACKAGE_NAME)
8585
def test_closing_connection_closes_commands(self, mock_result_set_class):
8686
# Test once with has_been_closed_server side, once without
@@ -97,7 +97,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class):
9797
)
9898
mock_result_set_class.return_value.close.assert_called_once_with()
9999

100-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
100+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
101101
def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
102102
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
103103
self.assertTrue(connection.open)
@@ -107,7 +107,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
107107
connection.cursor()
108108
self.assertIn("closed", str(cm.exception))
109109

110-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
110+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
111111
@patch("%s.client.Cursor" % PACKAGE_NAME)
112112
def test_arraysize_buffer_size_passthrough(
113113
self, mock_cursor_class, mock_client_class
@@ -124,7 +124,7 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
124124
mock_backend = Mock()
125125
result_set = client.ResultSet(
126126
connection=mock_connection,
127-
thrift_backend=mock_backend,
127+
backend=mock_backend,
128128
execute_response=Mock(),
129129
)
130130
# Setup session mock on the mock_connection
@@ -166,7 +166,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command(
166166
mock_result_set_class.side_effect = mock_result_sets
167167

168168
cursor = client.Cursor(
169-
connection=Mock(), thrift_backend=ThriftBackendMockFactory.new()
169+
connection=Mock(), backend=ThriftDatabricksClientMockFactory.new()
170170
)
171171
cursor.execute("SELECT 1;")
172172
cursor.execute("SELECT 1;")
@@ -215,7 +215,7 @@ def dict_product(self, dicts):
215215
"""
216216
return (dict(zip(dicts.keys(), x)) for x in itertools.product(*dicts.values()))
217217

218-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
218+
@patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME)
219219
def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backend):
220220
req_args_combinations = self.dict_product(
221221
dict(
@@ -236,7 +236,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe
236236
for k, v in req_args.items():
237237
self.assertEqual(v, call_args[k])
238238

239-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
239+
@patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME)
240240
def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backend):
241241
req_args_combinations = self.dict_product(
242242
dict(
@@ -259,7 +259,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen
259259
for k, v in req_args.items():
260260
self.assertEqual(v, call_args[k])
261261

262-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
262+
@patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME)
263263
def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backend):
264264
req_args_combinations = self.dict_product(
265265
dict(
@@ -310,7 +310,7 @@ def test_version_is_canonical(self):
310310
self.assertIsNotNone(re.match(canonical_version_re, version))
311311

312312
def test_execute_parameter_passthrough(self):
313-
mock_thrift_backend = ThriftBackendMockFactory.new()
313+
mock_thrift_backend = ThriftDatabricksClientMockFactory.new()
314314
cursor = client.Cursor(Mock(), mock_thrift_backend)
315315

316316
tests = [
@@ -334,16 +334,16 @@ def test_execute_parameter_passthrough(self):
334334
expected_query,
335335
)
336336

337-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
337+
@patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME)
338338
@patch("%s.client.ResultSet" % PACKAGE_NAME)
339339
def test_executemany_parameter_passhthrough_and_uses_last_result_set(
340340
self, mock_result_set_class, mock_thrift_backend
341341
):
342342
# Create a new mock result set each time the class is instantiated
343343
mock_result_set_instances = [Mock(), Mock(), Mock()]
344344
mock_result_set_class.side_effect = mock_result_set_instances
345-
mock_thrift_backend = ThriftBackendMockFactory.new()
346-
cursor = client.Cursor(Mock(), mock_thrift_backend())
345+
mock_backend = ThriftDatabricksClientMockFactory.new()
346+
cursor = client.Cursor(Mock(), mock_backend())
347347

348348
params = [{"x": None}, {"x": "foo1"}, {"x": "bar2"}]
349349
expected_queries = ["SELECT NULL", "SELECT 'foo1'", "SELECT 'bar2'"]
@@ -368,7 +368,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(
368368
"last operation",
369369
)
370370

371-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
371+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
372372
def test_commit_a_noop(self, mock_thrift_backend_class):
373373
c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
374374
c.commit()
@@ -381,14 +381,14 @@ def test_setoutputsizes_a_noop(self):
381381
cursor = client.Cursor(Mock(), Mock())
382382
cursor.setoutputsize(1)
383383

384-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
384+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
385385
def test_rollback_not_supported(self, mock_thrift_backend_class):
386386
c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
387387
with self.assertRaises(NotSupportedError):
388388
c.rollback()
389389

390390
@unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface")
391-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
391+
@patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME)
392392
def test_row_number_respected(self, mock_thrift_backend_class):
393393
def make_fake_row_slice(n_rows):
394394
mock_slice = Mock()
@@ -413,7 +413,7 @@ def make_fake_row_slice(n_rows):
413413
self.assertEqual(cursor.rownumber, 29)
414414

415415
@unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface")
416-
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
416+
@patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME)
417417
def test_disable_pandas_respected(self, mock_thrift_backend_class):
418418
mock_thrift_backend = mock_thrift_backend_class.return_value
419419
mock_table = Mock()
@@ -466,7 +466,7 @@ def test_column_name_api(self):
466466
},
467467
)
468468

469-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
469+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
470470
def test_cursor_keeps_connection_alive(self, mock_client_class):
471471
instance = mock_client_class.return_value
472472

@@ -485,13 +485,13 @@ def test_cursor_keeps_connection_alive(self, mock_client_class):
485485

486486
@patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True)
487487
@patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME)
488-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
488+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
489489
def test_staging_operation_response_is_handled(
490490
self, mock_client_class, mock_handle_staging_operation, mock_execute_response
491491
):
492492
# If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called
493493

494-
ThriftBackendMockFactory.apply_property_to_mock(
494+
ThriftDatabricksClientMockFactory.apply_property_to_mock(
495495
mock_execute_response, is_staging_operation=True
496496
)
497497
mock_client_class.execute_command.return_value = mock_execute_response
@@ -504,7 +504,7 @@ def test_staging_operation_response_is_handled(
504504

505505
mock_handle_staging_operation.call_count == 1
506506

507-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new())
507+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME, ThriftDatabricksClientMockFactory.new())
508508
def test_access_current_query_id(self):
509509
operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821"
510510

tests/unit/test_fetches.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import databricks.sql.client as client
1111
from databricks.sql.utils import ExecuteResponse, ArrowQueue
12+
from databricks.sql.thrift_backend import ThriftDatabricksClient
1213

1314

1415
@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
@@ -39,7 +40,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
3940
arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0)
4041
rs = client.ResultSet(
4142
connection=Mock(),
42-
thrift_backend=None,
43+
backend=None,
4344
execute_response=ExecuteResponse(
4445
status=None,
4546
has_been_closed_server_side=True,
@@ -79,13 +80,13 @@ def fetch_results(
7980

8081
return results, batch_index < len(batch_list)
8182

82-
mock_thrift_backend = Mock()
83+
mock_thrift_backend = Mock(spec=ThriftDatabricksClient)
8384
mock_thrift_backend.fetch_results = fetch_results
8485
num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0
8586

8687
rs = client.ResultSet(
8788
connection=Mock(),
88-
thrift_backend=mock_thrift_backend,
89+
backend=mock_thrift_backend,
8990
execute_response=ExecuteResponse(
9091
status=None,
9192
has_been_closed_server_side=False,

tests/unit/test_session.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class SessionTestSuite(unittest.TestCase):
2121
"access_token": "tok",
2222
}
2323

24-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
24+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
2525
def test_close_uses_the_correct_session_id(self, mock_client_class):
2626
instance = mock_client_class.return_value
2727

@@ -36,7 +36,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class):
3636
close_session_id = instance.close_session.call_args[0][0].sessionId
3737
self.assertEqual(close_session_id, b"\x22")
3838

39-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
39+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
4040
def test_auth_args(self, mock_client_class):
4141
# Test that the following auth args work:
4242
# token = foo,
@@ -63,15 +63,15 @@ def test_auth_args(self, mock_client_class):
6363
self.assertEqual(args["http_path"], http_path)
6464
connection.close()
6565

66-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
66+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
6767
def test_http_header_passthrough(self, mock_client_class):
6868
http_headers = [("foo", "bar")]
6969
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
7070

7171
call_args = mock_client_class.call_args[0][3]
7272
self.assertIn(("foo", "bar"), call_args)
7373

74-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
74+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
7575
def test_tls_arg_passthrough(self, mock_client_class):
7676
databricks.sql.connect(
7777
**self.DUMMY_CONNECTION_ARGS,
@@ -87,7 +87,7 @@ def test_tls_arg_passthrough(self, mock_client_class):
8787
self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert")
8888
self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password")
8989

90-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
90+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
9191
def test_useragent_header(self, mock_client_class):
9292
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
9393

@@ -108,7 +108,7 @@ def test_useragent_header(self, mock_client_class):
108108
http_headers = mock_client_class.call_args[0][3]
109109
self.assertIn(user_agent_header_with_entry, http_headers)
110110

111-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
111+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
112112
def test_context_manager_closes_connection(self, mock_client_class):
113113
instance = mock_client_class.return_value
114114

@@ -123,7 +123,7 @@ def test_context_manager_closes_connection(self, mock_client_class):
123123
close_session_id = instance.close_session.call_args[0][0].sessionId
124124
self.assertEqual(close_session_id, b"\x22")
125125

126-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
126+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
127127
def test_max_number_of_retries_passthrough(self, mock_client_class):
128128
databricks.sql.connect(
129129
_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS
@@ -133,12 +133,12 @@ def test_max_number_of_retries_passthrough(self, mock_client_class):
133133
mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54
134134
)
135135

136-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
136+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
137137
def test_socket_timeout_passthrough(self, mock_client_class):
138138
databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS)
139139
self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234)
140140

141-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
141+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
142142
def test_configuration_passthrough(self, mock_client_class):
143143
mock_session_config = Mock()
144144
databricks.sql.connect(
@@ -150,7 +150,7 @@ def test_configuration_passthrough(self, mock_client_class):
150150
mock_session_config,
151151
)
152152

153-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
153+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
154154
def test_initial_namespace_passthrough(self, mock_client_class):
155155
mock_cat = Mock()
156156
mock_schem = Mock()
@@ -165,7 +165,7 @@ def test_initial_namespace_passthrough(self, mock_client_class):
165165
mock_client_class.return_value.open_session.call_args[0][2], mock_schem
166166
)
167167

168-
@patch("%s.session.ThriftBackend" % PACKAGE_NAME)
168+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
169169
def test_finalizer_closes_abandoned_connection(self, mock_client_class):
170170
instance = mock_client_class.return_value
171171

0 commit comments

Comments
 (0)