Skip to content

Commit d59b351

Browse files
constrain backend diff
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 44183db commit d59b351

File tree

1 file changed

+66
-62
lines changed

1 file changed

+66
-62
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 66 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,16 @@
11
import logging
2-
import uuid
32
import time
43
import re
5-
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set
6-
7-
from databricks.sql.backend.sea.models.base import ExternalLink
4+
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
85

6+
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
97
from databricks.sql.backend.sea.utils.constants import (
108
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
11-
MetadataCommands,
129
ResultFormat,
1310
ResultDisposition,
1411
ResultCompression,
1512
WaitTimeout,
13+
MetadataCommands,
1614
)
1715

1816
if TYPE_CHECKING:
@@ -29,7 +27,6 @@
2927
)
3028
from databricks.sql.exc import DatabaseError, ServerOperationError
3129
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
32-
from databricks.sql.thrift_api.TCLIService import ttypes
3330
from databricks.sql.types import SSLOptions
3431

3532
from databricks.sql.backend.sea.models import (
@@ -43,6 +40,8 @@
4340
ExecuteStatementResponse,
4441
GetStatementResponse,
4542
CreateSessionResponse,
43+
)
44+
from databricks.sql.backend.sea.models.responses import (
4645
GetChunksResponse,
4746
)
4847

@@ -91,6 +90,9 @@ class SeaDatabricksClient(DatabricksClient):
9190
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
9291
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9392

