Skip to content

Commit 74dd311

Browse files
move get chunk links into backend
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 75c5a62 commit 74dd311

14 files changed

+432
-284
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from databricks.sql.client import Cursor
1818

1919
from databricks.sql.thrift_api.TCLIService import ttypes
20-
from databricks.sql.backend.types import SessionId, CommandId, CommandState, ExecuteResponse
20+
from databricks.sql.backend.types import (
21+
SessionId,
22+
CommandId,
23+
CommandState,
24+
ExecuteResponse,
25+
)
2126
from databricks.sql.types import SSLOptions
2227

2328
# Forward reference for type hints

src/databricks/sql/backend/filters.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,40 +41,40 @@ def _filter_sea_result_set(
4141
) -> "SeaResultSet":
4242
"""
4343
Filter a SEA result set using the provided filter function.
44-
44+
4545
Args:
4646
result_set: The SEA result set to filter
4747
filter_func: Function that takes a row and returns True if the row should be included
48-
48+
4949
Returns:
5050
A filtered SEA result set
5151
"""
5252
# Get all remaining rows
5353
original_index = result_set.results.cur_row_index
5454
result_set.results.cur_row_index = 0 # Reset to beginning
5555
all_rows = result_set.results.remaining_rows()
56-
56+
5757
# Filter rows
5858
filtered_rows = [row for row in all_rows if filter_func(row)]
59-
59+
6060
# Import SeaResultSet here to avoid circular imports
6161
from databricks.sql.result_set import SeaResultSet
62-
62+
6363
# Reuse the command_id from the original result set
6464
command_id = result_set.command_id
65-
65+
6666
# Create an ExecuteResponse with the filtered data
6767
execute_response = ExecuteResponse(
6868
command_id=command_id,
6969
status=result_set.status,
7070
description=result_set.description,
71-
has_more_rows=result_set._has_more_rows,
71+
has_more_rows=result_set._has_more_rows,
7272
results_queue=JsonQueue(filtered_rows),
7373
has_been_closed_server_side=result_set.has_been_closed_server_side,
7474
lz4_compressed=False,
7575
is_staging_operation=False,
7676
)
77-
77+
7878
return SeaResultSet(
7979
connection=result_set.connection,
8080
execute_response=execute_response,
@@ -108,6 +108,7 @@ def filter_by_column_values(
108108

109109
# Determine the type of result set and apply appropriate filtering
110110
from databricks.sql.result_set import SeaResultSet
111+
111112
if isinstance(result_set, SeaResultSet):
112113
return ResultSetFilter._filter_sea_result_set(
113114
result_set,
@@ -156,4 +157,4 @@ def filter_tables_by_type(
156157
# Table type is the 6th column (index 5)
157158
return ResultSetFilter.filter_by_column_values(
158159
result_set, 5, valid_types, case_sensitive=False
159-
)
160+
)

src/databricks/sql/backend/models/responses.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
4444
error=error,
4545
sql_state=status_data.get("sql_state"),
4646
)
47-
47+
4848
# Parse manifest
4949
manifest = None
5050
if "manifest" in data:
@@ -59,13 +59,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
5959
chunks=manifest_data.get("chunks"),
6060
result_compression=manifest_data.get("result_compression"),
6161
)
62-
62+
6363
# Parse result data
6464
result = None
6565
if "result" in data:
6666
result_data = data["result"]
6767
external_links = None
68-
68+
6969
if "external_links" in result_data:
7070
external_links = []
7171
for link_data in result_data["external_links"]:
@@ -78,11 +78,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
7878
row_count=link_data.get("row_count", 0),
7979
row_offset=link_data.get("row_offset", 0),
8080
next_chunk_index=link_data.get("next_chunk_index"),
81-
next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
81+
next_chunk_internal_link=link_data.get(
82+
"next_chunk_internal_link"
83+
),
8284
http_headers=link_data.get("http_headers"),
8385
)
8486
)
85-
87+
8688
result = ResultData(
8789
data=result_data.get("data_array"),
8890
external_links=external_links,
@@ -122,7 +124,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
122124
error=error,
123125
sql_state=status_data.get("sql_state"),
124126
)
125-
127+
126128
# Parse manifest
127129
manifest = None
128130
if "manifest" in data:
@@ -137,13 +139,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
137139
chunks=manifest_data.get("chunks"),
138140
result_compression=manifest_data.get("result_compression"),
139141
)
140-
142+
141143
# Parse result data
142144
result = None
143145
if "result" in data:
144146
result_data = data["result"]
145147
external_links = None
146-
148+
147149
if "external_links" in result_data:
148150
external_links = []
149151
for link_data in result_data["external_links"]:
@@ -156,11 +158,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
156158
row_count=link_data.get("row_count", 0),
157159
row_offset=link_data.get("row_offset", 0),
158160
next_chunk_index=link_data.get("next_chunk_index"),
159-
next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
161+
next_chunk_internal_link=link_data.get(
162+
"next_chunk_internal_link"
163+
),
160164
http_headers=link_data.get("http_headers"),
161165
)
162166
)
163-
167+
164168
result = ResultData(
165169
data=result_data.get("data_array"),
166170
external_links=external_links,
@@ -208,11 +212,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
208212
row_count=link_data.get("row_count", 0),
209213
row_offset=link_data.get("row_offset", 0),
210214
next_chunk_index=link_data.get("next_chunk_index"),
211-
next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
215+
next_chunk_internal_link=link_data.get(
216+
"next_chunk_internal_link"
217+
),
212218
http_headers=link_data.get("http_headers"),
213219
)
214220
)
215-
221+
216222
return cls(
217223
statement_id=data.get("statement_id", ""),
218224
external_links=external_links,

0 commit comments

Comments
 (0)