39
39
from hsfs .constructor import query
40
40
from hsfs .core .variable_api import VariableApi
41
41
from hsfs .storage_connector import StorageConnector
42
+ from pyarrow .flight import FlightServerError
42
43
from retrying import retry
43
44
44
45
@@ -96,6 +97,13 @@ def _should_retry_healthcheck(exception):
96
97
)
97
98
98
99
100
+ def _should_retry_certificate_registration (exception ):
101
+ return (
102
+ _should_retry_healthcheck (exception )
103
+ or _is_feature_query_service_queue_full_error (exception )
104
+ )
105
+
106
+
99
107
# Avoid unnecessary client init
100
108
def is_data_format_supported (data_format : str , read_options : Optional [Dict [str , Any ]]):
101
109
if data_format not in ArrowFlightClient .SUPPORTED_FORMATS :
@@ -199,6 +207,12 @@ def __init__(self, disabled_for_session: bool = False):
199
207
200
208
try :
201
209
self ._health_check ()
210
+ if "get-version" in [action .type for action in self ._connection .list_actions ()]:
211
+ self ._server_version = self ._get_server_version ()
212
+ else :
213
+ self ._server_version = None
214
+ if self ._server_version is None :
215
+ self ._register_certificates ()
202
216
except Exception as e :
203
217
_logger .debug ("Failed to connect to Hopsworks Query Service." )
204
218
_logger .exception (e )
@@ -293,6 +307,20 @@ def _health_check(self):
293
307
list (self ._connection .do_action (action , options = options ))
294
308
_logger .debug ("Healthcheck succeeded." )
295
309
310
+ @retry (
311
+ wait_exponential_multiplier = 1000 ,
312
+ stop_max_attempt_number = 3 ,
313
+ retry_on_exception = _should_retry ,
314
+ )
315
+ def _get_server_version (self ):
316
+ _logger .debug ("Acquiring the server version of Hopsworks Query Service." )
317
+ action = pyarrow .flight .Action ("get-version" , b"" )
318
+ options = pyarrow .flight .FlightCallOptions (timeout = self .health_check_timeout )
319
+ for res in self ._connection .do_action (action , options = options ):
320
+ version = res .body .to_pybytes ()
321
+ _logger .debug (f"The HQS server is of version { version } ." )
322
+ return version
323
+
296
324
def _should_be_used (self ):
297
325
if not self ._enabled_on_cluster :
298
326
_logger .debug (
@@ -331,10 +359,32 @@ def _certificates(self):
331
359
cert_key = self ._client ._cert_key
332
360
return {"kstore" : kstore , "tstore" : tstore , "cert_key" : cert_key }
333
361
334
- def _certificates_header (self ):
362
+ def _make_certificates_json (self ):
335
363
if self ._certificates_json is None :
336
364
self ._certificates_json = json .dumps (self ._certificates ()).encode ("utf-8" )
337
- return (b"x-certificates-json" , self ._certificates_json )
365
+ return self ._certificates_json
366
+
367
+ def _certificates_headers (self ):
368
+ if self ._server_version is None :
369
+ return []
370
+ return [(b"x-certificates-json" , self ._make_certificates_json ())]
371
+
372
+ @retry (
373
+ wait_exponential_multiplier = 1000 ,
374
+ stop_max_attempt_number = 3 ,
375
+ retry_on_exception = _should_retry_certificate_registration ,
376
+ )
377
+ def _register_certificates (self ):
378
+ certificates_json = self ._make_certificates_json ()
379
+ certificates_json_buf = pyarrow .py_buffer (certificates_json )
380
+ action = pyarrow .flight .Action (
381
+ "register-client-certificates" , certificates_json_buf
382
+ )
383
+ # Registering certificates queue time occasionally spike.
384
+ options = pyarrow .flight .FlightCallOptions (timeout = self .health_check_timeout )
385
+ _logger .debug ("Registering client certificates with Hopsworks Query Service." )
386
+ self ._connection .do_action (action , options = options )
387
+ _logger .debug ("Client certificates registered." )
338
388
339
389
def _handle_afs_exception (user_message = "None" ):
340
390
def decorator (func ):
@@ -346,7 +396,13 @@ def afs_error_handler_wrapper(instance, *args, **kw):
346
396
message = str (e )
347
397
_logger .debug ("Caught exception in %s: %s" , func .__name__ , message )
348
398
_logger .exception (e )
349
- if _is_feature_query_service_queue_full_error (e ):
399
+ if instance ._server_version is None and (
400
+ isinstance (e , FlightServerError )
401
+ and "Please register client certificates first." in message
402
+ ):
403
+ instance ._register_certificates ()
404
+ return func (instance , * args , ** kw )
405
+ elif _is_feature_query_service_queue_full_error (e ):
350
406
raise FeatureStoreException (
351
407
"Hopsworks Query Service is busy right now. Please try again later."
352
408
) from e
@@ -384,7 +440,7 @@ def _get_dataset(self, descriptor, timeout=None, dataframe_type="pandas"):
384
440
info = self .get_flight_info (descriptor )
385
441
_logger .debug ("Retrieved flight info: %s. Fetching dataset." , str (info ))
386
442
options = pyarrow .flight .FlightCallOptions (
387
- timeout = timeout , headers = [ self ._certificates_header ()]
443
+ timeout = timeout , headers = self ._certificates_headers ()
388
444
)
389
445
reader = self ._connection .do_get (info .endpoints [0 ].ticket , options )
390
446
_logger .debug ("Dataset fetched. Converting to dataframe %s." , dataframe_type )
@@ -456,7 +512,7 @@ def create_training_dataset(
456
512
else self .timeout
457
513
)
458
514
options = pyarrow .flight .FlightCallOptions (
459
- timeout = timeout , headers = [ self ._certificates_header ()]
515
+ timeout = timeout , headers = self ._certificates_headers ()
460
516
)
461
517
for result in self ._connection .do_action (action , options ):
462
518
return result .body .to_pybytes ()
0 commit comments