Skip to content

Commit db5bbea

Browse files
[squashed from sea-exec] merge sea stuffs
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent dae15e3 commit db5bbea

File tree

16 files changed

+1805
-232
lines changed

16 files changed

+1805
-232
lines changed

examples/experimental/sea_connector_test.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,99 +22,90 @@
2222
"test_sea_metadata",
2323
]
2424

25-
2625
def load_test_function(module_name: str) -> Callable:
2726
"""Load a test function from a module."""
2827
module_path = os.path.join(
29-
os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py"
28+
os.path.dirname(os.path.abspath(__file__)),
29+
"tests",
30+
f"{module_name}.py"
3031
)
31-
32+
3233
spec = importlib.util.spec_from_file_location(module_name, module_path)
3334
module = importlib.util.module_from_spec(spec)
3435
spec.loader.exec_module(module)
35-
36+
3637
# Get the main test function (assuming it starts with "test_")
3738
for name in dir(module):
3839
if name.startswith("test_") and callable(getattr(module, name)):
3940
# For sync and async query modules, we want the main function that runs both tests
4041
if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec":
4142
return getattr(module, name)
42-
43+
4344
# Fallback to the first test function found
4445
for name in dir(module):
4546
if name.startswith("test_") and callable(getattr(module, name)):
4647
return getattr(module, name)
47-
48+
4849
raise ValueError(f"No test function found in module {module_name}")
4950

50-
5151
def run_tests() -> List[Tuple[str, bool]]:
5252
"""Run all tests and return results."""
5353
results = []
54-
54+
5555
for module_name in TEST_MODULES:
5656
try:
5757
test_func = load_test_function(module_name)
5858
logger.info(f"\n{'=' * 50}")
5959
logger.info(f"Running test: {module_name}")
6060
logger.info(f"{'-' * 50}")
61-
61+
6262
success = test_func()
6363
results.append((module_name, success))
64-
64+
6565
status = "✅ PASSED" if success else "❌ FAILED"
6666
logger.info(f"Test {module_name}: {status}")
67-
67+
6868
except Exception as e:
6969
logger.error(f"Error loading or running test {module_name}: {str(e)}")
7070
import traceback
71-
7271
logger.error(traceback.format_exc())
7372
results.append((module_name, False))
74-
73+
7574
return results
7675

77-
7876
def print_summary(results: List[Tuple[str, bool]]) -> None:
7977
"""Print a summary of test results."""
8078
logger.info(f"\n{'=' * 50}")
8179
logger.info("TEST SUMMARY")
8280
logger.info(f"{'-' * 50}")
83-
81+
8482
passed = sum(1 for _, success in results if success)
8583
total = len(results)
86-
84+
8785
for module_name, success in results:
8886
status = "✅ PASSED" if success else "❌ FAILED"
8987
logger.info(f"{status} - {module_name}")
90-
88+
9189
logger.info(f"{'-' * 50}")
9290
logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}")
9391
logger.info(f"{'=' * 50}")
9492

95-
9693
if __name__ == "__main__":
9794
# Check if required environment variables are set
98-
required_vars = [
99-
"DATABRICKS_SERVER_HOSTNAME",
100-
"DATABRICKS_HTTP_PATH",
101-
"DATABRICKS_TOKEN",
102-
]
95+
required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"]
10396
missing_vars = [var for var in required_vars if not os.environ.get(var)]
104-
97+
10598
if missing_vars:
106-
logger.error(
107-
f"Missing required environment variables: {', '.join(missing_vars)}"
108-
)
99+
logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
109100
logger.error("Please set these variables before running the tests.")
110101
sys.exit(1)
111-
102+
112103
# Run all tests
113104
results = run_tests()
114-
105+
115106
# Print summary
116107
print_summary(results)
117-
108+
118109
# Exit with appropriate status code
119110
all_passed = all(success for _, success in results)
120111
sys.exit(0 if all_passed else 1)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This file makes the tests directory a Python package

src/databricks/sql/backend/databricks_client.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -86,34 +86,6 @@ def execute_command(
8686
async_op: bool,
8787
enforce_embedded_schema_correctness: bool,
8888
) -> Union["ResultSet", None]:
89-
"""
90-
Executes a SQL command or query within the specified session.
91-
92-
This method sends a SQL command to the server for execution and handles
93-
the response. It can operate in both synchronous and asynchronous modes.
94-
95-
Args:
96-
operation: The SQL command or query to execute
97-
session_id: The session identifier in which to execute the command
98-
max_rows: Maximum number of rows to fetch in a single fetch batch
99-
max_bytes: Maximum number of bytes to fetch in a single fetch batch
100-
lz4_compression: Whether to use LZ4 compression for result data
101-
cursor: The cursor object that will handle the results
102-
use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
103-
parameters: List of parameters to bind to the query
104-
async_op: Whether to execute the command asynchronously
105-
enforce_embedded_schema_correctness: Whether to enforce schema correctness
106-
107-
Returns:
108-
If async_op is False, returns a ResultSet object containing the
109-
query results and metadata. If async_op is True, returns None and the
110-
results must be fetched later using get_execution_result().
111-
112-
Raises:
113-
ValueError: If the session ID is invalid
114-
OperationalError: If there's an error executing the command
115-
ServerOperationError: If the server encounters an error during execution
116-
"""
11789
pass
11890

