Skip to content

Commit 93edb93

Browse files
Revert "remove un-necessary backend changes"
This reverts commit 20822e4.
1 parent 0e3c0a1 commit 93edb93

File tree

1 file changed

+107
-91
lines changed

1 file changed

+107
-91
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 107 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import logging
2+
import uuid
23
import time
34
import re
4-
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
5+
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set
56

6-
from databricks.sql.backend.sea.models.base import ResultManifest
7+
from databricks.sql.backend.sea.models.base import ExternalLink
78
from databricks.sql.backend.sea.utils.constants import (
89
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
910
ResultFormat,
1011
ResultDisposition,
1112
ResultCompression,
1213
WaitTimeout,
13-
MetadataCommands,
1414
)
1515

1616
if TYPE_CHECKING:
@@ -25,8 +25,9 @@
2525
BackendType,
2626
ExecuteResponse,
2727
)
28-
from databricks.sql.exc import DatabaseError, ServerOperationError
28+
from databricks.sql.exc import ServerOperationError
2929
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
30+
from databricks.sql.thrift_api.TCLIService import ttypes
3031
from databricks.sql.types import SSLOptions
3132

3233
from databricks.sql.backend.sea.models import (
@@ -40,11 +41,12 @@
4041
ExecuteStatementResponse,
4142
GetStatementResponse,
4243
CreateSessionResponse,
44+
GetChunksResponse,
4345
)
4446
from databricks.sql.backend.sea.models.responses import (
45-
_parse_status,
46-
_parse_manifest,
47-
_parse_result,
47+
parse_status,
48+
parse_manifest,
49+
parse_result,
4850
)
4951

5052
logger = logging.getLogger(__name__)
@@ -90,9 +92,7 @@ class SeaDatabricksClient(DatabricksClient):
9092
STATEMENT_PATH = BASE_PATH + "statements"
9193
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9294
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
93-
94-
# SEA constants
95-
POLL_INTERVAL_SECONDS = 0.2
95+
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
9696

