1
1
from __future__ import annotations
2
2
3
3
from abc import ABC
4
- from typing import List , Optional , Tuple
4
+ from typing import List , Optional , Tuple , Union
5
+
6
+ try :
7
+ import pyarrow
8
+ except ImportError :
9
+ pyarrow = None
10
+
11
+ import dateutil
5
12
6
13
from databricks .sql .backend .sea .backend import SeaDatabricksClient
7
- from databricks .sql .backend .sea .models .base import ResultData , ResultManifest
14
+ from databricks .sql .backend .sea .models .base import (
15
+ ExternalLink ,
16
+ ResultData ,
17
+ ResultManifest ,
18
+ )
8
19
from databricks .sql .backend .sea .utils .constants import ResultFormat
9
20
from databricks .sql .exc import ProgrammingError
10
- from databricks .sql .utils import ResultSetQueue
21
+ from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
22
+ from databricks .sql .types import SSLOptions
23
+ from databricks .sql .utils import CloudFetchQueue , ResultSetQueue
24
+
25
+ import logging
26
+
27
+ logger = logging .getLogger (__name__ )
11
28
12
29
13
30
class SeaResultSetQueueFactory (ABC ):
@@ -42,8 +59,30 @@ def build_queue(
42
59
return JsonQueue (sea_result_data .data )
43
60
elif manifest .format == ResultFormat .ARROW_STREAM .value :
44
61
# EXTERNAL_LINKS disposition
45
- raise NotImplementedError (
46
- "EXTERNAL_LINKS disposition is not implemented for SEA backend"
62
+ if not max_download_threads :
63
+ raise ValueError (
64
+ "Max download threads is required for EXTERNAL_LINKS disposition"
65
+ )
66
+ if not ssl_options :
67
+ raise ValueError (
68
+ "SSL options are required for EXTERNAL_LINKS disposition"
69
+ )
70
+ if not sea_client :
71
+ raise ValueError (
72
+ "SEA client is required for EXTERNAL_LINKS disposition"
73
+ )
74
+ if not manifest :
75
+ raise ValueError ("Manifest is required for EXTERNAL_LINKS disposition" )
76
+
77
+ return SeaCloudFetchQueue (
78
+ initial_links = sea_result_data .external_links ,
79
+ max_download_threads = max_download_threads ,
80
+ ssl_options = ssl_options ,
81
+ sea_client = sea_client ,
82
+ statement_id = statement_id ,
83
+ total_chunk_count = manifest .total_chunk_count ,
84
+ lz4_compressed = lz4_compressed ,
85
+ description = description ,
47
86
)
48
87
raise ProgrammingError ("Invalid result format" )
49
88
@@ -69,3 +108,134 @@ def remaining_rows(self) -> List[List[str]]:
69
108
slice = self .data_array [self .cur_row_index :]
70
109
self .cur_row_index += len (slice )
71
110
return slice
111
+
112
+
113
+ class SeaCloudFetchQueue (CloudFetchQueue ):
114
+ """Queue implementation for EXTERNAL_LINKS disposition with ARROW format for SEA backend."""
115
+
116
+ def __init__ (
117
+ self ,
118
+ initial_links : List ["ExternalLink" ],
119
+ max_download_threads : int ,
120
+ ssl_options : SSLOptions ,
121
+ sea_client : "SeaDatabricksClient" ,
122
+ statement_id : str ,
123
+ total_chunk_count : int ,
124
+ lz4_compressed : bool = False ,
125
+ description : Optional [List [Tuple ]] = None ,
126
+ ):
127
+ """
128
+ Initialize the SEA CloudFetchQueue.
129
+
130
+ Args:
131
+ initial_links: Initial list of external links to download
132
+ schema_bytes: Arrow schema bytes
133
+ max_download_threads: Maximum number of download threads
134
+ ssl_options: SSL options for downloads
135
+ sea_client: SEA client for fetching additional links
136
+ statement_id: Statement ID for the query
137
+ total_chunk_count: Total number of chunks in the result set
138
+ lz4_compressed: Whether the data is LZ4 compressed
139
+ description: Column descriptions
140
+ """
141
+
142
+ super ().__init__ (
143
+ max_download_threads = max_download_threads ,
144
+ ssl_options = ssl_options ,
145
+ schema_bytes = None ,
146
+ lz4_compressed = lz4_compressed ,
147
+ description = description ,
148
+ )
149
+
150
+ self ._sea_client = sea_client
151
+ self ._statement_id = statement_id
152
+
153
+ logger .debug (
154
+ "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}" .format (
155
+ statement_id , total_chunk_count
156
+ )
157
+ )
158
+
159
+ initial_link = next ((l for l in initial_links if l .chunk_index == 0 ), None )
160
+ if not initial_link :
161
+ raise ValueError ("No initial link found for chunk index 0" )
162
+
163
+ self .download_manager = ResultFileDownloadManager (
164
+ links = [],
165
+ max_download_threads = max_download_threads ,
166
+ lz4_compressed = lz4_compressed ,
167
+ ssl_options = ssl_options ,
168
+ )
169
+
170
+ # Track the current chunk we're processing
171
+ self ._current_chunk_link : Optional ["ExternalLink" ] = initial_link
172
+ self ._download_current_link ()
173
+
174
+ # Initialize table and position
175
+ self .table = self ._create_next_table ()
176
+
177
+ def _convert_to_thrift_link (self , link : "ExternalLink" ) -> TSparkArrowResultLink :
178
+ """Convert SEA external links to Thrift format for compatibility with existing download manager."""
179
+ # Parse the ISO format expiration time
180
+ expiry_time = int (dateutil .parser .parse (link .expiration ).timestamp ())
181
+ return TSparkArrowResultLink (
182
+ fileLink = link .external_link ,
183
+ expiryTime = expiry_time ,
184
+ rowCount = link .row_count ,
185
+ bytesNum = link .byte_count ,
186
+ startRowOffset = link .row_offset ,
187
+ httpHeaders = link .http_headers or {},
188
+ )
189
+
190
+ def _download_current_link (self ):
191
+ """Download the current chunk link."""
192
+ if not self ._current_chunk_link :
193
+ return None
194
+
195
+ if not self .download_manager :
196
+ logger .debug ("SeaCloudFetchQueue: No download manager, returning" )
197
+ return None
198
+
199
+ thrift_link = self ._convert_to_thrift_link (self ._current_chunk_link )
200
+ self .download_manager .add_link (thrift_link )
201
+
202
+ def _progress_chunk_link (self ):
203
+ """Progress to the next chunk link."""
204
+ if not self ._current_chunk_link :
205
+ return None
206
+
207
+ next_chunk_index = self ._current_chunk_link .next_chunk_index
208
+
209
+ if next_chunk_index is None :
210
+ self ._current_chunk_link = None
211
+ return None
212
+
213
+ try :
214
+ self ._current_chunk_link = self ._sea_client .get_chunk_link (
215
+ self ._statement_id , next_chunk_index
216
+ )
217
+ except Exception as e :
218
+ logger .error (
219
+ "SeaCloudFetchQueue: Error fetching link for chunk {}: {}" .format (
220
+ next_chunk_index , e
221
+ )
222
+ )
223
+ return None
224
+
225
+ logger .debug (
226
+ f"SeaCloudFetchQueue: Progressed to link for chunk { next_chunk_index } : { self ._current_chunk_link } "
227
+ )
228
+ self ._download_current_link ()
229
+
230
+ def _create_next_table (self ) -> Union ["pyarrow.Table" , None ]:
231
+ """Create next table by retrieving the logical next downloaded file."""
232
+ if not self ._current_chunk_link :
233
+ logger .debug ("SeaCloudFetchQueue: No current chunk link, returning" )
234
+ return None
235
+
236
+ row_offset = self ._current_chunk_link .row_offset
237
+ arrow_table = self ._create_table_at_offset (row_offset )
238
+
239
+ self ._progress_chunk_link ()
240
+
241
+ return arrow_table
0 commit comments