Skip to content

Fix snapshot_download when unreliable number of files #3241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,19 @@ def snapshot_download(
# At this stage, internet connection is up and running
# => let's download the files!
assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."

# Corner case: on very large repos, the siblings list in `repo_info` might not contain all files.
# In that case, we need to use the `list_repo_tree` method to prevent caching issues.
repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings]
has_many_files = len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
if has_many_files:
logger.info("The repo has more than 50,000 files. Using `list_repo_tree` to ensure all files are listed.")
repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings] if repo_info.siblings is not None else []
unreliable_nb_files = (
repo_info.siblings is None
or len(repo_info.siblings) == 0
or len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
)
if unreliable_nb_files:
logger.info(
"Number of files in the repo is unreliable. Using `list_repo_tree` to ensure all files are listed."
)
repo_files = (
f.rfilename
for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type)
Expand All @@ -274,7 +279,7 @@ def snapshot_download(
ignore_patterns=ignore_patterns,
)

if not has_many_files:
if not unreliable_nb_files:
filtered_repo_files = list(filtered_repo_files)
tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
else:
Expand Down