From 89f86804f4a89c88c2fc110ac1dfa0adcb999fc5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 7 Jul 2025 09:13:33 +0530 Subject: [PATCH 1/2] stop passing client to ResultSet, infer from connection Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 1 - src/databricks/sql/backend/sea/result_set.py | 38 ++++++----- .../sql/backend/sea/utils/filters.py | 7 ++- src/databricks/sql/backend/thrift_backend.py | 6 -- src/databricks/sql/result_set.py | 63 ++++++++++++------- tests/unit/test_client.py | 42 +++++++++---- tests/unit/test_fetches.py | 32 +++++++--- tests/unit/test_sea_backend.py | 10 ++- tests/unit/test_sea_result_set.py | 59 ++++++++--------- 9 files changed, 159 insertions(+), 99 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 814859a31..353252c42 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -620,7 +620,6 @@ def get_execution_result( return SeaResultSet( connection=cursor.connection, execute_response=execute_response, - sea_client=self, result_data=response.result, manifest=response.manifest, buffer_size_bytes=cursor.buffer_size_bytes, diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 302af5e3a..14ed61575 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -31,7 +31,6 @@ def __init__( self, connection: Connection, execute_response: ExecuteResponse, - sea_client: SeaDatabricksClient, result_data: ResultData, manifest: ResultManifest, buffer_size_bytes: int = 104857600, @@ -43,7 +42,6 @@ def __init__( Args: connection: The parent connection execute_response: Response from the execute command - sea_client: The SeaDatabricksClient instance for direct access buffer_size_bytes: Buffer size for fetching results arraysize: Default number of rows to fetch result_data: Result data from SEA response @@ -56,32 +54,38 @@ def __init__( if statement_id is None: raise ValueError("Command ID is not a SEA statement ID") - results_queue = SeaResultSetQueueFactory.build_queue( - result_data, - self.manifest, - statement_id, - description=execute_response.description, - max_download_threads=sea_client.max_download_threads, - sea_client=sea_client, - lz4_compressed=execute_response.lz4_compressed, - ) - # Call parent constructor with common attributes super().__init__( connection=connection, - backend=sea_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, - results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + # Assert that the backend is of the correct type + assert isinstance( + self.backend, SeaDatabricksClient + ), "Backend must be a SeaDatabricksClient" + + results_queue = SeaResultSetQueueFactory.build_queue( + result_data, + self.manifest, + statement_id, + description=execute_response.description, + max_download_threads=self.backend.max_download_threads, + sea_client=self.backend, + lz4_compressed=execute_response.lz4_compressed, + ) + + # Set the results queue + self.results = results_queue + def _convert_json_types(self, row: List[str]) -> List[Any]: """ Convert 
string values in the row to appropriate Python types based on column metadata. @@ -160,6 +164,9 @@ def fetchmany_json(self, size: int) -> List[List[str]]: if size < 0: raise ValueError(f"size argument for fetchmany is {size} but must be >= 0") + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) self._next_row_index += len(results) @@ -173,6 +180,9 @@ def fetchall_json(self) -> List[List[str]]: Columnar table containing all remaining rows """ + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += len(results) diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index ef6c91d7d..cd27778fb 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -12,14 +12,13 @@ Optional, Any, Callable, - cast, TYPE_CHECKING, ) if TYPE_CHECKING: from databricks.sql.backend.sea.result_set import SeaResultSet -from databricks.sql.backend.types import ExecuteResponse +from databricks.sql.backend.types import ExecuteResponse, CommandId, CommandState logger = logging.getLogger(__name__) @@ -45,6 +44,9 @@ def _filter_sea_result_set( """ # Get all remaining rows + if result_set.results is None: + raise RuntimeError("Results queue is not initialized") + all_rows = result_set.results.remaining_rows() # Filter rows @@ -79,7 +81,6 @@ def _filter_sea_result_set( filtered_result_set = SeaResultSet( connection=result_set.connection, execute_response=execute_response, - sea_client=cast(SeaDatabricksClient, result_set.backend), result_data=result_data, manifest=manifest, buffer_size_bytes=result_set.buffer_size_bytes, diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 02d335aa4..12b727120 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -856,7 +856,6 @@ def get_execution_result( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=cursor.buffer_size_bytes, arraysize=cursor.arraysize, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -987,7 +986,6 @@ def execute_command( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=use_cloud_fetch, @@ -1027,7 +1025,6 @@ def get_catalogs( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1071,7 +1068,6 @@ def get_schemas( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1119,7 +1115,6 @@ def get_tables( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, @@ -1167,7 +1162,6 @@ def get_columns( return ThriftResultSet( connection=cursor.connection, execute_response=execute_response, - thrift_client=self, buffer_size_bytes=max_bytes, arraysize=max_rows, use_cloud_fetch=cursor.connection.use_cloud_fetch, 
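
A minimal sketch of the ownership change the constructor edits above implement. FakeBackend, FakeSession, FakeConnection, and ResultSetSketch are illustrative stand-ins, not the real classes (those live in databricks.sql.session and databricks.sql.result_set); the point is only that the backend client is no longer threaded through every ResultSet constructor and is instead resolved from the parent connection's session.

    # Illustrative stand-ins for the real types; names are hypothetical.
    class FakeBackend:
        """Stand-in for ThriftDatabricksClient / SeaDatabricksClient."""

    class FakeSession:
        def __init__(self, backend: FakeBackend):
            self.backend = backend  # the session owns the backend client

    class FakeConnection:
        def __init__(self, session: FakeSession):
            self.session = session

    class ResultSetSketch:
        def __init__(self, connection: FakeConnection):
            # No backend parameter: the result set infers its backend from
            # the connection, keeping a single source of truth for the client.
            self.connection = connection
            self.backend = connection.session.backend
            self.results = None  # subclasses build and assign their queue

    conn = FakeConnection(FakeSession(FakeBackend()))
    rs = ResultSetSketch(conn)
    assert rs.backend is conn.session.backend

The same shape explains why the constructors below gain an isinstance assertion: once the backend is inferred rather than passed, each concrete ResultSet checks that the connection handed it the backend type it expects.
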
diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 8934d0d56..5151988ad 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -20,6 +20,7 @@ from databricks.sql.utils import ( ColumnTable, ColumnQueue, + ResultSetQueue, ) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse @@ -36,14 +37,12 @@ class ResultSet(ABC): def __init__( self, connection: "Connection", - backend: "DatabricksClient", arraysize: int, buffer_size_bytes: int, command_id: CommandId, status: CommandState, has_been_closed_server_side: bool = False, is_direct_results: bool = False, - results_queue=None, description: List[Tuple] = [], is_staging_operation: bool = False, lz4_compressed: bool = False, @@ -54,32 +53,30 @@ def __init__( Parameters: :param connection: The parent connection - :param backend: The backend client :param arraysize: The max number of rows to fetch at a time (PEP-249) :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch :param command_id: The command ID :param status: The command status :param has_been_closed_server_side: Whether the command has been closed on the server :param is_direct_results: Whether the command has more rows - :param results_queue: The results queue :param description: column description of the results :param is_staging_operation: Whether the command is a staging operation """ - self.connection = connection - self.backend = backend - self.arraysize = arraysize - self.buffer_size_bytes = buffer_size_bytes - self._next_row_index = 0 - self.description = description - self.command_id = command_id - self.status = status - self.has_been_closed_server_side = has_been_closed_server_side - self.is_direct_results = is_direct_results - self.results = results_queue - self._is_staging_operation = is_staging_operation - self.lz4_compressed = lz4_compressed - self._arrow_schema_bytes = arrow_schema_bytes + self.connection: "Connection" = connection + self.backend: DatabricksClient = connection.session.backend + self.arraysize: int = arraysize + self.buffer_size_bytes: int = buffer_size_bytes + self._next_row_index: int = 0 + self.description: List[Tuple] = description + self.command_id: CommandId = command_id + self.status: CommandState = status + self.has_been_closed_server_side: bool = has_been_closed_server_side + self.is_direct_results: bool = is_direct_results + self.results: Optional[ResultSetQueue] = None # Children will set this + self._is_staging_operation: bool = is_staging_operation + self.lz4_compressed: bool = lz4_compressed + self._arrow_schema_bytes: Optional[bytes] = arrow_schema_bytes def __iter__(self): while True: @@ -190,7 +187,6 @@ def __init__( self, connection: "Connection", execute_response: "ExecuteResponse", - thrift_client: "ThriftDatabricksClient", buffer_size_bytes: int = 104857600, arraysize: int = 10000, use_cloud_fetch: bool = True, @@ -205,7 +201,6 @@ def __init__( Parameters: :param connection: The parent connection :param execute_response: Response from the execute command - :param thrift_client: The ThriftDatabricksClient instance for direct access :param buffer_size_bytes: Buffer size for fetching results :param arraysize: Default number of rows to fetch :param use_cloud_fetch: Whether to use cloud fetch for retrieving results @@ -238,20 +233,28 @@ def __init__( # Call parent constructor with common attributes super().__init__( connection=connection, - backend=thrift_client, arraysize=arraysize, buffer_size_bytes=buffer_size_bytes, 
command_id=execute_response.command_id, status=execute_response.status, has_been_closed_server_side=execute_response.has_been_closed_server_side, is_direct_results=is_direct_results, - results_queue=results_queue, description=execute_response.description, is_staging_operation=execute_response.is_staging_operation, lz4_compressed=execute_response.lz4_compressed, arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + # Assert that the backend is of the correct type + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + assert isinstance( + self.backend, ThriftDatabricksClient + ), "Backend must be a ThriftDatabricksClient" + + # Set the results queue + self.results = results_queue + # Initialize results queue if not provided if not self.results: self._fill_results_buffer() @@ -307,6 +310,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows @@ -332,6 +339,9 @@ def fetchmany_columnar(self, size: int): if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows @@ -351,6 +361,9 @@ def fetchmany_columnar(self, size: int): def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -377,6 +390,9 @@ def fetchall_arrow(self) -> "pyarrow.Table": def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -393,6 +409,9 @@ def fetchone(self) -> Optional[Row]: Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. 
""" + if self.results is None: + raise RuntimeError("Results queue is not initialized") + if isinstance(self.results, ColumnQueue): res = self._convert_columnar_table(self.fetchmany_columnar(1)) else: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5ffdea9f0..c14a74038 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -118,7 +118,6 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class): real_result_set = ThriftResultSet( connection=connection, execute_response=mock_execute_response, - thrift_client=mock_backend, ) # Verify initial state @@ -185,19 +184,24 @@ def test_arraysize_buffer_size_passthrough( def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() - mock_backend = Mock() + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.fetch_results.return_value = (Mock(), False) + # Ensure isinstance check passes + mock_backend.__class__ = ThriftDatabricksClient - result_set = ThriftResultSet( - connection=mock_connection, - execute_response=Mock(), - thrift_client=mock_backend, - ) # Setup session mock on the mock_connection mock_session = Mock() mock_session.open = False + mock_session.backend = mock_backend type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = ThriftResultSet( + connection=mock_connection, + execute_response=Mock(), + ) + result_set.close() self.assertFalse(mock_backend.close_command.called) @@ -207,15 +211,21 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response = Mock() mock_results_response.has_been_closed_server_side = False mock_connection = Mock() - mock_thrift_backend = Mock() + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient # Setup session mock on the mock_connection mock_session = Mock() mock_session.open = True + mock_session.backend = mock_thrift_backend type(mock_connection).session = PropertyMock(return_value=mock_session) mock_thrift_backend.fetch_results.return_value = (Mock(), False) result_set = ThriftResultSet( - mock_connection, mock_results_response, mock_thrift_backend + mock_connection, + mock_results_response, ) result_set.close() @@ -258,10 +268,20 @@ def test_closed_cursor_doesnt_allow_operations(self): self.assertIn("closed", e.msg) def test_negative_fetch_throws_exception(self): - mock_backend = Mock() + mock_connection = Mock() + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + + mock_backend = Mock(spec=ThriftDatabricksClient) mock_backend.fetch_results.return_value = (Mock(), False) + # Ensure isinstance check passes + mock_backend.__class__ = ThriftDatabricksClient + + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.backend = mock_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) - result_set = ThriftResultSet(Mock(), Mock(), mock_backend) + result_set = ThriftResultSet(mock_connection, Mock()) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index a649941e1..e6ad33aae 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -1,6 +1,6 @@ import unittest import pytest -from unittest.mock import Mock +from 
unittest.mock import Mock, PropertyMock try: import pyarrow as pa @@ -38,12 +38,19 @@ def make_arrow_queue(batch): @staticmethod def make_dummy_result_set_from_initial_results(initial_results): # If the initial results have been set, then we should never try and fetch more - schema, arrow_table = FetchTests.make_arrow_table(initial_results) - arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0) + arrow_queue = FetchTests.make_arrow_queue(initial_results) + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient - # Create a mock backend that will return the queue when _fill_results_buffer is called mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results.return_value = (arrow_queue, False) + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient + + # Setup mock connection with session.backend + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = mock_thrift_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) num_cols = len(initial_results[0]) if initial_results else 0 description = [ @@ -52,7 +59,7 @@ def make_dummy_result_set_from_initial_results(initial_results): ] rs = ThriftResultSet( - connection=Mock(), + connection=mock_connection, execute_response=ExecuteResponse( command_id=None, status=None, @@ -61,7 +68,6 @@ def make_dummy_result_set_from_initial_results(initial_results): lz4_compressed=True, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, t_row_set=None, ) return rs @@ -86,8 +92,19 @@ def fetch_results( return results, batch_index < len(batch_list) + from databricks.sql.backend.thrift_backend import ThriftDatabricksClient + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) mock_thrift_backend.fetch_results = fetch_results + # Ensure isinstance check passes + mock_thrift_backend.__class__ = ThriftDatabricksClient + + # Setup mock connection with session.backend + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = mock_thrift_backend + type(mock_connection).session = PropertyMock(return_value=mock_session) + num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0 description = [ @@ -96,7 +113,7 @@ def fetch_results( ] rs = ThriftResultSet( - connection=Mock(), + connection=mock_connection, execute_response=ExecuteResponse( command_id=None, status=None, @@ -105,7 +122,6 @@ def fetch_results( lz4_compressed=True, is_staging_operation=False, ), - thrift_client=mock_thrift_backend, ) return rs diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 7eae8e5a8..3185589f6 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -68,12 +68,20 @@ def sea_command_id(self): return CommandId.from_sea_statement_id("test-statement-123") @pytest.fixture - def mock_cursor(self): + def mock_cursor(self, sea_client): """Create a mock cursor.""" cursor = Mock() cursor.active_command_id = None cursor.buffer_size_bytes = 1000 cursor.arraysize = 100 + + # Set up a mock connection with session.backend pointing to the sea_client + mock_connection = Mock() + mock_session = Mock() + mock_session.backend = sea_client + mock_connection.session = mock_session + cursor.connection = mock_connection + return cursor @pytest.fixture diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 544edaf96..49b2564c4 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -23,12 +23,20 @@ def 
mock_connection(self): """Create a mock connection.""" connection = Mock() connection.open = True - return connection - @pytest.fixture - def mock_sea_client(self): - """Create a mock SEA client.""" - return Mock() + # Mock the session.backend to return a SeaDatabricksClient + mock_session = Mock() + from databricks.sql.backend.sea.backend import SeaDatabricksClient + + mock_backend = Mock(spec=SeaDatabricksClient) + mock_backend.max_download_threads = 10 + mock_backend.close_command = Mock() + # Ensure isinstance check passes + mock_backend.__class__ = SeaDatabricksClient + mock_session.backend = mock_backend + connection.session = mock_session + + return connection @pytest.fixture def execute_response(self): @@ -71,9 +79,7 @@ def _create_empty_manifest(self, format: ResultFormat): ) @pytest.fixture - def result_set_with_data( - self, mock_connection, mock_sea_client, execute_response, sample_data - ): + def result_set_with_data(self, mock_connection, execute_response, sample_data): """Create a SeaResultSet with sample data.""" # Create ResultData with inline data result_data = ResultData( @@ -84,7 +90,6 @@ def result_set_with_data( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=result_data, manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -99,14 +104,11 @@ def json_queue(self, sample_data): """Create a JsonQueue with sample data.""" return JsonQueue(sample_data) - def test_init_with_execute_response( - self, mock_connection, mock_sea_client, execute_response - ): + def test_init_with_execute_response(self, mock_connection, execute_response): """Test initializing SeaResultSet with an execute response.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -117,17 +119,15 @@ def test_init_with_execute_response( assert result_set.command_id == execute_response.command_id assert result_set.status == CommandState.SUCCEEDED assert result_set.connection == mock_connection - assert result_set.backend == mock_sea_client assert result_set.buffer_size_bytes == 1000 assert result_set.arraysize == 100 assert result_set.description == execute_response.description - def test_close(self, mock_connection, mock_sea_client, execute_response): + def test_close(self, mock_connection, execute_response): """Test closing a result set.""" result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -138,18 +138,19 @@ def test_close(self, mock_connection, mock_sea_client, execute_response): result_set.close() # Verify the backend's close_command was called - mock_sea_client.close_command.assert_called_once_with(result_set.command_id) + mock_connection.session.backend.close_command.assert_called_once_with( + result_set.command_id + ) assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED def test_close_when_already_closed_server_side( - self, mock_connection, mock_sea_client, execute_response + self, mock_connection, execute_response ): """Test closing a result set that has already been closed server-side.""" result_set = SeaResultSet( connection=mock_connection, 
execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -161,19 +162,16 @@ def test_close_when_already_closed_server_side( result_set.close() # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() + mock_connection.session.backend.close_command.assert_not_called() assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED - def test_close_when_connection_closed( - self, mock_connection, mock_sea_client, execute_response - ): + def test_close_when_connection_closed(self, mock_connection, execute_response): """Test closing a result set when the connection is closed.""" mock_connection.open = False result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, @@ -184,7 +182,7 @@ def test_close_when_connection_closed( result_set.close() # Verify the backend's close_command was NOT called - mock_sea_client.close_command.assert_not_called() + mock_connection.session.backend.close_command.assert_not_called() assert result_set.has_been_closed_server_side is True assert result_set.status == CommandState.CLOSED @@ -316,7 +314,7 @@ def test_iteration(self, result_set_with_data, sample_data): assert rows[0].col3 is True def test_fetchmany_arrow_not_implemented( - self, mock_connection, mock_sea_client, execute_response, sample_data + self, mock_connection, execute_response, sample_data ): """Test that fetchmany_arrow raises NotImplementedError for non-JSON data.""" @@ -329,7 +327,6 @@ def test_fetchmany_arrow_not_implemented( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), buffer_size_bytes=1000, @@ -337,7 +334,7 @@ def test_fetchmany_arrow_not_implemented( ) def test_fetchall_arrow_not_implemented( - self, mock_connection, mock_sea_client, execute_response, sample_data + self, mock_connection, execute_response, sample_data ): """Test that fetchall_arrow raises NotImplementedError for non-JSON data.""" # Test that NotImplementedError is raised @@ -349,16 +346,13 @@ def test_fetchall_arrow_not_implemented( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=None, external_links=[]), manifest=self._create_empty_manifest(ResultFormat.ARROW_STREAM), buffer_size_bytes=1000, arraysize=100, ) - def test_is_staging_operation( - self, mock_connection, mock_sea_client, execute_response - ): + def test_is_staging_operation(self, mock_connection, execute_response): """Test the is_staging_operation property.""" # Set is_staging_operation to True execute_response.is_staging_operation = True @@ -367,7 +361,6 @@ def test_is_staging_operation( result_set = SeaResultSet( connection=mock_connection, execute_response=execute_response, - sea_client=mock_sea_client, result_data=ResultData(data=[]), manifest=self._create_empty_manifest(ResultFormat.JSON_ARRAY), buffer_size_bytes=1000, From 7f9b35d54685880ba9e7f36c7322eb5bedd46421 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 7 Jul 2025 09:23:19 +0530 Subject: [PATCH 2/2] rename 
sea.backend to sea.client for clarity Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/{backend.py => client.py} | 0 src/databricks/sql/backend/sea/queue.py | 2 +- src/databricks/sql/backend/sea/result_set.py | 2 +- src/databricks/sql/backend/sea/utils/filters.py | 2 +- src/databricks/sql/session.py | 2 +- src/databricks/sql/utils.py | 2 +- tests/unit/test_sea_backend.py | 4 ++-- tests/unit/test_sea_result_set.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) rename src/databricks/sql/backend/sea/{backend.py => client.py} (100%) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/client.py similarity index 100% rename from src/databricks/sql/backend/sea/backend.py rename to src/databricks/sql/backend/sea/client.py diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 73f47ea96..3aeee41c4 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -3,7 +3,7 @@ from abc import ABC from typing import List, Optional, Tuple -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.constants import ResultFormat from databricks.sql.exc import ProgrammingError diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 14ed61575..c6ed63900 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -4,7 +4,7 @@ import logging -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter diff --git a/src/databricks/sql/backend/sea/utils/filters.py b/src/databricks/sql/backend/sea/utils/filters.py index cd27778fb..639b6495f 100644 --- a/src/databricks/sql/backend/sea/utils/filters.py +++ b/src/databricks/sql/backend/sea/utils/filters.py @@ -71,7 +71,7 @@ def _filter_sea_result_set( result_data = ResultData(data=filtered_rows, external_links=None) - from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.sea.result_set import SeaResultSet # Create a new SeaResultSet with the filtered data diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 76aec4675..4c8b882f4 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -8,7 +8,7 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 35c7bce4d..1a2a8e693 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -13,7 +13,7 @@ import lz4.frame -from databricks.sql.backend.sea.backend import SeaDatabricksClient +from databricks.sql.backend.sea.client import SeaDatabricksClient 
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest try: diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 3185589f6..96485d235 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -8,7 +8,7 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from databricks.sql.backend.sea.backend import ( +from databricks.sql.backend.sea.client import ( SeaDatabricksClient, _filter_session_configuration, ) @@ -31,7 +31,7 @@ class TestSeaBackend: def mock_http_client(self): """Create a mock HTTP client.""" with patch( - "databricks.sql.backend.sea.backend.SeaHttpClient" + "databricks.sql.backend.sea.client.SeaHttpClient" ) as mock_client_class: mock_client = mock_client_class.return_value yield mock_client diff --git a/tests/unit/test_sea_result_set.py b/tests/unit/test_sea_result_set.py index 49b2564c4..8884e812a 100644 --- a/tests/unit/test_sea_result_set.py +++ b/tests/unit/test_sea_result_set.py @@ -26,7 +26,7 @@ def mock_connection(self): # Mock the session.backend to return a SeaDatabricksClient mock_session = Mock() - from databricks.sql.backend.sea.backend import SeaDatabricksClient + from databricks.sql.backend.sea.client import SeaDatabricksClient mock_backend = Mock(spec=SeaDatabricksClient) mock_backend.max_download_threads = 10
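
All of the updated unit tests rely on the same mock wiring so that the isinstance assertions in the refactored constructors pass. A condensed sketch of that pattern, assuming unittest.mock semantics (a Mock's __class__ is assignable precisely so it can satisfy isinstance checks); the fixture names are illustrative:

    from unittest.mock import Mock, PropertyMock

    from databricks.sql.backend.thrift_backend import ThriftDatabricksClient

    # Backend mock: spec'd so attribute access is checked, with __class__
    # assigned explicitly (as the patch does) so isinstance() reports a
    # ThriftDatabricksClient.
    mock_backend = Mock(spec=ThriftDatabricksClient)
    mock_backend.fetch_results.return_value = (Mock(), False)
    mock_backend.__class__ = ThriftDatabricksClient

    # Connection mock: expose the backend at connection.session.backend,
    # the path the refactored ResultSet constructor now reads.
    mock_session = Mock()
    mock_session.backend = mock_backend
    mock_connection = Mock()
    type(mock_connection).session = PropertyMock(return_value=mock_session)

After the second commit, the SEA equivalent of this wiring imports SeaDatabricksClient from databricks.sql.backend.sea.client rather than databricks.sql.backend.sea.backend; only the module path changes.
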