Skip to content

Commit ffd478e

Browse files
Merge branch 'sea-migration' into metadata-sea
2 parents 68ec65f + a74d279 commit ffd478e

File tree

3 files changed: +169 -68 lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def execute_command(
8282
lz4_compression: bool,
8383
cursor: "Cursor",
8484
use_cloud_fetch: bool,
85-
parameters: List[ttypes.TSparkParameter],
85+
parameters: List,
8686
async_op: bool,
8787
enforce_embedded_schema_correctness: bool,
8888
) -> Union["ResultSet", None]:

src/databricks/sql/backend/sea/backend.py

Lines changed: 70 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
22
import time
33
import re
4-
from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
4+
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
55

6+
from databricks.sql.backend.sea.models.base import ResultManifest
67
from databricks.sql.backend.sea.utils.constants import (
78
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
89
ResultFormat,
@@ -23,7 +24,7 @@
2324
BackendType,
2425
ExecuteResponse,
2526
)
26-
from databricks.sql.exc import ServerOperationError
27+
from databricks.sql.exc import DatabaseError, ServerOperationError
2728
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
2829
from databricks.sql.types import SSLOptions
2930

@@ -89,6 +90,9 @@ class SeaDatabricksClient(DatabricksClient):
8990
STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}"
9091
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
9192