93+
# SEA constants
94+
POLL_INTERVAL_SECONDS = 0.2
95+
9496
def __init__(
9597
self,
9698
server_hostname: str,
@@ -121,7 +123,7 @@ def __init__(
121123
http_path,
122124
)
123125

124-
super().__init__(ssl_options, **kwargs)
126+
super().__init__(ssl_options=ssl_options, **kwargs)
125127

126128
# Extract warehouse ID from http_path
127129
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -288,28 +290,28 @@ def get_allowed_session_configurations() -> List[str]:
288290
"""
289291
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
290292

291-
def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
293+
def _extract_description_from_manifest(
294+
self, manifest: ResultManifest
295+
) -> Optional[List]:
292296
"""
293-
Extract column description from a manifest object.
297+
Extract column description from a manifest object, in the format defined by
298+
the spec: https://peps.python.org/pep-0249/#description
294299
295300
Args:
296-
manifest_obj: The ResultManifest object containing schema information
301+
manifest: The ResultManifest object containing schema information
297302
298303
Returns:
299304
Optional[List]: A list of column tuples or None if no columns are found
300305
"""
301306

302-
schema_data = manifest_obj.schema
307+
schema_data = manifest.schema
303308
columns_data = schema_data.get("columns", [])
304309

305310
if not columns_data:
306311
return None
307312

308313
columns = []
309314
for col_data in columns_data:
310-
if not isinstance(col_data, dict):
311-
continue
312-
313315
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
314316
columns.append(
315317
(
@@ -325,38 +327,9 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
325327

326328
return columns if columns else None
327329

328-
def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
329-
"""
330-
Get links for chunks starting from the specified index.
331-
332-
Args:
333-
statement_id: The statement ID
334-
chunk_index: The starting chunk index
335-
336-
Returns:
337-
ExternalLink: External link for the chunk
338-
"""
339-
340-
response_data = self.http_client._make_request(
341-
method="GET",
342-
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
343-
)
344-
response = GetChunksResponse.from_dict(response_data)
345-
346-
links = response.external_links
347-
link = next((l for l in links if l.chunk_index == chunk_index), None)
348-
if not link:
349-
raise ServerOperationError(
350-
f"No link found for chunk index {chunk_index}",
351-
{
352-
"operation-id": statement_id,
353-
"diagnostic-info": None,
354-
},
355-
)
356-
357-
return link
358-
359-
def _results_message_to_execute_response(self, response: GetStatementResponse, command_id: CommandId):
330+
def _results_message_to_execute_response(
331+
self, response: GetStatementResponse
332+
) -> ExecuteResponse:
360333
"""
361334
Convert a SEA response to an ExecuteResponse and extract result data.
362335
@@ -365,18 +338,19 @@ def _results_message_to_execute_response(self, response: GetStatementResponse, c
365338
command_id: The command ID
366339
367340
Returns:
368-
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
369-
result data object, and manifest object
341+
ExecuteResponse: The normalized execute response
370342
"""
371343

372344
# Extract description from manifest schema
373345
description = self._extract_description_from_manifest(response.manifest)
374346

375347
# Check for compression
376-
lz4_compressed = response.manifest.result_compression == ResultCompression.LZ4_FRAME.value
348+
lz4_compressed = (
349+
response.manifest.result_compression == ResultCompression.LZ4_FRAME.value
350+
)
377351

378352
execute_response = ExecuteResponse(
379-
command_id=command_id,
353+
command_id=CommandId.from_sea_statement_id(response.statement_id),
380354
status=response.status.state,
381355
description=description,
382356
has_been_closed_server_side=False,
@@ -433,7 +407,7 @@ def execute_command(
433407
lz4_compression: bool,
434408
cursor: "Cursor",
435409
use_cloud_fetch: bool,
436-
parameters: List,
410+
parameters: List[Dict[str, Any]],
437411
async_op: bool,
438412
enforce_embedded_schema_correctness: bool,
439413
) -> Union["ResultSet", None]:
@@ -467,9 +441,9 @@ def execute_command(
467441
for param in parameters:
468442
sea_parameters.append(
469443
StatementParameter(
470-
name=param.name,
471-
value=param.value,
472-
type=param.type if hasattr(param, "type") else None,
444+
name=param["name"],
445+
value=param["value"],
446+
type=param["type"] if "type" in param else None,
473447
)
474448
)
475449

@@ -638,8 +612,7 @@ def get_execution_result(
638612
# Create and return a SeaResultSet
639613
from databricks.sql.result_set import SeaResultSet
640614

641-
# Convert the response to an ExecuteResponse and extract result data
642-
execute_response = self._results_message_to_execute_response(response, command_id)
615+
execute_response = self._results_message_to_execute_response(response)
643616

644617
return SeaResultSet(
645618
connection=cursor.connection,
@@ -651,6 +624,35 @@ def get_execution_result(
651624
manifest=response.manifest,
652625
)
653626

627+
def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
628+
"""
629+
Get links for chunks starting from the specified index.
630+
Args:
631+
statement_id: The statement ID
632+
chunk_index: The starting chunk index
633+
Returns:
634+
ExternalLink: External link for the chunk
635+
"""
636+
637+
response_data = self.http_client._make_request(
638+
method="GET",
639+
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
640+
)
641+
response = GetChunksResponse.from_dict(response_data)
642+
643+
links = response.external_links
644+
link = next((l for l in links if l.chunk_index == chunk_index), None)
645+
if not link:
646+
raise ServerOperationError(
647+
f"No link found for chunk index {chunk_index}",
648+
{
649+
"operation-id": statement_id,
650+
"diagnostic-info": None,
651+
},
652+
)
653+
654+
return link
655+
654656
# == Metadata Operations ==
655657

656658
def get_catalogs(
@@ -692,7 +694,7 @@ def get_schemas(
692694
operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name)
693695

694696
if schema_name:
695-
operation += f" LIKE '{schema_name}'"
697+
operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name)
696698

697699
result = self.execute_command(
698700
operation=operation,
@@ -724,17 +726,19 @@ def get_tables(
724726
if not catalog_name:
725727
raise ValueError("Catalog name is required for get_tables")
726728

727-
operation = MetadataCommands.SHOW_TABLES.value.format(
729+
operation = (
728730
MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value
729731
if catalog_name in [None, "*", "%"]
730-
else MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name)
732+
else MetadataCommands.SHOW_TABLES.value.format(
733+
MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name)
734+
)
731735
)
732736

733737
if schema_name:
734-
operation += f" SCHEMA LIKE '{schema_name}'"
738+
operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
735739

736740
if table_name:
737-
operation += f" LIKE '{table_name}'"
741+
operation += MetadataCommands.LIKE_PATTERN.value.format(table_name)
738742

739743
result = self.execute_command(
740744
operation=operation,
@@ -750,7 +754,7 @@ def get_tables(
750754
)
751755
assert result is not None, "execute_command returned None in synchronous mode"
752756

753-
# Apply client-side filtering by table_types if specified
757+
# Apply client-side filtering by table_types
754758
from databricks.sql.backend.filters import ResultSetFilter
755759

756760
result = ResultSetFilter.filter_tables_by_type(result, table_types)

0 commit comments

Comments
 (0)