
Commit 89a46af

access ssl_options through connection
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
Parent: e1842d8

4 files changed: 96 additions and 64 deletions

examples/experimental/test_sea_multi_chunk.py

Lines changed: 54 additions & 42 deletions
@@ -21,22 +21,22 @@
 def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
     """
     Test executing a query that generates multiple chunks using cloud fetch.
-
+
     Args:
         requested_row_count: Number of rows to request in the query
-
+
     Returns:
         bool: True if the test passed, False otherwise
     """
     server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
     http_path = os.environ.get("DATABRICKS_HTTP_PATH")
     access_token = os.environ.get("DATABRICKS_TOKEN")
     catalog = os.environ.get("DATABRICKS_CATALOG")
-
+
     # Create output directory for test results
     output_dir = Path("test_results")
     output_dir.mkdir(exist_ok=True)
-
+
     # Files to store results
     rows_file = output_dir / "cloud_fetch_rows.csv"
     stats_file = output_dir / "cloud_fetch_stats.json"
@@ -50,9 +50,7 @@ def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
 
     try:
         # Create connection with cloud fetch enabled
-        logger.info(
-            "Creating connection for query execution with cloud fetch enabled"
-        )
+        logger.info("Creating connection for query execution with cloud fetch enabled")
         connection = Connection(
             server_hostname=server_hostname,
             http_path=http_path,
@@ -76,46 +74,50 @@ def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
             concat('value_', repeat('a', 10000)) as test_value
         FROM range(1, {requested_row_count} + 1) AS t(id)
         """
-
-        logger.info(f"Executing query with cloud fetch to generate {requested_row_count} rows")
+
+        logger.info(
+            f"Executing query with cloud fetch to generate {requested_row_count} rows"
+        )
         start_time = time.time()
         cursor.execute(query)
-
+
         # Fetch all rows
         rows = cursor.fetchall()
         actual_row_count = len(rows)
         end_time = time.time()
         execution_time = end_time - start_time
-
+
         logger.info(f"Query executed in {execution_time:.2f} seconds")
-        logger.info(f"Requested {requested_row_count} rows, received {actual_row_count} rows")
-
+        logger.info(
+            f"Requested {requested_row_count} rows, received {actual_row_count} rows"
+        )
+
         # Write rows to CSV file for inspection
         logger.info(f"Writing rows to {rows_file}")
-        with open(rows_file, 'w', newline='') as f:
+        with open(rows_file, "w", newline="") as f:
             writer = csv.writer(f)
-            writer.writerow(['id', 'value_length'])  # Header
-
+            writer.writerow(["id", "value_length"])  # Header
+
             # Extract IDs to check for duplicates and missing values
             row_ids = []
             for row in rows:
                 row_id = row[0]
                 value_length = len(row[1])
                 writer.writerow([row_id, value_length])
                 row_ids.append(row_id)
-
+
         # Verify row count
         success = actual_row_count == requested_row_count
-
+
         # Check for duplicate IDs
         unique_ids = set(row_ids)
         duplicate_count = len(row_ids) - len(unique_ids)
-
+
         # Check for missing IDs
         expected_ids = set(range(1, requested_row_count + 1))
         missing_ids = expected_ids - unique_ids
         extra_ids = unique_ids - expected_ids
-
+
         # Write statistics to JSON file
         stats = {
             "requested_row_count": requested_row_count,
@@ -124,56 +126,64 @@ def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
             "duplicate_count": duplicate_count,
             "missing_ids_count": len(missing_ids),
             "extra_ids_count": len(extra_ids),
-            "missing_ids": list(missing_ids)[:100] if missing_ids else [],  # Limit to first 100 for readability
-            "extra_ids": list(extra_ids)[:100] if extra_ids else [],  # Limit to first 100 for readability
-            "success": success and duplicate_count == 0 and len(missing_ids) == 0 and len(extra_ids) == 0
+            "missing_ids": list(missing_ids)[:100]
+            if missing_ids
+            else [],  # Limit to first 100 for readability
+            "extra_ids": list(extra_ids)[:100]
+            if extra_ids
+            else [],  # Limit to first 100 for readability
+            "success": success
+            and duplicate_count == 0
+            and len(missing_ids) == 0
+            and len(extra_ids) == 0,
         }
-
-        with open(stats_file, 'w') as f:
+
+        with open(stats_file, "w") as f:
             json.dump(stats, f, indent=2)
-
+
         # Log detailed results
         if duplicate_count > 0:
             logger.error(f"❌ FAILED: Found {duplicate_count} duplicate row IDs")
             success = False
         else:
             logger.info("✅ PASSED: No duplicate row IDs found")
-
+
         if missing_ids:
             logger.error(f"❌ FAILED: Missing {len(missing_ids)} expected row IDs")
             if len(missing_ids) <= 10:
                 logger.error(f"Missing IDs: {sorted(list(missing_ids))}")
             success = False
         else:
             logger.info("✅ PASSED: All expected row IDs present")
-
+
         if extra_ids:
             logger.error(f"❌ FAILED: Found {len(extra_ids)} unexpected row IDs")
             if len(extra_ids) <= 10:
                 logger.error(f"Extra IDs: {sorted(list(extra_ids))}")
             success = False
         else:
             logger.info("✅ PASSED: No unexpected row IDs found")
-
+
         if actual_row_count == requested_row_count:
             logger.info("✅ PASSED: Row count matches requested count")
         else:
-            logger.error(f"❌ FAILED: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}")
+            logger.error(
+                f"❌ FAILED: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
+            )
             success = False
-
+
         # Close resources
         cursor.close()
        connection.close()
         logger.info("Successfully closed SEA session")
-
+
         logger.info(f"Test results written to {rows_file} and {stats_file}")
         return success
 
     except Exception as e:
-        logger.error(
-            f"Error during SEA multi-chunk test with cloud fetch: {str(e)}"
-        )
+        logger.error(f"Error during SEA multi-chunk test with cloud fetch: {str(e)}")
         import traceback
+
         logger.error(traceback.format_exc())
         return False
 
@@ -193,31 +203,33 @@ def main():
         )
         logger.error("Please set these variables before running the tests.")
         sys.exit(1)
-
+
     # Get row count from command line or use default
     requested_row_count = 10000
-
+
     if len(sys.argv) > 1:
         try:
             requested_row_count = int(sys.argv[1])
         except ValueError:
             logger.error(f"Invalid row count: {sys.argv[1]}")
             logger.error("Please provide a valid integer for row count.")
             sys.exit(1)
-
+
     logger.info(f"Testing with {requested_row_count} rows")
-
+
     # Run the multi-chunk test with cloud fetch
     success = test_sea_multi_chunk_with_cloud_fetch(requested_row_count)
-
+
     # Report results
     if success:
-        logger.info("✅ TEST PASSED: Multi-chunk cloud fetch test completed successfully")
+        logger.info(
+            "✅ TEST PASSED: Multi-chunk cloud fetch test completed successfully"
+        )
         sys.exit(0)
     else:
         logger.error("❌ TEST FAILED: Multi-chunk cloud fetch test encountered errors")
         sys.exit(1)
 
 
 if __name__ == "__main__":
-    main()
+    main()
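
Note: the core of this test is a set-based integrity check over the returned row IDs. A minimal standalone sketch of that logic follows; the function name and toy values are illustrative, not part of the diff.

# Sketch of the set-based row-ID integrity check used in the test above.
def check_row_ids(row_ids, requested_row_count):
    unique_ids = set(row_ids)
    duplicate_count = len(row_ids) - len(unique_ids)
    expected_ids = set(range(1, requested_row_count + 1))
    missing_ids = expected_ids - unique_ids  # requested but never returned
    extra_ids = unique_ids - expected_ids    # returned but never requested
    return duplicate_count == 0 and not missing_ids and not extra_ids


assert check_row_ids([1, 2, 3], 3)       # exact match passes
assert not check_row_ids([1, 2, 2], 3)   # duplicate 2, missing 3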

examples/experimental/tests/test_sea_async_query.py

Lines changed: 18 additions & 8 deletions
@@ -77,24 +77,29 @@ def test_sea_async_query_with_cloud_fetch():
 
         logger.info("Query is no longer pending, getting results...")
         cursor.get_async_execution_result()
-
+
         results = [cursor.fetchone()]
         results.extend(cursor.fetchmany(10))
         results.extend(cursor.fetchall())
-        logger.info(f"{len(results)} rows retrieved against 100 requested")
+        actual_row_count = len(results)
+        logger.info(
+            f"{actual_row_count} rows retrieved against {requested_row_count} requested"
+        )
 
         logger.info(
             f"Requested {requested_row_count} rows, received {actual_row_count} rows"
         )
-
+
         # Verify total row count
         if actual_row_count != requested_row_count:
             logger.error(
                 f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
             )
             return False
-
-        logger.info("PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly")
+
+        logger.info(
+            "PASS: Received correct number of rows with cloud fetch and all fetch methods work correctly"
+        )
 
         # Close resources
         cursor.close()
@@ -182,20 +187,25 @@ def test_sea_async_query_without_cloud_fetch():
         results = [cursor.fetchone()]
         results.extend(cursor.fetchmany(10))
         results.extend(cursor.fetchall())
-        logger.info(f"{len(results)} rows retrieved against 100 requested")
+        actual_row_count = len(results)
+        logger.info(
+            f"{actual_row_count} rows retrieved against {requested_row_count} requested"
+        )
 
         logger.info(
             f"Requested {requested_row_count} rows, received {actual_row_count} rows"
         )
-
+
         # Verify total row count
         if actual_row_count != requested_row_count:
             logger.error(
                 f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
             )
             return False
 
-        logger.info("PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly")
+        logger.info(
+            "PASS: Received correct number of rows without cloud fetch and all fetch methods work correctly"
+        )
 
         # Close resources
         cursor.close()
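
Note: both async variants exercise the same mixed-fetch pattern, and the fix replaces a hard-coded "100 requested" in the log line with the actual requested count. A minimal sketch of the accounting, assuming a DB-API-style cursor; the helper name is illustrative.

# Mixed-fetch accounting pattern used by both async tests above.
def fetch_all_three_ways(cursor):
    results = [cursor.fetchone()]         # first row (None if the set is empty)
    results.extend(cursor.fetchmany(10))  # up to ten more rows
    results.extend(cursor.fetchall())     # everything that remains
    return len(results)                   # compare against the requested count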

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 5 additions & 1 deletion
@@ -62,10 +62,14 @@ def test_sea_sync_query_with_cloud_fetch():
         logger.info(
             f"Executing synchronous query with cloud fetch to generate {requested_row_count} rows"
         )
+        cursor.execute(query)
         results = [cursor.fetchone()]
         results.extend(cursor.fetchmany(10))
         results.extend(cursor.fetchall())
-        logger.info(f"{len(results)} rows retrieved against 100 requested")
+        actual_row_count = len(results)
+        logger.info(
+            f"{actual_row_count} rows retrieved against {requested_row_count} requested"
+        )
 
         # Close resources
         cursor.close()
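
Note: the functional fix here is the added cursor.execute(query): DB-API fetch methods return rows only after a statement has been executed on the cursor. A minimal ordering sketch, with placeholder connection parameters that are not part of this commit.

# Minimal ordering sketch: execute() must precede any fetch call.
# Connection parameters are placeholders, not values from this commit.
from databricks import sql

with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<http-path>",
    access_token="<token>",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT id FROM range(1, 11) AS t(id)")  # run first
        print(len(cursor.fetchall()))  # then fetch; prints 10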

src/databricks/sql/result_set.py

Lines changed: 19 additions & 13 deletions
@@ -472,23 +472,14 @@ def __init__(
             result_data: Result data from SEA response (optional)
             manifest: Manifest from SEA response (optional)
         """
-        # Extract and store SEA-specific properties
-        self.statement_id = (
-            execute_response.command_id.to_sea_statement_id()
-            if execute_response.command_id
-            else None
-        )
-
-        # Build the results queue
-        results_queue = None
 
         results_queue = None
         if result_data:
             results_queue = SeaResultSetQueueFactory.build_queue(
                 result_data,
                 manifest,
                 str(execute_response.command_id.to_sea_statement_id()),
-                ssl_options=self.connection.session.ssl_options,
+                ssl_options=connection.session.ssl_options,
                 description=execute_response.description,
                 max_download_threads=sea_client.max_download_threads,
                 sea_client=sea_client,
@@ -513,6 +504,21 @@ def __init__(
         # Initialize queue for result data if not provided
         self.results = results_queue or JsonQueue([])
 
+    def _convert_json_table(self, rows):
+        """
+        Convert raw data rows to Row objects with named columns based on description.
+        Args:
+            rows: List of raw data rows
+        Returns:
+            List of Row objects with named columns
+        """
+        if not self.description or not rows:
+            return rows
+
+        column_names = [col[0] for col in self.description]
+        ResultRow = Row(*column_names)
+        return [ResultRow(*row) for row in rows]
+
     def fetchmany_json(self, size: int):
         """
         Fetch the next set of rows as a columnar table.
@@ -586,7 +592,7 @@ def fetchone(self) -> Optional[Row]:
             A single Row object or None if no more rows are available
         """
         if isinstance(self.results, JsonQueue):
-            res = self.fetchmany_json(1)
+            res = self._convert_json_table(self.fetchmany_json(1))
         else:
             res = self._convert_arrow_table(self.fetchmany_arrow(1))
 
@@ -606,7 +612,7 @@ def fetchmany(self, size: int) -> List[Row]:
             ValueError: If size is negative
         """
         if isinstance(self.results, JsonQueue):
-            return self.fetchmany_json(size)
+            return self._convert_json_table(self.fetchmany_json(size))
         else:
             return self._convert_arrow_table(self.fetchmany_arrow(size))
 
@@ -618,6 +624,6 @@ def fetchall(self) -> List[Row]:
             List of Row objects containing all remaining rows
         """
         if isinstance(self.results, JsonQueue):
-            return self.fetchall_json()
+            return self._convert_json_table(self.fetchall_json())
         else:
             return self._convert_arrow_table(self.fetchall_arrow())
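
Note: two changes land in this file. First, the queue build now reads ssl_options from the connection argument passed to __init__ rather than from self.connection, which is what the commit message refers to. Second, the JSON fetch paths wrap raw rows into named Row objects via the new _convert_json_table. A rough standalone sketch of that conversion, using a namedtuple as a stand-in for the connector's Row type; the description follows the DB-API 7-tuple column shape.

# Standalone sketch of the JSON-row conversion; namedtuple stands in for
# the connector's Row type (illustrative, not the real class).
from collections import namedtuple

description = [
    ("id", "int", None, None, None, None, None),
    ("value", "string", None, None, None, None, None),
]

def convert_json_table(rows, description):
    if not description or not rows:
        return rows
    ResultRow = namedtuple("ResultRow", [col[0] for col in description])
    return [ResultRow(*row) for row in rows]

converted = convert_json_table([[1, "a"], [2, "b"]], description)
print(converted[0].value)  # -> "a" (columns are addressable by name)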
