
Commit 961873a

clean up SeaCloudFetchQueue
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent d5322eb commit 961873a

File tree

1 file changed: +155 -136 lines changed


src/databricks/sql/cloud_fetch_queue.py

Lines changed: 155 additions & 136 deletions
@@ -159,15 +159,19 @@ def __init__(
         self.table = None
         self.table_row_index = 0
 
-        # Initialize download manager - subclasses must set this
-        self.download_manager: Optional[ResultFileDownloadManager] = None
+        # Initialize download manager - will be set by subclasses
+        self.download_manager: Optional["ResultFileDownloadManager"] = None
 
     def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """Get up to the next n rows of the cloud fetch Arrow dataframes."""
         if not self.table:
             # Return empty pyarrow table to cause retry of fetch
+            logger.debug(
+                "SeaCloudFetchQueue: No table available, returning empty table"
+            )
             return self._create_empty_table()
 
+        logger.debug("SeaCloudFetchQueue: Retrieving up to {} rows".format(num_rows))
         results = pyarrow.Table.from_pydict({})  # Empty table
         while num_rows > 0 and self.table:
             # Get remaining of num_rows or the rest of the current table, whichever is smaller
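
Note on the annotation change in this hunk: quoting the class name turns it into a forward reference, so the attribute can be typed without importing ResultFileDownloadManager at runtime. A minimal, self-contained sketch of that pattern; the module name "downloader" below is hypothetical and used only for illustration:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported for type checkers only; avoids a runtime (possibly circular) import.
    # "downloader" is a hypothetical module name, not the real one.
    from downloader import ResultFileDownloadManager


class Queue:
    def __init__(self) -> None:
        # The quoted name is a forward reference, so this line runs even though
        # ResultFileDownloadManager is never imported at runtime.
        self.download_manager: Optional["ResultFileDownloadManager"] = None
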
@@ -184,11 +188,15 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
 
             # Replace current table with the next table if we are at the end of the current table
             if self.table_row_index == self.table.num_rows:
+                logger.debug(
+                    "SeaCloudFetchQueue: Reached end of current table, fetching next"
+                )
                 self.table = self._create_next_table()
                 self.table_row_index = 0
 
             num_rows -= table_slice.num_rows
 
+        logger.debug("SeaCloudFetchQueue: Retrieved {} rows".format(results.num_rows))
         return results
 
     @abstractmethod
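
As context for the retry behaviour visible in these hunks: next_n_rows returns an empty Arrow table when no data has been downloaded yet, and the comment in the diff states this is meant to cause a retry of the fetch. A minimal sketch of that consumption pattern; FakeQueue below is hypothetical and only mimics the interface shown in the diff:

import pyarrow


class FakeQueue:
    """Hypothetical stand-in that mimics the queue interface shown in the diff."""

    def __init__(self, batches):
        # Each batch is a pyarrow.Table; None simulates "table not ready yet".
        self._batches = batches

    def next_n_rows(self, num_rows):
        batch = self._batches.pop(0) if self._batches else None
        if batch is None:
            # Mirrors the "empty table to cause retry of fetch" behaviour.
            return pyarrow.Table.from_pydict({})
        return batch.slice(0, num_rows)


queue = FakeQueue([None, pyarrow.table({"x": [1, 2, 3]})])
out = queue.next_n_rows(2)
while out.num_rows == 0:   # empty result => data not downloaded yet, retry
    out = queue.next_n_rows(2)
print(out.to_pydict())     # {'x': [1, 2]}
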
@@ -247,15 +255,29 @@ def __init__(
         self._statement_id = statement_id
         self._total_chunk_count = total_chunk_count
 
-        # Track which links we've already fetched
-        self._fetched_chunk_indices = set()
-        for link in initial_links:
-            self._fetched_chunk_indices.add(link.chunk_index)
+        # Track the current chunk we're processing
+        self._current_chunk_index: Optional[int] = None
+        self._current_chunk_link: Optional["ExternalLink"] = None
 
-        # Create a mapping from chunk index to link
-        self._chunk_index_to_link = {link.chunk_index: link for link in initial_links}
+        logger.debug(
+            "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format(
+                statement_id, total_chunk_count
+            )
+        )
 
-        # Initialize download manager
+        if initial_links:
+            logger.debug("SeaCloudFetchQueue: Initial links provided:")
+            for link in initial_links:
+                logger.debug(
+                    "- chunk: {}, row offset: {}, row count: {}, next chunk: {}".format(
+                        link.chunk_index,
+                        link.row_offset,
+                        link.row_count,
+                        link.next_chunk_index,
+                    )
+                )
+
+        # Initialize download manager with initial links
         self.download_manager = ResultFileDownloadManager(
             links=self._convert_to_thrift_links(initial_links),
             max_download_threads=max_download_threads,
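
The state rework in this hunk replaces the fetched-chunk set and chunk-index-to-link dict with a single current-chunk pointer, relying on the fact that SEA external links chain to each other through next_chunk_index. A small standalone sketch of that chained traversal; the Link dataclass is a hypothetical stand-in for an SEA ExternalLink with only the fields used here:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Link:
    """Hypothetical stand-in for an SEA ExternalLink (only the fields used here)."""
    chunk_index: int
    row_offset: int
    row_count: int
    next_chunk_index: Optional[int]


links = {
    0: Link(0, 0, 100, 1),
    1: Link(1, 100, 100, 2),
    2: Link(2, 200, 50, None),   # next_chunk_index=None marks the last chunk
}

# The cleaned-up queue only remembers the link it is currently on and follows
# next_chunk_index, instead of tracking every fetched chunk in a set/dict.
current: Optional[Link] = links[0]
while current is not None:
    print("chunk", current.chunk_index, "rows", current.row_count)
    nxt = current.next_chunk_index
    current = links[nxt] if nxt is not None else None
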
@@ -265,11 +287,26 @@ def __init__(
 
         # Initialize table and position
         self.table = self._create_next_table()
+        if self.table:
+            logger.debug(
+                "SeaCloudFetchQueue: Initial table created with {} rows".format(
+                    self.table.num_rows
+                )
+            )
 
     def _convert_to_thrift_links(
         self, links: List["ExternalLink"]
     ) -> List[TSparkArrowResultLink]:
         """Convert SEA external links to Thrift format for compatibility with existing download manager."""
+        if not links:
+            logger.debug("SeaCloudFetchQueue: No links to convert to Thrift format")
+            return []
+
+        logger.debug(
+            "SeaCloudFetchQueue: Converting {} links to Thrift format".format(
+                len(links)
+            )
+        )
         thrift_links = []
         for link in links:
             # Parse the ISO format expiration time
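
The conversion loop continuing below this context line parses each link's ISO-format expiration time before building a TSparkArrowResultLink; that code sits in the unchanged part of the file. A small standalone sketch of one plausible ISO-to-epoch conversion; the sample timestamp, the "Z" normalization, and the variable names are illustrative assumptions, not the file's exact logic:

from datetime import datetime, timezone

# Hypothetical ISO-8601 expiration string as it might appear on an SEA external link.
expiration = "2025-06-17T10:30:00Z"

# datetime.fromisoformat does not accept a trailing "Z" before Python 3.11,
# so normalize it first; then convert to epoch seconds for a Thrift-style expiry.
parsed = datetime.fromisoformat(expiration.replace("Z", "+00:00"))
expiry_epoch_seconds = int(parsed.astimezone(timezone.utc).timestamp())
print(expiry_epoch_seconds)
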
@@ -286,170 +323,146 @@ def _convert_to_thrift_links(
             thrift_links.append(thrift_link)
         return thrift_links
 
-    def _fetch_links_for_chunk(self, chunk_index: int) -> List["ExternalLink"]:
-        """Fetch links for the specified chunk index."""
-        if chunk_index in self._fetched_chunk_indices:
-            return [self._chunk_index_to_link[chunk_index]]
-
-        # Find the link that has this chunk_index as its next_chunk_index
-        next_chunk_link = None
-        next_chunk_internal_link = None
-        for link in self._chunk_index_to_link.values():
-            if link.next_chunk_index == chunk_index:
-                next_chunk_link = link
-                next_chunk_internal_link = link.next_chunk_internal_link
-                break
-
-        if not next_chunk_internal_link:
-            # If we can't find a link with next_chunk_index, we can't fetch the chunk
-            logger.warning(
-                f"Cannot find next_chunk_internal_link for chunk {chunk_index}"
+    def _fetch_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
+        """Fetch link for the specified chunk index."""
+        # Check if we already have this chunk as our current chunk
+        if (
+            self._current_chunk_link
+            and self._current_chunk_link.chunk_index == chunk_index
+        ):
+            logger.debug(
+                "SeaCloudFetchQueue: Already have current chunk {}".format(chunk_index)
             )
-            return []
+            return self._current_chunk_link
 
-        logger.info(f"Fetching chunk {chunk_index} using SEA client")
+        # We need to fetch this chunk
+        logger.debug(
+            "SeaCloudFetchQueue: Fetching chunk {} using SEA client".format(chunk_index)
+        )
 
         # Use the SEA client to fetch the chunk links
         links = self._sea_client.fetch_chunk_links(self._statement_id, chunk_index)
 
-        # Update our tracking
-        for link in links:
-            self._fetched_chunk_indices.add(link.chunk_index)
-            self._chunk_index_to_link[link.chunk_index] = link
+        if not links:
+            logger.debug(
+                "SeaCloudFetchQueue: No links found for chunk {}".format(chunk_index)
+            )
+            return None
 
-            # Log link details
-            logger.info(
-                f"Link details: chunk_index={link.chunk_index}, row_offset={link.row_offset}, row_count={link.row_count}, next_chunk_index={link.next_chunk_index}"
+        # Get the link for the requested chunk
+        link = next((l for l in links if l.chunk_index == chunk_index), None)
+
+        if link:
+            logger.debug(
+                "SeaCloudFetchQueue: Link details for chunk {}: row_offset={}, row_count={}, next_chunk_index={}".format(
+                    link.chunk_index,
+                    link.row_offset,
+                    link.row_count,
+                    link.next_chunk_index,
+                )
             )
 
-        # Add to download manager
-        if self.download_manager:
-            self.download_manager.add_links(self._convert_to_thrift_links(links))
+            if self.download_manager:
+                self.download_manager.add_links(self._convert_to_thrift_links([link]))
 
-        return links
+        return link
 
     def remaining_rows(self) -> "pyarrow.Table":
         """Get all remaining rows of the cloud fetch Arrow dataframes."""
         if not self.table:
             # Return empty pyarrow table to cause retry of fetch
+            logger.debug(
+                "SeaCloudFetchQueue: No table available, returning empty table"
+            )
             return self._create_empty_table()
 
+        logger.debug("SeaCloudFetchQueue: Retrieving all remaining rows")
         results = pyarrow.Table.from_pydict({})  # Empty table
-
-        # First, fetch the current table's remaining rows
-        if self.table_row_index < self.table.num_rows:
+        while self.table:
             table_slice = self.table.slice(
                 self.table_row_index, self.table.num_rows - self.table_row_index
             )
-            results = table_slice
-            self.table_row_index += table_slice.num_rows
+            if results.num_rows > 0:
+                results = pyarrow.concat_tables([results, table_slice])
+            else:
+                results = table_slice
 
-        # Now, try to fetch all remaining chunks
-        for chunk_index in range(self._total_chunk_count):
-            if chunk_index not in self._fetched_chunk_indices:
-                try:
-                    # Try to fetch this chunk
-                    self._fetch_links_for_chunk(chunk_index)
-                except Exception as e:
-                    logger.error(f"Error fetching chunk {chunk_index}: {e}")
-                    continue
-
-            # If we successfully fetched the chunk, get its data
-            if chunk_index in self._fetched_chunk_indices:
-                link = self._chunk_index_to_link[chunk_index]
-                downloaded_file = self.download_manager.get_next_downloaded_file(
-                    link.row_offset
-                )
-                if downloaded_file:
-                    arrow_table = create_arrow_table_from_arrow_file(
-                        downloaded_file.file_bytes, self.description
-                    )
-
-                    # Ensure the table has the correct number of rows
-                    if arrow_table.num_rows > downloaded_file.row_count:
-                        arrow_table = arrow_table.slice(
-                            0, downloaded_file.row_count
-                        )
-
-                    # Concatenate with results
-                    if results.num_rows > 0:
-                        results = pyarrow.concat_tables([results, arrow_table])
-                    else:
-                        results = arrow_table
-
-        self.table = None  # We've fetched everything, so clear the current table
-        self.table_row_index = 0
+            self.table_row_index += table_slice.num_rows
+            self.table = self._create_next_table()
+            self.table_row_index = 0
 
+        logger.debug(
+            "SeaCloudFetchQueue: Retrieved {} total rows".format(results.num_rows)
+        )
         return results
 
     def _create_next_table(self) -> Union["pyarrow.Table", None]:
         """Create next table by retrieving the logical next downloaded file."""
-        # Get the next chunk index based on current state
-        next_chunk_index = 0
-        if self.table is not None:
-            # Find the current chunk we're processing
-            current_chunk = None
-            for chunk_index, link in self._chunk_index_to_link.items():
-                # We're looking for the chunk that contains our current position
-                if (
-                    link.row_offset
-                    <= self.table_row_index
-                    < link.row_offset + link.row_count
-                ):
-                    current_chunk = link
-                    break
-
-            if current_chunk and current_chunk.next_chunk_index is not None:
-                next_chunk_index = current_chunk.next_chunk_index
-                logger.info(
-                    f"Found next_chunk_index {next_chunk_index} from current chunk {current_chunk.chunk_index}"
-                )
-            else:
-                # If we can't find the next chunk, try to fetch the next sequential one
-                next_chunk_index = (
-                    max(self._fetched_chunk_indices) + 1
-                    if self._fetched_chunk_indices
-                    else 0
-                )
-                logger.info(f"Using sequential next_chunk_index {next_chunk_index}")
+        # if we're still processing the current table, just return it
+        if self.table is not None and self.table_row_index < self.table.num_rows:
+            return self.table
+
+        # if we've reached the end of the response, return None
+        if (
+            self._current_chunk_link
+            and self._current_chunk_link.next_chunk_index is None
+        ):
+            logger.debug(
+                "SeaCloudFetchQueue: Reached end of chunks (no next chunk index)"
+            )
+            return None
 
-        # Check if we've reached the end of all chunks
-        if next_chunk_index >= self._total_chunk_count:
-            logger.info(
-                f"Reached end of chunks: next_chunk_index {next_chunk_index} >= total_chunk_count {self._total_chunk_count}"
+        # Determine the next chunk index
+        next_chunk_index = (
+            0
+            if self._current_chunk_link is None
+            else self._current_chunk_link.next_chunk_index
+        )
+        if next_chunk_index is None:  # This can happen if we're at the end of chunks
+            logger.debug(
+                "SeaCloudFetchQueue: Reached end of chunks (next_chunk_index is None)"
            )
            return None
 
-        # Check if we need to fetch links for this chunk
-        if next_chunk_index not in self._fetched_chunk_indices:
-            try:
-                logger.info(f"Fetching links for chunk {next_chunk_index}")
-                self._fetch_links_for_chunk(next_chunk_index)
-            except Exception as e:
-                logger.error(f"Error fetching links for chunk {next_chunk_index}: {e}")
-                # If we can't fetch the next chunk, try to return what we have
-                return None
-        else:
-            logger.info(f"Already have links for chunk {next_chunk_index}")
+        logger.debug(
+            "SeaCloudFetchQueue: Trying to get downloaded file for chunk {}".format(
+                next_chunk_index
+            )
+        )
 
-        # Find the next downloaded file
-        link = self._chunk_index_to_link.get(next_chunk_index)
-        if not link:
-            logger.error(f"No link found for chunk {next_chunk_index}")
+        # Update current chunk to the next one
+        self._current_chunk_index = next_chunk_index
+        self._current_chunk_link = None
+        try:
+            self._current_chunk_link = self._fetch_chunk_link(next_chunk_index)
+        except Exception as e:
+            logger.error(
+                "SeaCloudFetchQueue: Error fetching link for chunk {}: {}".format(
+                    self._current_chunk_index, e
+                )
+            )
+            return None
+        if not self._current_chunk_link:
+            logger.error(
+                "SeaCloudFetchQueue: No link found for chunk {}".format(
+                    self._current_chunk_index
+                )
+            )
             return None
 
-        row_offset = link.row_offset
-        logger.info(
-            f"Getting downloaded file for chunk {next_chunk_index} with row_offset {row_offset}"
-        )
+        # Get the data for the current chunk
+        row_offset = self._current_chunk_link.row_offset
+
         if not self.download_manager:
-            logger.error(f"No download manager available")
+            logger.debug("SeaCloudFetchQueue: No download manager available")
            return None
 
         downloaded_file = self.download_manager.get_next_downloaded_file(row_offset)
         if not downloaded_file:
-            logger.error(
-                f"No downloaded file found for chunk {next_chunk_index} with row_offset {row_offset}"
+            logger.debug(
+                "SeaCloudFetchQueue: Cannot find downloaded file for row {}".format(
+                    row_offset
+                )
             )
             return None
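
The reworked remaining_rows above drains the current table and every subsequent one into a single result via pyarrow.concat_tables. A small standalone illustration of that accumulation pattern; the two in-memory tables stand in for successive downloaded chunk files, and the starting row index is made up:

import pyarrow

# Two in-memory tables standing in for successive downloaded chunk files.
chunks = [
    pyarrow.table({"x": [1, 2, 3]}),
    pyarrow.table({"x": [4, 5]}),
]

results = pyarrow.Table.from_pydict({})  # empty accumulator, as in the diff
table_row_index = 1                      # pretend the first row was already consumed

for table in chunks:
    table_slice = table.slice(table_row_index, table.num_rows - table_row_index)
    if results.num_rows > 0:
        results = pyarrow.concat_tables([results, table_slice])
    else:
        results = table_slice
    table_row_index = 0                  # subsequent chunks are read from the start

print(results.to_pydict())               # {'x': [2, 3, 4, 5]}
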

@@ -461,9 +474,15 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
         if arrow_table.num_rows > downloaded_file.row_count:
             arrow_table = arrow_table.slice(0, downloaded_file.row_count)
 
-        logger.info(
-            f"Created arrow table for chunk {next_chunk_index} with {arrow_table.num_rows} rows"
+        # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows
+        assert downloaded_file.row_count == arrow_table.num_rows
+
+        logger.debug(
+            "SeaCloudFetchQueue: Found downloaded file for chunk {}, row count: {}, row offset: {}".format(
+                self._current_chunk_index, arrow_table.num_rows, row_offset
+            )
         )
+
         return arrow_table
469488

0 commit comments