Skip to content

Commit 576eafc

Browse files
Fix potential resource leak in CloudFetchQueue (#624)
* add close() for Queue, add ResultSet invocation Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * move Queue closure to finally: block to ensure client-side cleanup regardless of server side state Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * add unit test assertions to ensure Queue closure Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> * move results closure to try block Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com> --------- Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 9c34acd commit 576eafc

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

src/databricks/sql/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,7 @@ def close(self) -> None:
15961596
been closed on the server for some other reason, issue a request to the server to close it.
15971597
"""
15981598
try:
1599+
self.results.close()
15991600
if (
16001601
self.op_state != self.thrift_backend.CLOSED_OP_STATE
16011602
and not self.has_been_closed_server_side

src/databricks/sql/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def next_n_rows(self, num_rows: int):
4646
def remaining_rows(self):
4747
pass
4848

49+
@abstractmethod
50+
def close(self):
51+
pass
52+
4953

5054
class ResultSetQueueFactory(ABC):
5155
@staticmethod
@@ -157,6 +161,9 @@ def remaining_rows(self):
157161
self.cur_row_index += slice.num_rows
158162
return slice
159163

164+
def close(self):
165+
return
166+
160167

161168
class ArrowQueue(ResultSetQueue):
162169
def __init__(
@@ -192,6 +199,9 @@ def remaining_rows(self) -> "pyarrow.Table":
192199
self.cur_row_index += slice.num_rows
193200
return slice
194201

202+
def close(self):
203+
return
204+
195205

196206
class CloudFetchQueue(ResultSetQueue):
197207
def __init__(
@@ -341,6 +351,9 @@ def _create_empty_table(self) -> "pyarrow.Table":
341351
# Create a 0-row table with just the schema bytes
342352
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
343353

354+
def close(self):
355+
self.download_manager._shutdown_manager()
356+
344357

345358
ExecuteResponse = namedtuple(
346359
"ExecuteResponse",

tests/unit/test_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,33 +267,39 @@ def test_arraysize_buffer_size_passthrough(
267267
def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
268268
mock_connection = Mock()
269269
mock_backend = Mock()
270+
mock_results = Mock()
270271
result_set = client.ResultSet(
271272
connection=mock_connection,
272273
thrift_backend=mock_backend,
273274
execute_response=Mock(),
274275
)
276+
result_set.results = mock_results
275277
mock_connection.open = False
276278

277279
result_set.close()
278280

279281
self.assertFalse(mock_backend.close_command.called)
280282
self.assertTrue(result_set.has_been_closed_server_side)
283+
mock_results.close.assert_called_once()
281284

282285
def test_closing_result_set_hard_closes_commands(self):
283286
mock_results_response = Mock()
284287
mock_results_response.has_been_closed_server_side = False
285288
mock_connection = Mock()
286289
mock_thrift_backend = Mock()
290+
mock_results = Mock()
287291
mock_connection.open = True
288292
result_set = client.ResultSet(
289293
mock_connection, mock_results_response, mock_thrift_backend
290294
)
295+
result_set.results = mock_results
291296

292297
result_set.close()
293298

294299
mock_thrift_backend.close_command.assert_called_once_with(
295300
mock_results_response.command_handle
296301
)
302+
mock_results.close.assert_called_once()
297303

298304
@patch("%s.client.ResultSet" % PACKAGE_NAME)
299305
def test_executing_multiple_commands_uses_the_most_recent_command(

0 commit comments

Comments
 (0)