15
15
THandleIdentifier ,
16
16
TOperationType ,
17
17
)
18
- from databricks .sql .thrift_backend import ThriftBackend
18
+ from databricks .sql .thrift_backend import ThriftDatabricksClient
19
19
20
20
import databricks .sql
21
21
import databricks .sql .client as client
27
27
from tests .unit .test_arrow_queue import ArrowQueueSuite
28
28
29
29
30
- class ThriftBackendMockFactory :
30
+ class ThriftDatabricksClientMockFactory :
31
31
@classmethod
32
32
def new (cls ):
33
- ThriftBackendMock = Mock (spec = ThriftBackend )
33
+ ThriftBackendMock = Mock (spec = ThriftDatabricksClient )
34
34
ThriftBackendMock .return_value = ThriftBackendMock
35
35
36
36
cls .apply_property_to_mock (ThriftBackendMock , staging_allowed_local_path = None )
@@ -80,7 +80,7 @@ class ClientTestSuite(unittest.TestCase):
80
80
"access_token" : "tok" ,
81
81
}
82
82
83
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME , ThriftBackendMockFactory .new ())
83
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME , ThriftDatabricksClientMockFactory .new ())
84
84
@patch ("%s.client.ResultSet" % PACKAGE_NAME )
85
85
def test_closing_connection_closes_commands (self , mock_result_set_class ):
86
86
# Test once with has_been_closed_server side, once without
@@ -97,7 +97,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class):
97
97
)
98
98
mock_result_set_class .return_value .close .assert_called_once_with ()
99
99
100
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME )
100
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME )
101
101
def test_cant_open_cursor_on_closed_connection (self , mock_client_class ):
102
102
connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
103
103
self .assertTrue (connection .open )
@@ -107,7 +107,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class):
107
107
connection .cursor ()
108
108
self .assertIn ("closed" , str (cm .exception ))
109
109
110
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME )
110
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME )
111
111
@patch ("%s.client.Cursor" % PACKAGE_NAME )
112
112
def test_arraysize_buffer_size_passthrough (
113
113
self , mock_cursor_class , mock_client_class
@@ -124,7 +124,7 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
124
124
mock_backend = Mock ()
125
125
result_set = client .ResultSet (
126
126
connection = mock_connection ,
127
- thrift_backend = mock_backend ,
127
+ backend = mock_backend ,
128
128
execute_response = Mock (),
129
129
)
130
130
# Setup session mock on the mock_connection
@@ -166,7 +166,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command(
166
166
mock_result_set_class .side_effect = mock_result_sets
167
167
168
168
cursor = client .Cursor (
169
- connection = Mock (), thrift_backend = ThriftBackendMockFactory .new ()
169
+ connection = Mock (), backend = ThriftDatabricksClientMockFactory .new ()
170
170
)
171
171
cursor .execute ("SELECT 1;" )
172
172
cursor .execute ("SELECT 1;" )
@@ -215,7 +215,7 @@ def dict_product(self, dicts):
215
215
"""
216
216
return (dict (zip (dicts .keys (), x )) for x in itertools .product (* dicts .values ()))
217
217
218
- @patch ("%s.client.ThriftBackend " % PACKAGE_NAME )
218
+ @patch ("%s.client.ThriftDatabricksClient " % PACKAGE_NAME )
219
219
def test_get_schemas_parameters_passed_to_thrift_backend (self , mock_thrift_backend ):
220
220
req_args_combinations = self .dict_product (
221
221
dict (
@@ -236,7 +236,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe
236
236
for k , v in req_args .items ():
237
237
self .assertEqual (v , call_args [k ])
238
238
239
- @patch ("%s.client.ThriftBackend " % PACKAGE_NAME )
239
+ @patch ("%s.client.ThriftDatabricksClient " % PACKAGE_NAME )
240
240
def test_get_tables_parameters_passed_to_thrift_backend (self , mock_thrift_backend ):
241
241
req_args_combinations = self .dict_product (
242
242
dict (
@@ -259,7 +259,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen
259
259
for k , v in req_args .items ():
260
260
self .assertEqual (v , call_args [k ])
261
261
262
- @patch ("%s.client.ThriftBackend " % PACKAGE_NAME )
262
+ @patch ("%s.client.ThriftDatabricksClient " % PACKAGE_NAME )
263
263
def test_get_columns_parameters_passed_to_thrift_backend (self , mock_thrift_backend ):
264
264
req_args_combinations = self .dict_product (
265
265
dict (
@@ -310,7 +310,7 @@ def test_version_is_canonical(self):
310
310
self .assertIsNotNone (re .match (canonical_version_re , version ))
311
311
312
312
def test_execute_parameter_passthrough (self ):
313
- mock_thrift_backend = ThriftBackendMockFactory .new ()
313
+ mock_thrift_backend = ThriftDatabricksClientMockFactory .new ()
314
314
cursor = client .Cursor (Mock (), mock_thrift_backend )
315
315
316
316
tests = [
@@ -334,16 +334,16 @@ def test_execute_parameter_passthrough(self):
334
334
expected_query ,
335
335
)
336
336
337
- @patch ("%s.client.ThriftBackend " % PACKAGE_NAME )
337
+ @patch ("%s.client.ThriftDatabricksClient " % PACKAGE_NAME )
338
338
@patch ("%s.client.ResultSet" % PACKAGE_NAME )
339
339
def test_executemany_parameter_passhthrough_and_uses_last_result_set (
340
340
self , mock_result_set_class , mock_thrift_backend
341
341
):
342
342
# Create a new mock result set each time the class is instantiated
343
343
mock_result_set_instances = [Mock (), Mock (), Mock ()]
344
344
mock_result_set_class .side_effect = mock_result_set_instances
345
- mock_thrift_backend = ThriftBackendMockFactory .new ()
346
- cursor = client .Cursor (Mock (), mock_thrift_backend ())
345
+ mock_backend = ThriftDatabricksClientMockFactory .new ()
346
+ cursor = client .Cursor (Mock (), mock_backend ())
347
347
348
348
params = [{"x" : None }, {"x" : "foo1" }, {"x" : "bar2" }]
349
349
expected_queries = ["SELECT NULL" , "SELECT 'foo1'" , "SELECT 'bar2'" ]
@@ -368,7 +368,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(
368
368
"last operation" ,
369
369
)
370
370
371
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME )
371
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME )
372
372
def test_commit_a_noop (self , mock_thrift_backend_class ):
373
373
c = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
374
374
c .commit ()
@@ -381,14 +381,14 @@ def test_setoutputsizes_a_noop(self):
381
381
cursor = client .Cursor (Mock (), Mock ())
382
382
cursor .setoutputsize (1 )
383
383
384
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME )
384
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME )
385
385
def test_rollback_not_supported (self , mock_thrift_backend_class ):
386
386
c = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
387
387
with self .assertRaises (NotSupportedError ):
388
388
c .rollback ()
389
389
390
390
@unittest .skip ("JDW: skipping winter 2024 as we're about to rewrite this interface" )
391
- @patch ("%s.client.ThriftBackend " % PACKAGE_NAME )
391
+ @patch ("%s.client.ThriftDatabricksClient " % PACKAGE_NAME )
392
392
def test_row_number_respected (self , mock_thrift_backend_class ):
393
393
def make_fake_row_slice (n_rows ):
394
394
mock_slice = Mock ()
@@ -413,7 +413,7 @@ def make_fake_row_slice(n_rows):
413
413
self .assertEqual (cursor .rownumber , 29 )
414
414
415
415
@unittest .skip ("JDW: skipping winter 2024 as we're about to rewrite this interface" )
416
- @patch ("%s.client.ThriftBackend " % PACKAGE_NAME )
416
+ @patch ("%s.client.ThriftDatabricksClient " % PACKAGE_NAME )
417
417
def test_disable_pandas_respected (self , mock_thrift_backend_class ):
418
418
mock_thrift_backend = mock_thrift_backend_class .return_value
419
419
mock_table = Mock ()
@@ -466,7 +466,7 @@ def test_column_name_api(self):
466
466
},
467
467
)
468
468
469
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME )
469
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME )
470
470
def test_cursor_keeps_connection_alive (self , mock_client_class ):
471
471
instance = mock_client_class .return_value
472
472
@@ -485,13 +485,13 @@ def test_cursor_keeps_connection_alive(self, mock_client_class):
485
485
486
486
@patch ("%s.utils.ExecuteResponse" % PACKAGE_NAME , autospec = True )
487
487
@patch ("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME )
488
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME )
488
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME )
489
489
def test_staging_operation_response_is_handled (
490
490
self , mock_client_class , mock_handle_staging_operation , mock_execute_response
491
491
):
492
492
# If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called
493
493
494
- ThriftBackendMockFactory .apply_property_to_mock (
494
+ ThriftDatabricksClientMockFactory .apply_property_to_mock (
495
495
mock_execute_response , is_staging_operation = True
496
496
)
497
497
mock_client_class .execute_command .return_value = mock_execute_response
@@ -504,7 +504,7 @@ def test_staging_operation_response_is_handled(
504
504
505
505
mock_handle_staging_operation .call_count == 1
506
506
507
- @patch ("%s.session.ThriftBackend " % PACKAGE_NAME , ThriftBackendMockFactory .new ())
507
+ @patch ("%s.session.ThriftDatabricksClient " % PACKAGE_NAME , ThriftDatabricksClientMockFactory .new ())
508
508
def test_access_current_query_id (self ):
509
509
operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821"
510
510
0 commit comments