Skip to content

Commit b77acbe

Browse files
Merge branch 'sea-migration' into ext-links-sea
2 parents 38c2b88 + 70c7dc8 commit b77acbe

15 files changed

+557
-483
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
from databricks.sql.client import Cursor
20-
from databricks.sql.result_set import SeaResultSet
20+
from databricks.sql.backend.sea.result_set import SeaResultSet
2121

2222
from databricks.sql.backend.databricks_client import DatabricksClient
2323
from databricks.sql.backend.types import (
@@ -253,7 +253,7 @@ def close_session(self, session_id: SessionId) -> None:
253253
logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id)
254254

255255
if session_id.backend_type != BackendType.SEA:
256-
raise ProgrammingError("Not a valid SEA session ID")
256+
raise ValueError("Not a valid SEA session ID")
257257
sea_session_id = session_id.to_sea_session_id()
258258

259259
request_data = DeleteSessionRequest(
@@ -292,7 +292,7 @@ def get_allowed_session_configurations() -> List[str]:
292292

293293
def _extract_description_from_manifest(
294294
self, manifest: ResultManifest
295-
) -> Optional[List]:
295+
) -> List[Tuple]:
296296
"""
297297
Extract column description from a manifest object, in the format defined by
298298
the spec: https://peps.python.org/pep-0249/#description
@@ -301,15 +301,12 @@ def _extract_description_from_manifest(
301301
manifest: The ResultManifest object containing schema information
302302
303303
Returns:
304-
Optional[List]: A list of column tuples or None if no columns are found
304+
List[Tuple]: A list of column tuples
305305
"""
306306

307307
schema_data = manifest.schema
308308
columns_data = schema_data.get("columns", [])
309309

310-
if not columns_data:
311-
return None
312-
313310
columns = []
314311
for col_data in columns_data:
315312
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
@@ -325,7 +322,7 @@ def _extract_description_from_manifest(
325322
)
326323
)
327324

328-
return columns if columns else None
325+
return columns
329326

