Skip to content

Commit 75c5a62

Browse files
Fixed multi-chunk result fetching in the SEA backend
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent d7ab57f commit 75c5a62

File tree

3 files changed

+413
-24
lines changed

3 files changed

+413
-24
lines changed

examples/experimental/sea_connector_test.py

Lines changed: 259 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ def test_sea_result_set_arrow_external_links():
320320
# Execute a query that returns a large result set (will use EXTERNAL_LINKS disposition)
321321
# Use a larger result set to ensure multiple chunks
322322
# Using a CROSS JOIN to generate a larger result set
323-
logger.info("Executing query: SELECT a.id as id1, b.id as id2 FROM range(1, 1000) a CROSS JOIN range(1, 1000) b LIMIT 100000")
324-
cursor.execute("SELECT a.id as id1, b.id as id2 FROM range(1, 1000) a CROSS JOIN range(1, 1000) b LIMIT 100000")
323+
logger.info("Executing query: SELECT a.id as id1, b.id as id2, CONCAT(CAST(a.id AS STRING), '-', CAST(b.id AS STRING)) as concat_str FROM range(1, 1000) a CROSS JOIN range(1, 1000) b LIMIT 500000")
324+
cursor.execute("SELECT a.id as id1, b.id as id2, CONCAT(CAST(a.id AS STRING), '-', CAST(b.id AS STRING)) as concat_str FROM range(1, 1000) a CROSS JOIN range(1, 1000) b LIMIT 500000")
325325

326326
# Test the manifest to verify we're getting multiple chunks
327327
# We can't easily access the manifest in the SeaResultSet, so we'll just continue with the test
@@ -387,6 +387,259 @@ def test_sea_result_set_arrow_external_links():
387387
logger.info("SEA result set test with ARROW format and EXTERNAL_LINKS disposition completed successfully")
388388

389389

