Skip to content

Commit dd7dc6a

Browse files
convert complex types to string if not _use_arrow_native_complex_types
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent bc467d1 commit dd7dc6a

File tree

4 files changed

+58
-9
lines changed

4 files changed

+58
-9
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from abc import ABC, abstractmethod
1212
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING
1313

14+
from databricks.sql.types import SSLOptions
15+
1416
if TYPE_CHECKING:
1517
from databricks.sql.client import Cursor
1618

@@ -25,6 +27,13 @@
2527

2628

2729
class DatabricksClient(ABC):
30+
def __init__(self, ssl_options: SSLOptions, **kwargs):
31+
self._use_arrow_native_complex_types = kwargs.get(
32+
"_use_arrow_native_complex_types", True
33+
)
34+
self._max_download_threads = kwargs.get("max_download_threads", 10)
35+
self._ssl_options = ssl_options
36+
2837
# == Connection and Session Management ==
2938
@abstractmethod
3039
def open_session(

src/databricks/sql/backend/sea/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
http_path,
125125
)
126126

127-
self._max_download_threads = kwargs.get("max_download_threads", 10)
127+
super().__init__(ssl_options, **kwargs)
128128

129129
# Extract warehouse ID from http_path
130130
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -136,7 +136,7 @@ def __init__(
136136
http_path=http_path,
137137
http_headers=http_headers,
138138
auth_provider=auth_provider,
139-
ssl_options=ssl_options,
139+
ssl_options=self._ssl_options,
140140
**kwargs,
141141
)
142142

src/databricks/sql/backend/thrift_backend.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def __init__(
147147
http_path,
148148
)
149149

150+
super().__init__(ssl_options, **kwargs)
151+
150152
port = port or 443
151153
if kwargs.get("_connection_uri"):
152154
uri = kwargs.get("_connection_uri")
@@ -160,19 +162,13 @@ def __init__(
160162
raise ValueError("No valid connection settings.")
161163

162164
self._initialize_retry_args(kwargs)
163-
self._use_arrow_native_complex_types = kwargs.get(
164-
"_use_arrow_native_complex_types", True
165-
)
165+
166166
self._use_arrow_native_decimals = kwargs.get("_use_arrow_native_decimals", True)
167167
self._use_arrow_native_timestamps = kwargs.get(
168168
"_use_arrow_native_timestamps", True
169169
)
170170

171171
# Cloud fetch
172-
self._max_download_threads = kwargs.get("max_download_threads", 10)
173-
174-
self._ssl_options = ssl_options
175-
176172
self._auth_provider = auth_provider
177173

178174
# Connector version 3 retry approach

src/databricks/sql/result_set.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
import json
23
from typing import List, Optional, Any, Union, Tuple, TYPE_CHECKING
34

45
import logging
@@ -551,6 +552,43 @@ def fetchall_json(self):
551552

552553
return results
553554

555+
def _convert_complex_types_to_string(
556+
self, rows: "pyarrow.Table"
557+
) -> "pyarrow.Table":
558+
"""
559+
Convert complex types (array, struct, map) to string representation.
560+
561+
Args:
562+
rows: Input PyArrow table
563+
564+
Returns:
565+
PyArrow table with complex types converted to strings
566+
"""
567+
568+
if not pyarrow:
569+
return rows
570+
571+
def convert_complex_column_to_string(col: "pyarrow.Array") -> "pyarrow.Array":
572+
python_values = col.to_pylist()
573+
json_strings = [
574+
(None if val is None else json.dumps(val)) for val in python_values
575+
]
576+
return pyarrow.array(json_strings, type=pyarrow.string())
577+
578+
converted_columns = []
579+
for col in rows.columns:
580+
converted_col = col
581+
if (
582+
pyarrow.types.is_list(col.type)
583+
or pyarrow.types.is_large_list(col.type)
584+
or pyarrow.types.is_struct(col.type)
585+
or pyarrow.types.is_map(col.type)
586+
):
587+
converted_col = convert_complex_column_to_string(col)
588+
converted_columns.append(converted_col)
589+
590+
return pyarrow.Table.from_arrays(converted_columns, names=rows.column_names)
591+
554592
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
555593
"""
556594
Fetch the next set of rows as an Arrow table.
@@ -571,6 +609,9 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
571609
results = self.results.next_n_rows(size)
572610
self._next_row_index += results.num_rows
573611

612+
if not self.backend._use_arrow_native_complex_types:
613+
results = self._convert_complex_types_to_string(results)
614+
574615
return results
575616

576617
def fetchall_arrow(self) -> "pyarrow.Table":
@@ -580,6 +621,9 @@ def fetchall_arrow(self) -> "pyarrow.Table":
580621
results = self.results.remaining_rows()
581622
self._next_row_index += results.num_rows
582623

624+
if not self.backend._use_arrow_native_complex_types:
625+
results = self._convert_complex_types_to_string(results)
626+
583627
return results
584628

585629
def fetchone(self) -> Optional[Row]:

0 commit comments

Comments
 (0)