Skip to content

Commit 179435d

Browse files
averseySirOibaf
authored andcommitted
[FSTORE-1661] Fix backwards compatibility (#469)
1 parent b6684c9 commit 179435d

File tree

1 file changed

+61
-5
lines changed

1 file changed

+61
-5
lines changed

python/hsfs/core/arrow_flight_client.py

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from hsfs.constructor import query
4040
from hsfs.core.variable_api import VariableApi
4141
from hsfs.storage_connector import StorageConnector
42+
from pyarrow.flight import FlightServerError
4243
from retrying import retry
4344

4445

@@ -96,6 +97,13 @@ def _should_retry_healthcheck(exception):
9697
)
9798

9899

100+
def _should_retry_certificate_registration(exception):
101+
return (
102+
_should_retry_healthcheck(exception)
103+
or _is_feature_query_service_queue_full_error(exception)
104+
)
105+
106+
99107
# Avoid unnecessary client init
100108
def is_data_format_supported(data_format: str, read_options: Optional[Dict[str, Any]]):
101109
if data_format not in ArrowFlightClient.SUPPORTED_FORMATS:
@@ -199,6 +207,12 @@ def __init__(self, disabled_for_session: bool = False):
199207

200208
try:
201209
self._health_check()
210+
if "get-version" in [action.type for action in self._connection.list_actions()]:
211+
self._server_version = self._get_server_version()
212+
else:
213+
self._server_version = None
214+
if self._server_version is None:
215+
self._register_certificates()
202216
except Exception as e:
203217
_logger.debug("Failed to connect to Hopsworks Query Service.")
204218
_logger.exception(e)
@@ -293,6 +307,20 @@ def _health_check(self):
293307
list(self._connection.do_action(action, options=options))
294308
_logger.debug("Healthcheck succeeded.")
295309

310+
@retry(
311+
wait_exponential_multiplier=1000,
312+
stop_max_attempt_number=3,
313+
retry_on_exception=_should_retry,
314+
)
315+
def _get_server_version(self):
316+
_logger.debug("Acquiring the server version of Hopsworks Query Service.")
317+
action = pyarrow.flight.Action("get-version", b"")
318+
options = pyarrow.flight.FlightCallOptions(timeout=self.health_check_timeout)
319+
for res in self._connection.do_action(action, options=options):
320+
version = res.body.to_pybytes()
321+
_logger.debug(f"The HQS server is of version {version}.")
322+
return version
323+
296324
def _should_be_used(self):
297325
if not self._enabled_on_cluster:
298326
_logger.debug(
@@ -331,10 +359,32 @@ def _certificates(self):
331359
cert_key = self._client._cert_key
332360
return {"kstore": kstore, "tstore": tstore, "cert_key": cert_key}
333361

334-
def _certificates_header(self):
362+
def _make_certificates_json(self):
335363
if self._certificates_json is None:
336364
self._certificates_json = json.dumps(self._certificates()).encode("utf-8")
337-
return (b"x-certificates-json", self._certificates_json)
365+
return self._certificates_json
366+
367+
def _certificates_headers(self):
368+
if self._server_version is None:
369+
return []
370+
return [(b"x-certificates-json", self._make_certificates_json())]
371+
372+
@retry(
373+
wait_exponential_multiplier=1000,
374+
stop_max_attempt_number=3,
375+
retry_on_exception=_should_retry_certificate_registration,
376+
)
377+
def _register_certificates(self):
378+
certificates_json = self._make_certificates_json()
379+
certificates_json_buf = pyarrow.py_buffer(certificates_json)
380+
action = pyarrow.flight.Action(
381+
"register-client-certificates", certificates_json_buf
382+
)
383+
# Registering certificates queue time occasionally spike.
384+
options = pyarrow.flight.FlightCallOptions(timeout=self.health_check_timeout)
385+
_logger.debug("Registering client certificates with Hopsworks Query Service.")
386+
self._connection.do_action(action, options=options)
387+
_logger.debug("Client certificates registered.")
338388

339389
def _handle_afs_exception(user_message="None"):
340390
def decorator(func):
@@ -346,7 +396,13 @@ def afs_error_handler_wrapper(instance, *args, **kw):
346396
message = str(e)
347397
_logger.debug("Caught exception in %s: %s", func.__name__, message)
348398
_logger.exception(e)
349-
if _is_feature_query_service_queue_full_error(e):
399+
if instance._server_version is None and (
400+
isinstance(e, FlightServerError)
401+
and "Please register client certificates first." in message
402+
):
403+
instance._register_certificates()
404+
return func(instance, *args, **kw)
405+
elif _is_feature_query_service_queue_full_error(e):
350406
raise FeatureStoreException(
351407
"Hopsworks Query Service is busy right now. Please try again later."
352408
) from e
@@ -384,7 +440,7 @@ def _get_dataset(self, descriptor, timeout=None, dataframe_type="pandas"):
384440
info = self.get_flight_info(descriptor)
385441
_logger.debug("Retrieved flight info: %s. Fetching dataset.", str(info))
386442
options = pyarrow.flight.FlightCallOptions(
387-
timeout=timeout, headers=[self._certificates_header()]
443+
timeout=timeout, headers=self._certificates_headers()
388444
)
389445
reader = self._connection.do_get(info.endpoints[0].ticket, options)
390446
_logger.debug("Dataset fetched. Converting to dataframe %s.", dataframe_type)
@@ -456,7 +512,7 @@ def create_training_dataset(
456512
else self.timeout
457513
)
458514
options = pyarrow.flight.FlightCallOptions(
459-
timeout=timeout, headers=[self._certificates_header()]
515+
timeout=timeout, headers=self._certificates_headers()
460516
)
461517
for result in self._connection.do_action(action, options):
462518
return result.body.to_pybytes()

0 commit comments

Comments
 (0)