
Commit 6ae1caa

Merge branch 'ext-links-sea' into cloudfetchq-sea
2 parents: 715cc13 + f90b4d4

20 files changed (+1582 / -1697 lines)

examples/experimental/sea_connector_test.py

Lines changed: 26 additions & 24 deletions
@@ -1,51 +1,54 @@
 """
 Main script to run all SEA connector tests.

-This script imports and runs all the individual test modules and displays
+This script runs all the individual test modules and displays
 a summary of test results with visual indicators.
 """
 import os
 import sys
 import logging
-import importlib.util
-from typing import Dict, Callable, List, Tuple
+import subprocess
+from typing import List, Tuple

-# Configure logging
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)

-# Define test modules and their main test functions
 TEST_MODULES = [
     "test_sea_session",
     "test_sea_sync_query",
     "test_sea_async_query",
     "test_sea_metadata",
+    "test_sea_multi_chunk",
 ]


-def load_test_function(module_name: str) -> Callable:
-    """Load a test function from a module."""
+def run_test_module(module_name: str) -> bool:
+    """Run a test module and return success status."""
     module_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py"
     )

-    spec = importlib.util.spec_from_file_location(module_name, module_path)
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
+    # Handle the multi-chunk test which is in the main directory
+    if module_name == "test_sea_multi_chunk":
+        module_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), f"{module_name}.py"
+        )
+
+    # Simply run the module as a script - each module handles its own test execution
+    result = subprocess.run(
+        [sys.executable, module_path], capture_output=True, text=True
+    )

-    # Get the main test function (assuming it starts with "test_")
-    for name in dir(module):
-        if name.startswith("test_") and callable(getattr(module, name)):
-            # For sync and async query modules, we want the main function that runs both tests
-            if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec":
-                return getattr(module, name)
+    # Log the output from the test module
+    if result.stdout:
+        for line in result.stdout.strip().split("\n"):
+            logger.info(line)

-    # Fallback to the first test function found
-    for name in dir(module):
-        if name.startswith("test_") and callable(getattr(module, name)):
-            return getattr(module, name)
+    if result.stderr:
+        for line in result.stderr.strip().split("\n"):
+            logger.error(line)

-    raise ValueError(f"No test function found in module {module_name}")
+    return result.returncode == 0


 def run_tests() -> List[Tuple[str, bool]]:
@@ -54,12 +57,11 @@ def run_tests() -> List[Tuple[str, bool]]:

     for module_name in TEST_MODULES:
         try:
-            test_func = load_test_function(module_name)
             logger.info(f"\n{'=' * 50}")
             logger.info(f"Running test: {module_name}")
             logger.info(f"{'-' * 50}")

-            success = test_func()
+            success = run_test_module(module_name)
             results.append((module_name, success))

             status = "✅ PASSED" if success else "❌ FAILED"
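
The runner change above swaps in-process module loading for subprocess execution: each test file runs in a fresh interpreter, so an uncaught exception or sys.exit() in one module can no longer abort the whole suite, and its stdout/stderr are captured and re-logged by the parent. A minimal standalone sketch of the same pattern (the script path in the usage comment is hypothetical):

    import subprocess
    import sys

    def run_script(path: str) -> bool:
        # Launch the script in a fresh interpreter. A crash or sys.exit(1)
        # in the child surfaces here as a nonzero return code instead of
        # taking down the parent process.
        result = subprocess.run([sys.executable, path], capture_output=True, text=True)
        if result.stdout:
            print(result.stdout, end="")
        if result.stderr:
            print(result.stderr, end="", file=sys.stderr)
        return result.returncode == 0

    # Hypothetical usage:
    # ok = run_script("tests/test_sea_session.py")
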
examples/experimental/test_sea_multi_chunk.py (new file)

Lines changed: 223 additions & 0 deletions
"""
Test for SEA multi-chunk responses.

This script tests the SEA connector's ability to handle multi-chunk responses correctly.
It runs a query that generates large rows to force multiple chunks and verifies that
the correct number of rows is returned.
"""
import os
import sys
import logging
import time
import json
import csv
from pathlib import Path
from databricks.sql.client import Connection

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
    """
    Test executing a query that generates multiple chunks using cloud fetch.

    Args:
        requested_row_count: Number of rows to request in the query

    Returns:
        bool: True if the test passed, False otherwise
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    # Create output directory for test results
    output_dir = Path("test_results")
    output_dir.mkdir(exist_ok=True)

    # Files to store results
    rows_file = output_dir / "cloud_fetch_rows.csv"
    stats_file = output_dir / "cloud_fetch_stats.json"

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        return False

    try:
        # Create connection with cloud fetch enabled
        logger.info(
            "Creating connection for query execution with cloud fetch enabled"
        )
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",
            use_cloud_fetch=True,
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )

        # Execute a query that generates large rows to force multiple chunks
        cursor = connection.cursor()
        query = f"""
        SELECT
            id,
            concat('value_', repeat('a', 10000)) as test_value
        FROM range(1, {requested_row_count} + 1) AS t(id)
        """

        logger.info(f"Executing query with cloud fetch to generate {requested_row_count} rows")
        start_time = time.time()
        cursor.execute(query)

        # Fetch all rows
        rows = cursor.fetchall()
        actual_row_count = len(rows)
        end_time = time.time()
        execution_time = end_time - start_time

        logger.info(f"Query executed in {execution_time:.2f} seconds")
        logger.info(f"Requested {requested_row_count} rows, received {actual_row_count} rows")

        # Write rows to CSV file for inspection
        logger.info(f"Writing rows to {rows_file}")
        with open(rows_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['id', 'value_length'])  # Header

            # Extract IDs to check for duplicates and missing values
            row_ids = []
            for row in rows:
                row_id = row[0]
                value_length = len(row[1])
                writer.writerow([row_id, value_length])
                row_ids.append(row_id)

        # Verify row count
        success = actual_row_count == requested_row_count

        # Check for duplicate IDs
        unique_ids = set(row_ids)
        duplicate_count = len(row_ids) - len(unique_ids)

        # Check for missing IDs
        expected_ids = set(range(1, requested_row_count + 1))
        missing_ids = expected_ids - unique_ids
        extra_ids = unique_ids - expected_ids

        # Write statistics to JSON file
        stats = {
            "requested_row_count": requested_row_count,
            "actual_row_count": actual_row_count,
            "execution_time_seconds": execution_time,
            "duplicate_count": duplicate_count,
            "missing_ids_count": len(missing_ids),
            "extra_ids_count": len(extra_ids),
            "missing_ids": list(missing_ids)[:100] if missing_ids else [],  # Limit to first 100 for readability
            "extra_ids": list(extra_ids)[:100] if extra_ids else [],  # Limit to first 100 for readability
            "success": success and duplicate_count == 0 and len(missing_ids) == 0 and len(extra_ids) == 0,
        }

        with open(stats_file, 'w') as f:
            json.dump(stats, f, indent=2)

        # Log detailed results
        if duplicate_count > 0:
            logger.error(f"❌ FAILED: Found {duplicate_count} duplicate row IDs")
            success = False
        else:
            logger.info("✅ PASSED: No duplicate row IDs found")

        if missing_ids:
            logger.error(f"❌ FAILED: Missing {len(missing_ids)} expected row IDs")
            if len(missing_ids) <= 10:
                logger.error(f"Missing IDs: {sorted(list(missing_ids))}")
            success = False
        else:
            logger.info("✅ PASSED: All expected row IDs present")

        if extra_ids:
            logger.error(f"❌ FAILED: Found {len(extra_ids)} unexpected row IDs")
            if len(extra_ids) <= 10:
                logger.error(f"Extra IDs: {sorted(list(extra_ids))}")
            success = False
        else:
            logger.info("✅ PASSED: No unexpected row IDs found")

        if actual_row_count == requested_row_count:
            logger.info("✅ PASSED: Row count matches requested count")
        else:
            logger.error(f"❌ FAILED: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}")
            success = False

        # Close resources
        cursor.close()
        connection.close()
        logger.info("Successfully closed SEA session")

        logger.info(f"Test results written to {rows_file} and {stats_file}")
        return success

    except Exception as e:
        logger.error(
            f"Error during SEA multi-chunk test with cloud fetch: {str(e)}"
        )
        import traceback
        logger.error(traceback.format_exc())
        return False


def main():
    # Check if required environment variables are set
    required_vars = [
        "DATABRICKS_SERVER_HOSTNAME",
        "DATABRICKS_HTTP_PATH",
        "DATABRICKS_TOKEN",
    ]
    missing_vars = [var for var in required_vars if not os.environ.get(var)]

    if missing_vars:
        logger.error(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )
        logger.error("Please set these variables before running the tests.")
        sys.exit(1)

    # Get row count from command line or use default
    requested_row_count = 5000

    if len(sys.argv) > 1:
        try:
            requested_row_count = int(sys.argv[1])
        except ValueError:
            logger.error(f"Invalid row count: {sys.argv[1]}")
            logger.error("Please provide a valid integer for row count.")
            sys.exit(1)

    logger.info(f"Testing with {requested_row_count} rows")

    # Run the multi-chunk test with cloud fetch
    success = test_sea_multi_chunk_with_cloud_fetch(requested_row_count)

    # Report results
    if success:
        logger.info("✅ TEST PASSED: Multi-chunk cloud fetch test completed successfully")
        sys.exit(0)
    else:
        logger.error("❌ TEST FAILED: Multi-chunk cloud fetch test encountered errors")
        sys.exit(1)


if __name__ == "__main__":
    main()
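
For context on why this query should spill across result chunks: each row carries a string of just over 10 KB ('value_' plus 10,000 repeated characters), so the default 5,000-row run produces roughly 50 MB of result data. A back-of-the-envelope check (the actual chunk size is server-determined, so treating 50 MB as multi-chunk is an assumption):

    # Rough size of the result set produced by the test query.
    rows = 5000
    bytes_per_row = len("value_") + 10_000   # concat('value_', repeat('a', 10000))
    total_mb = rows * bytes_per_row / 1_000_000
    print(f"~{total_mb:.0f} MB of row data")  # ~50 MB, assumed to span several chunks

The row count is adjustable from the command line, since main() reads an optional integer from sys.argv, e.g. python test_sea_multi_chunk.py 10000.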
