Skip to content

Commit d33e5fd

Browse files
Merge branch 'fetch-json-inline' into ext-links-sea
2 parents be17812 + adecd53 commit d33e5fd

File tree

14 files changed

+749
-1113
lines changed

14 files changed

+749
-1113
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 38 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66

77
from databricks.sql.backend.sea.utils.constants import (
88
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
9+
ResultFormat,
10+
ResultDisposition,
11+
ResultCompression,
12+
WaitTimeout,
913
)
1014

1115
if TYPE_CHECKING:
1216
from databricks.sql.client import Cursor
1317
from databricks.sql.result_set import ResultSet
14-
from databricks.sql.backend.sea.models.responses import GetChunksResponse
1518

1619
from databricks.sql.backend.databricks_client import DatabricksClient
1720
from databricks.sql.backend.types import (
@@ -21,16 +24,10 @@
2124
BackendType,
2225
ExecuteResponse,
2326
)
24-
from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
27+
from databricks.sql.exc import ServerOperationError
2528
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
2629
from databricks.sql.thrift_api.TCLIService import ttypes
2730
from databricks.sql.types import SSLOptions
28-
from databricks.sql.utils import SeaResultSetQueueFactory
29-
from databricks.sql.backend.sea.models.base import (
30-
ResultData,
31-
ExternalLink,
32-
ResultManifest,
33-
)
3431

3532
from databricks.sql.backend.sea.models import (
3633
ExecuteStatementRequest,
@@ -45,6 +42,11 @@
4542
CreateSessionResponse,
4643
GetChunksResponse,
4744
)
45+
from databricks.sql.backend.sea.models.responses import (
46+
parse_status,
47+
parse_manifest,
48+
parse_result,
49+
)
4850

4951
logger = logging.getLogger(__name__)
5052