9797
def __init__(
9898
self,
@@ -124,7 +124,7 @@ def __init__(
124124
http_path,
125125
)
126126

127-
self._max_download_threads = kwargs.get("max_download_threads", 10)
127+
super().__init__(ssl_options, **kwargs)
128128

129129
# Extract warehouse ID from http_path
130130
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -136,7 +136,7 @@ def __init__(
136136
http_path=http_path,
137137
http_headers=http_headers,
138138
auth_provider=auth_provider,
139-
ssl_options=ssl_options,
139+
ssl_options=self._ssl_options,
140140
**kwargs,
141141
)
142142

@@ -291,28 +291,28 @@ def get_allowed_session_configurations() -> List[str]:
291291
"""
292292
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
293293

294-
def _extract_description_from_manifest(
295-
self, manifest: ResultManifest
296-
) -> Optional[List]:
294+
def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
297295
"""
298-
Extract column description from a manifest object, in the format defined by
299-
the spec: https://peps.python.org/pep-0249/#description
296+
Extract column description from a manifest object.
300297
301298
Args:
302-
manifest: The ResultManifest object containing schema information
299+
manifest_obj: The ResultManifest object containing schema information
303300
304301
Returns:
305302
Optional[List]: A list of column tuples or None if no columns are found
306303
"""
307304

308-
schema_data = manifest.schema
305+
schema_data = manifest_obj.schema
309306
columns_data = schema_data.get("columns", [])
310307

311308
if not columns_data:
312309
return None
313310

314311
columns = []
315312
for col_data in columns_data:
313+
if not isinstance(col_data, dict):
314+
continue
315+
316316
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
317317
columns.append(
318318
(
@@ -328,9 +328,38 @@ def _extract_description_from_manifest(
328328

329329
return columns if columns else None
330330

331-
def _results_message_to_execute_response(
332-
self, response: GetStatementResponse
333-
) -> ExecuteResponse:
331+
def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
332+
"""
333+
Get links for chunks starting from the specified index.
334+
335+
Args:
336+
statement_id: The statement ID
337+
chunk_index: The starting chunk index
338+
339+
Returns:
340+
ExternalLink: External link for the chunk
341+
"""
342+
343+
response_data = self.http_client._make_request(
344+
method="GET",
345+
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
346+
)
347+
response = GetChunksResponse.from_dict(response_data)
348+
349+
links = response.external_links
350+
link = next((l for l in links if l.chunk_index == chunk_index), None)
351+
if not link:
352+
raise ServerOperationError(
353+
f"No link found for chunk index {chunk_index}",
354+
{
355+
"operation-id": statement_id,
356+
"diagnostic-info": None,
357+
},
358+
)
359+
360+
return link
361+
362+
def _results_message_to_execute_response(self, sea_response, command_id):
334363
"""
335364
Convert a SEA response to an ExecuteResponse and extract result data.
336365
@@ -339,65 +368,33 @@ def _results_message_to_execute_response(
339368
command_id: The command ID
340369
341370
Returns:
342-
ExecuteResponse: The normalized execute response
371+
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
372+
result data object, and manifest object
343373
"""
344374

375+
# Parse the response
376+
status = parse_status(sea_response)
377+
manifest_obj = parse_manifest(sea_response)
378+
result_data_obj = parse_result(sea_response)
379+
345380
# Extract description from manifest schema
346-
description = self._extract_description_from_manifest(response.manifest)
381+
description = self._extract_description_from_manifest(manifest_obj)
347382

348383
# Check for compression
349-
lz4_compressed = (
350-
response.manifest.result_compression == ResultCompression.LZ4_FRAME
351-
)
384+
lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME"
352385

353386
execute_response = ExecuteResponse(
354-
command_id=CommandId.from_sea_statement_id(response.statement_id),
355-
status=response.status.state,
387+
command_id=command_id,
388+
status=status.state,
356389
description=description,
357390
has_been_closed_server_side=False,
358391
lz4_compressed=lz4_compressed,
359392
is_staging_operation=False,
360393
arrow_schema_bytes=None,
361-
result_format=response.manifest.format,
394+
result_format=manifest_obj.format,
362395
)
363396

364-
return execute_response
365-
366-
def _check_command_not_in_failed_or_closed_state(
367-
self, state: CommandState, command_id: CommandId
368-
) -> None:
369-
if state == CommandState.CLOSED:
370-
raise DatabaseError(
371-
"Command {} unexpectedly closed server side".format(command_id),
372-
{
373-
"operation-id": command_id,
374-
},
375-
)
376-
if state == CommandState.FAILED:
377-
raise ServerOperationError(
378-
"Command {} failed".format(command_id),
379-
{
380-
"operation-id": command_id,
381-
},
382-
)
383-
384-
def _wait_until_command_done(
385-
self, response: ExecuteStatementResponse
386-
) -> CommandState:
387-
"""
388-
Wait until a command is done.
389-
"""
390-
391-
state = response.status.state
392-
command_id = CommandId.from_sea_statement_id(response.statement_id)
393-
394-
while state in [CommandState.PENDING, CommandState.RUNNING]:
395-
time.sleep(self.POLL_INTERVAL_SECONDS)
396-
state = self.get_query_state(command_id)
397-
398-
self._check_command_not_in_failed_or_closed_state(state, command_id)
399-
400-
return state
397+
return execute_response, result_data_obj, manifest_obj
401398

402399
def execute_command(
403400
self,
@@ -408,7 +405,7 @@ def execute_command(
408405
lz4_compression: bool,
409406
cursor: "Cursor",
410407
use_cloud_fetch: bool,
411-
parameters: List[Dict[str, Any]],
408+
parameters: List,
412409
async_op: bool,
413410
enforce_embedded_schema_correctness: bool,
414411
) -> Union["ResultSet", None]:
@@ -442,9 +439,9 @@ def execute_command(
442439
for param in parameters:
443440
sea_parameters.append(
444441
StatementParameter(
445-
name=param["name"],
446-
value=param["value"],
447-
type=param["type"] if "type" in param else None,
442+
name=param.name,
443+
value=param.value,
444+
type=param.type if hasattr(param, "type") else None,
448445
)
449446
)
450447

@@ -496,7 +493,24 @@ def execute_command(
496493
if async_op:
497494
return None
498495

499-
self._wait_until_command_done(response)
496+
# For synchronous operation, wait for the statement to complete
497+
status = response.status
498+
state = status.state
499+
500+
# Keep polling until we reach a terminal state
501+
while state in [CommandState.PENDING, CommandState.RUNNING]:
502+
time.sleep(0.5) # add a small delay to avoid excessive API calls
503+
state = self.get_query_state(command_id)
504+
505+
if state != CommandState.SUCCEEDED:
506+
raise ServerOperationError(
507+
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
508+
{
509+
"operation-id": command_id.to_sea_statement_id(),
510+
"diagnostic-info": None,
511+
},
512+
)
513+
500514
return self.get_execution_result(command_id, cursor)
501515

502516
def cancel_command(self, command_id: CommandId) -> None:
@@ -608,21 +622,25 @@ def get_execution_result(
608622
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
609623
data=request.to_dict(),
610624
)
611-
response = GetStatementResponse.from_dict(response_data)
612625

613626
# Create and return a SeaResultSet
614627
from databricks.sql.result_set import SeaResultSet
615628

616-
execute_response = self._results_message_to_execute_response(response)
629+
# Convert the response to an ExecuteResponse and extract result data
630+
(
631+
execute_response,
632+
result_data,
633+
manifest,
634+
) = self._results_message_to_execute_response(response_data, command_id)
617635

618636
return SeaResultSet(
619637
connection=cursor.connection,
620638
execute_response=execute_response,
621639
sea_client=self,
622640
buffer_size_bytes=cursor.buffer_size_bytes,
623641
arraysize=cursor.arraysize,
624-
result_data=response.result,
625-
manifest=response.manifest,
642+
result_data=result_data,
643+
manifest=manifest,
626644
)
627645

628646
# == Metadata Operations ==
@@ -636,7 +654,7 @@ def get_catalogs(
636654
) -> "ResultSet":
637655
"""Get available catalogs by executing 'SHOW CATALOGS'."""
638656
result = self.execute_command(
639-
operation=MetadataCommands.SHOW_CATALOGS.value,
657+
operation="SHOW CATALOGS",
640658
session_id=session_id,
641659
max_rows=max_rows,
642660
max_bytes=max_bytes,
@@ -663,10 +681,10 @@ def get_schemas(
663681
if not catalog_name:
664682
raise ValueError("Catalog name is required for get_schemas")
665683

666-
operation = MetadataCommands.SHOW_SCHEMAS.value.format(catalog_name)
684+
operation = f"SHOW SCHEMAS IN `{catalog_name}`"
667685

668686
if schema_name:
669-
operation += MetadataCommands.LIKE_PATTERN.value.format(schema_name)
687+
operation += f" LIKE '{schema_name}'"
670688

671689
result = self.execute_command(
672690
operation=operation,
@@ -698,19 +716,17 @@ def get_tables(
698716
if not catalog_name:
699717
raise ValueError("Catalog name is required for get_tables")
700718

701-
operation = (
702-
MetadataCommands.SHOW_TABLES_ALL_CATALOGS.value
719+
operation = "SHOW TABLES IN " + (
720+
"ALL CATALOGS"
703721
if catalog_name in [None, "*", "%"]
704-
else MetadataCommands.SHOW_TABLES.value.format(
705-
MetadataCommands.CATALOG_SPECIFIC.value.format(catalog_name)
706-
)
722+
else f"CATALOG `{catalog_name}`"
707723
)
708724

709725
if schema_name:
710-
operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
726+
operation += f" SCHEMA LIKE '{schema_name}'"
711727

712728
if table_name:
713-
operation += MetadataCommands.LIKE_PATTERN.value.format(table_name)
729+
operation += f" LIKE '{table_name}'"
714730

715731
result = self.execute_command(
716732
operation=operation,
@@ -726,7 +742,7 @@ def get_tables(
726742
)
727743
assert result is not None, "execute_command returned None in synchronous mode"
728744

729-
# Apply client-side filtering by table_types
745+
# Apply client-side filtering by table_types if specified
730746
from databricks.sql.backend.filters import ResultSetFilter
731747

732748
result = ResultSetFilter.filter_tables_by_type(result, table_types)
@@ -748,16 +764,16 @@ def get_columns(
748764
if not catalog_name:
749765
raise ValueError("Catalog name is required for get_columns")
750766

751-
operation = MetadataCommands.SHOW_COLUMNS.value.format(catalog_name)
767+
operation = f"SHOW COLUMNS IN CATALOG `{catalog_name}`"
752768

753769
if schema_name:
754-
operation += MetadataCommands.SCHEMA_LIKE_PATTERN.value.format(schema_name)
770+
operation += f" SCHEMA LIKE '{schema_name}'"
755771

756772
if table_name:
757-
operation += MetadataCommands.TABLE_LIKE_PATTERN.value.format(table_name)
773+
operation += f" TABLE LIKE '{table_name}'"
758774

759775
if column_name:
760-
operation += MetadataCommands.LIKE_PATTERN.value.format(column_name)
776+
operation += f" LIKE '{column_name}'"
761777

762778
result = self.execute_command(
763779
operation=operation,

0 commit comments

Comments
 (0)