Skip to content

Commit f6ccb98

Browse files
mfuntowicz and Wauplin authored
fix(inference_endpoints): use GET healthRoute instead of GET / to check status (#3165)
* fix(inference_endpoints): use GET /health instead of GET / to check status
* misc(quality): format
* feat(inference_endpoints): use healthRoute from the API to retrieve the path where to query endpoint status
* feat(inference_endpoints): missing health_url in classdef
* feat(inference_endpoints): wrong variable name ... 🤦🏻‍♂️
* feat(inference_endpoints): need coffee at this stage.
* feat(inference_endpoints): fix remaining tests
* feat(inference_endpoints): address comments use health_route and compute _health_url
* feat(inference_endpoints): move _health_url computation when endpoint is running and we have an actual url
* Apply suggestions from code review
* add test
* Update src/huggingface_hub/_inference_endpoints.py

---------

Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Lucain Pouget <lucainp@gmail.com>
1 parent 3734b64 commit f6ccb98

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

src/huggingface_hub/_inference_endpoints.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class InferenceEndpoint:
100100
namespace: str
101101
repository: str = field(init=False)
102102
status: InferenceEndpointStatus = field(init=False)
103+
health_route: str = field(init=False)
103104
url: Optional[str] = field(init=False)
104105

105106
# Other fields
@@ -220,7 +221,8 @@ def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "Infere
220221
)
221222
if self.status == InferenceEndpointStatus.RUNNING and self.url is not None:
222223
# Verify the endpoint is actually reachable
223-
response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token))
224+
_health_url = f"{self.url.rstrip('/')}/{self.health_route.lstrip('/')}"
225+
response = get_session().get(_health_url, headers=self._api._build_hf_headers(token=self._token))
224226
if response.status_code == 200:
225227
logger.info("Inference Endpoint is ready to be used.")
226228
return self
@@ -400,6 +402,7 @@ def _populate_from_raw(self) -> None:
400402
self.repository = self.raw["model"]["repository"]
401403
self.status = self.raw["status"]["state"]
402404
self.url = self.raw["status"].get("url")
405+
self.health_route = self.raw["healthRoute"]
403406

404407
# Other fields
405408
self.framework = self.raw["model"]["framework"]

tests/test_hf_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4521,6 +4521,7 @@ def test_create_inference_endpoint_from_catalog(self, mock_get_session: Mock) ->
45214521
},
45224522
"name": "llama-3-2-3b-instruct-eey",
45234523
"provider": {"region": "us-east-1", "vendor": "aws"},
4524+
"healthRoute": "/health",
45244525
"status": {
45254526
"createdAt": "2025-03-07T15:30:13.949Z",
45264527
"createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
@@ -4650,6 +4651,7 @@ def test_create_inference_endpoint_custom_image_payload(
46504651
},
46514652
"name": "llama-3-2-3b-instruct-eey",
46524653
"provider": {"region": "us-east-1", "vendor": "aws"},
4654+
"healthRoute": "/health",
46534655
"status": {
46544656
"createdAt": "2025-03-07T15:30:13.949Z",
46554657
"createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},

tests/test_inference_endpoints.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"type": "protected",
2020
"accountId": None,
2121
"provider": {"vendor": "aws", "region": "us-east-1"},
22+
"healthRoute": "/health",
2223
"compute": {
2324
"accelerator": "cpu",
2425
"instanceType": "intel-icl",
@@ -51,6 +52,7 @@
5152
"type": "protected",
5253
"accountId": None,
5354
"provider": {"vendor": "aws", "region": "us-east-1"},
55+
"healthRoute": "/health",
5456
"compute": {
5557
"accelerator": "cpu",
5658
"instanceType": "intel-icl",
@@ -84,6 +86,7 @@
8486
"type": "protected",
8587
"accountId": None,
8688
"provider": {"vendor": "aws", "region": "us-east-1"},
89+
"healthRoute": "/health",
8790
"compute": {
8891
"accelerator": "cpu",
8992
"instanceType": "intel-icl",
@@ -116,6 +119,7 @@
116119
"type": "protected",
117120
"accountId": None,
118121
"provider": {"vendor": "aws", "region": "us-east-1"},
122+
"healthRoute": "/health",
119123
"compute": {
120124
"accelerator": "cpu",
121125
"instanceType": "intel-icl",
@@ -158,6 +162,7 @@ def test_from_raw_initialization():
158162
assert endpoint.revision == "11c5a3d5811f50298f278a704980280950aedb10"
159163
assert endpoint.task == "text-generation"
160164
assert endpoint.type == "protected"
165+
assert endpoint.health_route == "/health"
161166

162167
# Datetime parsed correctly
163168
assert endpoint.created_at == datetime(2023, 10, 26, 12, 41, 53, 263078, tzinfo=timezone.utc)
@@ -197,6 +202,7 @@ def test_get_client_ready():
197202
# Endpoint is ready
198203
assert endpoint.status == "running"
199204
assert endpoint.url == "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud"
205+
assert endpoint.health_route == "/health"
200206

201207
# => Client available
202208
client = endpoint.client
@@ -218,6 +224,7 @@ def test_fetch(mock_get: Mock):
218224

219225
assert endpoint.status == "running"
220226
assert endpoint.url == "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud"
227+
assert endpoint.health_route == "/health"
221228

222229

223230
@patch("huggingface_hub._inference_endpoints.get_session")
@@ -245,6 +252,11 @@ def test_wait_until_running(mock_get: Mock, mock_session: Mock):
245252
assert endpoint.status == "running"
246253
assert len(mock_get.call_args_list) == 6
247254

255+
# Ensure the health route has been called
256+
assert mock_session.return_value.get.call_count == 2
257+
for call in mock_session.return_value.get.call_args_list:
258+
assert call[0][0] == "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud/health"
259+
248260

249261
@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
250262
def test_wait_timeout(mock_get: Mock):

0 commit comments

Comments
 (0)