From 048af73f330958908380d5d7aa60f1aba0275961 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 30 May 2025 11:33:04 +0530 Subject: [PATCH 1/5] Enhance Arrow to Pandas conversion with type overrides and additional kwargs * Introduced _arrow_pandas_type_override and _arrow_to_pandas_kwargs in Connection class for customizable dtype mapping and Arrow-to-pandas conversion parameters. * Updated ResultSet to utilize these new options during conversion from Arrow tables to Pandas DataFrames. * Added unit tests to validate the new functionality, including scenarios for type overrides and additional kwargs handling. --- src/databricks/sql/client.py | 21 +++- tests/unit/test_arrow_conversion.py | 128 ++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_arrow_conversion.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0c9a08a85..0b17ae7d7 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -213,6 +213,11 @@ def read(self) -> Optional[OAuthToken]: # (True by default) # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage + # _arrow_pandas_type_override + # Override the default pandas dtype mapping for Arrow types. + # This is a dictionary of Arrow types to pandas dtypes. + # _arrow_to_pandas_kwargs + # Additional or modified keyword arguments to pass to pyarrow.Table.to_pandas(). logger.debug( "Connection.__init__(server_hostname=%s, http_path=%s)", @@ -1346,7 +1351,7 @@ def _convert_arrow_table(self, table): # Need to use nullable types, as otherwise type can change when there are missing values. 
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - dtype_mapping = { + DEFAULT_DTYPE_MAPPING: Dict[pyarrow.DataType, pandas.api.extensions.ExtensionDtype] = { pyarrow.int8(): pandas.Int8Dtype(), pyarrow.int16(): pandas.Int16Dtype(), pyarrow.int32(): pandas.Int32Dtype(), @@ -1360,14 +1365,18 @@ def _convert_arrow_table(self, table): pyarrow.float64(): pandas.Float64Dtype(), pyarrow.string(): pandas.StringDtype(), } + dtype_mapping = {**DEFAULT_DTYPE_MAPPING, **self.connection._arrow_pandas_type_override} + + to_pandas_kwargs: dict[str, Any] = { + "types_mapper": dtype_mapping.get, + "date_as_object": True, + "timestamp_as_object": True, + } + to_pandas_kwargs.update(self.connection._arrow_to_pandas_kwargs) # Need to rename columns, as the to_pandas function cannot handle duplicate column names table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) - df = table_renamed.to_pandas( - types_mapper=dtype_mapping.get, - date_as_object=True, - timestamp_as_object=True, - ) + df = table_renamed.to_pandas(**to_pandas_kwargs) res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py new file mode 100644 index 000000000..b673d0dfa --- /dev/null +++ b/tests/unit/test_arrow_conversion.py @@ -0,0 +1,128 @@ +import pytest +import pyarrow +import pandas +import datetime +from unittest.mock import MagicMock, patch + +from databricks.sql.client import ResultSet, Connection, ExecuteResponse +from databricks.sql.types import Row +from databricks.sql.utils import ArrowQueue + + +@pytest.fixture +def mock_connection(): + conn = MagicMock(spec=Connection) + conn.disable_pandas = False + conn._arrow_pandas_type_override = {} + conn._arrow_to_pandas_kwargs = {} + if not hasattr(conn, '_arrow_to_pandas_kwargs'): + 
conn._arrow_to_pandas_kwargs = {} + return conn + +@pytest.fixture +def mock_thrift_backend(sample_arrow_table): + tb = MagicMock() + empty_arrays = [pyarrow.array([], type=field.type) for field in sample_arrow_table.schema] + empty_table = pyarrow.Table.from_arrays(empty_arrays, schema=sample_arrow_table.schema) + tb.fetch_results.return_value = (ArrowQueue(empty_table, 0) , False) + return tb + +@pytest.fixture +def mock_raw_execute_response(): + er = MagicMock(spec=ExecuteResponse) + er.description = [("col_int", "int", None, None, None, None, None), + ("col_str", "string", None, None, None, None, None)] + er.arrow_schema_bytes = None + er.arrow_queue = None + er.has_more_rows = False + er.lz4_compressed = False + er.command_handle = MagicMock() + er.status = MagicMock() + er.has_been_closed_server_side = False + er.is_staging_operation = False + return er + +@pytest.fixture +def sample_arrow_table(): + data = [ + pyarrow.array([1, 2, 3], type=pyarrow.int32()), + pyarrow.array(["a", "b", "c"], type=pyarrow.string()) + ] + schema = pyarrow.schema([ + ('col_int', pyarrow.int32()), + ('col_str', pyarrow.string()) + ]) + return pyarrow.Table.from_arrays(data, schema=schema) + + +def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_one = rs.fetchone() + assert isinstance(result_one, Row) + assert result_one.col_int == 1 + assert result_one.col_str == "a" + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_all = rs.fetchall() + assert len(result_all) == 3 + assert isinstance(result_all[0], Row) + assert result_all[0].col_int == 1 + assert result_all[1].col_str == "b" + + 
+def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): + mock_connection.disable_pandas = True + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + assert len(result) == 3 + assert isinstance(result[0], Row) + assert result[0].col_int == 1 + assert result[0].col_str == "a" + assert isinstance(sample_arrow_table.column(0)[0].as_py(), int) + assert isinstance(sample_arrow_table.column(1)[0].as_py(), str) + + +def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): + mock_connection._arrow_pandas_type_override = {pyarrow.int32(): pandas.Float64Dtype()} + mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + assert len(result) == 3 + assert isinstance(result[0].col_int, float) + assert result[0].col_int == 1.0 + assert result[0].col_str == "a" + + +def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backend, mock_raw_execute_response): + dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp('us', tz='UTC')) + ts_schema = pyarrow.schema([('col_ts', pyarrow.timestamp('us', tz='UTC'))]) + ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema) + + mock_raw_execute_response.description = [("col_ts", "timestamp", None, None, None, None, None)] + mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + + # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. 
+ mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} + rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_true = rs_ts_true.fetchall() + assert len(result_true) == 1 + assert isinstance(result_true[0].col_ts, datetime.datetime) + + # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. + mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} + rs_ts_false = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_false = rs_ts_false.fetchall() + assert len(result_false) == 1 + assert isinstance(result_false[0].col_ts, pandas.Timestamp) + + # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. + mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + mock_connection._arrow_to_pandas_kwargs = {} + rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_true = rs_ts_true.fetchall() + assert len(result_true) == 1 + assert isinstance(result_true[0].col_ts, datetime.datetime) From 0b1b05b9fc7d88036f180178a99883a6dbeda921 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 30 May 2025 11:54:14 +0530 Subject: [PATCH 2/5] fmt --- src/databricks/sql/client.py | 9 ++- tests/unit/test_arrow_conversion.py | 86 ++++++++++++++++++++--------- 2 files changed, 67 insertions(+), 28 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0b17ae7d7..79338f387 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1351,7 +1351,9 @@ def _convert_arrow_table(self, table): # Need to use nullable types, as otherwise type can change when there are missing values. 
# See https://arrow.apache.org/docs/python/pandas.html#nullable-types # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html - DEFAULT_DTYPE_MAPPING: Dict[pyarrow.DataType, pandas.api.extensions.ExtensionDtype] = { + DEFAULT_DTYPE_MAPPING: Dict[ + pyarrow.DataType, pandas.api.extensions.ExtensionDtype + ] = { pyarrow.int8(): pandas.Int8Dtype(), pyarrow.int16(): pandas.Int16Dtype(), pyarrow.int32(): pandas.Int32Dtype(), @@ -1365,7 +1367,10 @@ def _convert_arrow_table(self, table): pyarrow.float64(): pandas.Float64Dtype(), pyarrow.string(): pandas.StringDtype(), } - dtype_mapping = {**DEFAULT_DTYPE_MAPPING, **self.connection._arrow_pandas_type_override} + dtype_mapping = { + **DEFAULT_DTYPE_MAPPING, + **self.connection._arrow_pandas_type_override, + } to_pandas_kwargs: dict[str, Any] = { "types_mapper": dtype_mapping.get, diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py index b673d0dfa..30fd4f04e 100644 --- a/tests/unit/test_arrow_conversion.py +++ b/tests/unit/test_arrow_conversion.py @@ -15,23 +15,31 @@ def mock_connection(): conn.disable_pandas = False conn._arrow_pandas_type_override = {} conn._arrow_to_pandas_kwargs = {} - if not hasattr(conn, '_arrow_to_pandas_kwargs'): + if not hasattr(conn, "_arrow_to_pandas_kwargs"): conn._arrow_to_pandas_kwargs = {} return conn + @pytest.fixture def mock_thrift_backend(sample_arrow_table): tb = MagicMock() - empty_arrays = [pyarrow.array([], type=field.type) for field in sample_arrow_table.schema] - empty_table = pyarrow.Table.from_arrays(empty_arrays, schema=sample_arrow_table.schema) - tb.fetch_results.return_value = (ArrowQueue(empty_table, 0) , False) + empty_arrays = [ + pyarrow.array([], type=field.type) for field in sample_arrow_table.schema + ] + empty_table = pyarrow.Table.from_arrays( + empty_arrays, schema=sample_arrow_table.schema + ) + tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False) return tb + 
@pytest.fixture def mock_raw_execute_response(): er = MagicMock(spec=ExecuteResponse) - er.description = [("col_int", "int", None, None, None, None, None), - ("col_str", "string", None, None, None, None, None)] + er.description = [ + ("col_int", "int", None, None, None, None, None), + ("col_str", "string", None, None, None, None, None), + ] er.arrow_schema_bytes = None er.arrow_queue = None er.has_more_rows = False @@ -42,27 +50,33 @@ def mock_raw_execute_response(): er.is_staging_operation = False return er + @pytest.fixture def sample_arrow_table(): data = [ pyarrow.array([1, 2, 3], type=pyarrow.int32()), - pyarrow.array(["a", "b", "c"], type=pyarrow.string()) + pyarrow.array(["a", "b", "c"], type=pyarrow.string()), ] - schema = pyarrow.schema([ - ('col_int', pyarrow.int32()), - ('col_str', pyarrow.string()) - ]) + schema = pyarrow.schema( + [("col_int", pyarrow.int32()), ("col_str", pyarrow.string())] + ) return pyarrow.Table.from_arrays(data, schema=schema) -def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) +def test_convert_arrow_table_default( + mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table +): + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result_one = rs.fetchone() assert isinstance(result_one, Row) assert result_one.col_int == 1 assert result_one.col_str == "a" - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result_all = rs.fetchall() assert len(result_all) == 3 @@ -71,9 
+85,13 @@ def test_convert_arrow_table_default(mock_connection, mock_thrift_backend, mock_ assert result_all[1].col_str == "b" -def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): +def test_convert_arrow_table_disable_pandas( + mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table +): mock_connection.disable_pandas = True - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result = rs.fetchall() assert len(result) == 3 @@ -84,9 +102,15 @@ def test_convert_arrow_table_disable_pandas(mock_connection, mock_thrift_backend assert isinstance(sample_arrow_table.column(1)[0].as_py(), str) -def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table): - mock_connection._arrow_pandas_type_override = {pyarrow.int32(): pandas.Float64Dtype()} - mock_raw_execute_response.arrow_queue = ArrowQueue(sample_arrow_table, sample_arrow_table.num_rows) +def test_convert_arrow_table_type_override( + mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table +): + mock_connection._arrow_pandas_type_override = { + pyarrow.int32(): pandas.Float64Dtype() + } + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) result = rs.fetchall() assert len(result) == 3 @@ -95,18 +119,24 @@ def test_convert_arrow_table_type_override(mock_connection, mock_thrift_backend, assert result[0].col_str == "a" -def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backend, mock_raw_execute_response): +def 
test_convert_arrow_table_to_pandas_kwargs( + mock_connection, mock_thrift_backend, mock_raw_execute_response +): dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) - ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp('us', tz='UTC')) - ts_schema = pyarrow.schema([('col_ts', pyarrow.timestamp('us', tz='UTC'))]) + ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp("us", tz="UTC")) + ts_schema = pyarrow.schema([("col_ts", pyarrow.timestamp("us", tz="UTC"))]) ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema) - mock_raw_execute_response.description = [("col_ts", "timestamp", None, None, None, None, None)] + mock_raw_execute_response.description = [ + ("col_ts", "timestamp", None, None, None, None, None) + ] mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} - rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + rs_ts_true = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) result_true = rs_ts_true.fetchall() assert len(result_true) == 1 assert isinstance(result_true[0].col_ts, datetime.datetime) @@ -114,7 +144,9 @@ def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backe # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. 
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} - rs_ts_false = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + rs_ts_false = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) result_false = rs_ts_false.fetchall() assert len(result_false) == 1 assert isinstance(result_false[0].col_ts, pandas.Timestamp) @@ -122,7 +154,9 @@ def test_convert_arrow_table_to_pandas_kwargs(mock_connection, mock_thrift_backe # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) mock_connection._arrow_to_pandas_kwargs = {} - rs_ts_true = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + rs_ts_true = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) result_true = rs_ts_true.fetchall() assert len(result_true) == 1 assert isinstance(result_true[0].col_ts, datetime.datetime) From 647ed391be8377afab7fb2ad48d336f656f16185 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Fri, 30 May 2025 12:29:47 +0530 Subject: [PATCH 3/5] fix unit tests --- src/databricks/sql/client.py | 19 +- tests/unit/test_arrow_conversion.py | 328 +++++++++++++++------------- 2 files changed, 192 insertions(+), 155 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 79338f387..da1177f45 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1367,9 +1367,17 @@ def _convert_arrow_table(self, table): pyarrow.float64(): pandas.Float64Dtype(), pyarrow.string(): pandas.StringDtype(), } + + arrow_pandas_type_override = self.connection._arrow_pandas_type_override + if not isinstance(arrow_pandas_type_override, dict): + logger.debug( + "_arrow_pandas_type_override on connection was not a dict, using default type 
mapping" + ) + arrow_pandas_type_override = {} + dtype_mapping = { **DEFAULT_DTYPE_MAPPING, - **self.connection._arrow_pandas_type_override, + **arrow_pandas_type_override, } to_pandas_kwargs: dict[str, Any] = { @@ -1377,7 +1385,14 @@ def _convert_arrow_table(self, table): "date_as_object": True, "timestamp_as_object": True, } - to_pandas_kwargs.update(self.connection._arrow_to_pandas_kwargs) + + arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs + if isinstance(arrow_to_pandas_kwargs, dict): + to_pandas_kwargs.update(arrow_to_pandas_kwargs) + else: + logger.debug( + "_arrow_to_pandas_kwargs on connection was not a dict, using default arguments" + ) # Need to rename columns, as the to_pandas function cannot handle duplicate column names table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)]) diff --git a/tests/unit/test_arrow_conversion.py b/tests/unit/test_arrow_conversion.py index 30fd4f04e..78d43635a 100644 --- a/tests/unit/test_arrow_conversion.py +++ b/tests/unit/test_arrow_conversion.py @@ -1,162 +1,184 @@ import pytest -import pyarrow + +try: + import pyarrow as pa +except ImportError: + pa = None import pandas import datetime -from unittest.mock import MagicMock, patch +import unittest +from unittest.mock import MagicMock from databricks.sql.client import ResultSet, Connection, ExecuteResponse from databricks.sql.types import Row from databricks.sql.utils import ArrowQueue - -@pytest.fixture -def mock_connection(): - conn = MagicMock(spec=Connection) - conn.disable_pandas = False - conn._arrow_pandas_type_override = {} - conn._arrow_to_pandas_kwargs = {} - if not hasattr(conn, "_arrow_to_pandas_kwargs"): +@pytest.mark.skipif(pa is None, reason="PyArrow is not installed") +class ArrowConversionTests(unittest.TestCase): + @staticmethod + def mock_connection_static(): + conn = MagicMock(spec=Connection) + conn.disable_pandas = False + conn._arrow_pandas_type_override = {} conn._arrow_to_pandas_kwargs = {} - return conn - 
- -@pytest.fixture -def mock_thrift_backend(sample_arrow_table): - tb = MagicMock() - empty_arrays = [ - pyarrow.array([], type=field.type) for field in sample_arrow_table.schema - ] - empty_table = pyarrow.Table.from_arrays( - empty_arrays, schema=sample_arrow_table.schema - ) - tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False) - return tb - - -@pytest.fixture -def mock_raw_execute_response(): - er = MagicMock(spec=ExecuteResponse) - er.description = [ - ("col_int", "int", None, None, None, None, None), - ("col_str", "string", None, None, None, None, None), - ] - er.arrow_schema_bytes = None - er.arrow_queue = None - er.has_more_rows = False - er.lz4_compressed = False - er.command_handle = MagicMock() - er.status = MagicMock() - er.has_been_closed_server_side = False - er.is_staging_operation = False - return er - - -@pytest.fixture -def sample_arrow_table(): - data = [ - pyarrow.array([1, 2, 3], type=pyarrow.int32()), - pyarrow.array(["a", "b", "c"], type=pyarrow.string()), - ] - schema = pyarrow.schema( - [("col_int", pyarrow.int32()), ("col_str", pyarrow.string())] - ) - return pyarrow.Table.from_arrays(data, schema=schema) - - -def test_convert_arrow_table_default( - mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table -): - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result_one = rs.fetchone() - assert isinstance(result_one, Row) - assert result_one.col_int == 1 - assert result_one.col_str == "a" - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result_all = rs.fetchall() - assert len(result_all) == 3 - assert isinstance(result_all[0], Row) - assert result_all[0].col_int == 1 - assert result_all[1].col_str == "b" - - -def 
test_convert_arrow_table_disable_pandas( - mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table -): - mock_connection.disable_pandas = True - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result = rs.fetchall() - assert len(result) == 3 - assert isinstance(result[0], Row) - assert result[0].col_int == 1 - assert result[0].col_str == "a" - assert isinstance(sample_arrow_table.column(0)[0].as_py(), int) - assert isinstance(sample_arrow_table.column(1)[0].as_py(), str) - - -def test_convert_arrow_table_type_override( - mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table -): - mock_connection._arrow_pandas_type_override = { - pyarrow.int32(): pandas.Float64Dtype() - } - mock_raw_execute_response.arrow_queue = ArrowQueue( - sample_arrow_table, sample_arrow_table.num_rows - ) - rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) - result = rs.fetchall() - assert len(result) == 3 - assert isinstance(result[0].col_int, float) - assert result[0].col_int == 1.0 - assert result[0].col_str == "a" - - -def test_convert_arrow_table_to_pandas_kwargs( - mock_connection, mock_thrift_backend, mock_raw_execute_response -): - dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) - ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp("us", tz="UTC")) - ts_schema = pyarrow.schema([("col_ts", pyarrow.timestamp("us", tz="UTC"))]) - ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema) - - mock_raw_execute_response.description = [ - ("col_ts", "timestamp", None, None, None, None, None) - ] - mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) - - # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. 
- mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} - rs_ts_true = ResultSet( - mock_connection, mock_raw_execute_response, mock_thrift_backend - ) - result_true = rs_ts_true.fetchall() - assert len(result_true) == 1 - assert isinstance(result_true[0].col_ts, datetime.datetime) - - # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. - mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) - mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} - rs_ts_false = ResultSet( - mock_connection, mock_raw_execute_response, mock_thrift_backend - ) - result_false = rs_ts_false.fetchall() - assert len(result_false) == 1 - assert isinstance(result_false[0].col_ts, pandas.Timestamp) - - # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. - mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) - mock_connection._arrow_to_pandas_kwargs = {} - rs_ts_true = ResultSet( - mock_connection, mock_raw_execute_response, mock_thrift_backend - ) - result_true = rs_ts_true.fetchall() - assert len(result_true) == 1 - assert isinstance(result_true[0].col_ts, datetime.datetime) + return conn + + @staticmethod + def sample_arrow_table_static(): + data = [ + pa.array([1, 2, 3], type=pa.int32()), + pa.array(["a", "b", "c"], type=pa.string()), + ] + schema = pa.schema([("col_int", pa.int32()), ("col_str", pa.string())]) + return pa.Table.from_arrays(data, schema=schema) + + @staticmethod + def mock_thrift_backend_static(): + sample_table = ArrowConversionTests.sample_arrow_table_static() + tb = MagicMock() + empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema] + empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema) + tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False) + return tb + + @staticmethod + def mock_raw_execute_response_static(): + er 
= MagicMock(spec=ExecuteResponse) + er.description = [ + ("col_int", "int", None, None, None, None, None), + ("col_str", "string", None, None, None, None, None), + ] + er.arrow_schema_bytes = None + er.arrow_queue = None + er.has_more_rows = False + er.lz4_compressed = False + er.command_handle = MagicMock() + er.status = MagicMock() + er.has_been_closed_server_side = False + er.is_staging_operation = False + return er + + def test_convert_arrow_table_default(self): + mock_connection = ArrowConversionTests.mock_connection_static() + sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() + mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_one = rs.fetchone() + self.assertIsInstance(result_one, Row) + self.assertEqual(result_one.col_int, 1) + self.assertEqual(result_one.col_str, "a") + + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result_all = rs.fetchall() + self.assertEqual(len(result_all), 3) + self.assertIsInstance(result_all[0], Row) + self.assertEqual(result_all[0].col_int, 1) + self.assertEqual(result_all[1].col_str, "b") + + def test_convert_arrow_table_disable_pandas(self): + mock_connection = ArrowConversionTests.mock_connection_static() + sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() + mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + mock_connection.disable_pandas = True + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, 
sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + self.assertEqual(len(result), 3) + self.assertIsInstance(result[0], Row) + self.assertEqual(result[0].col_int, 1) + self.assertEqual(result[0].col_str, "a") + self.assertIsInstance(sample_arrow_table.column(0)[0].as_py(), int) + self.assertIsInstance(sample_arrow_table.column(1)[0].as_py(), str) + + def test_convert_arrow_table_type_override(self): + mock_connection = ArrowConversionTests.mock_connection_static() + sample_arrow_table = ArrowConversionTests.sample_arrow_table_static() + mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static() + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + mock_connection._arrow_pandas_type_override = { + pa.int32(): pandas.Float64Dtype() + } + mock_raw_execute_response.arrow_queue = ArrowQueue( + sample_arrow_table, sample_arrow_table.num_rows + ) + rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend) + result = rs.fetchall() + self.assertEqual(len(result), 3) + self.assertIsInstance(result[0].col_int, float) + self.assertEqual(result[0].col_int, 1.0) + self.assertEqual(result[0].col_str, "a") + + def test_convert_arrow_table_to_pandas_kwargs(self): + mock_connection = ArrowConversionTests.mock_connection_static() + mock_thrift_backend = ( + ArrowConversionTests.mock_thrift_backend_static() + ) # Does not use sample_arrow_table + mock_raw_execute_response = ( + ArrowConversionTests.mock_raw_execute_response_static() + ) + + dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc) + ts_array = pa.array([dt_obj], type=pa.timestamp("us", tz="UTC")) + ts_schema = pa.schema([("col_ts", pa.timestamp("us", tz="UTC"))]) + ts_table = pa.Table.from_arrays([ts_array], schema=ts_schema) + + mock_raw_execute_response.description = [ + ("col_ts", "timestamp", None, None, None, None, None) + ] + 
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows) + + # Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row. + mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True} + rs_ts_true = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) + result_true = rs_ts_true.fetchall() + self.assertEqual(len(result_true), 1) + self.assertIsInstance(result_true[0].col_ts, datetime.datetime) + + # Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input. + mock_raw_execute_response.arrow_queue = ArrowQueue( + ts_table, ts_table.num_rows + ) # Reset queue + mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False} + rs_ts_false = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) + result_false = rs_ts_false.fetchall() + self.assertEqual(len(result_false), 1) + self.assertIsInstance(result_false[0].col_ts, pandas.Timestamp) + + # Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default. 
+ mock_raw_execute_response.arrow_queue = ArrowQueue( + ts_table, ts_table.num_rows + ) # Reset queue + mock_connection._arrow_to_pandas_kwargs = {} + rs_ts_default = ResultSet( + mock_connection, mock_raw_execute_response, mock_thrift_backend + ) + result_default = rs_ts_default.fetchall() + self.assertEqual(len(result_default), 1) + self.assertIsInstance(result_default[0].col_ts, datetime.datetime) + + +if __name__ == "__main__": + unittest.main() From 31b44d4d53bdf154a0cd63630509621fcd753264 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 16 Jun 2025 06:38:48 +0000 Subject: [PATCH 4/5] Add _arrow_pandas_type_override and _arrow_to_pandas_kwargs to Connection class --- src/databricks/sql/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index da1177f45..d47c3f24c 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -234,6 +234,10 @@ def read(self) -> Optional[OAuthToken]: self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self._arrow_pandas_type_override = kwargs.get( + "_arrow_pandas_type_override", {} + ) + self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {}) auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs From 2f32c6c08d9b40266a931b75b7868d526642de7e Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Mon, 16 Jun 2025 06:47:33 +0000 Subject: [PATCH 5/5] fmt --- src/databricks/sql/client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index d47c3f24c..112a60dc9 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -234,9 +234,7 @@ def read(self) -> Optional[OAuthToken]: self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) 
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) - self._arrow_pandas_type_override = kwargs.get( - "_arrow_pandas_type_override", {} - ) + self._arrow_pandas_type_override = kwargs.get("_arrow_pandas_type_override", {}) self._arrow_to_pandas_kwargs = kwargs.get("_arrow_to_pandas_kwargs", {}) auth_provider = get_python_sql_connector_auth_provider(