390+
def test_sea_result_set_with_multiple_chunks():
    """
    Test the SEA result set implementation with multiple chunks.

    Connects to a Databricks SQL endpoint using the SEA backend, executes a
    query large enough to be returned in multiple EXTERNAL_LINKS chunks, and
    verifies that sequential fetches (fetchone/fetchmany/fetchall) and a bulk
    fetchall_arrow both return every row exactly once (no gaps, no duplicates).

    Exits the process with status 1 when required environment variables are
    missing or when the test fails.
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG", "samples")
    schema = os.environ.get("DATABRICKS_SCHEMA", "default")

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        sys.exit(1)

    try:
        # Create connection with SEA backend
        logger.info("Creating connection with SEA backend...")
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema=schema,
            use_sea=True,
            use_cloud_fetch=True,  # Enable cloud fetch to trigger EXTERNAL_LINKS + ARROW
            user_agent_entry="SEA-Test-Client",
            # Use a smaller arraysize to potentially force multiple chunks
            arraysize=1000,
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )

        # Create cursor
        cursor = connection.cursor()

        # Execute the query that we know returns multiple chunks from interactive-sea testing
        logger.info("Executing query that returns multiple chunks...")
        query = """
        WITH large_dataset AS (
            SELECT
                id,
                id * 2 as double_id,
                id * 3 as triple_id,
                concat('value_', repeat(cast(id as string), 100)) as large_string_value,
                array_repeat(id, 50) as large_array_value,
                rand() as random_val,
                current_timestamp() as current_time
            FROM range(1, 100000) AS t(id)
        )
        SELECT * FROM large_dataset
        """
        cursor.execute(query)

        # BUGFIX: initialize manifest defaults up front so the verification
        # code below cannot raise NameError when the backend is not a
        # SeaDatabricksClient or no statement ID is available.
        manifest = {}

        # Attempt to access the manifest to check for multiple chunks
        from databricks.sql.backend.sea_backend import SeaDatabricksClient

        if isinstance(connection.session.backend, SeaDatabricksClient):
            # Get the statement ID from the cursor's active result set
            statement_id = cursor.active_result_set.statement_id
            if statement_id:
                # Make a direct request to get the statement status
                response_data = connection.session.backend.http_client._make_request(
                    method="GET",
                    path=f"/api/2.0/sql/statements/{statement_id}",
                )

                # Check if we have multiple chunks
                manifest = response_data.get("manifest", {})
                total_chunk_count = manifest.get("total_chunk_count", 0)
                truncated = manifest.get("truncated", False)

                logger.info(f"Total chunk count: {total_chunk_count}")
                logger.info(f"Result truncated: {truncated}")

                # Log chunk information
                chunks = manifest.get("chunks", [])
                for i, chunk in enumerate(chunks):
                    logger.info(
                        f"Chunk {i}: index={chunk.get('chunk_index')}, "
                        f"rows={chunk.get('row_count')}, bytes={chunk.get('byte_count')}"
                    )

                # Log the next_chunk_index from the first external link
                result_data = response_data.get("result", {})
                external_links = result_data.get("external_links", [])
                if external_links:
                    first_link = external_links[0]
                    logger.info(f"First link next_chunk_index: {first_link.get('next_chunk_index')}")
                    logger.info(f"First link next_chunk_internal_link: {first_link.get('next_chunk_internal_link')}")

        # Test fetchone
        logger.info("Testing fetchone...")
        row = cursor.fetchone()
        logger.info(f"First row: {row}")

        # Test fetchmany with a size that spans multiple chunks
        fetch_size = 30000  # This should span at least 2 chunks based on our test
        logger.info(f"Testing fetchmany({fetch_size})...")
        rows = cursor.fetchmany(fetch_size)
        logger.info(f"Fetched {len(rows)} rows with fetchmany")
        first_batch_count = len(rows)

        # Test another fetchmany to get more chunks
        logger.info(f"Testing another fetchmany({fetch_size})...")
        more_rows = cursor.fetchmany(fetch_size)
        logger.info(f"Fetched {len(more_rows)} more rows with fetchmany")
        second_batch_count = len(more_rows)

        # Test fetchall for remaining rows
        logger.info("Testing fetchall...")
        remaining_rows = cursor.fetchall()
        logger.info(f"Fetched {len(remaining_rows)} remaining rows with fetchall")
        remaining_count = len(remaining_rows)

        # Verify results using row IDs instead of row counts
        # Calculate the sum of rows from the manifest chunks (0 when no manifest)
        manifest_rows_sum = sum(chunk.get('row_count', 0) for chunk in manifest.get('chunks', []))
        logger.info(f"Expected rows from manifest chunks: {manifest_rows_sum}")

        # Collect all row IDs to check for duplicates and completeness
        all_row_ids = set()
        # Total rows fetched across all calls, counting duplicates; compared
        # against the number of unique IDs below to detect duplicated rows.
        total_fetched = (
            (1 if row is not None else 0)
            + first_batch_count
            + second_batch_count
            + remaining_count
        )

        # Add the first row's ID
        if row is not None and hasattr(row, 'id'):
            all_row_ids.add(row.id)
            first_id = row.id
            logger.info(f"First row ID: {first_id}")

        # Add IDs from first batch
        if rows and hasattr(rows[0], 'id'):
            batch_ids = [r.id for r in rows if hasattr(r, 'id')]
            all_row_ids.update(batch_ids)
            logger.info(f"First batch: {len(rows)} rows, ID range {min(batch_ids)} to {max(batch_ids)}")

        # Add IDs from second batch
        if more_rows and hasattr(more_rows[0], 'id'):
            batch_ids = [r.id for r in more_rows if hasattr(r, 'id')]
            all_row_ids.update(batch_ids)
            logger.info(f"Second batch: {len(more_rows)} rows, ID range {min(batch_ids)} to {max(batch_ids)}")

        # Add IDs from remaining rows
        if remaining_rows and hasattr(remaining_rows[0], 'id'):
            batch_ids = [r.id for r in remaining_rows if hasattr(r, 'id')]
            all_row_ids.update(batch_ids)
            logger.info(f"Remaining batch: {len(remaining_rows)} rows, ID range {min(batch_ids)} to {max(batch_ids)}")

        # Check for completeness and duplicates
        if all_row_ids:
            min_id = min(all_row_ids)
            max_id = max(all_row_ids)
            expected_count = max_id - min_id + 1
            actual_count = len(all_row_ids)

            logger.info(f"Row ID range: {min_id} to {max_id}")
            logger.info(f"Expected unique IDs in range: {expected_count}")
            logger.info(f"Actual unique IDs collected: {actual_count}")

            if expected_count == actual_count:
                logger.info("✅ All rows fetched correctly with no gaps")
            else:
                logger.warning("⚠️ Gap detected in row IDs")

            # BUGFIX: the original compared actual_count == len(all_row_ids),
            # which is always true. Compare the total number of rows fetched
            # (including duplicates) against the unique-ID count instead.
            if total_fetched == actual_count:
                logger.info("✅ No duplicate row IDs detected")
            else:
                logger.warning("⚠️ Duplicate row IDs detected")

            # Check if we got all expected rows
            if max_id == manifest_rows_sum:
                logger.info("✅ Last row ID matches expected row count from manifest")

        # Let's try one more time with a fresh cursor to fetch all rows at once
        logger.info("\nTesting fetchall_arrow with a fresh cursor...")
        new_cursor = connection.cursor()
        new_cursor.execute(query)

        try:
            # Fetch all rows as Arrow
            arrow_table = new_cursor.fetchall_arrow()
            logger.info(f"Arrow table num rows: {arrow_table.num_rows}")
            logger.info(f"Arrow table columns: {arrow_table.column_names}")

            # Get the ID column if it exists
            if 'id' in arrow_table.column_names:
                id_column = arrow_table.column('id').to_pylist()
                logger.info(f"First 5 rows of id column: {id_column[:5]}")
                logger.info(f"Last 5 rows of id column: {id_column[-5:]}")

                # Check for completeness and duplicates in Arrow results
                arrow_id_set = set(id_column)
                arrow_min_id = min(id_column)
                arrow_max_id = max(id_column)
                arrow_expected_count = arrow_max_id - arrow_min_id + 1
                arrow_actual_count = len(arrow_id_set)

                logger.info(f"Arrow result row ID range: {arrow_min_id} to {arrow_max_id}")
                logger.info(f"Arrow result expected unique IDs: {arrow_expected_count}")
                logger.info(f"Arrow result actual unique IDs: {arrow_actual_count}")

                if arrow_expected_count == arrow_actual_count:
                    logger.info("✅ Arrow results: All rows fetched correctly with no gaps")
                else:
                    logger.warning("⚠️ Arrow results: Gap detected in row IDs")

                # BUGFIX: the original compared len(arrow_id_set) to itself.
                # Compare the total row count against the unique-ID count.
                if len(id_column) == arrow_actual_count:
                    logger.info("✅ Arrow results: No duplicate row IDs detected")
                else:
                    logger.warning("⚠️ Arrow results: Duplicate row IDs detected")

                # Compare with manifest row count
                if arrow_max_id == manifest_rows_sum:
                    logger.info("✅ Arrow results: Last row ID matches expected row count from manifest")

                # Compare with sequential fetch results
                if arrow_id_set == all_row_ids:
                    logger.info("✅ Arrow and sequential fetch results contain exactly the same row IDs")
                else:
                    logger.warning("⚠️ Arrow and sequential fetch results contain different row IDs")
                    only_in_arrow = arrow_id_set - all_row_ids
                    only_in_sequential = all_row_ids - arrow_id_set
                    if only_in_arrow:
                        logger.warning(f"IDs only in Arrow results: {len(only_in_arrow)} rows")
                    if only_in_sequential:
                        logger.warning(f"IDs only in sequential fetch: {len(only_in_sequential)} rows")

            # Check if we got all rows
            logger.info(f"Expected rows from manifest chunks: {manifest_rows_sum}")
            logger.info(f"Actual rows in arrow table: {arrow_table.num_rows}")
        except Exception as e:
            logger.error(f"Error fetching all rows as Arrow: {e}")
        finally:
            # BUGFIX: close the fresh cursor even if fetchall_arrow raises.
            new_cursor.close()

        # Close cursor and connection
        cursor.close()
        connection.close()
        logger.info("Successfully closed SEA session")

    except Exception as e:
        logger.error(f"Error during SEA result set test: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        sys.exit(1)

    logger.info("SEA result set test with multiple chunks completed successfully")
390643
if __name__ == "__main__":
391644
# Test session management
392645
# test_sea_session()
@@ -395,4 +648,7 @@ def test_sea_result_set_arrow_external_links():
395648
# test_sea_result_set_json_array_inline()
396649

397650
# Test result set implementation with ARROW format and EXTERNAL_LINKS disposition
398-
test_sea_result_set_arrow_external_links()
651+
# test_sea_result_set_arrow_external_links()
652+
653+
# Test result set implementation with multiple chunks
654+
test_sea_result_set_with_multiple_chunks()

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def __init__(
2424
ssl_options: SSLOptions,
2525
):
2626
self._pending_links: List[TSparkArrowResultLink] = []
27+
# Add a cache to store downloaded files by row offset
28+
self._downloaded_files_cache = {}
29+
2730
for link in links:
2831
if link.rowCount <= 0:
2932
continue
@@ -56,24 +59,60 @@ def get_next_downloaded_file(
5659
Args:
5760
next_row_offset (int): The offset of the starting row of the next file we want data from.
5861
"""
62+
logger.info(f"ResultFileDownloadManager: get_next_downloaded_file for row offset {next_row_offset}")
63+
64+
# Check if we have this file in the cache
65+
if next_row_offset in self._downloaded_files_cache:
66+
logger.info(f"ResultFileDownloadManager: Found file in cache for row offset {next_row_offset}")
67+
return self._downloaded_files_cache[next_row_offset]
5968

