Skip to content

Commit 37813ba

Browse files
reduce code duplication in response parsing
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 68657a3 commit 37813ba

File tree

1 file changed

+78
-141
lines changed

1 file changed

+78
-141
lines changed

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 78 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
These models define the structures used in SEA API responses.
55
"""
66

7-
from typing import Dict, List, Any, Optional, Union
8-
from dataclasses import dataclass, field
7+
from typing import Dict, Any
8+
from dataclasses import dataclass
99

1010
from databricks.sql.backend.types import CommandState
1111
from databricks.sql.backend.sea.models.base import (
@@ -14,91 +14,92 @@
1414
ResultData,
1515
ServiceError,
1616
ExternalLink,
17-
ColumnInfo,
1817
)
1918

2019

20+
def _parse_status(data: Dict[str, Any]) -> StatementStatus:
21+
"""Parse status from response data."""
22+
status_data = data.get("status", {})
23+
error = None
24+
if "error" in status_data:
25+
error_data = status_data["error"]
26+
error = ServiceError(
27+
message=error_data.get("message", ""),
28+
error_code=error_data.get("error_code"),
29+
)
30+
31+
state = CommandState.from_sea_state(status_data.get("state", ""))
32+
if state is None:
33+
raise ValueError(f"Invalid state: {status_data.get('state', '')}")
34+
35+
return StatementStatus(
36+
state=state,
37+
error=error,
38+
sql_state=status_data.get("sql_state"),
39+
)
40+
41+
42+
def _parse_manifest(data: Dict[str, Any]) -> ResultManifest:
43+
"""Parse manifest from response data."""
44+
45+
manifest_data = data.get("manifest", {})
46+
return ResultManifest(
47+
format=manifest_data.get("format", ""),
48+
schema=manifest_data.get("schema", {}),
49+
total_row_count=manifest_data.get("total_row_count", 0),
50+
total_byte_count=manifest_data.get("total_byte_count", 0),
51+
total_chunk_count=manifest_data.get("total_chunk_count", 0),
52+
truncated=manifest_data.get("truncated", False),
53+
chunks=manifest_data.get("chunks"),
54+
result_compression=manifest_data.get("result_compression"),
55+
)
56+
57+
58+
def _parse_result(data: Dict[str, Any]) -> ResultData:
59+
"""Parse result data from response data."""
60+
result_data = data.get("result", {})
61+
external_links = None
62+
63+
if "external_links" in result_data:
64+
external_links = []
65+
for link_data in result_data["external_links"]:
66+
external_links.append(
67+
ExternalLink(
68+
external_link=link_data.get("external_link", ""),
69+
expiration=link_data.get("expiration", ""),
70+
chunk_index=link_data.get("chunk_index", 0),
71+
byte_count=link_data.get("byte_count", 0),
72+
row_count=link_data.get("row_count", 0),
73+
row_offset=link_data.get("row_offset", 0),
74+
next_chunk_index=link_data.get("next_chunk_index"),
75+
next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
76+
http_headers=link_data.get("http_headers"),
77+
)
78+
)
79+
80+
return ResultData(
81+
data=result_data.get("data_array"),
82+
external_links=external_links,
83+
)
84+
85+
2186
@dataclass
2287
class ExecuteStatementResponse:
2388
"""Response from executing a SQL statement."""
2489

2590
statement_id: str
2691
status: StatementStatus
27-
manifest: Optional[ResultManifest] = None
28-
result: Optional[ResultData] = None
92+
manifest: ResultManifest
93+
result: ResultData
2994

3095
@classmethod
3196
def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
3297
"""Create an ExecuteStatementResponse from a dictionary."""
33-
status_data = data.get("status", {})
34-
error = None
35-
if "error" in status_data:
36-
error_data = status_data["error"]
37-
error = ServiceError(
38-
message=error_data.get("message", ""),
39-
error_code=error_data.get("error_code"),
40-
)
41-
42-
state = CommandState.from_sea_state(status_data.get("state", ""))
43-
if state is None:
44-
raise ValueError(f"Invalid state: {status_data.get('state', '')}")
45-
46-
status = StatementStatus(
47-
state=state,
48-
error=error,
49-
sql_state=status_data.get("sql_state"),
50-
)
51-
52-
# Parse manifest
53-
manifest = None
54-
if "manifest" in data:
55-
manifest_data = data["manifest"]
56-
manifest = ResultManifest(
57-
format=manifest_data.get("format", ""),
58-
schema=manifest_data.get("schema", {}),
59-
total_row_count=manifest_data.get("total_row_count", 0),
60-
total_byte_count=manifest_data.get("total_byte_count", 0),
61-
total_chunk_count=manifest_data.get("total_chunk_count", 0),
62-
truncated=manifest_data.get("truncated", False),
63-
chunks=manifest_data.get("chunks"),
64-
result_compression=manifest_data.get("result_compression"),
65-
)
66-
67-
# Parse result data
68-
result = None
69-
if "result" in data:
70-
result_data = data["result"]
71-
external_links = None
72-
73-
if "external_links" in result_data:
74-
external_links = []
75-
for link_data in result_data["external_links"]:
76-
external_links.append(
77-
ExternalLink(
78-
external_link=link_data.get("external_link", ""),
79-
expiration=link_data.get("expiration", ""),
80-
chunk_index=link_data.get("chunk_index", 0),
81-
byte_count=link_data.get("byte_count", 0),
82-
row_count=link_data.get("row_count", 0),
83-
row_offset=link_data.get("row_offset", 0),
84-
next_chunk_index=link_data.get("next_chunk_index"),
85-
next_chunk_internal_link=link_data.get(
86-
"next_chunk_internal_link"
87-
),
88-
http_headers=link_data.get("http_headers"),
89-
)
90-
)
91-
92-
result = ResultData(
93-
data=result_data.get("data_array"),
94-
external_links=external_links,
95-
)
96-
9798
return cls(
9899
statement_id=data.get("statement_id", ""),
99-
status=status,
100-
manifest=manifest,
101-
result=result,
100+
status=_parse_status(data),
101+
manifest=_parse_manifest(data),
102+
result=_parse_result(data),
102103
)
103104

104105

@@ -108,81 +109,17 @@ class GetStatementResponse:
108109

109110
statement_id: str
110111
status: StatementStatus
111-
manifest: Optional[ResultManifest] = None
112-
result: Optional[ResultData] = None
112+
manifest: ResultManifest
113+
result: ResultData
113114

114115
@classmethod
115116
def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
116117
"""Create a GetStatementResponse from a dictionary."""
117-
status_data = data.get("status", {})
118-
error = None
119-
if "error" in status_data:
120-
error_data = status_data["error"]
121-
error = ServiceError(
122-
message=error_data.get("message", ""),
123-
error_code=error_data.get("error_code"),
124-
)
125-
126-
state = CommandState.from_sea_state(status_data.get("state", ""))
127-
if state is None:
128-
raise ValueError(f"Invalid state: {status_data.get('state', '')}")
129-
130-
status = StatementStatus(
131-
state=state,
132-
error=error,
133-
sql_state=status_data.get("sql_state"),
134-
)
135-
136-
# Parse manifest
137-
manifest = None
138-
if "manifest" in data:
139-
manifest_data = data["manifest"]
140-
manifest = ResultManifest(
141-
format=manifest_data.get("format", ""),
142-
schema=manifest_data.get("schema", {}),
143-
total_row_count=manifest_data.get("total_row_count", 0),
144-
total_byte_count=manifest_data.get("total_byte_count", 0),
145-
total_chunk_count=manifest_data.get("total_chunk_count", 0),
146-
truncated=manifest_data.get("truncated", False),
147-
chunks=manifest_data.get("chunks"),
148-
result_compression=manifest_data.get("result_compression"),
149-
)
150-
151-
# Parse result data
152-
result = None
153-
if "result" in data:
154-
result_data = data["result"]
155-
external_links = None
156-
157-
if "external_links" in result_data:
158-
external_links = []
159-
for link_data in result_data["external_links"]:
160-
external_links.append(
161-
ExternalLink(
162-
external_link=link_data.get("external_link", ""),
163-
expiration=link_data.get("expiration", ""),
164-
chunk_index=link_data.get("chunk_index", 0),
165-
byte_count=link_data.get("byte_count", 0),
166-
row_count=link_data.get("row_count", 0),
167-
row_offset=link_data.get("row_offset", 0),
168-
next_chunk_index=link_data.get("next_chunk_index"),
169-
next_chunk_internal_link=link_data.get(
170-
"next_chunk_internal_link"
171-
),
172-
http_headers=link_data.get("http_headers"),
173-
)
174-
)
175-
176-
result = ResultData(
177-
data=result_data.get("data_array"),
178-
external_links=external_links,
179-
)
180-
181118
return cls(
182119
statement_id=data.get("statement_id", ""),
183-
status=status,
184-
manifest=manifest,
185-
result=result,
120+
status=_parse_status(data),
121+
manifest=_parse_manifest(data),
122+
result=_parse_result(data),
186123
)
187124

188125

0 commit comments

Comments
 (0)