Skip to content

Commit 5540c5c

Browse files
reduce code repetititon + introduce gaps after multi line pydocs
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent af8f74e commit 5540c5c

File tree

3 files changed

+30
-68
lines changed

3 files changed

+30
-68
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 20 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
GetStatementResponse,
4141
CreateSessionResponse,
4242
)
43+
from databricks.sql.backend.sea.models.responses import (
44+
parse_status,
45+
parse_manifest,
46+
parse_result,
47+
)
4348

4449
logger = logging.getLogger(__name__)
4550

@@ -75,9 +80,6 @@ def _filter_session_configuration(
7580
class SeaDatabricksClient(DatabricksClient):
7681
"""
7782
Statement Execution API (SEA) implementation of the DatabricksClient interface.
78-
79-
This implementation provides session management functionality for SEA,
80-
while other operations raise NotImplementedError.
8183
"""
8284

8385
# SEA API paths
@@ -119,7 +121,6 @@ def __init__(
119121
)
120122

121123
self._max_download_threads = kwargs.get("max_download_threads", 10)
122-
self.ssl_options = ssl_options
123124

124125
# Extract warehouse ID from http_path
125126
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -298,16 +299,16 @@ def _results_message_to_execute_response(self, sea_response, command_id):
298299
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
299300
result data object, and manifest object
300301
"""
301-
# Extract status
302-
status_data = sea_response.get("status", {})
303-
state = CommandState.from_sea_state(status_data.get("state", ""))
304302

305-
# Extract description from manifest
303+
# Parse the response
304+
status = parse_status(sea_response)
305+
manifest_obj = parse_manifest(sea_response)
306+
result_data_obj = parse_result(sea_response)
307+
308+
# Extract description from manifest schema
306309
description = None
307-
manifest_data = sea_response.get("manifest", {})
308-
schema_data = manifest_data.get("schema", {})
310+
schema_data = manifest_obj.schema
309311
columns_data = schema_data.get("columns", [])
310-
311312
if columns_data:
312313
columns = []
313314
for col_data in columns_data:
@@ -329,61 +330,17 @@ def _results_message_to_execute_response(self, sea_response, command_id):
329330
description = columns if columns else None
330331

331332
# Check for compression
332-
lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME"
333-
334-
# Initialize result_data_obj and manifest_obj
335-
result_data_obj = None
336-
manifest_obj = None
337-
338-
result_data = sea_response.get("result", {})
339-
if result_data:
340-
# Convert external links
341-
external_links = None
342-
if "external_links" in result_data:
343-
external_links = []
344-
for link_data in result_data["external_links"]:
345-
external_links.append(
346-
ExternalLink(
347-
external_link=link_data.get("external_link", ""),
348-
expiration=link_data.get("expiration", ""),
349-
chunk_index=link_data.get("chunk_index", 0),
350-
byte_count=link_data.get("byte_count", 0),
351-
row_count=link_data.get("row_count", 0),
352-
row_offset=link_data.get("row_offset", 0),
353-
next_chunk_index=link_data.get("next_chunk_index"),
354-
next_chunk_internal_link=link_data.get(
355-
"next_chunk_internal_link"
356-
),
357-
http_headers=link_data.get("http_headers", {}),
358-
)
359-
)
360-
361-
# Create the result data object
362-
result_data_obj = ResultData(
363-
data=result_data.get("data_array"), external_links=external_links
364-
)
365-
366-
# Create the manifest object
367-
manifest_obj = ResultManifest(
368-
format=manifest_data.get("format", ""),
369-
schema=manifest_data.get("schema", {}),
370-
total_row_count=manifest_data.get("total_row_count", 0),
371-
total_byte_count=manifest_data.get("total_byte_count", 0),
372-
total_chunk_count=manifest_data.get("total_chunk_count", 0),
373-
truncated=manifest_data.get("truncated", False),
374-
chunks=manifest_data.get("chunks"),
375-
result_compression=manifest_data.get("result_compression"),
376-
)
333+
lz4_compressed = manifest_obj.result_compression == "LZ4_FRAME"
377334

378335
execute_response = ExecuteResponse(
379336
command_id=command_id,
380-
status=state,
337+
status=status.state,
381338
description=description,
382339
has_been_closed_server_side=False,
383340
lz4_compressed=lz4_compressed,
384341
is_staging_operation=False,
385342
arrow_schema_bytes=None, # to be extracted during fetch phase for ARROW
386-
result_format=manifest_data.get("format"),
343+
result_format=manifest_obj.format,
387344
)
388345

389346
return execute_response, result_data_obj, manifest_obj
@@ -419,6 +376,7 @@ def execute_command(
419376
Returns:
420377
ResultSet: A SeaResultSet instance for the executed command
421378
"""
379+
422380
if session_id.backend_type != BackendType.SEA:
423381
raise ValueError("Not a valid SEA session ID")
424382

@@ -506,6 +464,7 @@ def cancel_command(self, command_id: CommandId) -> None:
506464
Raises:
507465
ValueError: If the command ID is invalid
508466
"""
467+
509468
if command_id.backend_type != BackendType.SEA:
510469
raise ValueError("Not a valid SEA command ID")
511470

@@ -528,6 +487,7 @@ def close_command(self, command_id: CommandId) -> None:
528487
Raises:
529488
ValueError: If the command ID is invalid
530489
"""
490+
531491
if command_id.backend_type != BackendType.SEA:
532492
raise ValueError("Not a valid SEA command ID")
533493

@@ -553,6 +513,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
553513
Raises:
554514
ValueError: If the command ID is invalid
555515
"""
516+
556517
if command_id.backend_type != BackendType.SEA:
557518
raise ValueError("Not a valid SEA command ID")
558519

@@ -587,6 +548,7 @@ def get_execution_result(
587548
Raises:
588549
ValueError: If the command ID is invalid
589550
"""
551+
590552
if command_id.backend_type != BackendType.SEA:
591553
raise ValueError("Not a valid SEA command ID")
592554

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919

2020

21-
def _parse_status(data: Dict[str, Any]) -> StatementStatus:
21+
def parse_status(data: Dict[str, Any]) -> StatementStatus:
2222
"""Parse status from response data."""
2323
status_data = data.get("status", {})
2424
error = None
@@ -40,7 +40,7 @@ def _parse_status(data: Dict[str, Any]) -> StatementStatus:
4040
)
4141