6069
# Make sure the download queue is always full
6170
self._schedule_downloads()
6271

6372
# No more files to download from this batch of links
6473
if len(self._download_tasks) == 0:
74+
logger.info("ResultFileDownloadManager: No more download tasks")
6575
self._shutdown_manager()
6676
return None
6777

78+
# Log all pending download tasks
79+
logger.info(f"ResultFileDownloadManager: {len(self._download_tasks)} download tasks pending")
80+
81+
# Find the task that matches the requested row offset
82+
matching_task_index = None
83+
for i, task in enumerate(self._download_tasks):
84+
if task.done():
85+
try:
86+
file = task.result(timeout=0) # Don't block
87+
logger.info(f"Task {i}: start_row_offset={file.start_row_offset}, row_count={file.row_count}")
88+
if file.start_row_offset == next_row_offset:
89+
matching_task_index = i
90+
break
91+
except Exception as e:
92+
logger.error(f"Error getting task result: {e}")
93+
94+
# If we found a matching task, use it
95+
if matching_task_index is not None:
96+
logger.info(f"ResultFileDownloadManager: Found matching task at index {matching_task_index}")
97+
task = self._download_tasks.pop(matching_task_index)
98+
file = task.result()
99+
# Cache the file for future use
100+
self._downloaded_files_cache[file.start_row_offset] = file
101+
return file
102+
103+
# Otherwise, just use the first task
68104
task = self._download_tasks.pop(0)
69105
# Future's `result()` method will wait for the call to complete, and return
70106
# the value returned by the call. If the call throws an exception - `result()`
71107
# will throw the same exception
72108
file = task.result()
109+
# Cache the file for future use
110+
self._downloaded_files_cache[file.start_row_offset] = file
111+
73112
if (next_row_offset < file.start_row_offset) or (
74113
next_row_offset > file.start_row_offset + file.row_count
75114
):
76-
logger.debug(
115+
logger.warning(
77116
"ResultFileDownloadManager: file does not contain row {}, start {}, row count {}".format(
78117
next_row_offset, file.start_row_offset, file.row_count
79118
)

0 commit comments

Comments
 (0)