Skip to content

Commit be68e2f

Browse files
Jessesaishreeeee
authored andcommitted
[PECO-1286] Add tests for complex types in query results (#293)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com> Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 83e8565 commit be68e2f

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

src/databricks/sql/client.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
session_configuration: Dict[str, Any] = None,
6262
catalog: Optional[str] = None,
6363
schema: Optional[str] = None,
64+
_use_arrow_native_complex_types: Optional[bool] = True,
6465
**kwargs,
6566
) -> None:
6667
"""
@@ -152,8 +153,13 @@ def read(self) -> Optional[OAuthToken]:
152153
experimental_oauth_persistence=DevOnlyFilePersistence("~/dev-oauth.json")
153154
)
154155
```
155-
156-
156+
:param _use_arrow_native_complex_types: `bool`, optional
157+
Controls whether a complex type field value is returned as a string or as a native Arrow type. Defaults to True.
158+
When True:
159+
MAP is returned as List[Tuple[str, Any]]
160+
STRUCT is returned as Dict[str, Any]
161+
ARRAY is returned as numpy.ndarray
162+
When False, complex types are returned as a strings. These are generally deserializable as JSON.
157163
"""
158164

159165
# Internal arguments in **kwargs:
@@ -184,9 +190,6 @@ def read(self) -> Optional[OAuthToken]:
184190
# _disable_pandas
185191
# In case the deserialisation through pandas causes any issues, it can be disabled with
186192
# this flag.
187-
# _use_arrow_native_complex_types
188-
# DBR will return native Arrow types for structs, arrays and maps instead of Arrow strings
189-
# (True by default)
190193
# _use_arrow_native_decimals
191194
# Databricks runtime will return native Arrow types for decimals instead of Arrow strings
192195
# (True by default)
@@ -225,6 +228,7 @@ def read(self) -> Optional[OAuthToken]:
225228
http_path,
226229
(http_headers or []) + base_headers,
227230
auth_provider,
231+
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
228232
**kwargs,
229233
)
230234

tests/e2e/test_complex_types.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
2+
import pytest
3+
from numpy import ndarray
4+
5+
from tests.e2e.test_driver import PySQLPytestTestCase
6+
7+
8+
class TestComplexTypes(PySQLPytestTestCase):
9+
@pytest.fixture(scope="class")
10+
def table_fixture(self):
11+
"""A pytest fixture that creates a table with a complex type, inserts a record, yields, and then drops the table"""
12+
13+
with self.cursor() as cursor:
14+
# Create the table
15+
cursor.execute(
16+
"""
17+
CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table (
18+
array_col ARRAY<STRING>,
19+
map_col MAP<STRING, INTEGER>,
20+
struct_col STRUCT<field1: STRING, field2: INTEGER>
21+
)
22+
"""
23+
)
24+
# Insert a record
25+
cursor.execute(
26+
"""
27+
INSERT INTO pysql_test_complex_types_table
28+
VALUES (
29+
ARRAY('a', 'b', 'c'),
30+
MAP('a', 1, 'b', 2, 'c', 3),
31+
NAMED_STRUCT('field1', 'a', 'field2', 1)
32+
)
33+
"""
34+
)
35+
yield
36+
# Clean up the table after the test
37+
cursor.execute("DROP TABLE IF EXISTS pysql_test_complex_types_table")
38+
39+
@pytest.mark.parametrize(
40+
"field,expected_type",
41+
[("array_col", ndarray), ("map_col", list), ("struct_col", dict)],
42+
)
43+
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
44+
"""Confirms the return types of a complex type field when reading as arrow"""
45+
46+
with self.cursor() as cursor:
47+
result = cursor.execute(
48+
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
49+
).fetchone()
50+
51+
assert isinstance(result[field], expected_type)
52+
53+
@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")])
54+
def test_read_complex_types_as_string(self, field, table_fixture):
55+
"""Confirms the return type of a complex type that is returned as a string"""
56+
with self.cursor(
57+
extra_params={"_use_arrow_native_complex_types": False}
58+
) as cursor:
59+
result = cursor.execute(
60+
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
61+
).fetchone()
62+
63+
assert isinstance(result[field], str)

0 commit comments

Comments
 (0)