330327
def _results_message_to_execute_response(
331328
self, response: GetStatementResponse
@@ -431,7 +428,7 @@ def execute_command(
431428
"""
432429

433430
if session_id.backend_type != BackendType.SEA:
434-
raise ProgrammingError("Not a valid SEA session ID")
431+
raise ValueError("Not a valid SEA session ID")
435432

436433
sea_session_id = session_id.to_sea_session_id()
437434

@@ -510,9 +507,11 @@ def cancel_command(self, command_id: CommandId) -> None:
510507
"""
511508

512509
if command_id.backend_type != BackendType.SEA:
513-
raise ProgrammingError("Not a valid SEA command ID")
510+
raise ValueError("Not a valid SEA command ID")
514511

515512
sea_statement_id = command_id.to_sea_statement_id()
513+
if sea_statement_id is None:
514+
raise ValueError("Not a valid SEA command ID")
516515

517516
request = CancelStatementRequest(statement_id=sea_statement_id)
518517
self.http_client._make_request(
@@ -533,9 +532,11 @@ def close_command(self, command_id: CommandId) -> None:
533532
"""
534533

535534
if command_id.backend_type != BackendType.SEA:
536-
raise ProgrammingError("Not a valid SEA command ID")
535+
raise ValueError("Not a valid SEA command ID")
537536

538537
sea_statement_id = command_id.to_sea_statement_id()
538+
if sea_statement_id is None:
539+
raise ValueError("Not a valid SEA command ID")
539540

540541
request = CloseStatementRequest(statement_id=sea_statement_id)
541542
self.http_client._make_request(
@@ -562,6 +563,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
562563
raise ProgrammingError("Not a valid SEA command ID")
563564

564565
sea_statement_id = command_id.to_sea_statement_id()
566+
if sea_statement_id is None:
567+
raise ValueError("Not a valid SEA command ID")
565568

566569
request = GetStatementRequest(statement_id=sea_statement_id)
567570
response_data = self.http_client._make_request(
@@ -594,9 +597,11 @@ def get_execution_result(
594597
"""
595598

596599
if command_id.backend_type != BackendType.SEA:
597-
raise ProgrammingError("Not a valid SEA command ID")
600+
raise ValueError("Not a valid SEA command ID")
598601

599602
sea_statement_id = command_id.to_sea_statement_id()
603+
if sea_statement_id is None:
604+
raise ValueError("Not a valid SEA command ID")
600605

601606
# Create the request model
602607
request = GetStatementRequest(statement_id=sea_statement_id)
@@ -610,7 +615,7 @@ def get_execution_result(
610615
response = GetStatementResponse.from_dict(response_data)
611616

612617
# Create and return a SeaResultSet
613-
from databricks.sql.result_set import SeaResultSet
618+
from databricks.sql.backend.sea.result_set import SeaResultSet
614619

615620
execute_response = self._results_message_to_execute_response(response)
616621

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC
4+
from typing import List, Optional, Tuple
5+
6+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
7+
from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
8+
from databricks.sql.backend.sea.utils.constants import ResultFormat
9+
from databricks.sql.exc import ProgrammingError
10+
from databricks.sql.utils import ResultSetQueue
11+
12+
13+
class SeaResultSetQueueFactory(ABC):
14+
@staticmethod
15+
def build_queue(
16+
sea_result_data: ResultData,
17+
manifest: ResultManifest,
18+
statement_id: str,
19+
description: List[Tuple] = [],
20+
max_download_threads: Optional[int] = None,
21+
sea_client: Optional[SeaDatabricksClient] = None,
22+
lz4_compressed: bool = False,
23+
) -> ResultSetQueue:
24+
"""
25+
Factory method to build a result set queue for SEA backend.
26+
27+
Args:
28+
sea_result_data (ResultData): Result data from SEA response
29+
manifest (ResultManifest): Manifest from SEA response
30+
statement_id (str): Statement ID for the query
31+
description (List[List[Any]]): Column descriptions
32+
max_download_threads (int): Maximum number of download threads
33+
sea_client (SeaDatabricksClient): SEA client for fetching additional links
34+
lz4_compressed (bool): Whether the data is LZ4 compressed
35+
36+
Returns:
37+
ResultSetQueue: The appropriate queue for the result data
38+
"""
39+
40+
if manifest.format == ResultFormat.JSON_ARRAY.value:
41+
# INLINE disposition with JSON_ARRAY format
42+
return JsonQueue(sea_result_data.data)
43+
elif manifest.format == ResultFormat.ARROW_STREAM.value:
44+
# EXTERNAL_LINKS disposition
45+
raise NotImplementedError(
46+
"EXTERNAL_LINKS disposition is not implemented for SEA backend"
47+
)
48+
raise ProgrammingError("Invalid result format")
49+
50+
51+
class JsonQueue(ResultSetQueue):
52+
"""Queue implementation for JSON_ARRAY format data."""
53+
54+
def __init__(self, data_array: Optional[List[List[str]]]):
55+
"""Initialize with JSON array data."""
56+
self.data_array = data_array or []
57+
self.cur_row_index = 0
58+
self.num_rows = len(self.data_array)
59+
60+
def next_n_rows(self, num_rows: int) -> List[List[str]]:
61+
"""Get the next n rows from the data array."""
62+
length = min(num_rows, self.num_rows - self.cur_row_index)
63+
slice = self.data_array[self.cur_row_index : self.cur_row_index + length]
64+
self.cur_row_index += length
65+
return slice
66+
67+
def remaining_rows(self) -> List[List[str]]:
68+
"""Get all remaining rows from the data array."""
69+
slice = self.data_array[self.cur_row_index :]
70+
self.cur_row_index += len(slice)
71+
return slice

0 commit comments

Comments
 (0)