93+
# SEA constants
94+
POLL_INTERVAL_SECONDS = 0.2
95+
9296
def __init__(
9397
self,
9498
server_hostname: str,
@@ -286,28 +290,28 @@ def get_allowed_session_configurations() -> List[str]:
286290
"""
287291
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
288292

289-
def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
293+
def _extract_description_from_manifest(
294+
self, manifest: ResultManifest
295+
) -> Optional[List]:
290296
"""
291-
Extract column description from a manifest object.
297+
Extract column description from a manifest object, in the format defined by
298+
the spec: https://peps.python.org/pep-0249/#description
292299
293300
Args:
294-
manifest_obj: The ResultManifest object containing schema information
301+
manifest: The ResultManifest object containing schema information
295302
296303
Returns:
297304
Optional[List]: A list of column tuples or None if no columns are found
298305
"""
299306

300-
schema_data = manifest_obj.schema
307+
schema_data = manifest.schema
301308
columns_data = schema_data.get("columns", [])
302309

303310
if not columns_data:
304311
return None
305312

306313
columns = []
307314
for col_data in columns_data:
308-
if not isinstance(col_data, dict):
309-
continue
310-
311315
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
312316
columns.append(
313317
(
@@ -323,7 +327,9 @@ def _extract_description_from_manifest(self, manifest_obj) -> Optional[List]:
323327

324328
return columns if columns else None
325329

326-
def _results_message_to_execute_response(self, sea_response, command_id):
330+
def _results_message_to_execute_response(
331+
self, response: GetStatementResponse
332+
) -> ExecuteResponse:
327333
"""
328334
Convert a SEA response to an ExecuteResponse and extract result data.
329335
@@ -332,33 +338,65 @@ def _results_message_to_execute_response(self, sea_response, command_id):
332338
command_id: The command ID
333339
334340
Returns:
335-
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
336-
result data object, and manifest object
341+
ExecuteResponse: The normalized execute response
337342
"""
338343

339-
# Parse the response
340-
status = parse_status(sea_response)
341-
manifest_obj = parse_manifest(sea_response)
342-
result_data_obj = parse_result(sea_response)
343-
344344
# Extract description from manifest schema
345-
description = self._extract_description_from_manifest(manifest_obj)
345+
description = self._extract_description_from_manifest(response.manifest)
346346

347347
# Check for compression
348-
lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME"
348+
lz4_compressed = (
349+
response.manifest.result_compression == ResultCompression.LZ4_FRAME
350+
)
349351

350352
execute_response = ExecuteResponse(
351-
command_id=command_id,
352-
status=status.state,
353+
command_id=CommandId.from_sea_statement_id(response.statement_id),
354+
status=response.status.state,
353355
description=description,
354356
has_been_closed_server_side=False,
355357
lz4_compressed=lz4_compressed,
356358
is_staging_operation=False,
357359
arrow_schema_bytes=None,
358-
result_format=manifest_obj.format,
360+
result_format=response.manifest.format,
359361
)
360362

361-
return execute_response, result_data_obj, manifest_obj
363+
return execute_response
364+
365+
def _check_command_not_in_failed_or_closed_state(
366+
self, state: CommandState, command_id: CommandId
367+
) -> None:
368+
if state == CommandState.CLOSED:
369+
raise DatabaseError(
370+
"Command {} unexpectedly closed server side".format(command_id),
371+
{
372+
"operation-id": command_id,
373+
},
374+
)
375+
if state == CommandState.FAILED:
376+
raise ServerOperationError(
377+
"Command {} failed".format(command_id),
378+
{
379+
"operation-id": command_id,
380+
},
381+
)
382+
383+
def _wait_until_command_done(
384+
self, response: ExecuteStatementResponse
385+
) -> CommandState:
386+
"""
387+
Wait until a command is done.
388+
"""
389+
390+
state = response.status.state
391+
command_id = CommandId.from_sea_statement_id(response.statement_id)
392+
393+
while state in [CommandState.PENDING, CommandState.RUNNING]:
394+
time.sleep(self.POLL_INTERVAL_SECONDS)
395+
state = self.get_query_state(command_id)
396+
397+
self._check_command_not_in_failed_or_closed_state(state, command_id)
398+
399+
return state
362400

363401
def execute_command(
364402
self,
@@ -369,7 +407,7 @@ def execute_command(
369407
lz4_compression: bool,
370408
cursor: "Cursor",
371409
use_cloud_fetch: bool,
372-
parameters: List,
410+
parameters: List[Dict[str, Any]],
373411
async_op: bool,
374412
enforce_embedded_schema_correctness: bool,
375413
) -> Union["ResultSet", None]:
@@ -403,9 +441,9 @@ def execute_command(
403441
for param in parameters:
404442
sea_parameters.append(
405443
StatementParameter(
406-
name=param.name,
407-
value=param.value,
408-
type=param.type if hasattr(param, "type") else None,
444+
name=param["name"],
445+
value=param["value"],
446+
type=param["type"] if "type" in param else None,
409447
)
410448
)
411449

@@ -457,24 +495,7 @@ def execute_command(
457495
if async_op:
458496
return None
459497

460-
# For synchronous operation, wait for the statement to complete
461-
status = response.status
462-
state = status.state
463-
464-
# Keep polling until we reach a terminal state
465-
while state in [CommandState.PENDING, CommandState.RUNNING]:
466-
time.sleep(0.5) # add a small delay to avoid excessive API calls
467-
state = self.get_query_state(command_id)
468-
469-
if state != CommandState.SUCCEEDED:
470-
raise ServerOperationError(
471-
f"Statement execution did not succeed: {status.error.message if status.error else 'Unknown error'}",
472-
{
473-
"operation-id": command_id.to_sea_statement_id(),
474-
"diagnostic-info": None,
475-
},
476-
)
477-
498+
self._wait_until_command_done(response)
478499
return self.get_execution_result(command_id, cursor)
479500

480501
def cancel_command(self, command_id: CommandId) -> None:
@@ -586,25 +607,21 @@ def get_execution_result(
586607
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
587608
data=request.to_dict(),
588609
)
610+
response = GetStatementResponse.from_dict(response_data)
589611

590612
# Create and return a SeaResultSet
591613
from databricks.sql.result_set import SeaResultSet
592614

593-
# Convert the response to an ExecuteResponse and extract result data
594-
(
595-
execute_response,
596-
result_data,
597-
manifest,
598-
) = self._results_message_to_execute_response(response_data, command_id)
615+
execute_response = self._results_message_to_execute_response(response)
599616

600617
return SeaResultSet(
601618
connection=cursor.connection,
602619
execute_response=execute_response,
603620
sea_client=self,
604621
buffer_size_bytes=cursor.buffer_size_bytes,
605622
arraysize=cursor.arraysize,
606-
result_data=result_data,
607-
manifest=manifest,
623+
result_data=response.result,
624+
manifest=response.manifest,
608625
)
609626

610627
# == Metadata Operations ==

tests/unit/test_sea_backend.py

Lines changed: 98 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
1616
from databricks.sql.types import SSLOptions
1717
from databricks.sql.auth.authenticators import AuthProvider
18-
from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
18+
from databricks.sql.exc import (
19+
Error,
20+
NotSupportedError,
21+
ServerOperationError,
22+
DatabaseError,
23+
)
1924

2025

2126
class TestSeaBackend:
@@ -349,10 +354,7 @@ def test_command_execution_advanced(
349354
"status": {"state": "SUCCEEDED"},
350355
}
351356
mock_http_client._make_request.return_value = execute_response
352-
param = MagicMock()
353-
param.name = "param1"
354-
param.value = "value1"
355-
param.type = "STRING"
357+
param = {"name": "param1", "value": "value1", "type": "STRING"}
356358

357359
with patch.object(sea_client, "get_execution_result"):
358360
sea_client.execute_command(
@@ -405,7 +407,7 @@ def test_command_execution_advanced(
405407
async_op=False,
406408
enforce_embedded_schema_correctness=False,
407409
)
408-
assert "Statement execution did not succeed" in str(excinfo.value)
410+
assert "Command test-statement-123 failed" in str(excinfo.value)
409411

410412
# Test missing statement ID
411413
mock_http_client.reset_mock()
@@ -523,6 +525,34 @@ def test_command_management(
523525
sea_client.get_execution_result(thrift_command_id, mock_cursor)
524526
assert "Not a valid SEA command ID" in str(excinfo.value)
525527

528+
def test_check_command_state(self, sea_client, sea_command_id):
529+
"""Test _check_command_not_in_failed_or_closed_state method."""
530+
# Test with RUNNING state (should not raise)
531+
sea_client._check_command_not_in_failed_or_closed_state(
532+
CommandState.RUNNING, sea_command_id
533+
)
534+
535+
# Test with SUCCEEDED state (should not raise)
536+
sea_client._check_command_not_in_failed_or_closed_state(
537+
CommandState.SUCCEEDED, sea_command_id
538+
)
539+
540+
# Test with CLOSED state (should raise DatabaseError)
541+
with pytest.raises(DatabaseError) as excinfo:
542+
sea_client._check_command_not_in_failed_or_closed_state(
543+
CommandState.CLOSED, sea_command_id
544+
)
545+
assert "Command test-statement-123 unexpectedly closed server side" in str(
546+
excinfo.value
547+
)
548+
549+
# Test with FAILED state (should raise ServerOperationError)
550+
with pytest.raises(ServerOperationError) as excinfo:
551+
sea_client._check_command_not_in_failed_or_closed_state(
552+
CommandState.FAILED, sea_command_id
553+
)
554+
assert "Command test-statement-123 failed" in str(excinfo.value)
555+
526556
def test_utility_methods(self, sea_client):
527557
"""Test utility methods."""
528558
# Test get_default_session_configuration_value
@@ -590,12 +620,66 @@ def test_utility_methods(self, sea_client):
590620
assert description[1][1] == "INT" # type_code
591621
assert description[1][6] is False # null_ok
592622

593-
# Test with manifest containing non-dict column
594-
manifest_obj.schema = {"columns": ["not_a_dict"]}
595-
description = sea_client._extract_description_from_manifest(manifest_obj)
596-
assert description is None
623+
# Test _extract_description_from_manifest with empty columns
624+
empty_manifest = MagicMock()
625+
empty_manifest.schema = {"columns": []}
626+
assert sea_client._extract_description_from_manifest(empty_manifest) is None
597627

598-
# Test with manifest without columns
599-
manifest_obj.schema = {}
600-
description = sea_client._extract_description_from_manifest(manifest_obj)
601-
assert description is None
628+
# Test _extract_description_from_manifest with no columns key
629+
no_columns_manifest = MagicMock()
630+
no_columns_manifest.schema = {}
631+
assert (
632+
sea_client._extract_description_from_manifest(no_columns_manifest) is None
633+
)
634+
635+
def test_unimplemented_metadata_methods(
636+
self, sea_client, sea_session_id, mock_cursor
637+
):
638+
"""Test that metadata methods raise NotImplementedError."""
639+
# Test get_catalogs
640+
with pytest.raises(NotImplementedError):
641+
sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor)
642+
643+
# Test get_schemas
644+
with pytest.raises(NotImplementedError):
645+
sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor)
646+
647+
# Test get_schemas with optional parameters
648+
with pytest.raises(NotImplementedError):
649+
sea_client.get_schemas(
650+
sea_session_id, 100, 1000, mock_cursor, "catalog", "schema"
651+
)
652+
653+
# Test get_tables
654+
with pytest.raises(NotImplementedError):
655+
sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor)
656+
657+
# Test get_tables with optional parameters
658+
with pytest.raises(NotImplementedError):
659+
sea_client.get_tables(
660+
sea_session_id,
661+
100,
662+
1000,
663+
mock_cursor,
664+
catalog_name="catalog",
665+
schema_name="schema",
666+
table_name="table",
667+
table_types=["TABLE", "VIEW"],
668+
)
669+
670+
# Test get_columns
671+
with pytest.raises(NotImplementedError):
672+
sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor)
673+
674+
# Test get_columns with optional parameters
675+
with pytest.raises(NotImplementedError):
676+
sea_client.get_columns(
677+
sea_session_id,
678+
100,
679+
1000,
680+
mock_cursor,
681+
catalog_name="catalog",
682+
schema_name="schema",
683+
table_name="table",
684+
column_name="column",
685+
)

0 commit comments

Comments
 (0)