Skip to content

Commit b6684c9

Browse files
aversey authored and SirOibaf committed
[FSTORE-1661] Fix Arrow Flight certificate registration (#454)
1 parent 0e64745 commit b6684c9

File tree

1 file changed: 40 additions and 62 deletions.

python/hsfs/core/arrow_flight_client.py

Lines changed: 40 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@
3939
from hsfs.constructor import query
4040
from hsfs.core.variable_api import VariableApi
4141
from hsfs.storage_connector import StorageConnector
42-
from pyarrow.flight import FlightServerError
4342
from retrying import retry
4443

4544

@@ -67,7 +66,7 @@ def close() -> None:
6766

6867
def _disable_feature_query_service_client():
6968
global _arrow_flight_instance
70-
_logger.debug("Disabling Hopsworks Feature Query Service Client.")
69+
_logger.debug("Disabling Hopsworks Query Service Client.")
7170
if _arrow_flight_instance is None:
7271
_arrow_flight_instance.ArrowFlightClient(disabled_for_session=True)
7372
else:
@@ -90,12 +89,10 @@ def _is_no_commits_found_error(exception):
9089
) and "No commits found" in str(exception)
9190

9291

93-
def _should_retry_healthcheck_or_certificate_registration(exception):
92+
def _should_retry_healthcheck(exception):
9493
return (
9594
isinstance(exception, pyarrow._flight.FlightUnavailableError)
9695
or isinstance(exception, pyarrow._flight.FlightTimedOutError)
97-
# not applicable for healthcheck, only certificate registration
98-
or _is_feature_query_service_queue_full_error(exception)
9996
)
10097

10198

@@ -151,8 +148,8 @@ class ArrowFlightClient:
151148
StorageConnector.SNOWFLAKE,
152149
StorageConnector.BIGQUERY,
153150
]
154-
READ_ERROR = "Could not read data using Hopsworks Feature Query Service."
155-
WRITE_ERROR = 'Could not write data using Hopsworks Feature Query Service. If the issue persists, use write_options={"use_spark": True} instead.'
151+
READ_ERROR = "Could not read data using Hopsworks Query Service."
152+
WRITE_ERROR = 'Could not write data using Hopsworks Query Service. If the issue persists, use write_options={"use_spark": True} instead.'
156153
DEFAULTING_TO_DIFFERENT_SERVICE_WARNING = (
157154
"Defaulting to Spark execution for this call."
158155
)
@@ -162,7 +159,7 @@ class ArrowFlightClient:
162159
DEFAULT_GRPC_MIN_RECONNECT_BACKOFF_MS = 2000
163160