@@ -80,9 +82,6 @@ def _filter_session_configuration(
8082
class SeaDatabricksClient(DatabricksClient):
8183
"""
8284
Statement Execution API (SEA) implementation of the DatabricksClient interface.
83-
84-
This implementation provides session management functionality for SEA,
85-
while other operations raise NotImplementedError.
8685
"""
8786

8887
# SEA API paths
@@ -92,8 +91,6 @@ class SeaDatabricksClient(DatabricksClient):
9291
STATEMENT_PATH = BASE_PATH + "statements"
9392
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9493
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
95-
CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks"
96-
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9794

9895
def __init__(
9996
self,
@@ -126,7 +123,6 @@ def __init__(
126123
)
127124

128125
self._max_download_threads = kwargs.get("max_download_threads", 10)
129-
self.ssl_options = ssl_options
130126

131127
# Extract warehouse ID from http_path
132128
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -283,19 +279,6 @@ def get_default_session_configuration_value(name: str) -> Optional[str]:
283279
"""
284280
return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())
285281

286-
@staticmethod
287-
def is_session_configuration_parameter_supported(name: str) -> bool:
288-
"""
289-
Check if a session configuration parameter is supported.
290-
291-
Args:
292-
name: The name of the session configuration parameter
293-
294-
Returns:
295-
True if the parameter is supported, False otherwise
296-
"""
297-
return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP
298-
299282
@staticmethod
300283
def get_allowed_session_configurations() -> List[str]:
301284
"""
@@ -343,92 +326,27 @@ def _results_message_to_execute_response(self, sea_response, command_id):
343326
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
344327
result data object, and manifest object
345328
"""
346-
# Extract status
347-
status_data = sea_response.get("status", {})
348-
state = CommandState.from_sea_state(status_data.get("state", ""))
349-
350-
# Extract description from manifest
351-
description = None
352-
manifest_data = sea_response.get("manifest", {})
353-
schema_data = manifest_data.get("schema", {})
354-
columns_data = schema_data.get("columns", [])
355-
356-
if columns_data:
357-
columns = []
358-
for col_data in columns_data:
359-
if not isinstance(col_data, dict):
360-
continue
361-
362-
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
363-
columns.append(
364-
(
365-
col_data.get("name", ""), # name
366-
col_data.get("type_name", ""), # type_code
367-
None, # display_size (not provided by SEA)
368-
None, # internal_size (not provided by SEA)
369-
col_data.get("precision"), # precision
370-
col_data.get("scale"), # scale
371-
col_data.get("nullable", True), # null_ok
372-
)
373-
)
374-
description = columns if columns else None
375329

376-
# Check for compression
377-
lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME"
378-
379-
# Initialize result_data_obj and manifest_obj
380-
result_data_obj = None
381-
manifest_obj = None
382-
383-
result_data = sea_response.get("result", {})
384-
if result_data:
385-
# Convert external links
386-
external_links = None
387-
if "external_links" in result_data:
388-
external_links = []
389-
for link_data in result_data["external_links"]:
390-
external_links.append(
391-
ExternalLink(
392-
external_link=link_data.get("external_link", ""),
393-
expiration=link_data.get("expiration", ""),
394-
chunk_index=link_data.get("chunk_index", 0),
395-
byte_count=link_data.get("byte_count", 0),
396-
row_count=link_data.get("row_count", 0),
397-
row_offset=link_data.get("row_offset", 0),
398-
next_chunk_index=link_data.get("next_chunk_index"),
399-
next_chunk_internal_link=link_data.get(
400-
"next_chunk_internal_link"
401-
),
402-
http_headers=link_data.get("http_headers", {}),
403-
)
404-
)
330+
# Parse the response
331+
status = parse_status(sea_response)
332+
manifest_obj = parse_manifest(sea_response)
333+
result_data_obj = parse_result(sea_response)
405334

406-
# Create the result data object
407-
result_data_obj = ResultData(
408-
data=result_data.get("data_array"), external_links=external_links
409-
)
335+
# Extract description from manifest schema
336+
description = self._extract_description_from_manifest(manifest_obj)
410337

411-
# Create the manifest object
412-
manifest_obj = ResultManifest(
413-
format=manifest_data.get("format", ""),
414-
schema=manifest_data.get("schema", {}),
415-
total_row_count=manifest_data.get("total_row_count", 0),
416-
total_byte_count=manifest_data.get("total_byte_count", 0),
417-
total_chunk_count=manifest_data.get("total_chunk_count", 0),
418-
truncated=manifest_data.get("truncated", False),
419-
chunks=manifest_data.get("chunks"),
420-
result_compression=manifest_data.get("result_compression"),
421-
)
338+
# Check for compression
339+
lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME"
422340

423341
execute_response = ExecuteResponse(
424342
command_id=command_id,
425-
status=state,
343+
status=status.state,
426344
description=description,
427345
has_been_closed_server_side=False,
428346
lz4_compressed=lz4_compressed,
429347
is_staging_operation=False,
430348
arrow_schema_bytes=None,
431-
result_format=manifest_data.get("format"),
349+
result_format=manifest_obj.format,
432350
)
433351

434352
return execute_response, result_data_obj, manifest_obj
@@ -464,6 +382,7 @@ def execute_command(
464382
Returns:
465383
ResultSet: A SeaResultSet instance for the executed command
466384
"""
385+
467386
if session_id.backend_type != BackendType.SEA:
468387
raise ValueError("Not a valid SEA session ID")
469388

@@ -481,17 +400,25 @@ def execute_command(
481400
)
482401
)
483402

484-
format = "ARROW_STREAM" if use_cloud_fetch else "JSON_ARRAY"
485-
disposition = "EXTERNAL_LINKS" if use_cloud_fetch else "INLINE"
486-
result_compression = "LZ4_FRAME" if lz4_compression else None
403+
format = (
404+
ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY
405+
).value
406+
disposition = (
407+
ResultDisposition.EXTERNAL_LINKS
408+
if use_cloud_fetch
409+
else ResultDisposition.INLINE
410+
).value
411+
result_compression = (
412+
ResultCompression.LZ4_FRAME if lz4_compression else ResultCompression.NONE
413+
).value
487414

488415
request = ExecuteStatementRequest(
489416
warehouse_id=self.warehouse_id,
490417
session_id=sea_session_id,
491418
statement=operation,
492419
disposition=disposition,
493420
format=format,
494-
wait_timeout="0s" if async_op else "10s",
421+
wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value,
495422
on_wait_timeout="CONTINUE",
496423
row_limit=max_rows,
497424
parameters=sea_parameters if sea_parameters else None,
@@ -517,12 +444,11 @@ def execute_command(
517444
# Store the command ID in the cursor
518445
cursor.active_command_id = command_id
519446

520-
# If async operation, return None and let the client poll for results
447+
# If async operation, return and let the client poll for results
521448
if async_op:
522449
return None
523450

524451
# For synchronous operation, wait for the statement to complete
525-
# Poll until the statement is done
526452
status = response.status
527453
state = status.state
528454

@@ -552,6 +478,7 @@ def cancel_command(self, command_id: CommandId) -> None:
552478
Raises:
553479
ValueError: If the command ID is invalid
554480
"""
481+
555482
if command_id.backend_type != BackendType.SEA:
556483
raise ValueError("Not a valid SEA command ID")
557484

@@ -574,6 +501,7 @@ def close_command(self, command_id: CommandId) -> None:
574501
Raises:
575502
ValueError: If the command ID is invalid
576503
"""
504+
577505
if command_id.backend_type != BackendType.SEA:
578506
raise ValueError("Not a valid SEA command ID")
579507

@@ -599,6 +527,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
599527
Raises:
600528
ValueError: If the command ID is invalid
601529
"""
530+
602531
if command_id.backend_type != BackendType.SEA:
603532
raise ValueError("Not a valid SEA command ID")
604533

@@ -633,6 +562,7 @@ def get_execution_result(
633562
Raises:
634563
ValueError: If the command ID is invalid
635564
"""
565+
636566
if command_id.backend_type != BackendType.SEA:
637567
raise ValueError("Not a valid SEA command ID")
638568

src/databricks/sql/backend/sea/models/base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,29 @@ class ExternalLink:
4242
http_headers: Optional[Dict[str, str]] = None
4343

4444

45+
@dataclass
46+
class ChunkInfo:
47+
"""Information about a chunk in the result set."""
48+
49+
chunk_index: int
50+
byte_count: int
51+
row_offset: int
52+
row_count: int
53+
54+
4555
@dataclass
4656
class ResultData:
4757
"""Result data from a statement execution."""
4858

4959
data: Optional[List[List[Any]]] = None
5060
external_links: Optional[List[ExternalLink]] = None
61+
byte_count: Optional[int] = None
62+
chunk_index: Optional[int] = None
63+
next_chunk_index: Optional[int] = None
64+
next_chunk_internal_link: Optional[str] = None
65+
row_count: Optional[int] = None
66+
row_offset: Optional[int] = None
67+
attachment: Optional[bytes] = None
5168

5269

5370
@dataclass
@@ -73,5 +90,6 @@ class ResultManifest:
7390
total_byte_count: int
7491
total_chunk_count: int
7592
truncated: bool = False
76-
chunks: Optional[List[Dict[str, Any]]] = None
93+
chunks: Optional[List[ChunkInfo]] = None
7794
result_compression: Optional[str] = None
95+
is_volume_operation: Optional[bool] = None

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
@dataclass
1212
class StatementParameter:
13-
"""Parameter for a SQL statement."""
13+
"""Representation of a parameter for a SQL statement."""
1414

1515
name: str
1616
value: Optional[str] = None
@@ -19,7 +19,7 @@ class StatementParameter:
1919

2020
@dataclass
2121
class ExecuteStatementRequest:
22-
"""Request to execute a SQL statement."""
22+
"""Representation of a request to execute a SQL statement."""
2323

2424
session_id: str
2525
statement: str
@@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]:
6565

6666
@dataclass
6767
class GetStatementRequest:
68-
"""Request to get information about a statement."""
68+
"""Representation of a request to get information about a statement."""
6969

7070
statement_id: str
7171

@@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]:
7676

7777
@dataclass
7878
class CancelStatementRequest:
79-
"""Request to cancel a statement."""
79+
"""Representation of a request to cancel a statement."""
8080

8181
statement_id: str
8282

@@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]:
8787

8888
@dataclass
8989
class CloseStatementRequest:
90-
"""Request to close a statement."""
90+
"""Representation of a request to close a statement."""
9191

9292
statement_id: str
9393

@@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]:
9898

9999
@dataclass
100100
class CreateSessionRequest:
101-
"""Request to create a new session."""
101+
"""Representation of a request to create a new session."""
102102

103103
warehouse_id: str
104104
session_confs: Optional[Dict[str, str]] = None
@@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]:
123123

124124
@dataclass
125125
class DeleteSessionRequest:
126-
"""Request to delete a session."""
126+
"""Representation of a request to delete a session."""
127127

128128
warehouse_id: str
129129
session_id: str

0 commit comments

Comments
 (0)