diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index b81416e1..1e109405 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1509,6 +1509,7 @@ def close(self) -> None: if isinstance(e.args[1], CursorAlreadyClosedError): logger.info("Operation was canceled by a prior request") finally: + self.results.close() self.has_been_closed_server_side = True self.op_state = self.thrift_backend.CLOSED_OP_STATE diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0ce2fa16..23380877 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -46,6 +46,10 @@ def next_n_rows(self, num_rows: int): def remaining_rows(self): pass + @abstractmethod + def close(self): + pass + class ResultSetQueueFactory(ABC): @staticmethod @@ -157,6 +161,9 @@ def remaining_rows(self): self.cur_row_index += slice.num_rows return slice + def close(self): + return + class ArrowQueue(ResultSetQueue): def __init__( @@ -192,6 +199,9 @@ def remaining_rows(self) -> "pyarrow.Table": self.cur_row_index += slice.num_rows return slice + def close(self): + return + class CloudFetchQueue(ResultSetQueue): def __init__( @@ -341,6 +351,9 @@ def _create_empty_table(self) -> "pyarrow.Table": # Create a 0-row table with just the schema bytes return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) + def close(self): + self.download_manager._shutdown_manager() + ExecuteResponse = namedtuple( "ExecuteResponse", diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 91e426c6..44c84d79 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -267,33 +267,39 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() + mock_results = Mock() result_set = client.ResultSet( connection=mock_connection, thrift_backend=mock_backend, execute_response=Mock(), ) + result_set.results = mock_results mock_connection.open = False result_set.close() self.assertFalse(mock_backend.close_command.called) self.assertTrue(result_set.has_been_closed_server_side) + mock_results.close.assert_called_once() def test_closing_result_set_hard_closes_commands(self): mock_results_response = Mock() mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() + mock_results = Mock() mock_connection.open = True result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) + result_set.results = mock_results result_set.close() mock_thrift_backend.close_command.assert_called_once_with( mock_results_response.command_handle ) + mock_results.close.assert_called_once() @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_executing_multiple_commands_uses_the_most_recent_command(