11991
@abstractmethod

src/databricks/sql/backend/filters.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
Client-side filtering utilities for Databricks SQL connector.
3+
4+
This module provides filtering capabilities for result sets returned by different backends.
5+
"""
6+
7+
import logging
8+
from typing import (
9+
List,
10+
Optional,
11+
Any,
12+
Callable,
13+
TYPE_CHECKING,
14+
)
15+
16+
if TYPE_CHECKING:
17+
from databricks.sql.result_set import ResultSet
18+
19+
from databricks.sql.result_set import SeaResultSet
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class ResultSetFilter:
25+
"""
26+
A general-purpose filter for result sets that can be applied to any backend.
27+
28+
This class provides methods to filter result sets based on various criteria,
29+
similar to the client-side filtering in the JDBC connector.
30+
"""
31+
32+
@staticmethod
33+
def _filter_sea_result_set(
34+
result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
35+
) -> "SeaResultSet":
36+
"""
37+
Filter a SEA result set using the provided filter function.
38+
39+
Args:
40+
result_set: The SEA result set to filter
41+
filter_func: Function that takes a row and returns True if the row should be included
42+
43+
Returns:
44+
A filtered SEA result set
45+
"""
46+
# Create a filtered version of the result set
47+
filtered_response = result_set._response.copy()
48+
49+
# If there's a result with rows, filter them
50+
if (
51+
"result" in filtered_response
52+
and "data_array" in filtered_response["result"]
53+
):
54+
rows = filtered_response["result"]["data_array"]
55+
filtered_rows = [row for row in rows if filter_func(row)]
56+
filtered_response["result"]["data_array"] = filtered_rows
57+
58+
# Update row count if present
59+
if "row_count" in filtered_response["result"]:
60+
filtered_response["result"]["row_count"] = len(filtered_rows)
61+
62+
# Create a new result set with the filtered data
63+
return SeaResultSet(
64+
connection=result_set.connection,
65+
sea_response=filtered_response,
66+
sea_client=result_set.backend,
67+
buffer_size_bytes=result_set.buffer_size_bytes,
68+
arraysize=result_set.arraysize,
69+
)
70+
71+
@staticmethod
72+
def filter_by_column_values(
73+
result_set: "ResultSet",
74+
column_index: int,
75+
allowed_values: List[str],
76+
case_sensitive: bool = False,
77+
) -> "ResultSet":
78+
"""
79+
Filter a result set by values in a specific column.
80+
81+
Args:
82+
result_set: The result set to filter
83+
column_index: The index of the column to filter on
84+
allowed_values: List of allowed values for the column
85+
case_sensitive: Whether to perform case-sensitive comparison
86+
87+
Returns:
88+
A filtered result set
89+
"""
90+
# Convert to uppercase for case-insensitive comparison if needed
91+
if not case_sensitive:
92+
allowed_values = [v.upper() for v in allowed_values]
93+
94+
# Determine the type of result set and apply appropriate filtering
95+
if isinstance(result_set, SeaResultSet):
96+
return ResultSetFilter._filter_sea_result_set(
97+
result_set,
98+
lambda row: (
99+
len(row) > column_index
100+
and isinstance(row[column_index], str)
101+
and (
102+
row[column_index].upper()
103+
if not case_sensitive
104+
else row[column_index]
105+
)
106+
in allowed_values
107+
),
108+
)
109+
110+
# For other result set types, return the original (should be handled by specific implementations)
111+
logger.warning(
112+
f"Filtering not implemented for result set type: {type(result_set).__name__}"
113+
)
114+
return result_set
115+
116+
@staticmethod
117+
def filter_tables_by_type(
118+
result_set: "ResultSet", table_types: Optional[List[str]] = None
119+
) -> "ResultSet":
120+
"""
121+
Filter a result set of tables by the specified table types.
122+
123+
This is a client-side filter that processes the result set after it has been
124+
retrieved from the server. It filters out tables whose type does not match
125+
any of the types in the table_types list.
126+
127+
Args:
128+
result_set: The original result set containing tables
129+
table_types: List of table types to include (e.g., ["TABLE", "VIEW"])
130+
131+
Returns:
132+
A filtered result set containing only tables of the specified types
133+
"""
134+
# Default table types if none specified
135+
DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
136+
valid_types = (
137+
table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
138+
)
139+
140+
# Table type is typically in the 6th column (index 5)
141+
return ResultSetFilter.filter_by_column_values(
142+
result_set, 5, valid_types, case_sensitive=False
143+
)

0 commit comments

Comments
 (0)