Skip to content

Commit ed7cf91

Browse files
exec test example scripts
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent bf26ea3 commit ed7cf91

File tree

6 files changed

+566
-51
lines changed

6 files changed

+566
-51
lines changed
Lines changed: 96 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,111 @@
1+
"""
2+
Main script to run all SEA connector tests.
3+
4+
This script imports and runs all the individual test modules and displays
5+
a summary of test results with visual indicators.
6+
"""
17
import os
28
import sys
39
import logging
4-
from databricks.sql.client import Connection
10+
import importlib.util
11+
from typing import Dict, Callable, List, Tuple
512

6-
logging.basicConfig(level=logging.DEBUG)
13+
# Configure logging
14+
logging.basicConfig(level=logging.INFO)
715
logger = logging.getLogger(__name__)
816

9-
def test_sea_session():
10-
"""
11-
Test opening and closing a SEA session using the connector.
17+
# Define test modules and their main test functions
18+
TEST_MODULES = [
19+
"test_sea_session",
20+
"test_sea_sync_query",
21+
"test_sea_async_query",
22+
"test_sea_metadata",
23+
]
24+
25+
def load_test_function(module_name: str) -> Callable:
26+
"""Load a test function from a module."""
27+
module_path = os.path.join(
28+
os.path.dirname(os.path.abspath(__file__)),
29+
"tests",
30+
f"{module_name}.py"
31+
)
32+
33+
spec = importlib.util.spec_from_file_location(module_name, module_path)
34+
module = importlib.util.module_from_spec(spec)
35+
spec.loader.exec_module(module)
36+
37+
# Get the main test function (assuming it starts with "test_")
38+
for name in dir(module):
39+
if name.startswith("test_") and callable(getattr(module, name)):
40+
# For sync and async query modules, we want the main function that runs both tests
41+
if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec":
42+
return getattr(module, name)
1243

13-
This function connects to a Databricks SQL endpoint using the SEA backend,
14-
opens a session, and then closes it.
44+
# Fallback to the first test function found
45+
for name in dir(module):
46+
if name.startswith("test_") and callable(getattr(module, name)):
47+
return getattr(module, name)
1548

16-
Required environment variables:
17-
- DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
18-
- DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
19-
- DATABRICKS_TOKEN: Personal access token for authentication
20-
"""
49+
raise ValueError(f"No test function found in module {module_name}")
2150

22-
server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
23-
http_path = os.environ.get("DATABRICKS_HTTP_PATH")
24-
access_token = os.environ.get("DATABRICKS_TOKEN")
25-
catalog = os.environ.get("DATABRICKS_CATALOG")
51+
def run_tests() -> List[Tuple[str, bool]]:
52+
"""Run all tests and return results."""
53+
results = []
2654

27-
if not all([server_hostname, http_path, access_token]):
28-
logger.error("Missing required environment variables.")
29-
logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.")
30-
sys.exit(1)
55+
for module_name in TEST_MODULES:
56+
try:
57+
test_func = load_test_function(module_name)
58+
logger.info(f"\n{'=' * 50}")
59+
logger.info(f"Running test: {module_name}")
60+
logger.info(f"{'-' * 50}")
61+
62+
success = test_func()
63+
results.append((module_name, success))
64+
65+
status = "✅ PASSED" if success else "❌ FAILED"
66+
logger.info(f"Test {module_name}: {status}")
67+
68+
except Exception as e:
69+
logger.error(f"Error loading or running test {module_name}: {str(e)}")
70+
import traceback
71+
logger.error(traceback.format_exc())
72+
results.append((module_name, False))
3173

32-
logger.info(f"Connecting to {server_hostname}")
33-
logger.info(f"HTTP Path: {http_path}")
34-
if catalog:
35-
logger.info(f"Using catalog: {catalog}")
74+
return results
75+
76+
def print_summary(results: List[Tuple[str, bool]]) -> None:
77+
"""Print a summary of test results."""
78+
logger.info(f"\n{'=' * 50}")
79+
logger.info("TEST SUMMARY")
80+
logger.info(f"{'-' * 50}")
3681

