"""
Test for SEA multi-chunk responses.

This script tests the SEA connector's ability to handle multi-chunk responses correctly.
It runs a query that generates large rows to force multiple chunks and verifies that
the correct number of rows is returned.
"""
import os
import sys
import logging
import time
import json
import csv
from pathlib import Path
from databricks.sql.client import Connection

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
    """
    Test executing a query that generates multiple chunks using cloud fetch.

    Args:
        requested_row_count: Number of rows to request in the query

    Returns:
        bool: True if the test passed, False otherwise
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    # Create output directory for test results
    output_dir = Path("test_results")
    output_dir.mkdir(exist_ok=True)

    # Files to store results
    rows_file = output_dir / "cloud_fetch_rows.csv"
    stats_file = output_dir / "cloud_fetch_stats.json"

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        return False

    try:
        # Create connection with cloud fetch enabled
        logger.info(
            "Creating connection for query execution with cloud fetch enabled"
        )
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",
            use_cloud_fetch=True,
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )

        # Execute a query that generates large rows to force multiple chunks
        cursor = connection.cursor()
        query = f"""
        SELECT
            id,
            concat('value_', repeat('a', 10000)) as test_value
        FROM range(1, {requested_row_count} + 1) AS t(id)
        """

        logger.info(f"Executing query with cloud fetch to generate {requested_row_count} rows")
        start_time = time.time()
        cursor.execute(query)

        # Fetch all rows
        rows = cursor.fetchall()
        actual_row_count = len(rows)
        end_time = time.time()
        execution_time = end_time - start_time

        logger.info(f"Query executed in {execution_time:.2f} seconds")
        logger.info(f"Requested {requested_row_count} rows, received {actual_row_count} rows")

        # Write rows to CSV file for inspection
        logger.info(f"Writing rows to {rows_file}")
        with open(rows_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['id', 'value_length'])  # Header

            # Extract IDs to check for duplicates and missing values
            row_ids = []
            for row in rows:
                row_id = row[0]
                value_length = len(row[1])
                writer.writerow([row_id, value_length])
                row_ids.append(row_id)

        # Verify row count
        success = actual_row_count == requested_row_count

        # Check for duplicate IDs
        unique_ids = set(row_ids)
        duplicate_count = len(row_ids) - len(unique_ids)

        # Check for missing IDs
        expected_ids = set(range(1, requested_row_count + 1))
        missing_ids = expected_ids - unique_ids
        extra_ids = unique_ids - expected_ids

        # Write statistics to JSON file
        stats = {
            "requested_row_count": requested_row_count,
            "actual_row_count": actual_row_count,
            "execution_time_seconds": execution_time,
            "duplicate_count": duplicate_count,
            "missing_ids_count": len(missing_ids),
            "extra_ids_count": len(extra_ids),
            # Limit to the first 100 IDs for readability
            "missing_ids": list(missing_ids)[:100] if missing_ids else [],
            "extra_ids": list(extra_ids)[:100] if extra_ids else [],
            "success": success
            and duplicate_count == 0
            and len(missing_ids) == 0
            and len(extra_ids) == 0,
        }

        with open(stats_file, 'w') as f:
            json.dump(stats, f, indent=2)

        # Log detailed results
        if duplicate_count > 0:
            logger.error(f"❌ FAILED: Found {duplicate_count} duplicate row IDs")
            success = False
        else:
            logger.info("✅ PASSED: No duplicate row IDs found")

        if missing_ids:
            logger.error(f"❌ FAILED: Missing {len(missing_ids)} expected row IDs")
            if len(missing_ids) <= 10:
                logger.error(f"Missing IDs: {sorted(missing_ids)}")
            success = False
        else:
            logger.info("✅ PASSED: All expected row IDs present")

        if extra_ids:
            logger.error(f"❌ FAILED: Found {len(extra_ids)} unexpected row IDs")
            if len(extra_ids) <= 10:
                logger.error(f"Extra IDs: {sorted(extra_ids)}")
            success = False
        else:
            logger.info("✅ PASSED: No unexpected row IDs found")

        if actual_row_count == requested_row_count:
            logger.info("✅ PASSED: Row count matches requested count")
        else:
            logger.error(
                f"❌ FAILED: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
            )
            success = False

        # Close resources
        cursor.close()
        connection.close()
        logger.info("Successfully closed SEA session")

        logger.info(f"Test results written to {rows_file} and {stats_file}")
        return success

    except Exception as e:
        logger.error(f"Error during SEA multi-chunk test with cloud fetch: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return False


def main():
    # Check if required environment variables are set
    required_vars = [
        "DATABRICKS_SERVER_HOSTNAME",
        "DATABRICKS_HTTP_PATH",
        "DATABRICKS_TOKEN",
    ]
    missing_vars = [var for var in required_vars if not os.environ.get(var)]

    if missing_vars:
        logger.error(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )
        logger.error("Please set these variables before running the tests.")
        sys.exit(1)

    # Get row count from command line or use default
    requested_row_count = 5000

    if len(sys.argv) > 1:
        try:
            requested_row_count = int(sys.argv[1])
        except ValueError:
            logger.error(f"Invalid row count: {sys.argv[1]}")
            logger.error("Please provide a valid integer for row count.")
            sys.exit(1)

    logger.info(f"Testing with {requested_row_count} rows")

    # Run the multi-chunk test with cloud fetch
    success = test_sea_multi_chunk_with_cloud_fetch(requested_row_count)

    # Report results
    if success:
        logger.info("✅ TEST PASSED: Multi-chunk cloud fetch test completed successfully")
        sys.exit(0)
    else:
        logger.error("❌ TEST FAILED: Multi-chunk cloud fetch test encountered errors")
        sys.exit(1)


if __name__ == "__main__":
    main()