
Commit 4b456b2

move ThriftCloudFetchQueue above SeaCloudFetchQueue
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent: 74f59b7 · commit: 4b456b2

File tree

1 file changed: +80 −77 lines

src/databricks/sql/utils.py

Lines changed: 80 additions & 77 deletions
@@ -440,6 +440,83 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
         pass


+class ThriftCloudFetchQueue(CloudFetchQueue):
+    """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""
+
+    def __init__(
+        self,
+        schema_bytes,
+        max_download_threads: int,
+        ssl_options: SSLOptions,
+        start_row_offset: int = 0,
+        result_links: Optional[List[TSparkArrowResultLink]] = None,
+        lz4_compressed: bool = True,
+        description: Optional[List[Tuple]] = None,
+    ):
+        """
+        Initialize the Thrift CloudFetchQueue.
+
+        Args:
+            schema_bytes: Table schema in bytes
+            max_download_threads: Maximum number of downloader thread pool threads
+            ssl_options: SSL options for downloads
+            start_row_offset: The offset of the first row of the cloud fetch links
+            result_links: Links containing the downloadable URL and metadata
+            lz4_compressed: Whether the files are lz4 compressed
+            description: Hive table schema description
+        """
+        super().__init__(
+            max_download_threads=max_download_threads,
+            ssl_options=ssl_options,
+            schema_bytes=schema_bytes,
+            lz4_compressed=lz4_compressed,
+            description=description,
+        )
+
+        self.start_row_index = start_row_offset
+        self.result_links = result_links or []
+
+        logger.debug(
+            "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
+                start_row_offset
+            )
+        )
+        if self.result_links:
+            for result_link in self.result_links:
+                logger.debug(
+                    "- start row offset: {}, row count: {}".format(
+                        result_link.startRowOffset, result_link.rowCount
+                    )
+                )
+
+        # Initialize download manager
+        self.download_manager = ResultFileDownloadManager(
+            links=self.result_links,
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=self.lz4_compressed,
+            ssl_options=self._ssl_options,
+        )
+
+        # Initialize table and position
+        self.table = self._create_next_table()
+
+    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+        logger.debug(
+            "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
+                self.start_row_index
+            )
+        )
+        arrow_table = self._create_table_at_offset(self.start_row_index)
+        if arrow_table:
+            self.start_row_index += arrow_table.num_rows
+            logger.debug(
+                "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
+                    arrow_table.num_rows, self.start_row_index
+                )
+            )
+        return arrow_table
+
+
 class SeaCloudFetchQueue(CloudFetchQueue):
     """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""


@@ -571,83 +648,6 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
         return arrow_table


-class ThriftCloudFetchQueue(CloudFetchQueue):
-    """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""
-
-    def __init__(
-        self,
-        schema_bytes,
-        max_download_threads: int,
-        ssl_options: SSLOptions,
-        start_row_offset: int = 0,
-        result_links: Optional[List[TSparkArrowResultLink]] = None,
-        lz4_compressed: bool = True,
-        description: Optional[List[Tuple]] = None,
-    ):
-        """
-        Initialize the Thrift CloudFetchQueue.
-
-        Args:
-            schema_bytes: Table schema in bytes
-            max_download_threads: Maximum number of downloader thread pool threads
-            ssl_options: SSL options for downloads
-            start_row_offset: The offset of the first row of the cloud fetch links
-            result_links: Links containing the downloadable URL and metadata
-            lz4_compressed: Whether the files are lz4 compressed
-            description: Hive table schema description
-        """
-        super().__init__(
-            max_download_threads=max_download_threads,
-            ssl_options=ssl_options,
-            schema_bytes=schema_bytes,
-            lz4_compressed=lz4_compressed,
-            description=description,
-        )
-
-        self.start_row_index = start_row_offset
-        self.result_links = result_links or []
-
-        logger.debug(
-            "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
-                start_row_offset
-            )
-        )
-        if self.result_links:
-            for result_link in self.result_links:
-                logger.debug(
-                    "- start row offset: {}, row count: {}".format(
-                        result_link.startRowOffset, result_link.rowCount
-                    )
-                )
-
-        # Initialize download manager
-        self.download_manager = ResultFileDownloadManager(
-            links=self.result_links,
-            max_download_threads=self.max_download_threads,
-            lz4_compressed=self.lz4_compressed,
-            ssl_options=self._ssl_options,
-        )
-
-        # Initialize table and position
-        self.table = self._create_next_table()
-
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
-        logger.debug(
-            "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
-                self.start_row_index
-            )
-        )
-        arrow_table = self._create_table_at_offset(self.start_row_index)
-        if arrow_table:
-            self.start_row_index += arrow_table.num_rows
-            logger.debug(
-                "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
-                    arrow_table.num_rows, self.start_row_index
-                )
-            )
-        return arrow_table
-
-
 def _bound(min_x, max_x, x):
     """Bound x by [min_x, max_x]

@@ -894,6 +894,7 @@ def convert_arrow_based_file_to_arrow_table(file_bytes: bytes):
     except Exception as e:
         raise RuntimeError("Failure to convert arrow based file to arrow table", e)

+
 def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
     ba = bytearray()
     ba += schema_bytes
@@ -908,6 +909,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
     arrow_table = pyarrow.ipc.open_stream(ba).read_all()
     return arrow_table, n_rows

+
 def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
     new_columns = []
     new_fields = []
@@ -935,6 +937,7 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":

     return pyarrow.Table.from_arrays(new_columns, schema=new_schema)

+
 def convert_to_assigned_datatypes_in_column_table(column_table, description):
     converted_column_table = []
     for i, col in enumerate(column_table):
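
For readers skimming the diff, here is a minimal, hypothetical sketch of how the relocated ThriftCloudFetchQueue gets wired up. The import paths, the SSLOptions construction, and all literal values are assumptions made for illustration; only the constructor parameters and the TSparkArrowResultLink fields logged above come from this file.

    # Illustrative wiring only: not the connector's real call site.
    from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
    from databricks.sql.types import SSLOptions  # import path assumed

    # One cloud-fetch result link; startRowOffset and rowCount are the
    # fields the constructor's debug logging prints.
    link = TSparkArrowResultLink(
        fileLink="https://example.invalid/results/part-0.lz4",  # made-up presigned URL
        startRowOffset=0,
        rowCount=10_000,
    )

    queue = ThriftCloudFetchQueue(
        schema_bytes=b"",          # placeholder; really the Arrow IPC schema bytes from the Thrift response
        max_download_threads=10,
        ssl_options=SSLOptions(),  # assumed default-constructible
        start_row_offset=0,
        result_links=[link],
        lz4_compressed=True,
        description=None,          # or the Hive schema description tuples
    )

    # The constructor builds the download manager and materializes the first
    # pyarrow.Table, so downloads begin immediately; rows are then consumed
    # through the same base-class queue interface SeaCloudFetchQueue uses.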
