
Commit 72a5cd3

move sea result set into backend/sea package
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
Parent: 4566cb1

7 files changed: +271 / -260 lines

src/databricks/sql/backend/sea/backend.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 
 if TYPE_CHECKING:
     from databricks.sql.client import Cursor
-    from databricks.sql.result_set import SeaResultSet
+    from databricks.sql.backend.sea.result_set import SeaResultSet
 
 from databricks.sql.backend.databricks_client import DatabricksClient
 from databricks.sql.backend.types import (
@@ -613,7 +613,7 @@ def get_execution_result(
         response = GetStatementResponse.from_dict(response_data)
 
         # Create and return a SeaResultSet
-        from databricks.sql.result_set import SeaResultSet
+        from databricks.sql.backend.sea.result_set import SeaResultSet
 
         execute_response = self._results_message_to_execute_response(response)
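
The deferred import in get_execution_result is what lets backend.py return SeaResultSet instances without importing the module at load time: the new result_set.py itself imports SeaDatabricksClient from backend.py, so a module-level import in both directions would be circular. A minimal sketch of the two-part pattern the diff relies on (module names are real; the method signature and body are simplified here):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers and never executed at runtime, so it
    # cannot create a circular import between backend.py and result_set.py.
    from databricks.sql.backend.sea.result_set import SeaResultSet


def get_execution_result() -> "SeaResultSet":
    # Deferred import: resolved at call time, after both modules have
    # finished initializing.
    from databricks.sql.backend.sea.result_set import SeaResultSet

    ...  # construct and return the SeaResultSet as in the diff above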

src/databricks/sql/backend/sea/result_set.py (new file)

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
+from __future__ import annotations
+
+from typing import List, Optional, TYPE_CHECKING
+
+import logging
+
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
+from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter
+
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
+if TYPE_CHECKING:
+    from databricks.sql.client import Connection
+from databricks.sql.types import Row
+from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
+from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.result_set import ResultSet
+
+logger = logging.getLogger(__name__)
+
+
+class SeaResultSet(ResultSet):
+    """ResultSet implementation for SEA backend."""
+
+    def __init__(
+        self,
+        connection: Connection,
+        execute_response: ExecuteResponse,
+        sea_client: SeaDatabricksClient,
+        result_data: ResultData,
+        manifest: ResultManifest,
+        buffer_size_bytes: int = 104857600,
+        arraysize: int = 10000,
+    ):
+        """
+        Initialize a SeaResultSet with the response from a SEA query execution.
+
+        Args:
+            connection: The parent connection
+            execute_response: Response from the execute command
+            sea_client: The SeaDatabricksClient instance for direct access
+            buffer_size_bytes: Buffer size for fetching results
+            arraysize: Default number of rows to fetch
+            result_data: Result data from SEA response
+            manifest: Manifest from SEA response
+        """
+
+        self.manifest = manifest
+
+        statement_id = execute_response.command_id.to_sea_statement_id()
+        if statement_id is None:
+            raise ValueError("Command ID is not a SEA statement ID")
+
+        results_queue = SeaResultSetQueueFactory.build_queue(
+            result_data,
+            self.manifest,
+            statement_id,
+            description=execute_response.description,
+            max_download_threads=sea_client.max_download_threads,
+            sea_client=sea_client,
+            lz4_compressed=execute_response.lz4_compressed,
+        )
+
+        # Call parent constructor with common attributes
+        super().__init__(
+            connection=connection,
+            backend=sea_client,
+            arraysize=arraysize,
+            buffer_size_bytes=buffer_size_bytes,
+            command_id=execute_response.command_id,
+            status=execute_response.status,
+            has_been_closed_server_side=execute_response.has_been_closed_server_side,
+            results_queue=results_queue,
+            description=execute_response.description,
+            is_staging_operation=execute_response.is_staging_operation,
+            lz4_compressed=execute_response.lz4_compressed,
+            arrow_schema_bytes=execute_response.arrow_schema_bytes,
+        )
+
+    def _convert_json_types(self, row: List) -> List:
+        """
+        Convert string values to appropriate Python types based on column metadata.
+        """
+
+        # JSON + INLINE gives us string values, so we convert them to appropriate
+        # types based on column metadata
+        converted_row = []
+
+        for i, value in enumerate(row):
+            column_type = self.description[i][1]
+            precision = self.description[i][4]
+            scale = self.description[i][5]
+
+            try:
+                converted_value = SqlTypeConverter.convert_value(
+                    value, column_type, precision=precision, scale=scale
+                )
+                converted_row.append(converted_value)
+            except Exception as e:
+                logger.warning(
+                    f"Error converting value '{value}' to {column_type}: {e}"
+                )
+                converted_row.append(value)
+
+        return converted_row
+
+    def _convert_json_to_arrow_table(self, rows: List[List]) -> "pyarrow.Table":
+        """
+        Convert raw data rows to Arrow table.
+        """
+        if not rows:
+            return pyarrow.Table.from_pydict({})
+
+        # create a generator for row conversion
+        converted_rows_iter = (self._convert_json_types(row) for row in rows)
+        cols = list(map(list, zip(*converted_rows_iter)))
+
+        names = [col[0] for col in self.description]
+        return pyarrow.Table.from_arrays(cols, names=names)
+
+    def _create_json_table(self, rows: List[List]) -> List[Row]:
+        """
+        Convert raw data rows to Row objects with named columns based on description.
+        Also converts string values to appropriate Python types based on column metadata.
+
+        Args:
+            rows: List of raw data rows
+        Returns:
+            List of Row objects with named columns and converted values
+        """
+
+        ResultRow = Row(*[col[0] for col in self.description])
+        return [ResultRow(*self._convert_json_types(row)) for row in rows]
+
+    def fetchmany_json(self, size: int) -> List[List]:
+        """
+        Fetch the next set of rows as a columnar table.
+
+        Args:
+            size: Number of rows to fetch
+
+        Returns:
+            Columnar table containing the fetched rows
+
+        Raises:
+            ValueError: If size is negative
+        """
+
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        results = self.results.next_n_rows(size)
+        self._next_row_index += len(results)
+
+        return results
+
+    def fetchall_json(self) -> List[List]:
+        """
+        Fetch all remaining rows as a columnar table.
+
+        Returns:
+            Columnar table containing all remaining rows
+        """
+
+        results = self.results.remaining_rows()
+        self._next_row_index += len(results)
+
+        return results
+
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
+        """
+        Fetch the next set of rows as an Arrow table.
+
+        Args:
+            size: Number of rows to fetch
+
+        Returns:
+            PyArrow Table containing the fetched rows
+
+        Raises:
+            ImportError: If PyArrow is not installed
+            ValueError: If size is negative
+        """
+
+        if size < 0:
+            raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
+
+        if not isinstance(self.results, JsonQueue):
+            raise NotImplementedError("fetchmany_arrow only supported for JSON data")
+
+        results = self._convert_json_to_arrow_table(self.results.next_n_rows(size))
+        self._next_row_index += results.num_rows
+
+        return results
+
+    def fetchall_arrow(self) -> "pyarrow.Table":
+        """
+        Fetch all remaining rows as an Arrow table.
+        """
+
+        if not isinstance(self.results, JsonQueue):
+            raise NotImplementedError("fetchall_arrow only supported for JSON data")
+
+        results = self._convert_json_to_arrow_table(self.results.remaining_rows())
+        self._next_row_index += results.num_rows
+
+        return results
+
+    def fetchone(self) -> Optional[Row]:
+        """
+        Fetch the next row of a query result set, returning a single sequence,
+        or None when no more data is available.
+
+        Returns:
+            A single Row object or None if no more rows are available
+        """
+
+        if isinstance(self.results, JsonQueue):
+            res = self._create_json_table(self.fetchmany_json(1))
+        else:
+            raise NotImplementedError("fetchone only supported for JSON data")
+
+        return res[0] if res else None
+
+    def fetchmany(self, size: int) -> List[Row]:
+        """
+        Fetch the next set of rows of a query result, returning a list of rows.
+
+        Args:
+            size: Number of rows to fetch (defaults to arraysize if None)
+
+        Returns:
+            List of Row objects
+
+        Raises:
+            ValueError: If size is negative
+        """
+
+        if isinstance(self.results, JsonQueue):
+            return self._create_json_table(self.fetchmany_json(size))
+        else:
+            raise NotImplementedError("fetchmany only supported for JSON data")
+
+    def fetchall(self) -> List[Row]:
+        """
+        Fetch all remaining rows of a query result, returning them as a list of rows.
+
+        Returns:
+            List of Row objects containing all remaining rows
+        """
+
+        if isinstance(self.results, JsonQueue):
+            return self._create_json_table(self.fetchall_json())
+        else:
+            raise NotImplementedError("fetchall only supported for JSON data")
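
A note on _convert_json_to_arrow_table above: inline JSON results arrive row-oriented, so building the Arrow table is essentially a transpose followed by pyarrow.Table.from_arrays. A standalone sketch of that step (toy data and hardcoded column names are illustrative only; the per-value SqlTypeConverter conversion is omitted):

import pyarrow

# Toy inline-JSON rows, one list per row, as a JsonQueue would return them.
rows = [
    ["1", "alice"],
    ["2", "bob"],
]

# Transpose rows into columns -- the zip(*...) step from
# _convert_json_to_arrow_table above.
cols = list(map(list, zip(*rows)))

# In the real method, column names come from self.description.
table = pyarrow.Table.from_arrays(
    [pyarrow.array(col) for col in cols], names=["id", "name"]
)
print(table.num_rows)  # -> 2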

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 )
 
 if TYPE_CHECKING:
-    from databricks.sql.result_set import SeaResultSet
+    from databricks.sql.backend.sea.result_set import SeaResultSet
 
 from databricks.sql.backend.types import ExecuteResponse
 
@@ -70,7 +70,7 @@ def _filter_sea_result_set(
         result_data = ResultData(data=filtered_rows, external_links=None)
 
         from databricks.sql.backend.sea.backend import SeaDatabricksClient
-        from databricks.sql.result_set import SeaResultSet
+        from databricks.sql.backend.sea.result_set import SeaResultSet
 
         # Create a new SeaResultSet with the filtered data
         manifest = result_set.manifest
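
Net effect of the commit: SeaResultSet moves from databricks/sql/result_set.py into the backend/sea package, and the import sites above follow it; consumer-facing behavior is unchanged. A hypothetical consumption sketch (assumes result_set is a SeaResultSet backed by inline JSON results; other queue types raise NotImplementedError at this commit):

# Hypothetical: `result_set` obtained from a SEA-backed query execution.
row = result_set.fetchone()          # one Row, or None when exhausted
batch = result_set.fetchmany(100)    # list of Row; negative size raises ValueError
remaining = result_set.fetchall()    # every row still in the queue
table = result_set.fetchall_arrow()  # pyarrow.Table of whatever remains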
