|
| 1 | +""" |
| 2 | +Client-side filtering utilities for Databricks SQL connector. |
| 3 | +
|
| 4 | +This module provides filtering capabilities for result sets returned by different backends. |
| 5 | +""" |
| 6 | + |
| 7 | +import logging |
| 8 | +from typing import ( |
| 9 | + List, |
| 10 | + Optional, |
| 11 | + Any, |
| 12 | + Dict, |
| 13 | + Callable, |
| 14 | + TypeVar, |
| 15 | + Generic, |
| 16 | + cast, |
| 17 | + TYPE_CHECKING, |
| 18 | +) |
| 19 | + |
| 20 | +from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory |
| 21 | +from databricks.sql.backend.types import ExecuteResponse, CommandId |
| 22 | +from databricks.sql.backend.sea.models.base import ResultData |
| 23 | +from databricks.sql.backend.sea.backend import SeaDatabricksClient |
| 24 | + |
| 25 | +if TYPE_CHECKING: |
| 26 | + from databricks.sql.result_set import ResultSet, SeaResultSet |
| 27 | + |
| 28 | +logger = logging.getLogger(__name__) |
| 29 | + |
| 30 | + |
| 31 | +class ResultSetFilter: |
| 32 | + """ |
| 33 | + A general-purpose filter for result sets that can be applied to any backend. |
| 34 | +
|
| 35 | + This class provides methods to filter result sets based on various criteria, |
| 36 | + similar to the client-side filtering in the JDBC connector. |
| 37 | + """ |
| 38 | + |
| 39 | + @staticmethod |
| 40 | + def _filter_sea_result_set( |
| 41 | + result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool] |
| 42 | + ) -> "SeaResultSet": |
| 43 | + """ |
| 44 | + Filter a SEA result set using the provided filter function. |
| 45 | +
|
| 46 | + Args: |
| 47 | + result_set: The SEA result set to filter |
| 48 | + filter_func: Function that takes a row and returns True if the row should be included |
| 49 | +
|
| 50 | + Returns: |
| 51 | + A filtered SEA result set |
| 52 | + """ |
| 53 | + # Get all remaining rows |
| 54 | + all_rows = result_set.results.remaining_rows() |
| 55 | + |
| 56 | + # Filter rows |
| 57 | + filtered_rows = [row for row in all_rows if filter_func(row)] |
| 58 | + |
| 59 | + # Import SeaResultSet here to avoid circular imports |
| 60 | + from databricks.sql.result_set import SeaResultSet |
| 61 | + |
| 62 | + # Reuse the command_id from the original result set |
| 63 | + command_id = result_set.command_id |
| 64 | + |
| 65 | + # Create an ExecuteResponse with the filtered data |
| 66 | + execute_response = ExecuteResponse( |
| 67 | + command_id=command_id, |
| 68 | + status=result_set.status, |
| 69 | + description=result_set.description, |
| 70 | + has_been_closed_server_side=result_set.has_been_closed_server_side, |
| 71 | + lz4_compressed=result_set.lz4_compressed, |
| 72 | + arrow_schema_bytes=result_set._arrow_schema_bytes, |
| 73 | + is_staging_operation=False, |
| 74 | + ) |
| 75 | + |
| 76 | + # Create a new ResultData object with filtered data |
| 77 | + from databricks.sql.backend.sea.models.base import ResultData |
| 78 | + |
| 79 | + result_data = ResultData(data=filtered_rows, external_links=None) |
| 80 | + |
| 81 | + # Create a new SeaResultSet with the filtered data |
| 82 | + filtered_result_set = SeaResultSet( |
| 83 | + connection=result_set.connection, |
| 84 | + execute_response=execute_response, |
| 85 | + sea_client=cast(SeaDatabricksClient, result_set.backend), |
| 86 | + buffer_size_bytes=result_set.buffer_size_bytes, |
| 87 | + arraysize=result_set.arraysize, |
| 88 | + result_data=result_data, |
| 89 | + ) |
| 90 | + |
| 91 | + return filtered_result_set |
| 92 | + |
| 93 | + @staticmethod |
| 94 | + def filter_by_column_values( |
| 95 | + result_set: "ResultSet", |
| 96 | + column_index: int, |
| 97 | + allowed_values: List[str], |
| 98 | + case_sensitive: bool = False, |
| 99 | + ) -> "ResultSet": |
| 100 | + """ |
| 101 | + Filter a result set by values in a specific column. |
| 102 | +
|
| 103 | + Args: |
| 104 | + result_set: The result set to filter |
| 105 | + column_index: The index of the column to filter on |
| 106 | + allowed_values: List of allowed values for the column |
| 107 | + case_sensitive: Whether to perform case-sensitive comparison |
| 108 | +
|
| 109 | + Returns: |
| 110 | + A filtered result set |
| 111 | + """ |
| 112 | + # Convert to uppercase for case-insensitive comparison if needed |
| 113 | + if not case_sensitive: |
| 114 | + allowed_values = [v.upper() for v in allowed_values] |
| 115 | + |
| 116 | + # Determine the type of result set and apply appropriate filtering |
| 117 | + from databricks.sql.result_set import SeaResultSet |
| 118 | + |
| 119 | + if isinstance(result_set, SeaResultSet): |
| 120 | + return ResultSetFilter._filter_sea_result_set( |
| 121 | + result_set, |
| 122 | + lambda row: ( |
| 123 | + len(row) > column_index |
| 124 | + and isinstance(row[column_index], str) |
| 125 | + and ( |
| 126 | + row[column_index].upper() |
| 127 | + if not case_sensitive |
| 128 | + else row[column_index] |
| 129 | + ) |
| 130 | + in allowed_values |
| 131 | + ), |
| 132 | + ) |
| 133 | + |
| 134 | + # For other result set types, return the original (should be handled by specific implementations) |
| 135 | + logger.warning( |
| 136 | + f"Filtering not implemented for result set type: {type(result_set).__name__}" |
| 137 | + ) |
| 138 | + return result_set |
| 139 | + |
| 140 | + @staticmethod |
| 141 | + def filter_tables_by_type( |
| 142 | + result_set: "ResultSet", table_types: Optional[List[str]] = None |
| 143 | + ) -> "ResultSet": |
| 144 | + """ |
| 145 | + Filter a result set of tables by the specified table types. |
| 146 | +
|
| 147 | + This is a client-side filter that processes the result set after it has been |
| 148 | + retrieved from the server. It filters out tables whose type does not match |
| 149 | + any of the types in the table_types list. |
| 150 | +
|
| 151 | + Args: |
| 152 | + result_set: The original result set containing tables |
| 153 | + table_types: List of table types to include (e.g., ["TABLE", "VIEW"]) |
| 154 | +
|
| 155 | + Returns: |
| 156 | + A filtered result set containing only tables of the specified types |
| 157 | + """ |
| 158 | + # Default table types if none specified |
| 159 | + DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"] |
| 160 | + valid_types = ( |
| 161 | + table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES |
| 162 | + ) |
| 163 | + |
| 164 | + # Table type is the 6th column (index 5) |
| 165 | + return ResultSetFilter.filter_by_column_values( |
| 166 | + result_set, 5, valid_types, case_sensitive=True |
| 167 | + ) |
0 commit comments