4242

43-
def _parse_manifest(data: Dict[str, Any]) -> ResultManifest:
43+
def parse_manifest(data: Dict[str, Any]) -> ResultManifest:
4444
"""Parse manifest from response data."""
4545

4646
manifest_data = data.get("manifest", {})
@@ -69,7 +69,7 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest:
6969
)
7070

7171

72-
def _parse_result(data: Dict[str, Any]) -> ResultData:
72+
def parse_result(data: Dict[str, Any]) -> ResultData:
7373
"""Parse result data from response data."""
7474
result_data = data.get("result", {})
7575
external_links = None
@@ -118,9 +118,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
118118
"""Create an ExecuteStatementResponse from a dictionary."""
119119
return cls(
120120
statement_id=data.get("statement_id", ""),
121-
status=_parse_status(data),
122-
manifest=_parse_manifest(data),
123-
result=_parse_result(data),
121+
status=parse_status(data),
122+
manifest=parse_manifest(data),
123+
result=parse_result(data),
124124
)
125125

126126

@@ -138,9 +138,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
138138
"""Create a GetStatementResponse from a dictionary."""
139139
return cls(
140140
statement_id=data.get("statement_id", ""),
141-
status=_parse_status(data),
142-
manifest=_parse_manifest(data),
143-
result=_parse_result(data),
141+
status=parse_status(data),
142+
manifest=parse_manifest(data),
143+
result=parse_result(data),
144144
)
145145

146146

tests/unit/test_sea_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def test_get_execution_result(
536536
print(result)
537537

538538
# Verify basic properties of the result
539-
assert result.statement_id == "test-statement-123"
539+
assert result.command_id.to_sea_statement_id() == "test-statement-123"
540540
assert result.status == CommandState.SUCCEEDED
541541

542542
# Verify the HTTP request

0 commit comments

Comments
 (0)