@@ -440,6 +440,83 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
        pass


+class ThriftCloudFetchQueue(CloudFetchQueue):
+    """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""
+
+    def __init__(
+        self,
+        schema_bytes,
+        max_download_threads: int,
+        ssl_options: SSLOptions,
+        start_row_offset: int = 0,
+        result_links: Optional[List[TSparkArrowResultLink]] = None,
+        lz4_compressed: bool = True,
+        description: Optional[List[Tuple]] = None,
+    ):
+        """
+        Initialize the Thrift CloudFetchQueue.
+
+        Args:
+            schema_bytes: Table schema in bytes
+            max_download_threads: Maximum number of downloader thread pool threads
+            ssl_options: SSL options for downloads
+            start_row_offset: The offset of the first row of the cloud fetch links
+            result_links: Links containing the downloadable URL and metadata
+            lz4_compressed: Whether the files are lz4 compressed
+            description: Hive table schema description
+        """
+        super().__init__(
+            max_download_threads=max_download_threads,
+            ssl_options=ssl_options,
+            schema_bytes=schema_bytes,
+            lz4_compressed=lz4_compressed,
+            description=description,
+        )
+
+        self.start_row_index = start_row_offset
+        self.result_links = result_links or []
+
+        logger.debug(
+            "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
+                start_row_offset
+            )
+        )
+        if self.result_links:
+            for result_link in self.result_links:
+                logger.debug(
+                    "- start row offset: {}, row count: {}".format(
+                        result_link.startRowOffset, result_link.rowCount
+                    )
+                )
+
+        # Initialize download manager
+        self.download_manager = ResultFileDownloadManager(
+            links=self.result_links,
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=self.lz4_compressed,
+            ssl_options=self._ssl_options,
+        )
+
+        # Initialize table and position
+        self.table = self._create_next_table()
+
+    def _create_next_table(self) -> Union["pyarrow.Table", None]:
+        logger.debug(
+            "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
+                self.start_row_index
+            )
+        )
+        arrow_table = self._create_table_at_offset(self.start_row_index)
+        if arrow_table:
+            self.start_row_index += arrow_table.num_rows
+            logger.debug(
+                "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
+                    arrow_table.num_rows, self.start_row_index
+                )
+            )
+        return arrow_table
+
+
class SeaCloudFetchQueue(CloudFetchQueue):
    """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""

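For reviewers tracing the move: the class's behavior is unchanged. `__init__` eagerly materializes the first table, and each `_create_next_table()` call advances `start_row_index` by the rows consumed until `None` signals exhaustion. A minimal self-contained sketch of that loop, using a stub in place of the real downloader (the `StubQueue` class and sample tables below are illustrative, not part of this diff):

```python
import pyarrow

class StubQueue:
    """Test double mimicking the offset-advancing loop shown in this diff."""

    def __init__(self, tables):
        self._tables = tables
        self.start_row_index = 0

    def _create_table_at_offset(self, offset):
        # Stand-in for the real method, which returns the downloaded file
        # whose first row is `offset`, or None when nothing is left.
        return self._tables.pop(0) if self._tables else None

    def _create_next_table(self):
        arrow_table = self._create_table_at_offset(self.start_row_index)
        if arrow_table:
            self.start_row_index += arrow_table.num_rows
        return arrow_table

queue = StubQueue([pyarrow.table({"x": [1, 2]}), pyarrow.table({"x": [3]})])
while (table := queue._create_next_table()) is not None:
    print(queue.start_row_index)  # prints 2, then 3
```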
@@ -571,83 +648,6 @@ def _create_next_table(self) -> Union["pyarrow.Table", None]:
        return arrow_table


-class ThriftCloudFetchQueue(CloudFetchQueue):
-    """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for Thrift backend."""
-
-    def __init__(
-        self,
-        schema_bytes,
-        max_download_threads: int,
-        ssl_options: SSLOptions,
-        start_row_offset: int = 0,
-        result_links: Optional[List[TSparkArrowResultLink]] = None,
-        lz4_compressed: bool = True,
-        description: Optional[List[Tuple]] = None,
-    ):
-        """
-        Initialize the Thrift CloudFetchQueue.
-
-        Args:
-            schema_bytes: Table schema in bytes
-            max_download_threads: Maximum number of downloader thread pool threads
-            ssl_options: SSL options for downloads
-            start_row_offset: The offset of the first row of the cloud fetch links
-            result_links: Links containing the downloadable URL and metadata
-            lz4_compressed: Whether the files are lz4 compressed
-            description: Hive table schema description
-        """
-        super().__init__(
-            max_download_threads=max_download_threads,
-            ssl_options=ssl_options,
-            schema_bytes=schema_bytes,
-            lz4_compressed=lz4_compressed,
-            description=description,
-        )
-
-        self.start_row_index = start_row_offset
-        self.result_links = result_links or []
-
-        logger.debug(
-            "Initialize CloudFetch loader, row set start offset: {}, file list:".format(
-                start_row_offset
-            )
-        )
-        if self.result_links:
-            for result_link in self.result_links:
-                logger.debug(
-                    "- start row offset: {}, row count: {}".format(
-                        result_link.startRowOffset, result_link.rowCount
-                    )
-                )
-
-        # Initialize download manager
-        self.download_manager = ResultFileDownloadManager(
-            links=self.result_links,
-            max_download_threads=self.max_download_threads,
-            lz4_compressed=self.lz4_compressed,
-            ssl_options=self._ssl_options,
-        )
-
-        # Initialize table and position
-        self.table = self._create_next_table()
-
-    def _create_next_table(self) -> Union["pyarrow.Table", None]:
-        logger.debug(
-            "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
-                self.start_row_index
-            )
-        )
-        arrow_table = self._create_table_at_offset(self.start_row_index)
-        if arrow_table:
-            self.start_row_index += arrow_table.num_rows
-            logger.debug(
-                "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
-                    arrow_table.num_rows, self.start_row_index
-                )
-            )
-        return arrow_table
-
-
def _bound(min_x, max_x, x):
    """Bound x by [min_x, max_x]

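Only the first docstring line of `_bound` falls inside this hunk's context. Assuming the conventional clamp semantics that the docstring suggests, with `None` meaning unbounded on that side, a hedged reconstruction (not taken from this diff) would be:

```python
def _bound(min_x, max_x, x):
    """Clamp x into [min_x, max_x]; None means unbounded on that side."""
    if min_x is None and max_x is None:
        return x
    if min_x is None:
        return min(max_x, x)
    if max_x is None:
        return max(min_x, x)
    return min(max_x, max(min_x, x))

assert _bound(0, 10, -5) == 0 and _bound(0, 10, 15) == 10
assert _bound(None, 10, 99) == 10 and _bound(0, None, -1) == 0
```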
@@ -894,6 +894,7 @@ def convert_arrow_based_file_to_arrow_table(file_bytes: bytes):
    except Exception as e:
        raise RuntimeError("Failure to convert arrow based file to arrow table", e)

+
def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes):
    ba = bytearray()
    ba += schema_bytes
@@ -908,6 +909,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
    arrow_table = pyarrow.ipc.open_stream(ba).read_all()
    return arrow_table, n_rows

+
def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
    new_columns = []
    new_fields = []
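The spacing-only hunks above touch `convert_arrow_based_set_to_arrow_table`, which relies on Arrow's IPC framing: a stream is a schema message followed by record-batch messages, so prefixing batch bytes with schema bytes and reading the concatenation back yields one table. A self-contained demonstration of that round trip (standard `pyarrow` API only, independent of this module):

```python
import pyarrow

batch = pyarrow.RecordBatch.from_pydict({"x": [1, 2, 3]})

ba = bytearray()
ba += batch.schema.serialize().to_pybytes()  # standalone schema message
ba += batch.serialize().to_pybytes()         # standalone record-batch message

# The concatenation parses as a single IPC stream, as in the function above.
table = pyarrow.ipc.open_stream(ba).read_all()
assert table.to_pydict() == {"x": [1, 2, 3]}
```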
@@ -935,6 +937,7 @@ def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":

    return pyarrow.Table.from_arrays(new_columns, schema=new_schema)

+
def convert_to_assigned_datatypes_in_column_table(column_table, description):
    converted_column_table = []
    for i, col in enumerate(column_table):
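For context on the helper this last hunk touches: `convert_decimals_in_arrow_table` rebuilds the table column by column, casting decimal columns according to the cursor description and reassembling via `pyarrow.Table.from_arrays`, as the context lines show. A hedged standalone sketch of that pattern (the DB-API-style description tuples and the `cast_decimal_columns` name are illustrative, not the module's exact code):

```python
import pyarrow

def cast_decimal_columns(table, description):
    # description rows assumed DB-API style: (name, type_code, _, _, precision, scale, _)
    new_columns, new_fields = [], []
    for i, col in enumerate(table.itercolumns()):
        name, type_code, _, _, precision, scale, _ = description[i]
        if type_code == "decimal":
            dtype = pyarrow.decimal128(precision, scale)
            new_columns.append(col.cast(dtype))
            new_fields.append(pyarrow.field(name, dtype))
        else:
            new_columns.append(col)
            new_fields.append(table.field(i))
    return pyarrow.Table.from_arrays(new_columns, schema=pyarrow.schema(new_fields))

t = pyarrow.table({"amount": ["1.10", "2.25"]})
out = cast_decimal_columns(t, [("amount", "decimal", None, None, 10, 2, True)])
assert out.schema.field("amount").type == pyarrow.decimal128(10, 2)
```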