Commit 4cb15fd

improved models and filters from cloudfetch-sea branch
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 0216d7a commit 4cb15fd

4 files changed (+187 -49 lines changed)

src/databricks/sql/backend/filters.py

Lines changed: 39 additions & 22 deletions
@@ -9,14 +9,20 @@
     List,
     Optional,
     Any,
+    Dict,
     Callable,
+    TypeVar,
+    Generic,
+    cast,
     TYPE_CHECKING,
 )

-if TYPE_CHECKING:
-    from databricks.sql.result_set import ResultSet
+from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory
+from databricks.sql.backend.types import ExecuteResponse, CommandId
+from databricks.sql.backend.sea.models.base import ResultData

-from databricks.sql.result_set import SeaResultSet
+if TYPE_CHECKING:
+    from databricks.sql.result_set import ResultSet, SeaResultSet

 logger = logging.getLogger(__name__)

@@ -43,26 +49,35 @@ def _filter_sea_result_set(
         Returns:
             A filtered SEA result set
         """
-        # Create a filtered version of the result set
-        filtered_response = result_set._response.copy()
-
-        # If there's a result with rows, filter them
-        if (
-            "result" in filtered_response
-            and "data_array" in filtered_response["result"]
-        ):
-            rows = filtered_response["result"]["data_array"]
-            filtered_rows = [row for row in rows if filter_func(row)]
-            filtered_response["result"]["data_array"] = filtered_rows
-
-            # Update row count if present
-            if "row_count" in filtered_response["result"]:
-                filtered_response["result"]["row_count"] = len(filtered_rows)
-
-        # Create a new result set with the filtered data
+        # Get all remaining rows
+        original_index = result_set.results.cur_row_index
+        result_set.results.cur_row_index = 0  # Reset to beginning
+        all_rows = result_set.results.remaining_rows()
+
+        # Filter rows
+        filtered_rows = [row for row in all_rows if filter_func(row)]
+
+        # Import SeaResultSet here to avoid circular imports
+        from databricks.sql.result_set import SeaResultSet
+
+        # Reuse the command_id from the original result set
+        command_id = result_set.command_id
+
+        # Create an ExecuteResponse with the filtered data
+        execute_response = ExecuteResponse(
+            command_id=command_id,
+            status=result_set.status,
+            description=result_set.description,
+            has_more_rows=result_set._has_more_rows,
+            results_queue=JsonQueue(filtered_rows),
+            has_been_closed_server_side=result_set.has_been_closed_server_side,
+            lz4_compressed=False,
+            is_staging_operation=False,
+        )
+
         return SeaResultSet(
             connection=result_set.connection,
-            sea_response=filtered_response,
+            execute_response=execute_response,
             sea_client=result_set.backend,
             buffer_size_bytes=result_set.buffer_size_bytes,
             arraysize=result_set.arraysize,
@@ -92,6 +107,8 @@ def filter_by_column_values(
             allowed_values = [v.upper() for v in allowed_values]

         # Determine the type of result set and apply appropriate filtering
+        from databricks.sql.result_set import SeaResultSet
+
         if isinstance(result_set, SeaResultSet):
             return ResultSetFilter._filter_sea_result_set(
                 result_set,
@@ -137,7 +154,7 @@ def filter_tables_by_type(
             table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
         )

-        # Table type is typically in the 6th column (index 5)
+        # Table type is the 6th column (index 5)
         return ResultSetFilter.filter_by_column_values(
             result_set, 5, valid_types, case_sensitive=False
         )
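The rewritten _filter_sea_result_set no longer mutates the raw response dict: it drains the JsonQueue, filters the materialized rows, and wraps the survivors in a fresh ExecuteResponse. The sketch below isolates that row-filtering core on plain JSON_ARRAY rows; the helper name and the DEFAULT_TABLE_TYPES values are illustrative, not the library's API.

# Standalone sketch of the filtering logic above; names and defaults are assumptions.
from typing import Any, List

DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]  # assumed default set

def filter_rows_by_column(
    rows: List[List[Any]],
    column_index: int,
    allowed_values: List[str],
    case_sensitive: bool = False,
) -> List[List[Any]]:
    """Keep rows whose value at column_index is one of allowed_values."""
    if not case_sensitive:
        allowed_values = [v.upper() for v in allowed_values]

    def matches(row: List[Any]) -> bool:
        value = row[column_index]
        if not case_sensitive and isinstance(value, str):
            value = value.upper()
        return value in allowed_values

    return [row for row in rows if matches(row)]

rows = [
    ["cat", "sch", "t1", None, None, "TABLE"],
    ["cat", "sch", "v1", None, None, "view"],
    ["cat", "sch", "i1", None, None, "INDEX"],
]
# Table type sits at index 5, matching filter_tables_by_type above.
print(filter_rows_by_column(rows, 5, DEFAULT_TABLE_TYPES))
# keeps t1 (TABLE) and v1 ("view" matches case-insensitively)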

src/databricks/sql/backend/sea/models/base.py

Lines changed: 11 additions & 2 deletions
@@ -34,6 +34,12 @@ class ExternalLink:
     external_link: str
     expiration: str
     chunk_index: int
+    byte_count: int = 0
+    row_count: int = 0
+    row_offset: int = 0
+    next_chunk_index: Optional[int] = None
+    next_chunk_internal_link: Optional[str] = None
+    http_headers: Optional[Dict[str, str]] = None


 @dataclass
@@ -61,8 +67,11 @@ class ColumnInfo:
 class ResultManifest:
     """Manifest information for a result set."""

-    schema: List[ColumnInfo]
+    format: str
+    schema: Dict[str, Any]  # Will contain column information
     total_row_count: int
     total_byte_count: int
+    total_chunk_count: int
     truncated: bool = False
-    chunk_count: Optional[int] = None
+    chunks: Optional[List[Dict[str, Any]]] = None
+    result_compression: Optional[str] = None
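A minimal sketch of the extended ExternalLink, assuming this commit's checkout is importable; all values are invented. The new next_chunk_index / next_chunk_internal_link fields appear intended to let a fetcher follow the chunk chain link by link, with the count fields defaulting to 0 for payloads that omit them.

from databricks.sql.backend.sea.models.base import ExternalLink

link = ExternalLink(
    external_link="https://storage.example.com/chunk-0?sig=...",  # illustrative URL
    expiration="2025-01-01T00:00:00Z",
    chunk_index=0,
    byte_count=1024,
    row_count=100,
    row_offset=0,
    next_chunk_index=1,  # presumably None on the final chunk
    http_headers={"Authorization": "Bearer ..."},  # optional per-link headers
)

# next_chunk_internal_link was left at its None default above.
assert link.next_chunk_internal_link is None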

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 4 additions & 12 deletions
@@ -21,18 +21,16 @@ class StatementParameter:
 class ExecuteStatementRequest:
     """Request to execute a SQL statement."""

-    warehouse_id: str
-    statement: str
     session_id: str
+    statement: str
+    warehouse_id: str
     disposition: str = "EXTERNAL_LINKS"
     format: str = "JSON_ARRAY"
+    result_compression: Optional[str] = None
+    parameters: Optional[List[StatementParameter]] = None
     wait_timeout: str = "10s"
     on_wait_timeout: str = "CONTINUE"
     row_limit: Optional[int] = None
-    parameters: Optional[List[StatementParameter]] = None
-    catalog: Optional[str] = None
-    schema: Optional[str] = None
-    result_compression: Optional[str] = None

     def to_dict(self) -> Dict[str, Any]:
         """Convert the request to a dictionary for JSON serialization."""
@@ -49,12 +47,6 @@ def to_dict(self) -> Dict[str, Any]:
         if self.row_limit is not None and self.row_limit > 0:
             result["row_limit"] = self.row_limit

-        if self.catalog:
-            result["catalog"] = self.catalog
-
-        if self.schema:
-            result["schema"] = self.schema
-
         if self.result_compression:
             result["result_compression"] = self.result_compression

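A quick sketch of the reordered request dataclass, assuming this commit's checkout; the IDs and compression codec are placeholders. Unset optionals are skipped by to_dict, and the dropped catalog/schema fields can no longer leak into the payload.

from databricks.sql.backend.sea.models.requests import ExecuteStatementRequest

request = ExecuteStatementRequest(
    session_id="session-01ef",        # placeholder session handle
    statement="SELECT 1",
    warehouse_id="warehouse-abc123",  # placeholder warehouse ID
    result_compression="LZ4_FRAME",   # placeholder codec name
)

payload = request.to_dict()
assert payload["result_compression"] == "LZ4_FRAME"
assert "catalog" not in payload and "schema" not in payload  # removed fields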

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 133 additions & 13 deletions
@@ -13,6 +13,8 @@
     ResultManifest,
     ResultData,
     ServiceError,
+    ExternalLink,
+    ColumnInfo,
 )


@@ -37,20 +39,62 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
                 error_code=error_data.get("error_code"),
             )

-        state = CommandState.from_sea_state(status_data.get("state", ""))
-        if state is None:
-            raise ValueError(f"Invalid state: {status_data.get('state', '')}")
         status = StatementStatus(
-            state=state,
+            state=CommandState.from_sea_state(status_data.get("state", "")),
             error=error,
             sql_state=status_data.get("sql_state"),
         )

+        # Parse manifest
+        manifest = None
+        if "manifest" in data:
+            manifest_data = data["manifest"]
+            manifest = ResultManifest(
+                format=manifest_data.get("format", ""),
+                schema=manifest_data.get("schema", {}),
+                total_row_count=manifest_data.get("total_row_count", 0),
+                total_byte_count=manifest_data.get("total_byte_count", 0),
+                total_chunk_count=manifest_data.get("total_chunk_count", 0),
+                truncated=manifest_data.get("truncated", False),
+                chunks=manifest_data.get("chunks"),
+                result_compression=manifest_data.get("result_compression"),
+            )
+
+        # Parse result data
+        result = None
+        if "result" in data:
+            result_data = data["result"]
+            external_links = None
+
+            if "external_links" in result_data:
+                external_links = []
+                for link_data in result_data["external_links"]:
+                    external_links.append(
+                        ExternalLink(
+                            external_link=link_data.get("external_link", ""),
+                            expiration=link_data.get("expiration", ""),
+                            chunk_index=link_data.get("chunk_index", 0),
+                            byte_count=link_data.get("byte_count", 0),
+                            row_count=link_data.get("row_count", 0),
+                            row_offset=link_data.get("row_offset", 0),
+                            next_chunk_index=link_data.get("next_chunk_index"),
+                            next_chunk_internal_link=link_data.get(
+                                "next_chunk_internal_link"
+                            ),
+                            http_headers=link_data.get("http_headers"),
+                        )
+                    )
+
+            result = ResultData(
+                data=result_data.get("data_array"),
+                external_links=external_links,
+            )
+
         return cls(
             statement_id=data.get("statement_id", ""),
             status=status,
-            manifest=data.get("manifest"),  # We'll parse this more fully if needed
-            result=data.get("result"),  # We'll parse this more fully if needed
+            manifest=manifest,
+            result=result,
         )


@@ -75,21 +119,62 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
                 error_code=error_data.get("error_code"),
             )

-        state = CommandState.from_sea_state(status_data.get("state", ""))
-        if state is None:
-            raise ValueError(f"Invalid state: {status_data.get('state', '')}")
-
         status = StatementStatus(
-            state=state,
+            state=CommandState.from_sea_state(status_data.get("state", "")),
             error=error,
             sql_state=status_data.get("sql_state"),
         )

+        # Parse manifest
+        manifest = None
+        if "manifest" in data:
+            manifest_data = data["manifest"]
+            manifest = ResultManifest(
+                format=manifest_data.get("format", ""),
+                schema=manifest_data.get("schema", {}),
+                total_row_count=manifest_data.get("total_row_count", 0),
+                total_byte_count=manifest_data.get("total_byte_count", 0),
+                total_chunk_count=manifest_data.get("total_chunk_count", 0),
+                truncated=manifest_data.get("truncated", False),
+                chunks=manifest_data.get("chunks"),
+                result_compression=manifest_data.get("result_compression"),
+            )
+
+        # Parse result data
+        result = None
+        if "result" in data:
+            result_data = data["result"]
+            external_links = None
+
+            if "external_links" in result_data:
+                external_links = []
+                for link_data in result_data["external_links"]:
+                    external_links.append(
+                        ExternalLink(
+                            external_link=link_data.get("external_link", ""),
+                            expiration=link_data.get("expiration", ""),
+                            chunk_index=link_data.get("chunk_index", 0),
+                            byte_count=link_data.get("byte_count", 0),
+                            row_count=link_data.get("row_count", 0),
+                            row_offset=link_data.get("row_offset", 0),
+                            next_chunk_index=link_data.get("next_chunk_index"),
+                            next_chunk_internal_link=link_data.get(
+                                "next_chunk_internal_link"
+                            ),
+                            http_headers=link_data.get("http_headers"),
+                        )
+                    )
+
+            result = ResultData(
+                data=result_data.get("data_array"),
+                external_links=external_links,
+            )
+
         return cls(
             statement_id=data.get("statement_id", ""),
             status=status,
-            manifest=data.get("manifest"),  # We'll parse this more fully if needed
-            result=data.get("result"),  # We'll parse this more fully if needed
+            manifest=manifest,
+            result=result,
         )


@@ -103,3 +188,38 @@ class CreateSessionResponse:
     def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
         """Create a CreateSessionResponse from a dictionary."""
         return cls(session_id=data.get("session_id", ""))
+
+
+@dataclass
+class GetChunksResponse:
+    """Response from getting chunks for a statement."""
+
+    statement_id: str
+    external_links: List[ExternalLink]
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
+        """Create a GetChunksResponse from a dictionary."""
+        external_links = []
+        if "external_links" in data:
+            for link_data in data["external_links"]:
+                external_links.append(
+                    ExternalLink(
+                        external_link=link_data.get("external_link", ""),
+                        expiration=link_data.get("expiration", ""),
+                        chunk_index=link_data.get("chunk_index", 0),
+                        byte_count=link_data.get("byte_count", 0),
+                        row_count=link_data.get("row_count", 0),
+                        row_offset=link_data.get("row_offset", 0),
+                        next_chunk_index=link_data.get("next_chunk_index"),
+                        next_chunk_internal_link=link_data.get(
+                            "next_chunk_internal_link"
+                        ),
+                        http_headers=link_data.get("http_headers"),
+                    )
+                )
+
+        return cls(
+            statement_id=data.get("statement_id", ""),
+            external_links=external_links,
+        )