Skip to content

Commit 871a44f

Browse files
Revert "remove un-necessary filters changes"
This reverts commit 5e75fb5.
1 parent 93edb93 commit 871a44f

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,36 @@
99
List,
1010
Optional,
1111
Any,
12+
Dict,
1213
Callable,
14+
TypeVar,
15+
Generic,
1316
cast,
17+
TYPE_CHECKING,
1418
)
1519

20+
from databricks.sql.backend.types import ExecuteResponse, CommandId
21+
from databricks.sql.backend.sea.models.base import ResultData
1622
from databricks.sql.backend.sea.backend import SeaDatabricksClient
17-
from databricks.sql.backend.types import ExecuteResponse
1823

19-
from databricks.sql.result_set import ResultSet, SeaResultSet
24+
if TYPE_CHECKING:
25+
from databricks.sql.result_set import ResultSet, SeaResultSet
2026

2127
logger = logging.getLogger(__name__)
2228

2329

2430
class ResultSetFilter:
2531
"""
26-
A general-purpose filter for result sets.
32+
A general-purpose filter for result sets that can be applied to any backend.
33+
34+
This class provides methods to filter result sets based on various criteria,
35+
similar to the client-side filtering in the JDBC connector.
2736
"""
2837

2938
@staticmethod
3039
def _filter_sea_result_set(
31-
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
32-
) -> SeaResultSet:
40+
result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
41+
) -> "SeaResultSet":
3342
"""
3443
Filter a SEA result set using the provided filter function.
3544
@@ -40,13 +49,15 @@ def _filter_sea_result_set(
4049
Returns:
4150
A filtered SEA result set
4251
"""
43-
4452
# Get all remaining rows
4553
all_rows = result_set.results.remaining_rows()
4654

4755
# Filter rows
4856
filtered_rows = [row for row in all_rows if filter_func(row)]
4957

58+
# Import SeaResultSet here to avoid circular imports
59+
from databricks.sql.result_set import SeaResultSet
60+
5061
# Reuse the command_id from the original result set
5162
command_id = result_set.command_id
5263

@@ -62,13 +73,10 @@ def _filter_sea_result_set(
6273
)
6374

6475
# Create a new ResultData object with filtered data
65-
6676
from databricks.sql.backend.sea.models.base import ResultData
6777

6878
result_data = ResultData(data=filtered_rows, external_links=None)
6979

70-
from databricks.sql.result_set import SeaResultSet
71-
7280
# Create a new SeaResultSet with the filtered data
7381
filtered_result_set = SeaResultSet(
7482
connection=result_set.connection,
@@ -83,11 +91,11 @@ def _filter_sea_result_set(
8391

8492
@staticmethod
8593
def filter_by_column_values(
86-
result_set: ResultSet,
94+
result_set: "ResultSet",
8795
column_index: int,
8896
allowed_values: List[str],
8997
case_sensitive: bool = False,
90-
) -> ResultSet:
98+
) -> "ResultSet":
9199
"""
92100
Filter a result set by values in a specific column.
93101
@@ -100,7 +108,6 @@ def filter_by_column_values(
100108
Returns:
101109
A filtered result set
102110
"""
103-
104111
# Convert to uppercase for case-insensitive comparison if needed
105112
if not case_sensitive:
106113
allowed_values = [v.upper() for v in allowed_values]
@@ -131,8 +138,8 @@ def filter_by_column_values(
131138

132139
@staticmethod
133140
def filter_tables_by_type(
134-
result_set: ResultSet, table_types: Optional[List[str]] = None
135-
) -> ResultSet:
141+
result_set: "ResultSet", table_types: Optional[List[str]] = None
142+
) -> "ResultSet":
136143
"""
137144
Filter a result set of tables by the specified table types.
138145
@@ -147,7 +154,6 @@ def filter_tables_by_type(
147154
Returns:
148155
A filtered result set containing only tables of the specified types
149156
"""
150-
151157
# Default table types if none specified
152158
DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
153159
valid_types = (

0 commit comments

Comments
 (0)