Skip to content

Commit a515d26

Browse files
move filters.py to SEA utils
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 35f1ef0 commit a515d26

File tree

4 files changed

+31
-43
lines changed

4 files changed

+31
-43
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ def get_tables(
724724
assert result is not None, "execute_command returned None in synchronous mode"
725725

726726
# Apply client-side filtering by table_types
727-
from databricks.sql.backend.filters import ResultSetFilter
727+
from databricks.sql.backend.sea.utils.filters import ResultSetFilter
728728

729729
result = ResultSetFilter.filter_tables_by_type(result, table_types)
730730

src/databricks/sql/backend/filters.py renamed to src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ def _filter_sea_result_set(
8383

8484
@staticmethod
8585
def filter_by_column_values(
86-
result_set: ResultSet,
86+
result_set: SeaResultSet,
8787
column_index: int,
8888
allowed_values: List[str],
8989
case_sensitive: bool = False,
90-
) -> ResultSet:
90+
) -> SeaResultSet:
9191
"""
9292
Filter a result set by values in a specific column.
9393
@@ -105,34 +105,24 @@ def filter_by_column_values(
105105
if not case_sensitive:
106106
allowed_values = [v.upper() for v in allowed_values]
107107

108-
# Determine the type of result set and apply appropriate filtering
109-
from databricks.sql.result_set import SeaResultSet
110-
111-
if isinstance(result_set, SeaResultSet):
112-
return ResultSetFilter._filter_sea_result_set(
113-
result_set,
114-
lambda row: (
115-
len(row) > column_index
116-
and isinstance(row[column_index], str)
117-
and (
118-
row[column_index].upper()
119-
if not case_sensitive
120-
else row[column_index]
121-
)
122-
in allowed_values
123-
),
124-
)
125-
126-
# For other result set types, return the original (should be handled by specific implementations)
127-
logger.warning(
128-
f"Filtering not implemented for result set type: {type(result_set).__name__}"
108+
return ResultSetFilter._filter_sea_result_set(
109+
result_set,
110+
lambda row: (
111+
len(row) > column_index
112+
and isinstance(row[column_index], str)
113+
and (
114+
row[column_index].upper()
115+
if not case_sensitive
116+
else row[column_index]
117+
)
118+
in allowed_values
119+
),
129120
)
130-
return result_set
131121

132122
@staticmethod
133123
def filter_tables_by_type(
134-
result_set: ResultSet, table_types: Optional[List[str]] = None
135-
) -> ResultSet:
124+
result_set: SeaResultSet, table_types: Optional[List[str]] = None
125+
) -> SeaResultSet:
136126
"""
137127
Filter a result set of tables by the specified table types.
138128

tests/unit/test_filters.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import unittest
66
from unittest.mock import MagicMock, patch
77

8-
from databricks.sql.backend.filters import ResultSetFilter
8+
from databricks.sql.backend.sea.utils.filters import ResultSetFilter
99

1010

1111
class TestResultSetFilter(unittest.TestCase):
@@ -73,7 +73,9 @@ def test_filter_by_column_values(self):
7373
# Case 1: Case-sensitive filtering
7474
allowed_values = ["table1", "table3"]
7575

76-
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
76+
with patch(
77+
"databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
78+
):
7779
with patch(
7880
"databricks.sql.result_set.SeaResultSet"
7981
) as mock_sea_result_set_class:
@@ -98,7 +100,9 @@ def test_filter_by_column_values(self):
98100

99101
# Case 2: Case-insensitive filtering
100102
mock_sea_result_set_class.reset_mock()
101-
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
103+
with patch(
104+
"databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
105+
):
102106
with patch(
103107
"databricks.sql.result_set.SeaResultSet"
104108
) as mock_sea_result_set_class:
@@ -114,22 +118,14 @@ def test_filter_by_column_values(self):
114118
)
115119
mock_sea_result_set_class.assert_called_once()
116120

117-
# Case 3: Unsupported result set type
118-
mock_unsupported_result_set = MagicMock()
119-
with patch("databricks.sql.backend.filters.isinstance", return_value=False):
120-
with patch("databricks.sql.backend.filters.logger") as mock_logger:
121-
result = ResultSetFilter.filter_by_column_values(
122-
mock_unsupported_result_set, 0, ["value"], True
123-
)
124-
mock_logger.warning.assert_called_once()
125-
self.assertEqual(result, mock_unsupported_result_set)
126-
127121
def test_filter_tables_by_type(self):
128122
"""Test filtering tables by type with various options."""
129123
# Case 1: Specific table types
130124
table_types = ["TABLE", "VIEW"]
131125

132-
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
126+
with patch(
127+
"databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
128+
):
133129
with patch.object(
134130
ResultSetFilter, "filter_by_column_values"
135131
) as mock_filter:
@@ -143,7 +139,9 @@ def test_filter_tables_by_type(self):
143139
self.assertEqual(kwargs.get("case_sensitive"), True)
144140

145141
# Case 2: Default table types (None or empty list)
146-
with patch("databricks.sql.backend.filters.isinstance", return_value=True):
142+
with patch(
143+
"databricks.sql.backend.sea.utils.filters.isinstance", return_value=True
144+
):
147145
with patch.object(
148146
ResultSetFilter, "filter_by_column_values"
149147
) as mock_filter:

tests/unit/test_sea_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
735735
) as mock_execute:
736736
# Mock the filter_tables_by_type method
737737
with patch(
738-
"databricks.sql.backend.filters.ResultSetFilter.filter_tables_by_type",
738+
"databricks.sql.backend.sea.utils.filters.ResultSetFilter.filter_tables_by_type",
739739
return_value=mock_result_set,
740740
) as mock_filter:
741741
# Case 1: With catalog name only

0 commit comments

Comments
 (0)