@@ -128,6 +128,15 @@ def test_context_manager_closes_connection(self, mock_client_class):
128
128
self .assertEqual (close_session_call_args .guid , b"\x22 " )
129
129
self .assertEqual (close_session_call_args .secret , b"\x33 " )
130
130
131
+ connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
132
+ connection .close = Mock ()
133
+ try :
134
+ with self .assertRaises (KeyboardInterrupt ):
135
+ with connection :
136
+ raise KeyboardInterrupt ("Simulated interrupt" )
137
+ finally :
138
+ connection .close .assert_called ()
139
+
131
140
@patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
132
141
def test_max_number_of_retries_passthrough (self , mock_client_class ):
133
142
databricks .sql .connect (
@@ -146,33 +155,21 @@ def test_socket_timeout_passthrough(self, mock_client_class):
146
155
@patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
147
156
def test_configuration_passthrough (self , mock_client_class ):
148
157
mock_session_config = Mock ()
149
-
150
- # Create a mock SessionId that will be returned by open_session
151
- mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
152
- mock_client_class .return_value .open_session .return_value = mock_session_id
153
-
154
158
databricks .sql .connect (
155
159
session_configuration = mock_session_config , ** self .DUMMY_CONNECTION_ARGS
156
160
)
157
161
158
- # Check that open_session was called with the correct session_configuration as keyword argument
159
162
call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
160
163
self .assertEqual (call_kwargs ["session_configuration" ], mock_session_config )
161
164
162
165
@patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
163
166
def test_initial_namespace_passthrough (self , mock_client_class ):
164
167
mock_cat = Mock ()
165
168
mock_schem = Mock ()
166
-
167
- # Create a mock SessionId that will be returned by open_session
168
- mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
169
- mock_client_class .return_value .open_session .return_value = mock_session_id
170
-
171
169
databricks .sql .connect (
172
170
** self .DUMMY_CONNECTION_ARGS , catalog = mock_cat , schema = mock_schem
173
171
)
174
172
175
- # Check that open_session was called with the correct catalog and schema as keyword arguments
176
173
call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
177
174
self .assertEqual (call_kwargs ["catalog" ], mock_cat )
178
175
self .assertEqual (call_kwargs ["schema" ], mock_schem )
@@ -181,7 +178,6 @@ def test_initial_namespace_passthrough(self, mock_client_class):
181
178
def test_finalizer_closes_abandoned_connection (self , mock_client_class ):
182
179
instance = mock_client_class .return_value
183
180
184
- # Create a mock SessionId that will be returned by open_session
185
181
mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
186
182
instance .open_session .return_value = mock_session_id
187
183
0 commit comments