
Commit 292834a

hoytak and Wauplin authored
Improved progress reporting for Xet uploads (#3096)
* Improved progress reporting for Xet uploads

  This PR adds detailed progress reporting for upload_files when hf_xet is used, showing both per-file progress and accurate total progress. Total progress and transfer speed, which cover both deduplication and data transfer, are split into separate bars. Requires xet-core / hf_xet at commit 4faec0b or later.

* Smoothed out transfer speed.
* Updated progress bars.
* Update; requires new release.
* Update src/huggingface_hub/utils/_xet_progress_reporting.py (Co-authored-by: Lucain <lucainp@gmail.com>)
* Update src/huggingface_hub/utils/_xet_progress_reporting.py (Co-authored-by: Lucain <lucainp@gmail.com>)
* Update src/huggingface_hub/utils/_xet_progress_reporting.py (Co-authored-by: Lucain <lucainp@gmail.com>)
* Update src/huggingface_hub/_commit_api.py (Co-authored-by: Lucain <lucainp@gmail.com>)
* Update src/huggingface_hub/utils/_xet_progress_reporting.py (Co-authored-by: Lucain <lucainp@gmail.com>)
* Updates.
* Updated style.
* Style update.
* Update style.
* Updated mypy issues.
* Update for mypy issues.
* Updated style to work with Python 3.9

Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Lucain <lucainp@gmail.com>
1 parent a0429e7 commit 292834a
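
For context, the new reporting surfaces automatically through the regular upload APIs. A minimal sketch of what a caller might run to see the per-file and total bars; the repository name and local path below are placeholders, not part of this commit:

# Sketch: upload a folder with hf_xet installed; the new XetProgressReporter
# bars (per-file progress plus separate totals for processing and upload)
# are shown automatically. repo_id and folder_path are placeholders.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./my-model",           # placeholder local directory
    repo_id="my-username/my-model",     # placeholder repository
    repo_type="model",
)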

File tree

3 files changed: +163 -29 lines changed

setup.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ def get_version() -> str:
 install_requires = [
     "filelock",
     "fsspec>=2023.5.0",
-    "hf-xet>=1.1.2,<2.0.0; platform_machine=='x86_64' or platform_machine=='amd64' or platform_machine=='arm64' or platform_machine=='aarch64'",
+    "hf-xet>=1.1.3,<2.0.0; platform_machine=='x86_64' or platform_machine=='amd64' or platform_machine=='arm64' or platform_machine=='aarch64'",
     "packaging>=20.9",
     "pyyaml>=5.1",
     "requests",

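The dependency floor for hf-xet moves from 1.1.2 to 1.1.3 because the new progress-callback plumbing needs a recent hf_xet build. A small sketch, not part of this diff, of checking the installed version at runtime with importlib.metadata and packaging (both already dependencies of huggingface_hub):

# Sketch: confirm the installed hf_xet build is new enough for the detailed
# progress reporting added in this commit. Illustrative only, not code from the diff.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

try:
    installed = Version(version("hf-xet"))
except PackageNotFoundError:
    installed = None

if installed is None or installed < Version("1.1.3"):
    print("hf_xet >= 1.1.3 not available; detailed Xet progress bars will not be used")
else:
    print(f"hf_xet {installed} supports detailed Xet upload progress")
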
src/huggingface_hub/_commit_api.py

Lines changed: 21 additions & 28 deletions

@@ -4,7 +4,6 @@
 
 import base64
 import io
-import math
 import os
 import warnings
 from collections import defaultdict
@@ -23,6 +22,7 @@
 from .utils import (
     FORBIDDEN_FOLDERS,
     XetTokenType,
+    are_progress_bars_disabled,
     chunk_iterable,
     fetch_xet_connection_info_from_repo_info,
     get_session,
@@ -33,7 +33,6 @@
     validate_hf_hub_args,
 )
 from .utils import tqdm as hf_tqdm
-from .utils.tqdm import _get_progress_bar_context
 
 
 if TYPE_CHECKING:
@@ -529,9 +528,12 @@ def _upload_xet_files(
     """
     if len(additions) == 0:
         return
+
     # at this point, we know that hf_xet is installed
     from hf_xet import upload_bytes, upload_files
 
+    from .utils._xet_progress_reporting import XetProgressReporter
+
     try:
         xet_connection_info = fetch_xet_connection_info_from_repo_info(
             token_type=XetTokenType.WRITE,
@@ -567,40 +569,26 @@ def token_refresher() -> Tuple[str, int]:
             raise XetRefreshTokenError("Failed to refresh xet token")
         return new_xet_connection.access_token, new_xet_connection.expiration_unix_epoch
 
-    num_chunks = math.ceil(len(additions) / UPLOAD_BATCH_MAX_NUM_FILES)
-    num_chunks_num_digits = int(math.log10(num_chunks)) + 1
-    for i, chunk in enumerate(chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES)):
-        _chunk = [op for op in chunk]
-
-        bytes_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, bytes)]
-        paths_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, (str, Path))]
-        expected_size = sum(op.upload_info.size for op in bytes_ops + paths_ops)
+    if not are_progress_bars_disabled():
+        progress = XetProgressReporter()
+        progress_callback = progress.update_progress
+    else:
+        progress, progress_callback = None, None
 
-        if num_chunks > 1:
-            description = f"Uploading Batch [{str(i + 1).zfill(num_chunks_num_digits)}/{num_chunks}]..."
-        else:
-            description = "Uploading..."
-        progress_cm = _get_progress_bar_context(
-            desc=description,
-            total=expected_size,
-            initial=0,
-            unit="B",
-            unit_scale=True,
-            name="huggingface_hub.xet_put",
-            log_level=logger.getEffectiveLevel(),
-        )
-        with progress_cm as progress:
+    try:
+        for i, chunk in enumerate(chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES)):
+            _chunk = [op for op in chunk]
 
-            def update_progress(increment: int):
-                progress.update(increment)
+            bytes_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, bytes)]
+            paths_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, (str, Path))]
 
             if len(paths_ops) > 0:
                 upload_files(
                     [str(op.path_or_fileobj) for op in paths_ops],
                     xet_endpoint,
                     access_token_info,
                     token_refresher,
-                    update_progress,
+                    progress_callback,
                     repo_type,
                 )
             if len(bytes_ops) > 0:
@@ -609,9 +597,14 @@ def update_progress(increment: int):
                     xet_endpoint,
                     access_token_info,
                     token_refresher,
-                    update_progress,
+                    progress_callback,
                     repo_type,
                 )
+
+    finally:
+        if progress is not None:
+            progress.close(False)
+
     return
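
The rewrite above drops the per-batch tqdm context manager: a single XetProgressReporter is created once (only when progress bars are enabled), its update_progress method is passed to upload_files / upload_bytes as progress_callback, and it is closed in the finally block. Because the reporter respects are_progress_bars_disabled(), callers can still opt out with the existing helpers. A short sketch with placeholder repo and file names:

# Sketch: disabling progress bars globally also disables the new Xet
# progress reporter, since it is only created when bars are enabled.
# repo_id and the file path are placeholders.
from huggingface_hub import HfApi
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars

disable_progress_bars()   # _upload_xet_files will skip creating XetProgressReporter
api = HfApi()
api.upload_file(
    path_or_fileobj="./weights.bin",     # placeholder
    path_in_repo="weights.bin",
    repo_id="my-username/my-model",      # placeholder
)
enable_progress_bars()    # restore the default behavior
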
src/huggingface_hub/utils/_xet_progress_reporting.py

Lines changed: 141 additions & 0 deletions

@@ -0,0 +1,141 @@
+from collections import OrderedDict
+from typing import List
+
+from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate
+
+from .tqdm import tqdm
+
+
+class XetProgressReporter:
+    def __init__(self, n_lines: int = 10, description_width: int = 40):
+        self.n_lines = n_lines
+        self.description_width = description_width
+
+        self.tqdm_settings = {
+            "unit": "B",
+            "unit_scale": True,
+            "leave": True,
+            "unit_divisor": 1000,
+            "nrows": n_lines + 3,
+            "miniters": 1,
+            "bar_format": "{l_bar}{bar}| {n_fmt:>5}B / {total_fmt:>5}B{postfix:>12}",
+        }
+
+        # Overall progress bars
+        self.data_processing_bar = tqdm(
+            total=0, desc=self.format_desc("Processing Files (0 / 0)", False), position=0, **self.tqdm_settings
+        )
+
+        self.upload_bar = tqdm(
+            total=0, desc=self.format_desc("New Data Upload", False), position=1, **self.tqdm_settings
+        )
+
+        self.known_items: set[str] = set()
+        self.completed_items: set[str] = set()
+
+        # Item bars (scrolling view)
+        self.item_state: OrderedDict[str, PyItemProgressUpdate] = OrderedDict()
+        self.current_bars: List = [None] * self.n_lines
+
+    def format_desc(self, name: str, indent: bool) -> str:
+        """
+        If name is longer than width characters, prints "..." followed by the last width - 3 characters of the name;
+        otherwise the whole name is left justified to width characters. Also adds some padding.
+        """
+        padding = " " if indent else ""
+        width = self.description_width - len(padding)
+
+        if len(name) > width:
+            name = f"...{name[-(width - 3) :]}"
+
+        return f"{padding}{name.ljust(width)}"
+
+    def update_progress(self, total_update: PyTotalProgressUpdate, item_updates: List[PyItemProgressUpdate]):
+        # Update all the per-item values.
+        for item in item_updates:
+            item_name = item.item_name
+
+            self.known_items.add(item_name)
+
+            # Only care about items where the processing has already started.
+            if item.bytes_completed == 0:
+                continue
+
+            # Overwrite the existing value in there.
+            self.item_state[item_name] = item
+
+        bar_idx = 0
+        new_completed = []
+
+        # Now, go through and update all the bars
+        for name, item in self.item_state.items():
+            # Is this ready to be removed on the next update?
+            if item.bytes_completed == item.total_bytes:
+                self.completed_items.add(name)
+                new_completed.append(name)
+
+            # If we've run out of bars to use, then collapse the last ones together.
+            if bar_idx >= len(self.current_bars):
+                bar = self.current_bars[-1]
+                in_final_bar_mode = True
+                final_bar_aggregation_count = bar_idx + 1 - len(self.current_bars)
+            else:
+                bar = self.current_bars[bar_idx]
+                in_final_bar_mode = False
+
+            if bar is None:
+                self.current_bars[bar_idx] = tqdm(
+                    desc=self.format_desc(name, True),
+                    position=2 + bar_idx,  # Set to the position past the initial bars.
+                    total=item.total_bytes,
+                    initial=item.bytes_completed,
+                    **self.tqdm_settings,
+                )
+
+            elif in_final_bar_mode:
+                bar.n += item.bytes_completed
+                bar.total += item.total_bytes
+                bar.set_description(self.format_desc(f"[+ {final_bar_aggregation_count} files]", True), refresh=False)
+            else:
+                bar.set_description(self.format_desc(name, True), refresh=False)
+                bar.n = item.bytes_completed
+                bar.total = item.total_bytes
+
+            bar_idx += 1
+
+        # Remove all the completed ones from the ordered dictionary
+        for name in new_completed:
+            # Only remove ones from consideration to make room for more items coming in.
+            if len(self.item_state) <= self.n_lines:
+                break
+
+            del self.item_state[name]
+
+        # Now manually refresh each of the bars
+        for bar in self.current_bars:
+            if bar:
+                bar.refresh()
+
+        # Update overall bars
+        def postfix(speed):
+            s = tqdm.format_sizeof(speed) if speed is not None else "???"
+            return f"{s}B/s ".rjust(10, " ")
+
+        self.data_processing_bar.total = total_update.total_bytes
+        self.data_processing_bar.set_description(
+            self.format_desc(f"Processing Files ({len(self.completed_items)} / {len(self.known_items)})", False),
+            refresh=False,
+        )
+        self.data_processing_bar.set_postfix_str(postfix(total_update.total_bytes_completion_rate), refresh=False)
+        self.data_processing_bar.update(total_update.total_bytes_completion_increment)
+
+        self.upload_bar.total = total_update.total_transfer_bytes
+        self.upload_bar.set_postfix_str(postfix(total_update.total_transfer_bytes_completion_rate), refresh=False)
+        self.upload_bar.update(total_update.total_transfer_bytes_completion_increment)
+
+    def close(self, _success):
+        self.data_processing_bar.close()
+        self.upload_bar.close()
+        for bar in self.current_bars:
+            if bar:
+                bar.close()
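
The reporter keeps two fixed overall bars (data processing and new-data upload) plus a scrolling window of n_lines per-file bars, collapsing any overflow into the last bar. To illustrate the two formatting helpers it relies on, here is a small standalone sketch that re-implements the left-truncating description and the speed postfix with plain tqdm; the file name and speeds are made up, and the helpers are re-created here rather than imported from the diff:

# Standalone illustration (not code from the diff): re-implementations of the
# reporter's description truncation and speed-postfix formatting.
from tqdm import tqdm

DESCRIPTION_WIDTH = 40  # mirrors the reporter's default description_width

def format_desc(name: str, indent: bool, width: int = DESCRIPTION_WIDTH) -> str:
    # Truncate long names from the left so the tail of the path stays visible,
    # then left-justify into a fixed-width column so all bars line up.
    padding = "  " if indent else ""
    usable = width - len(padding)
    if len(name) > usable:
        name = f"...{name[-(usable - 3):]}"
    return f"{padding}{name.ljust(usable)}"

def speed_postfix(speed_bytes_per_s) -> str:
    # Human-readable transfer rate, or "???" while the rate is still unknown.
    s = tqdm.format_sizeof(speed_bytes_per_s) if speed_bytes_per_s is not None else "???"
    return f"{s}B/s ".rjust(10, " ")

print(format_desc("models/checkpoints/epoch-0042/pytorch_model-00003-of-00005.bin", True))
print(speed_postfix(37_500_000))  # e.g. " 37.5MB/s "
print(speed_postfix(None))        # e.g. "   ???B/s "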
