Skip to content

Commit 20822e4

Browse files
remove un-necessary backend changes
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 5e75fb5 commit 20822e4

File tree

1 file changed

+91
-107
lines changed

1 file changed

+91
-107
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 91 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import logging
2-
import uuid
32
import time
43
import re
5-
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set
4+
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
65

7-
from databricks.sql.backend.sea.models.base import ExternalLink
6+
from databricks.sql.backend.sea.models.base import ResultManifest
87
from databricks.sql.backend.sea.utils.constants import (
98
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
109
ResultFormat,
1110
ResultDisposition,
1211
ResultCompression,
1312
WaitTimeout,
13+
MetadataCommands,
1414
)
1515

1616
if TYPE_CHECKING:
@@ -25,9 +25,8 @@
2525
BackendType,
2626
ExecuteResponse,
2727
)
28-
from databricks.sql.exc import ServerOperationError
28+
from databricks.sql.exc import DatabaseError, ServerOperationError
2929
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
30-
from databricks.sql.thrift_api.TCLIService import ttypes
3130
from databricks.sql.types import SSLOptions
3231

3332
from databricks.sql.backend.sea.models import (
@@ -41,12 +40,11 @@
4140
ExecuteStatementResponse,
4241
GetStatementResponse,
4342
CreateSessionResponse,
44-
GetChunksResponse,
4543
)
4644
from databricks.sql.backend.sea.models.responses import (
47-
parse_status,
48-
parse_manifest,
49-
parse_result,
45+
_parse_status,
46+
_parse_manifest,
47+
_parse_result,
5048
)
5149

5250
logger = logging.getLogger(__name__)
@@ -92,7 +90,9 @@ class SeaDatabricksClient(DatabricksClient):
9290
STATEMENT_PATH = BASE_PATH + "statements"
9391
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9492
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
95-
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
93+
94+
# SEA constants
95+
POLL_INTERVAL_SECONDS = 0.2
9696

