Skip to content

Fix potential resource leak in CloudFetchQueue #624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,7 @@ def close(self) -> None:
if isinstance(e.args[1], CursorAlreadyClosedError):
logger.info("Operation was canceled by a prior request")
finally:
self.results.close()
self.has_been_closed_server_side = True
self.op_state = self.thrift_backend.CLOSED_OP_STATE

Expand Down
13 changes: 13 additions & 0 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def next_n_rows(self, num_rows: int):
def remaining_rows(self):
pass

@abstractmethod
def close(self):
pass


class ResultSetQueueFactory(ABC):
@staticmethod
Expand Down Expand Up @@ -157,6 +161,9 @@ def remaining_rows(self):
self.cur_row_index += slice.num_rows
return slice

def close(self):
return


class ArrowQueue(ResultSetQueue):
def __init__(
Expand Down Expand Up @@ -192,6 +199,9 @@ def remaining_rows(self) -> "pyarrow.Table":
self.cur_row_index += slice.num_rows
return slice

def close(self):
return


class CloudFetchQueue(ResultSetQueue):
def __init__(
Expand Down Expand Up @@ -341,6 +351,9 @@ def _create_empty_table(self) -> "pyarrow.Table":
# Create a 0-row table with just the schema bytes
return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)

def close(self):
self.download_manager._shutdown_manager()


ExecuteResponse = namedtuple(
"ExecuteResponse",
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,33 +267,39 @@ def test_arraysize_buffer_size_passthrough(
def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
mock_connection = Mock()
mock_backend = Mock()
mock_results = Mock()
result_set = client.ResultSet(
connection=mock_connection,
thrift_backend=mock_backend,
execute_response=Mock(),
)
result_set.results = mock_results
mock_connection.open = False

result_set.close()

self.assertFalse(mock_backend.close_command.called)
self.assertTrue(result_set.has_been_closed_server_side)
mock_results.close.assert_called_once()

def test_closing_result_set_hard_closes_commands(self):
mock_results_response = Mock()
mock_results_response.has_been_closed_server_side = False
mock_connection = Mock()
mock_thrift_backend = Mock()
mock_results = Mock()
mock_connection.open = True
result_set = client.ResultSet(
mock_connection, mock_results_response, mock_thrift_backend
)
result_set.results = mock_results

result_set.close()

mock_thrift_backend.close_command.assert_called_once_with(
mock_results_response.command_handle
)
mock_results.close.assert_called_once()

@patch("%s.client.ResultSet" % PACKAGE_NAME)
def test_executing_multiple_commands_uses_the_most_recent_command(
Expand Down
Loading