37-
try:
38-
logger.info("Creating connection with SEA backend...")
39-
connection = Connection(
40-
server_hostname=server_hostname,
41-
http_path=http_path,
42-
access_token=access_token,
43-
catalog=catalog,
44-
schema="default",
45-
use_sea=True,
46-
user_agent_entry="SEA-Test-Client" # add custom user agent
47-
)
48-
49-
logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}")
50-
logger.info(f"backend type: {type(connection.session.backend)}")
51-
52-
# Close the connection
53-
logger.info("Closing the SEA session...")
54-
connection.close()
55-
logger.info("Successfully closed SEA session")
56-
57-
except Exception as e:
58-
logger.error(f"Error testing SEA session: {str(e)}")
59-
import traceback
60-
logger.error(traceback.format_exc())
61-
sys.exit(1)
82+
passed = sum(1 for _, success in results if success)
83+
total = len(results)
6284

63-
logger.info("SEA session test completed successfully")
85+
for module_name, success in results:
86+
status = "✅ PASSED" if success else "❌ FAILED"
87+
logger.info(f"{status} - {module_name}")
88+
89+
logger.info(f"{'-' * 50}")
90+
logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}")
91+
logger.info(f"{'=' * 50}")
6492

6593
if __name__ == "__main__":
66-
test_sea_session()
94+
# Check if required environment variables are set
95+
required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"]
96+
missing_vars = [var for var in required_vars if not os.environ.get(var)]
97+
98+
if missing_vars:
99+
logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
100+
logger.error("Please set these variables before running the tests.")
101+
sys.exit(1)
102+
103+
# Run all tests
104+
results = run_tests()
105+
106+
# Print summary
107+
print_summary(results)
108+
109+
# Exit with appropriate status code
110+
all_passed = all(success for _, success in results)
111+
sys.exit(0 if all_passed else 1)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This file makes the tests directory a Python package
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""
2+
Test for SEA asynchronous query execution functionality.
3+
"""
4+
import os
5+
import sys
6+
import logging
7+
import time
8+
from databricks.sql.client import Connection
9+
from databricks.sql.backend.types import CommandState
10+
11+
logging.basicConfig(level=logging.INFO)
12+
logger = logging.getLogger(__name__)
13+
14+
15+
def test_sea_async_query_with_cloud_fetch():
16+
"""
17+
Test executing a query asynchronously using the SEA backend with cloud fetch enabled.
18+
19+
This function connects to a Databricks SQL endpoint using the SEA backend,
20+
executes a simple query asynchronously with cloud fetch enabled, and verifies that execution completes successfully.
21+
"""
22+
server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
23+
http_path = os.environ.get("DATABRICKS_HTTP_PATH")
24+
access_token = os.environ.get("DATABRICKS_TOKEN")
25+
catalog = os.environ.get("DATABRICKS_CATALOG")
26+
27+
if not all([server_hostname, http_path, access_token]):
28+
logger.error("Missing required environment variables.")
29+
logger.error(
30+
"Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
31+
)
32+
return False
33+
34+
try:
35+
# Create connection with cloud fetch enabled
36+
logger.info("Creating connection for asynchronous query execution with cloud fetch enabled")
37+
connection = Connection(
38+
server_hostname=server_hostname,
39+
http_path=http_path,
40+
access_token=access_token,
41+
catalog=catalog,
42+
schema="default",
43+
use_sea=True,
44+
user_agent_entry="SEA-Test-Client",
45+
use_cloud_fetch=True,
46+
)
47+
48+
logger.info(
49+
f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
50+
)
51+
52+
# Execute a simple query asynchronously
53+
cursor = connection.cursor()
54+
logger.info("Executing asynchronous query with cloud fetch: SELECT 1 as test_value")
55+
cursor.execute_async("SELECT 1 as test_value")
56+
logger.info("Asynchronous query submitted successfully with cloud fetch enabled")
57+
58+
# Check query state
59+
logger.info("Checking query state...")
60+
while cursor.is_query_pending():
61+
logger.info("Query is still pending, waiting...")
62+
time.sleep(1)
63+
64+
logger.info("Query is no longer pending, getting results...")
65+
cursor.get_async_execution_result()
66+
logger.info("Successfully retrieved asynchronous query results with cloud fetch enabled")
67+
68+
# Close resources
69+
cursor.close()
70+
connection.close()
71+
logger.info("Successfully closed SEA session")
72+
73+
return True
74+
75+
except Exception as e:
76+
logger.error(f"Error during SEA asynchronous query execution test with cloud fetch: {str(e)}")
77+
import traceback
78+
logger.error(traceback.format_exc())
79+
return False
80+
81+
82+
def test_sea_async_query_without_cloud_fetch():
83+
"""
84+
Test executing a query asynchronously using the SEA backend with cloud fetch disabled.
85+
86+
This function connects to a Databricks SQL endpoint using the SEA backend,
87+
executes a simple query asynchronously with cloud fetch disabled, and verifies that execution completes successfully.
88+
"""
89+
server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
90+
http_path = os.environ.get("DATABRICKS_HTTP_PATH")
91+
access_token = os.environ.get("DATABRICKS_TOKEN")
92+
catalog = os.environ.get("DATABRICKS_CATALOG")
93+
94+
if not all([server_hostname, http_path, access_token]):
95+
logger.error("Missing required environment variables.")
96+
logger.error(
97+
"Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
98+
)
99+
return False
100+
101+
try:
102+
# Create connection with cloud fetch disabled
103+
logger.info("Creating connection for asynchronous query execution with cloud fetch disabled")
104+
connection = Connection(
105+
server_hostname=server_hostname,
106+
http_path=http_path,
107+
access_token=access_token,
108+
catalog=catalog,
109+
schema="default",
110+
use_sea=True,
111+
user_agent_entry="SEA-Test-Client",
112+
use_cloud_fetch=False,
113+
enable_query_result_lz4_compression=False,
114+
)
115+
116+
logger.info(
117+
f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
118+
)
119+
120+
# Execute a simple query asynchronously
121+
cursor = connection.cursor()
122+
logger.info("Executing asynchronous query without cloud fetch: SELECT 1 as test_value")
123+
cursor.execute_async("SELECT 1 as test_value")
124+
logger.info("Asynchronous query submitted successfully with cloud fetch disabled")
125+
126+
# Check query state
127+
logger.info("Checking query state...")
128+
while cursor.is_query_pending():
129+
logger.info("Query is still pending, waiting...")
130+
time.sleep(1)
131+
132+
logger.info("Query is no longer pending, getting results...")
133+
cursor.get_async_execution_result()
134+
logger.info("Successfully retrieved asynchronous query results with cloud fetch disabled")
135+
136+
# Close resources
137+
cursor.close()
138+
connection.close()
139+
logger.info("Successfully closed SEA session")
140+
141+
return True
142+
143+
except Exception as e:
144+
logger.error(f"Error during SEA asynchronous query execution test without cloud fetch: {str(e)}")
145+
import traceback
146+
logger.error(traceback.format_exc())
147+
return False
148+
149+
150+
def test_sea_async_query_exec():
151+
"""
152+
Run both asynchronous query tests and return overall success.
153+
"""
154+
with_cloud_fetch_success = test_sea_async_query_with_cloud_fetch()
155+
logger.info(f"Asynchronous query with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}")
156+
157+
without_cloud_fetch_success = test_sea_async_query_without_cloud_fetch()
158+
logger.info(f"Asynchronous query without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}")
159+
160+
return with_cloud_fetch_success and without_cloud_fetch_success
161+
162+
163+
if __name__ == "__main__":
164+
success = test_sea_async_query_exec()
165+
sys.exit(0 if success else 1)

0 commit comments

Comments
 (0)