Skip to content

Commit eebd13d

Browse files
Override xet refresh route's base URL with HF Endpoint (#3180)
* override the xet refresh route base url with hf endpoint
* simpler
* Update src/huggingface_hub/utils/_xet.py

  Co-authored-by: Lucain <lucain@huggingface.co>
* add a check before replacing the base url

---------

Co-authored-by: Lucain <lucain@huggingface.co>
1 parent 6f9b87e commit eebd13d

File tree

4 files changed

+40
-4
lines changed

4 files changed

+40
-4
lines changed

src/huggingface_hub/file_download.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,7 @@ def get_hf_file_metadata(
14071407
library_version: Optional[str] = None,
14081408
user_agent: Union[Dict, str, None] = None,
14091409
headers: Optional[Dict[str, str]] = None,
1410+
endpoint: Optional[str] = None,
14101411
) -> HfFileMetadata:
14111412
"""Fetch metadata of a file versioned on the Hub for a given url.
14121413
@@ -1432,6 +1433,8 @@ def get_hf_file_metadata(
14321433
The user-agent info in the form of a dictionary or a string.
14331434
headers (`dict`, *optional*):
14341435
Additional headers to be sent with the request.
1436+
endpoint (`str`, *optional*):
1437+
Endpoint of the Hub. Defaults to <https://huggingface.co>.
14351438
14361439
Returns:
14371440
A [`HfFileMetadata`] object containing metadata such as location, etag, size and
@@ -1471,7 +1474,7 @@ def get_hf_file_metadata(
14711474
size=_int_or_none(
14721475
r.headers.get(constants.HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")
14731476
),
1474-
xet_file_data=parse_xet_file_data_from_response(r), # type: ignore
1477+
xet_file_data=parse_xet_file_data_from_response(r, endpoint=endpoint), # type: ignore
14751478
)
14761479

14771480

@@ -1531,7 +1534,7 @@ def _get_metadata_or_catch_error(
15311534
try:
15321535
try:
15331536
metadata = get_hf_file_metadata(
1534-
url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token
1537+
url=url, proxies=proxies, timeout=etag_timeout, headers=headers, token=token, endpoint=endpoint
15351538
)
15361539
except EntryNotFoundError as http_error:
15371540
if storage_folder is not None and relative_filename is not None:

src/huggingface_hub/hf_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5392,6 +5392,7 @@ def get_hf_file_metadata(
53925392
library_name=self.library_name,
53935393
library_version=self.library_version,
53945394
user_agent=self.user_agent,
5395+
endpoint=self.endpoint,
53955396
)
53965397

53975398
@validate_hf_hub_args

src/huggingface_hub/utils/_xet.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ class XetConnectionInfo:
2626
endpoint: str
2727

2828

29-
def parse_xet_file_data_from_response(response: requests.Response) -> Optional[XetFileData]:
29+
def parse_xet_file_data_from_response(
30+
response: requests.Response, endpoint: Optional[str] = None
31+
) -> Optional[XetFileData]:
3032
"""
3133
Parse XET file metadata from an HTTP response.
3234
@@ -52,7 +54,9 @@ def parse_xet_file_data_from_response(response: requests.Response) -> Optional[X
5254
refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
5355
except KeyError:
5456
return None
55-
57+
endpoint = endpoint if endpoint is not None else constants.ENDPOINT
58+
if refresh_route.startswith(constants.HUGGINGFACE_CO_URL_HOME):
59+
refresh_route = refresh_route.replace(constants.HUGGINGFACE_CO_URL_HOME.rstrip("/"), endpoint.rstrip("/"))
5660
return XetFileData(
5761
file_hash=file_hash,
5862
refresh_route=refresh_route,

tests/test_xet_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,34 @@ def test_parse_invalid_headers_file_info() -> None:
5151
assert parse_xet_file_data_from_response(mock_response) is None
5252

5353

54+
@pytest.mark.parametrize(
55+
"refresh_route, expected_refresh_route",
56+
[
57+
(
58+
"/api/refresh",
59+
"/api/refresh",
60+
),
61+
(
62+
"https://huggingface.co/api/refresh",
63+
"https://xet.example.com/api/refresh",
64+
),
65+
],
66+
)
67+
def test_parse_header_file_info_with_endpoint(refresh_route: str, expected_refresh_route: str) -> None:
68+
mock_response = MagicMock()
69+
mock_response.headers = {
70+
"X-Xet-Hash": "sha256:abcdef",
71+
"X-Xet-Refresh-Route": refresh_route,
72+
}
73+
mock_response.links = {}
74+
75+
file_data = parse_xet_file_data_from_response(mock_response, endpoint="https://xet.example.com")
76+
77+
assert file_data is not None
78+
assert file_data.refresh_route == expected_refresh_route
79+
assert file_data.file_hash == "sha256:abcdef"
80+
81+
5482
def test_parse_valid_headers_connection_info() -> None:
5583
headers = {
5684
"X-Xet-Cas-Url": "https://xet.example.com",

0 commit comments

Comments
 (0)