Skip to content

Commit 30f8266

Browse files
add metadata commands
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 030edf8 commit 30f8266

File tree

3 files changed

+386
-4
lines changed

3 files changed

+386
-4
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""
2+
Client-side filtering utilities for Databricks SQL connector.
3+
4+
This module provides filtering capabilities for result sets returned by different backends.
5+
"""
6+
7+
import logging
8+
from typing import (
9+
List,
10+
Optional,
11+
Any,
12+
Dict,
13+
Callable,
14+
TypeVar,
15+
Generic,
16+
cast,
17+
TYPE_CHECKING,
18+
)
19+
20+
from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory
21+
from databricks.sql.backend.types import ExecuteResponse, CommandId
22+
from databricks.sql.backend.sea.models.base import ResultData
23+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
24+
25+
if TYPE_CHECKING:
26+
from databricks.sql.result_set import ResultSet, SeaResultSet
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class ResultSetFilter:
    """
    A general-purpose client-side filter for result sets.

    Provides methods to filter result sets based on various criteria,
    similar to the client-side filtering in the JDBC connector.
    """

    @staticmethod
    def _filter_sea_result_set(
        result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
    ) -> "SeaResultSet":
        """
        Filter a SEA result set using the provided filter function.

        Args:
            result_set: The SEA result set to filter.
            filter_func: Predicate taking a row; returns True to keep the row.

        Returns:
            A new SeaResultSet containing only the rows accepted by
            ``filter_func``.
        """
        # Drain the remaining rows and keep only those the predicate accepts.
        all_rows = result_set.results.remaining_rows()
        filtered_rows = [row for row in all_rows if filter_func(row)]

        # Imported here (not at module top) to avoid a circular import.
        from databricks.sql.result_set import SeaResultSet

        # Build an ExecuteResponse mirroring the original result set's
        # metadata; only the row data changes.
        execute_response = ExecuteResponse(
            command_id=result_set.command_id,
            status=result_set.status,
            description=result_set.description,
            has_been_closed_server_side=result_set.has_been_closed_server_side,
            lz4_compressed=result_set.lz4_compressed,
            arrow_schema_bytes=result_set._arrow_schema_bytes,
            is_staging_operation=False,
        )

        # Wrap the filtered rows in a ResultData (module-level import) with no
        # external links: the filtered data is now held entirely in memory.
        result_data = ResultData(data=filtered_rows, external_links=None)

        return SeaResultSet(
            connection=result_set.connection,
            execute_response=execute_response,
            sea_client=cast(SeaDatabricksClient, result_set.backend),
            buffer_size_bytes=result_set.buffer_size_bytes,
            arraysize=result_set.arraysize,
            result_data=result_data,
        )

    @staticmethod
    def filter_by_column_values(
        result_set: "ResultSet",
        column_index: int,
        allowed_values: List[str],
        case_sensitive: bool = False,
    ) -> "ResultSet":
        """
        Filter a result set by values in a specific column.

        Args:
            result_set: The result set to filter.
            column_index: Index of the column to filter on.
            allowed_values: Values a row's cell must match to be kept.
            case_sensitive: Whether to perform case-sensitive comparison.

        Returns:
            A filtered result set, or the original result set unchanged when
            filtering is not implemented for its backend type.
        """
        # Normalize once for case-insensitive comparison, and use a set so
        # membership tests inside the per-row predicate are O(1).
        if case_sensitive:
            allowed = set(allowed_values)
        else:
            allowed = {v.upper() for v in allowed_values}

        def _row_matches(row: List[Any]) -> bool:
            # Rows that are too short or whose cell is not a string are
            # dropped, matching the original lambda's and-chain semantics.
            if len(row) <= column_index or not isinstance(row[column_index], str):
                return False
            value = row[column_index]
            if not case_sensitive:
                value = value.upper()
            return value in allowed

        # Imported here (not at module top) to avoid a circular import.
        from databricks.sql.result_set import SeaResultSet

        if isinstance(result_set, SeaResultSet):
            return ResultSetFilter._filter_sea_result_set(result_set, _row_matches)

        # Other backends: no client-side filtering implemented; return as-is.
        logger.warning(
            "Filtering not implemented for result set type: %s",
            type(result_set).__name__,
        )
        return result_set

    @staticmethod
    def filter_tables_by_type(
        result_set: "ResultSet", table_types: Optional[List[str]] = None
    ) -> "ResultSet":
        """
        Filter a result set of tables by the specified table types.

        This is a client-side filter that processes the result set after it
        has been retrieved from the server. It filters out tables whose type
        does not match any of the types in the table_types list.

        Args:
            result_set: The original result set containing tables.
            table_types: Table types to include (e.g., ["TABLE", "VIEW"]).
                An empty or None value selects the default types.

        Returns:
            A filtered result set containing only tables of the given types.
        """
        # Fall back to the standard table types when none are specified.
        DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
        valid_types = table_types if table_types else DEFAULT_TABLE_TYPES

        # Table type is the 6th column (index 5) in table metadata results.
        # NOTE(review): comparison is case-sensitive here, so lowercase
        # caller-supplied types will not match uppercase server values —
        # confirm this is intended.
        return ResultSetFilter.filter_by_column_values(
            result_set, 5, valid_types, case_sensitive=True
        )

src/databricks/sql/backend/sea/backend.py

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
def get_catalogs(
    self,
    session_id: "SessionId",
    max_rows: int,
    max_bytes: int,
    cursor: "Cursor",
) -> "ResultSet":
    """
    Get available catalogs by executing 'SHOW CATALOGS'.

    Args:
        session_id: Session in which to run the statement.
        max_rows: Maximum number of rows to fetch in a batch.
        max_bytes: Maximum number of bytes to fetch in a batch.
        cursor: Cursor that will own the result set.

    Returns:
        A ResultSet over the catalog listing.
    """
    catalogs = self.execute_command(
        operation="SHOW CATALOGS",
        session_id=session_id,
        max_rows=max_rows,
        max_bytes=max_bytes,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=False,
        enforce_embedded_schema_correctness=False,
    )
    # execute_command only returns None for async operations; this call is
    # synchronous, so a result must be present.
    assert catalogs is not None, "execute_command returned None in synchronous mode"
    return catalogs
728741

729742
def get_schemas(
    self,
    session_id: "SessionId",
    max_rows: int,
    max_bytes: int,
    cursor: "Cursor",
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
) -> "ResultSet":
    """
    Get schemas by executing 'SHOW SCHEMAS IN catalog [LIKE pattern]'.

    Args:
        session_id: Session in which to run the statement.
        max_rows: Maximum number of rows to fetch in a batch.
        max_bytes: Maximum number of bytes to fetch in a batch.
        cursor: Cursor that will own the result set.
        catalog_name: Catalog to list schemas from (required).
        schema_name: Optional LIKE pattern restricting schema names.

    Returns:
        A ResultSet over the schema listing.

    Raises:
        ValueError: If ``catalog_name`` is missing or empty.
    """
    if not catalog_name:
        raise ValueError("Catalog name is required for get_schemas")

    # NOTE(review): names/patterns are interpolated directly into the SQL
    # text; callers are expected to pass trusted values — confirm upstream.
    parts = [f"SHOW SCHEMAS IN `{catalog_name}`"]
    if schema_name:
        parts.append(f"LIKE '{schema_name}'")
    operation = " ".join(parts)

    schemas = self.execute_command(
        operation=operation,
        session_id=session_id,
        max_rows=max_rows,
        max_bytes=max_bytes,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=False,
        enforce_embedded_schema_correctness=False,
    )
    # Synchronous execution always yields a result set.
    assert schemas is not None, "execute_command returned None in synchronous mode"
    return schemas
740774

741775
def get_tables(
    self,
    session_id: "SessionId",
    max_rows: int,
    max_bytes: int,
    cursor: "Cursor",
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    table_name: Optional[str] = None,
    table_types: Optional[List[str]] = None,
) -> "ResultSet":
    """
    Get tables by executing 'SHOW TABLES IN catalog [SCHEMA LIKE pattern] [LIKE pattern]'.

    Args:
        session_id: Session in which to run the statement.
        max_rows: Maximum number of rows to fetch in a batch.
        max_bytes: Maximum number of bytes to fetch in a batch.
        cursor: Cursor that will own the result set.
        catalog_name: Catalog to list tables from (required); "*" or "%"
            searches across all catalogs.
        schema_name: Optional LIKE pattern restricting schema names.
        table_name: Optional LIKE pattern restricting table names.
        table_types: Optional table types to keep (client-side filter).

    Returns:
        A ResultSet over the (possibly filtered) table listing.

    Raises:
        ValueError: If ``catalog_name`` is missing or empty.
    """
    if not catalog_name:
        raise ValueError("Catalog name is required for get_tables")

    # "*" / "%" are wildcard catalogs -> search across all catalogs.
    # (None cannot reach this point: the guard above already rejected it,
    # so the original `in [None, "*", "%"]` membership was dead for None.)
    if catalog_name in ("*", "%"):
        operation = "SHOW TABLES IN ALL CATALOGS"
    else:
        operation = f"SHOW TABLES IN CATALOG `{catalog_name}`"

    # NOTE(review): patterns are interpolated directly into the SQL text;
    # callers are expected to pass trusted values — confirm upstream.
    if schema_name:
        operation += f" SCHEMA LIKE '{schema_name}'"

    if table_name:
        operation += f" LIKE '{table_name}'"

    result = self.execute_command(
        operation=operation,
        session_id=session_id,
        max_rows=max_rows,
        max_bytes=max_bytes,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=False,
        enforce_embedded_schema_correctness=False,
    )
    # Synchronous execution always yields a result set.
    assert result is not None, "execute_command returned None in synchronous mode"

    # Apply client-side filtering by table_types: SHOW TABLES cannot filter
    # by type server-side. Imported here to avoid a circular import.
    from databricks.sql.backend.filters import ResultSetFilter

    result = ResultSetFilter.filter_tables_by_type(result, table_types)

    return result
754822

755823
def get_columns(
    self,
    session_id: "SessionId",
    max_rows: int,
    max_bytes: int,
    cursor: "Cursor",
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    table_name: Optional[str] = None,
    column_name: Optional[str] = None,
) -> "ResultSet":
    """
    Get columns by executing 'SHOW COLUMNS IN CATALOG catalog [SCHEMA LIKE pattern] [TABLE LIKE pattern] [LIKE pattern]'.

    Args:
        session_id: Session in which to run the statement.
        max_rows: Maximum number of rows to fetch in a batch.
        max_bytes: Maximum number of bytes to fetch in a batch.
        cursor: Cursor that will own the result set.
        catalog_name: Catalog to list columns from (required).
        schema_name: Optional LIKE pattern restricting schema names.
        table_name: Optional LIKE pattern restricting table names.
        column_name: Optional LIKE pattern restricting column names.

    Returns:
        A ResultSet over the column listing.

    Raises:
        ValueError: If ``catalog_name`` is missing or empty.
    """
    if not catalog_name:
        raise ValueError("Catalog name is required for get_columns")

    # Assemble the statement from the optional pattern clauses, in the
    # order the SHOW COLUMNS grammar expects.
    # NOTE(review): patterns are interpolated directly into the SQL text;
    # callers are expected to pass trusted values — confirm upstream.
    clauses = [f"SHOW COLUMNS IN CATALOG `{catalog_name}`"]
    if schema_name:
        clauses.append(f"SCHEMA LIKE '{schema_name}'")
    if table_name:
        clauses.append(f"TABLE LIKE '{table_name}'")
    if column_name:
        clauses.append(f"LIKE '{column_name}'")
    operation = " ".join(clauses)

    columns = self.execute_command(
        operation=operation,
        session_id=session_id,
        max_rows=max_rows,
        max_bytes=max_bytes,
        lz4_compression=False,
        cursor=cursor,
        use_cloud_fetch=False,
        parameters=[],
        async_op=False,
        enforce_embedded_schema_correctness=False,
    )
    # Synchronous execution always yields a result set.
    assert columns is not None, "execute_command returned None in synchronous mode"
    return columns

0 commit comments

Comments
 (0)