|
| 1 | +from collections import OrderedDict |
| 2 | +from typing import List |
| 3 | + |
| 4 | +from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate |
| 5 | + |
| 6 | +from .tqdm import tqdm |
| 7 | + |
| 8 | + |
| 9 | +class XetProgressReporter: |
| 10 | + def __init__(self, n_lines: int = 10, description_width: int = 40): |
| 11 | + self.n_lines = n_lines |
| 12 | + self.description_width = description_width |
| 13 | + |
| 14 | + self.tqdm_settings = { |
| 15 | + "unit": "B", |
| 16 | + "unit_scale": True, |
| 17 | + "leave": True, |
| 18 | + "unit_divisor": 1000, |
| 19 | + "nrows": n_lines + 3, |
| 20 | + "miniters": 1, |
| 21 | + "bar_format": "{l_bar}{bar}| {n_fmt:>5}B / {total_fmt:>5}B{postfix:>12}", |
| 22 | + } |
| 23 | + |
| 24 | + # Overall progress bars |
| 25 | + self.data_processing_bar = tqdm( |
| 26 | + total=0, desc=self.format_desc("Processing Files (0 / 0)", False), position=0, **self.tqdm_settings |
| 27 | + ) |
| 28 | + |
| 29 | + self.upload_bar = tqdm( |
| 30 | + total=0, desc=self.format_desc("New Data Upload", False), position=1, **self.tqdm_settings |
| 31 | + ) |
| 32 | + |
| 33 | + self.known_items: set[str] = set() |
| 34 | + self.completed_items: set[str] = set() |
| 35 | + |
| 36 | + # Item bars (scrolling view) |
| 37 | + self.item_state: OrderedDict[str, PyItemProgressUpdate] = OrderedDict() |
| 38 | + self.current_bars: List = [None] * self.n_lines |
| 39 | + |
| 40 | + def format_desc(self, name: str, indent: bool) -> str: |
| 41 | + """ |
| 42 | + if name is longer than width characters, prints ... at the start and then the last width-3 characters of the name, otherwise |
| 43 | + the whole name right justified into 20 characters. Also adds some padding. |
| 44 | + """ |
| 45 | + padding = " " if indent else "" |
| 46 | + width = self.description_width - len(padding) |
| 47 | + |
| 48 | + if len(name) > width: |
| 49 | + name = f"...{name[-(width - 3) :]}" |
| 50 | + |
| 51 | + return f"{padding}{name.ljust(width)}" |
| 52 | + |
| 53 | + def update_progress(self, total_update: PyTotalProgressUpdate, item_updates: List[PyItemProgressUpdate]): |
| 54 | + # Update all the per-item values. |
| 55 | + for item in item_updates: |
| 56 | + item_name = item.item_name |
| 57 | + |
| 58 | + self.known_items.add(item_name) |
| 59 | + |
| 60 | + # Only care about items where the processing has already started. |
| 61 | + if item.bytes_completed == 0: |
| 62 | + continue |
| 63 | + |
| 64 | + # Overwrite the existing value in there. |
| 65 | + self.item_state[item_name] = item |
| 66 | + |
| 67 | + bar_idx = 0 |
| 68 | + new_completed = [] |
| 69 | + |
| 70 | + # Now, go through and update all the bars |
| 71 | + for name, item in self.item_state.items(): |
| 72 | + # Is this ready to be removed on the next update? |
| 73 | + if item.bytes_completed == item.total_bytes: |
| 74 | + self.completed_items.add(name) |
| 75 | + new_completed.append(name) |
| 76 | + |
| 77 | + # If we've run out of bars to use, then collapse the last ones together. |
| 78 | + if bar_idx >= len(self.current_bars): |
| 79 | + bar = self.current_bars[-1] |
| 80 | + in_final_bar_mode = True |
| 81 | + final_bar_aggregation_count = bar_idx + 1 - len(self.current_bars) |
| 82 | + else: |
| 83 | + bar = self.current_bars[bar_idx] |
| 84 | + in_final_bar_mode = False |
| 85 | + |
| 86 | + if bar is None: |
| 87 | + self.current_bars[bar_idx] = tqdm( |
| 88 | + desc=self.format_desc(name, True), |
| 89 | + position=2 + bar_idx, # Set to the position past the initial bars. |
| 90 | + total=item.total_bytes, |
| 91 | + initial=item.bytes_completed, |
| 92 | + **self.tqdm_settings, |
| 93 | + ) |
| 94 | + |
| 95 | + elif in_final_bar_mode: |
| 96 | + bar.n += item.bytes_completed |
| 97 | + bar.total += item.total_bytes |
| 98 | + bar.set_description(self.format_desc(f"[+ {final_bar_aggregation_count} files]", True), refresh=False) |
| 99 | + else: |
| 100 | + bar.set_description(self.format_desc(name, True), refresh=False) |
| 101 | + bar.n = item.bytes_completed |
| 102 | + bar.total = item.total_bytes |
| 103 | + |
| 104 | + bar_idx += 1 |
| 105 | + |
| 106 | + # Remove all the completed ones from the ordered dictionary |
| 107 | + for name in new_completed: |
| 108 | + # Only remove ones from consideration to make room for more items coming in. |
| 109 | + if len(self.item_state) <= self.n_lines: |
| 110 | + break |
| 111 | + |
| 112 | + del self.item_state[name] |
| 113 | + |
| 114 | + # Now manually refresh each of the bars |
| 115 | + for bar in self.current_bars: |
| 116 | + if bar: |
| 117 | + bar.refresh() |
| 118 | + |
| 119 | + # Update overall bars |
| 120 | + def postfix(speed): |
| 121 | + s = tqdm.format_sizeof(speed) if speed is not None else "???" |
| 122 | + return f"{s}B/s ".rjust(10, " ") |
| 123 | + |
| 124 | + self.data_processing_bar.total = total_update.total_bytes |
| 125 | + self.data_processing_bar.set_description( |
| 126 | + self.format_desc(f"Processing Files ({len(self.completed_items)} / {len(self.known_items)})", False), |
| 127 | + refresh=False, |
| 128 | + ) |
| 129 | + self.data_processing_bar.set_postfix_str(postfix(total_update.total_bytes_completion_rate), refresh=False) |
| 130 | + self.data_processing_bar.update(total_update.total_bytes_completion_increment) |
| 131 | + |
| 132 | + self.upload_bar.total = total_update.total_transfer_bytes |
| 133 | + self.upload_bar.set_postfix_str(postfix(total_update.total_transfer_bytes_completion_rate), refresh=False) |
| 134 | + self.upload_bar.update(total_update.total_transfer_bytes_completion_increment) |
| 135 | + |
| 136 | + def close(self, _success): |
| 137 | + self.data_processing_bar.close() |
| 138 | + self.upload_bar.close() |
| 139 | + for bar in self.current_bars: |
| 140 | + if bar: |
| 141 | + bar.close() |
0 commit comments