9797
def __init__(
9898
self,
@@ -124,7 +124,7 @@ def __init__(
124124
http_path,
125125
)
126126

127-
super().__init__(ssl_options, **kwargs)
127+
self._max_download_threads = kwargs.get("max_download_threads", 10)
128128

129129
# Extract warehouse ID from http_path
130130
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -136,7 +136,7 @@ def __init__(
136136
http_path=http_path,
137137
http_headers=http_headers,
138138
auth_provider=auth_provider,
139-
ssl_options=self._ssl_options,
139+
ssl_options=ssl_options,
140140
**kwargs,
141141
)
142142

@@ -291,28 +291,28 @@ def get_allowed_session_configurations() -> List[str]:
291291
"""
292292
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
293293

294-
def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
294+
def _extract_description_from_manifest(
295+
self, manifest: ResultManifest
296+
) -> Optional[List]:
295297
"""
296-
Extract column description from a manifest object.
298+
Extract column description from a manifest object, in the format defined by
299+
the spec: https://peps.python.org/pep-0249/#description
297300
298301
Args:
299-
manifest_obj: The ResultManifest object containing schema information
302+
manifest: The ResultManifest object containing schema information
300303
301304
Returns:
302305
Optional[List]: A list of column tuples or None if no columns are found
303306
"""
304307

305-
schema_data = manifest_obj.schema
308+
schema_data = manifest.schema
306309
columns_data = schema_data.get("columns", [])
307310

308311
if not columns_data:
309312
return None
310313

311314
columns = []
312315
for col_data in columns_data:
313-
if not isinstance(col_data, dict):
314-
continue
315-
316316
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
317317
columns.append(
318318
(
@@ -328,38 +328,9 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
328328

329329
return columns if columns else None
330330

331-
def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
332-
"""
333-
Get links for chunks starting from the specified index.
334-
335-
Args:
336-
statement_id: The statement ID
337-
chunk_index: The starting chunk index
338-
339-
Returns:
340-
ExternalLink: External link for the chunk
341-
"""
342-
343-
response_data = self.http_client._make_request(
344-
method="GET",
345-
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
346-
)
347-
response = GetChunksResponse.from_dict(response_data)
348-
349-
links = response.external_links
350-
link = next((l for l in links if l.chunk_index == chunk_index), None)
351-
if not link:
352-
raise ServerOperationError(
353-
f"No link found for chunk index {chunk_index}",
354-
{
355-
"operation-id": statement_id,
356-
"diagnostic-info": None,
357-
},
358-
)
359-
360-
return link
361-
362-
def _results_message_to_execute_response(self, sea_response, command_id):
331+
def _results_message_to_execute_response(
332+
self, response: GetStatementResponse
333+
) -> ExecuteResponse:
363334
"""
364335
Convert a SEA response to an ExecuteResponse and extract result data.
365336
@@ -368,33 +339,65 @@ def _results_message_to_execute_response(self, sea_response, command_id):
368339
command_id: The command ID
369340
370341
Returns:
371-
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
372-
result data object, and manifest object
342+
ExecuteResponse: The normalized execute response
373343
"""
374344

375-
# Parse the response
376-
status = parse_status(sea_response)
377-
manifest_obj = parse_manifest(sea_response)
378-
result_data_obj = parse_result(sea_response)
379-
380345
# Extract description from manifest schema
381-
description = self._extract_description_from_manifest(manifest_obj)
346+
description = self._extract_description_from_manifest(response.manifest)
382347

383348
# Check for compression
384-
lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME"
349+
lz4_compressed = (
350+
response.manifest.result_compression == ResultCompression.LZ4_FRAME
351+
)
385352

386353
execute_response = ExecuteResponse(
387-
command_id=command_id,
388-
status=status.state,
354+
command_id=CommandId.from_sea_statement_id(response.statement_id),
355+
status=response.status.state,
389356
description=description,
390357
has_been_closed_server_side=False,
391358
lz4_compressed=lz4_compressed,
392359
is_staging_operation=False,
393360
arrow_schema_bytes=None,
394-
result_format=manifest_obj.format,
361+
result_format=response.manifest.format,
395362
)
396363

397-
return execute_response, result_data_obj, manifest_obj
364+
return execute_response
365+
366+
def _check_command_not_in_failed_or_closed_state(
367+
self, state: CommandState, command_id: CommandId
368+
) -> None:
369+
if state == CommandState.CLOSED:
370+
raise DatabaseError(
371+
"Command {} unexpectedly closed server side".format(command_id),
372+
{
373+
"operation-id": command_id,
374+
},
375+
)
376+
if state == CommandState.FAILED:
377+
raise ServerOperationError(
378+
"Command {} failed".format(command_id),
379+
{
380+
"operation-id": command_id,
381+
},
382+
)
383+
384+
def _wait_until_command_done(
385+
self, response: ExecuteStatementResponse
386+
) -> CommandState:
387+
"""
388+
Wait until a command is done.
389+
"""
390+
391+
state = response.status.state
392+
command_id = CommandId.from_sea_statement_id(response.statement_id)
393+
394+
while state in [CommandState.PENDING, CommandState.RUNNING]:
395+
time.sleep(self.POLL_INTERVAL_SECONDS)
396+
state = self.get_query_state(command_id)
397+
398+
self._check_command_not_in_failed_or_closed_state(state, command_id)
399+
400+
return state
398401

399402
def execute_command(
400403
self,
@@ -405,7 +408,7 @@ def execute_command(
405408
lz4_compression: bool,
406409
cursor: "Cursor",
407410
use_cloud_fetch: bool,
408-
parameters: List,
411+
parameters: List[Dict[str, Any]],
409412
async_op: bool,
410413
enforce_embedded_schema_correctness: bool,
411414
) -> Union["ResultSet", None]:
@@ -439,9 +442,9 @@ def execute_command(
439442
for param in parameters:
440443
sea_parameters.append(
441444
StatementParameter(
442-
name=param.name,
443-
value=param.value,
444-
type=param.type if hasattr(param, "type") else None,
445+
name=param["name"],
446+
value=param["value"],
447+
type=param["type"] if "type" in param else None,
445448
)
446449
)
447450

@@ -493,24 +496,7 @@ def execute_command(
493496
if async_op:
494497
return None
495498

496-
# For synchronous operation, wait for the statement to complete
497-
status = response.status
498-
state = status.state
499-
500-
# Keep polling until we reach a terminal state
501-
while state in [CommandState.PENDING, CommandState.RUNNING]:
502-
time.sleep(0.5) # add a small delay to avoid excessive API calls
503-
state = self.get_query_state(command_id)
504-
505-
if state != CommandState.SUCCEEDED:
506-
raise ServerOperationError(
507-
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
508-
{
509-
"operation-id": command_id.to_sea_statement_id(),
510-
"diagnostic-info": None,
511-
},
512-
)
513-
499+
self._wait_until_command_done(response)
514500
return self.get_execution_result(command_id, cursor)
515501

516502
def cancel_command(self, command_id: CommandId) -> None:
@@ -622,25 +608,21 @@ def get_execution_result(
622608
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
623609
data=request.to_dict(),
624610
)
611+
response = GetStatementResponse.from_dict(response_data)
625612

626613
# Create and return a SeaResultSet
627614
from databricks.sql.result_set import SeaResultSet
628615

629-
# Convert the response to an ExecuteResponse and extract result data
630-
(
631-
execute_response,
632-
result_data,
633-
manifest,
634-
) = self._results_message_to_execute_response(response_data, command_id)
616+
execute_response = self._results_message_to_execute_response(response)
635617

636618
return SeaResultSet(
637619
connection=cursor.connection,
638620
execute_response=execute_response,
639621
sea_client=self,
640622
buffer_size_bytes=cursor.buffer_size_bytes,
641623
arraysize=cursor.arraysize,
642-
result_data=result_data,
643-
manifest=manifest,
624+
result_data=response.result,
625+
manifest=response.manifest,
644626
)
645627

646628
# == Metadata Operations ==
@@ -654,7 +636,7 @@ def get_catalogs(
654636
) -> "ResultSet":
655637
"""Get available catalogs by executing 'SHOW CATALOGS'."""
656638
result = self.execute_command(
657-
operation="SHOW CATALOGS",
639+
operation=MetadataCommands.SHOW_CATALOGS.value,
658640
session_id=session_id,
659641
max_rows=max_rows,
660642
max_bytes=max_bytes,
@@ -681,10 +663,10 @@ def get_schemas(
681663
if not catalog_name:
682664
raise ValueError("Catalog name is required for get_schemas")
683665

684-
operation = f"SHOW SCHEMAS IN `{catalog_name}`"
666+
operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name)
685667

686668
if schema_name:
687-
operation += f" LIKE '{schema_name}'"
669+
operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name)
688670

689671
result = self.execute_command(
690672
operation=operation,
@@ -716,17 +698,19 @@ def get_tables(
716698
if not catalog_name:
717699
raise ValueError("Catalog name is required for get_tables")
718700

719-
operation = "SHOW TABLES IN " + (
720-
"ALL CATALOGS"
701+
operation = (
702+
MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value
721703
if catalog_name in [None, "*", "%"]
722-
else f"CATALOG `{catalog_name}`"
704+
else MetadataCommands.SHOW_TABLES.value.format(
705+
MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name)
706+
)
723707
)
724708

725709
if schema_name:
726-
operation += f" SCHEMA LIKE '{schema_name}'"
710+
operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
727711

728712
if table_name:
729-
operation += f" LIKE '{table_name}'"
713+
operation += MetadataCommands.LIKE_PATTERN.value.format(table_name)
730714

731715
result = self.execute_command(
732716
operation=operation,
@@ -742,7 +726,7 @@ def get_tables(
742726
)
743727
assert result is not None, "execute_command returned None in synchronous mode"
744728

745-
# Apply client-side filtering by table_types if specified
729+
# Apply client-side filtering by table_types
746730
from databricks.sql.backend.filters import ResultSetFilter
747731

748732
result = ResultSetFilter.filter_tables_by_type(result, table_types)
@@ -764,16 +748,16 @@ def get_columns(
764748
if not catalog_name:
765749
raise ValueError("Catalog name is required for get_columns")
766750

767-
operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"
751+
operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name)
768752

769753
if schema_name:
770-
operation += f" SCHEMA LIKE '{schema_name}'"
754+
operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
771755

772756
if table_name:
773-
operation += f" TABLE LIKE '{table_name}'"
757+
operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name)
774758

775759
if column_name:
776-
operation += f" LIKE '{column_name}'"
760+
operation += MetadataCommands.LIKE_PATTERN.value.format(column_name)
777761

778762
result = self.execute_command(
779763
operation=operation,

0 commit comments

Comments
 (0)