164161
def __init__(self, disabled_for_session: bool = False):
165-
_logger.debug("Initializing Hopsworks Feature Query Service Client.")
162+
_logger.debug("Initializing Hopsworks Query Service Client.")
166163
self._timeout: float = ArrowFlightClient.DEFAULT_TIMEOUT_SECONDS
167164
self._health_check_timeout: float = (
168165
ArrowFlightClient.DEFAULT_HEALTHCHECK_TIMEOUT_SECONDS
@@ -179,36 +176,34 @@ def __init__(self, disabled_for_session: bool = False):
179176

180177
self._client = client.get_instance()
181178
self._variable_api: VariableApi = VariableApi()
179+
self._certificates_json: Optional[str] = None
182180

183181
try:
184182
self._check_cluster_service_enabled()
185183
self._host_url = self._retrieve_host_url()
186184

187185
if self._enabled_on_cluster:
188-
_logger.debug(
189-
"Hopsworks Feature Query Service is enabled on the cluster."
190-
)
186+
_logger.debug("Hopsworks Query Service is enabled on the cluster.")
191187
self._initialize_flight_client()
192188
else:
193189
_logger.debug(
194-
"Hopsworks Feature Query Service Client is not enabled on the cluster or a cluster variable is misconfigured."
190+
"Hopsworks Query Service Client is not enabled on the cluster or a cluster variable is misconfigured."
195191
)
196192
self._disable_for_session()
197193
return
198194
except Exception as e:
199-
_logger.debug("Failed to connect to Hopsworks Feature Query Service")
195+
_logger.debug("Failed to connect to Hopsworks Query Service")
200196
_logger.exception(e)
201197
self._disable_for_session(str(e))
202198
return
203199

204200
try:
205201
self._health_check()
206-
self._register_certificates()
207202
except Exception as e:
208-
_logger.debug("Failed to connect to Hopsworks Feature Query Service.")
203+
_logger.debug("Failed to connect to Hopsworks Query Service.")
209204
_logger.exception(e)
210205
warnings.warn(
211-
f"Failed to connect to Hopsworks Feature Query Service, got {str(e)}."
206+
f"Failed to connect to Hopsworks Query Service, got {str(e)}."
212207
+ ArrowFlightClient.DEFAULTING_TO_DIFFERENT_SERVICE_WARNING
213208
+ ArrowFlightClient.CLIENT_WILL_STAY_ACTIVE_WARNING,
214209
stacklevel=1,
@@ -218,13 +213,13 @@ def __init__(self, disabled_for_session: bool = False):
218213
def _check_cluster_service_enabled(self) -> None:
219214
try:
220215
_logger.debug(
221-
"Connecting to Hopsworks Cluster to check if Feature Query Service is enabled."
216+
"Connecting to Hopsworks Cluster to check if Hopsworks Query Service is enabled."
222217
)
223218
self._enabled_on_cluster = self._variable_api.get_flyingduck_enabled()
224219
except Exception as e:
225220
# if feature flag cannot be retrieved, assume it is disabled
226221
_logger.debug(
227-
"Unable to fetch Hopsworks Feature Query Service (HQFS) flag, disabling HFQS client."
222+
"Unable to fetch Hopsworks Query Service flag, disabling its client."
228223
)
229224
_logger.exception(e)
230225
self._enabled_on_cluster = False
@@ -240,13 +235,11 @@ def _retrieve_host_url(self) -> Optional[str]:
240235
service_discovery_domain = self._variable_api.get_service_discovery_domain()
241236
if service_discovery_domain == "":
242237
raise FeatureStoreException(
243-
"Client could not get Feature Query Service hostname from service_discovery_domain. "
238+
"Client could not get Hopsworks Query Service hostname from service_discovery_domain. "
244239
"The variable is either not set or empty in Hopsworks cluster configuration."
245240
)
246241
host_url = f"grpc+tls://flyingduck.service.{service_discovery_domain}:5005"
247-
_logger.debug(
248-
f"Connecting to Hopsworks Feature Query Service on host {host_url}"
249-
)
242+
_logger.debug(f"Connecting to Hopsworks Query Service on host {host_url}")
250243
return host_url
251244

252245
def _disable_for_session(
@@ -255,17 +248,17 @@ def _disable_for_session(
255248
self._disabled_for_session = True
256249
if on_purpose:
257250
warnings.warn(
258-
"Hopsworks Feature Query Service will be disabled for this session.",
251+
"Hopsworks Query Service will be disabled for this session.",
259252
stacklevel=1,
260253
)
261254
if self._enabled_on_cluster:
262255
warnings.warn(
263-
"Hospworks Feature Query Service is disabled on cluster. Contact your administrator to enable it.",
256+
"Hospworks Query Service is disabled on cluster. Contact your administrator to enable it.",
264257
stacklevel=1,
265258
)
266259
else:
267260
warnings.warn(
268-
f"Client initialisation failed: {message}. Hopsworks Feature Query Service will be disabled for this session."
261+
f"Client initialisation failed: {message}. Hopsworks Query Service will be disabled for this session."
269262
"If you believe this is a transient error, you can call `(hopsworks.)hsfs.reset_offline_query_service_client()`"
270263
" to re-enable it.",
271264
stacklevel=1,
@@ -291,10 +284,10 @@ def _initialize_flight_client(self):
291284
@retry(
292285
wait_exponential_multiplier=1000,
293286
stop_max_attempt_number=5,
294-
retry_on_exception=_should_retry_healthcheck_or_certificate_registration,
287+
retry_on_exception=_should_retry_healthcheck,
295288
)
296289
def _health_check(self):
297-
_logger.debug("Performing healthcheck of Hopsworks Feature Query Service.")
290+
_logger.debug("Performing healthcheck of Hopsworks Query Service.")
298291
action = pyarrow.flight.Action("healthcheck", b"")
299292
options = pyarrow.flight.FlightCallOptions(timeout=self.health_check_timeout)
300293
list(self._connection.do_action(action, options=options))
@@ -303,17 +296,17 @@ def _health_check(self):
303296
def _should_be_used(self):
304297
if not self._enabled_on_cluster:
305298
_logger.debug(
306-
"Hopsworks Feature Query Service not used as it is disabled on the cluster."
299+
"Hopsworks Query Service not used as it is disabled on the cluster."
307300
)
308301
return False
309302

310303
if self._disabled_for_session:
311304
_logger.debug(
312-
"Hopsworks Feature Query Service client failed to initialise and is disabled for the session."
305+
"Hopsworks Query Service client failed to initialise and is disabled for the session."
313306
)
314307
return False
315308

316-
_logger.debug("Using Hopsworks Feature Query Service.")
309+
_logger.debug("Using Hopsworks Query Service.")
317310
return True
318311

319312
def _extract_certs(self):
@@ -332,29 +325,16 @@ def _encode_certs(self, path):
332325
content = f.read()
333326
return base64.b64encode(content).decode("utf-8")
334327

335-
@retry(
336-
wait_exponential_multiplier=1000,
337-
stop_max_attempt_number=3,
338-
retry_on_exception=_should_retry_healthcheck_or_certificate_registration,
339-
)
340-
def _register_certificates(self):
328+
def _certificates(self):
341329
kstore = self._encode_certs(self._client._get_jks_key_store_path())
342330
tstore = self._encode_certs(self._client._get_jks_trust_store_path())
343331
cert_key = self._client._cert_key
344-
certificates_json = json.dumps(
345-
{"kstore": kstore, "tstore": tstore, "cert_key": cert_key}
346-
).encode("ascii")
347-
certificates_json_buf = pyarrow.py_buffer(certificates_json)
348-
action = pyarrow.flight.Action(
349-
"register-client-certificates", certificates_json_buf
350-
)
351-
# Registering certificates queue time occasionally spike.
352-
options = pyarrow.flight.FlightCallOptions(timeout=self.health_check_timeout)
353-
_logger.debug(
354-
"Registering client certificates with Hopsworks Feature Query Service."
355-
)
356-
self._connection.do_action(action, options=options)
357-
_logger.debug("Client certificates registered.")
332+
return {"kstore": kstore, "tstore": tstore, "cert_key": cert_key}
333+
334+
def _certificates_header(self):
335+
if self._certificates_json is None:
336+
self._certificates_json = json.dumps(self._certificates()).encode("utf-8")
337+
return (b"x-certificates-json", self._certificates_json)
358338

359339
def _handle_afs_exception(user_message="None"):
360340
def decorator(func):
@@ -366,15 +346,9 @@ def afs_error_handler_wrapper(instance, *args, **kw):
366346
message = str(e)
367347
_logger.debug("Caught exception in %s: %s", func.__name__, message)
368348
_logger.exception(e)
369-
if (
370-
isinstance(e, FlightServerError)
371-
and "Please register client certificates first." in message
372-
):
373-
instance._register_certificates()
374-
return func(instance, *args, **kw)
375-
elif _is_feature_query_service_queue_full_error(e):
349+
if _is_feature_query_service_queue_full_error(e):
376350
raise FeatureStoreException(
377-
"Hopsworks Feature Query Service is busy right now. Please try again later."
351+
"Hopsworks Query Service is busy right now. Please try again later."
378352
) from e
379353
elif _is_no_commits_found_error(e):
380354
raise FeatureStoreException(str(e).split("Details:")[0]) from e
@@ -409,7 +383,9 @@ def _get_dataset(self, descriptor, timeout=None, dataframe_type="pandas"):
409383
timeout = self.timeout
410384
info = self.get_flight_info(descriptor)
411385
_logger.debug("Retrieved flight info: %s. Fetching dataset.", str(info))
412-
options = pyarrow.flight.FlightCallOptions(timeout=timeout)
386+
options = pyarrow.flight.FlightCallOptions(
387+
timeout=timeout, headers=[self._certificates_header()]
388+
)
413389
reader = self._connection.do_get(info.endpoints[0].ticket, options)
414390
_logger.debug("Dataset fetched. Converting to dataframe %s.", dataframe_type)
415391
if dataframe_type.lower() == "polars":
@@ -479,7 +455,9 @@ def create_training_dataset(
479455
if arrow_flight_config
480456
else self.timeout
481457
)
482-
options = pyarrow.flight.FlightCallOptions(timeout=timeout)
458+
options = pyarrow.flight.FlightCallOptions(
459+
timeout=timeout, headers=[self._certificates_header()]
460+
)
483461
for result in self._connection.do_action(action, options):
484462
return result.body.to_pybytes()
485463
except pyarrow.lib.ArrowIOError as e:
@@ -516,7 +494,7 @@ def is_enabled(self):
516494

517495
@property
518496
def timeout(self) -> Union[int, float]:
519-
"""Timeout in seconds for Hopsworks Feature Query Service do_get or do_action operations, not including the healthcheck."""
497+
"""Timeout in seconds for Hopsworks Query Service do_get or do_action operations, not including the healthcheck."""
520498
return self._timeout
521499

522500
@timeout.setter
@@ -534,7 +512,7 @@ def health_check_timeout(self, value: Union[int, float]) -> None:
534512

535513
@property
536514
def host_url(self) -> Optional[str]:
537-
"""URL of Hopsworks Feature Query Service."""
515+
"""URL of Hopsworks Query Service."""
538516
return self._host_url
539517

540518
@host_url.setter

0 commit comments

Comments
 (0)