1
+ """
2
+ Main script to run all SEA connector tests.
3
+
4
+ This script imports and runs all the individual test modules and displays
5
+ a summary of test results with visual indicators.
6
+ """
1
7
import os
2
8
import sys
3
9
import logging
4
- from databricks .sql .client import Connection
10
+ import importlib .util
11
+ from typing import Dict , Callable , List , Tuple
5
12
6
- logging .basicConfig (level = logging .DEBUG )
13
+ # Configure logging
14
+ logging .basicConfig (level = logging .INFO )
7
15
logger = logging .getLogger (__name__ )
8
16
9
- def test_sea_session ():
10
- """
11
- Test opening and closing a SEA session using the connector.
17
+ # Define test modules and their main test functions
18
+ TEST_MODULES = [
19
+ "test_sea_session" ,
20
+ "test_sea_sync_query" ,
21
+ "test_sea_async_query" ,
22
+ "test_sea_metadata" ,
23
+ ]
24
+
25
+ def load_test_function (module_name : str ) -> Callable :
26
+ """Load a test function from a module."""
27
+ module_path = os .path .join (
28
+ os .path .dirname (os .path .abspath (__file__ )),
29
+ "tests" ,
30
+ f"{ module_name } .py"
31
+ )
32
+
33
+ spec = importlib .util .spec_from_file_location (module_name , module_path )
34
+ module = importlib .util .module_from_spec (spec )
35
+ spec .loader .exec_module (module )
36
+
37
+ # Get the main test function (assuming it starts with "test_")
38
+ for name in dir (module ):
39
+ if name .startswith ("test_" ) and callable (getattr (module , name )):
40
+ # For sync and async query modules, we want the main function that runs both tests
41
+ if name == f"test_sea_{ module_name .replace ('test_sea_' , '' )} _exec" :
42
+ return getattr (module , name )
12
43
13
- This function connects to a Databricks SQL endpoint using the SEA backend,
14
- opens a session, and then closes it.
44
+ # Fallback to the first test function found
45
+ for name in dir (module ):
46
+ if name .startswith ("test_" ) and callable (getattr (module , name )):
47
+ return getattr (module , name )
15
48
16
- Required environment variables:
17
- - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
18
- - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
19
- - DATABRICKS_TOKEN: Personal access token for authentication
20
- """
49
+ raise ValueError (f"No test function found in module { module_name } " )
21
50
22
- server_hostname = os .environ .get ("DATABRICKS_SERVER_HOSTNAME" )
23
- http_path = os .environ .get ("DATABRICKS_HTTP_PATH" )
24
- access_token = os .environ .get ("DATABRICKS_TOKEN" )
25
- catalog = os .environ .get ("DATABRICKS_CATALOG" )
51
+ def run_tests () -> List [Tuple [str , bool ]]:
52
+ """Run all tests and return results."""
53
+ results = []
26
54
27
- if not all ([server_hostname , http_path , access_token ]):
28
- logger .error ("Missing required environment variables." )
29
- logger .error ("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN." )
30
- sys .exit (1 )
55
+ for module_name in TEST_MODULES :
56
+ try :
57
+ test_func = load_test_function (module_name )
58
+ logger .info (f"\n { '=' * 50 } " )
59
+ logger .info (f"Running test: { module_name } " )
60
+ logger .info (f"{ '-' * 50 } " )
61
+
62
+ success = test_func ()
63
+ results .append ((module_name , success ))
64
+
65
+ status = "✅ PASSED" if success else "❌ FAILED"
66
+ logger .info (f"Test { module_name } : { status } " )
67
+
68
+ except Exception as e :
69
+ logger .error (f"Error loading or running test { module_name } : { str (e )} " )
70
+ import traceback
71
+ logger .error (traceback .format_exc ())
72
+ results .append ((module_name , False ))
31
73
32
- logger .info (f"Connecting to { server_hostname } " )
33
- logger .info (f"HTTP Path: { http_path } " )
34
- if catalog :
35
- logger .info (f"Using catalog: { catalog } " )
74
+ return results
75
+
76
+ def print_summary (results : List [Tuple [str , bool ]]) -> None :
77
+ """Print a summary of test results."""
78
+ logger .info (f"\n { '=' * 50 } " )
79
+ logger .info ("TEST SUMMARY" )
80
+ logger .info (f"{ '-' * 50 } " )
36
81
37
- try :
38
- logger .info ("Creating connection with SEA backend..." )
39
- connection = Connection (
40
- server_hostname = server_hostname ,
41
- http_path = http_path ,
42
- access_token = access_token ,
43
- catalog = catalog ,
44
- schema = "default" ,
45
- use_sea = True ,
46
- user_agent_entry = "SEA-Test-Client" # add custom user agent
47
- )
48
-
49
- logger .info (f"Successfully opened SEA session with ID: { connection .get_session_id_hex ()} " )
50
- logger .info (f"backend type: { type (connection .session .backend )} " )
51
-
52
- # Close the connection
53
- logger .info ("Closing the SEA session..." )
54
- connection .close ()
55
- logger .info ("Successfully closed SEA session" )
56
-
57
- except Exception as e :
58
- logger .error (f"Error testing SEA session: { str (e )} " )
59
- import traceback
60
- logger .error (traceback .format_exc ())
61
- sys .exit (1 )
82
+ passed = sum (1 for _ , success in results if success )
83
+ total = len (results )
62
84
63
- logger .info ("SEA session test completed successfully" )
85
+ for module_name , success in results :
86
+ status = "✅ PASSED" if success else "❌ FAILED"
87
+ logger .info (f"{ status } - { module_name } " )
88
+
89
+ logger .info (f"{ '-' * 50 } " )
90
+ logger .info (f"Total: { total } | Passed: { passed } | Failed: { total - passed } " )
91
+ logger .info (f"{ '=' * 50 } " )
64
92
65
93
if __name__ == "__main__" :
66
- test_sea_session ()
94
+ # Check if required environment variables are set
95
+ required_vars = ["DATABRICKS_SERVER_HOSTNAME" , "DATABRICKS_HTTP_PATH" , "DATABRICKS_TOKEN" ]
96
+ missing_vars = [var for var in required_vars if not os .environ .get (var )]
97
+
98
+ if missing_vars :
99
+ logger .error (f"Missing required environment variables: { ', ' .join (missing_vars )} " )
100
+ logger .error ("Please set these variables before running the tests." )
101
+ sys .exit (1 )
102
+
103
+ # Run all tests
104
+ results = run_tests ()
105
+
106
+ # Print summary
107
+ print_summary (results )
108
+
109
+ # Exit with appropriate status code
110
+ all_passed = all (success for _ , success in results )
111
+ sys .exit (0 if all_passed else 1 )
0 commit comments