Skip to content

Commit 13ffb8d

Browse files
remove un-necessary artifacts in test_session, add back assertion
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 63360b3 commit 13ffb8d

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

tests/unit/test_session.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ def test_context_manager_closes_connection(self, mock_client_class):
128128
self.assertEqual(close_session_call_args.guid, b"\x22")
129129
self.assertEqual(close_session_call_args.secret, b"\x33")
130130

131+
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
132+
connection.close = Mock()
133+
try:
134+
with self.assertRaises(KeyboardInterrupt):
135+
with connection:
136+
raise KeyboardInterrupt("Simulated interrupt")
137+
finally:
138+
connection.close.assert_called()
139+
131140
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
132141
def test_max_number_of_retries_passthrough(self, mock_client_class):
133142
databricks.sql.connect(
@@ -146,33 +155,21 @@ def test_socket_timeout_passthrough(self, mock_client_class):
146155
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
147156
def test_configuration_passthrough(self, mock_client_class):
148157
mock_session_config = Mock()
149-
150-
# Create a mock SessionId that will be returned by open_session
151-
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
152-
mock_client_class.return_value.open_session.return_value = mock_session_id
153-
154158
databricks.sql.connect(
155159
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
156160
)
157161

158-
# Check that open_session was called with the correct session_configuration as keyword argument
159162
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
160163
self.assertEqual(call_kwargs["session_configuration"], mock_session_config)
161164

162165
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
163166
def test_initial_namespace_passthrough(self, mock_client_class):
164167
mock_cat = Mock()
165168
mock_schem = Mock()
166-
167-
# Create a mock SessionId that will be returned by open_session
168-
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
169-
mock_client_class.return_value.open_session.return_value = mock_session_id
170-
171169
databricks.sql.connect(
172170
**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem
173171
)
174172

175-
# Check that open_session was called with the correct catalog and schema as keyword arguments
176173
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
177174
self.assertEqual(call_kwargs["catalog"], mock_cat)
178175
self.assertEqual(call_kwargs["schema"], mock_schem)
@@ -181,7 +178,6 @@ def test_initial_namespace_passthrough(self, mock_client_class):
181178
def test_finalizer_closes_abandoned_connection(self, mock_client_class):
182179
instance = mock_client_class.return_value
183180

184-
# Create a mock SessionId that will be returned by open_session
185181
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
186182
instance.open_session.return_value = mock_session_id
187183

0 commit comments

Comments
 (0)