|
| 1 | +# Import early configuration module first to set up logging and environment variables |
| 2 | +# This ensures Hugging Face's progress bars are displayed correctly |
| 3 | +from .utils.early_config import enable_hf_progress_bars, StdoutRedirector |
| 4 | + |
1 | 5 | from .config import HF_TOKEN_ENV, get_env_var, set_env_var
|
2 | 6 | import os
|
3 | 7 | import logging
|
4 | 8 | import torch
|
5 |
| -from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
6 | 9 | from typing import Optional, Generator, Dict, Any, List, Union, Callable, AsyncGenerator
|
7 | 10 | from fastapi import HTTPException
|
8 | 11 | import time
|
|
13 | 16 | )
|
14 | 17 | from .logger.logger import logger, log_model_loaded, log_model_unloaded
|
15 | 18 | from .utils import check_resource_availability, get_device, format_model_size
|
16 |
| -from .utils.progress import configure_hf_hub_progress |
17 | 19 | import gc
|
18 | 20 | from colorama import Fore, Style
|
19 | 21 | import asyncio
|
|
22 | 24 | import tempfile
|
23 | 25 | import json
|
24 | 26 |
|
25 |
| -# Configure HuggingFace Hub progress bars to use native display |
| 27 | +# Enable Hugging Face progress bars with native display |
26 | 28 | # This ensures we see the visually appealing progress bars from HuggingFace
|
27 |
| -configure_hf_hub_progress() |
28 |
| - |
29 |
| -# Also configure transformers to use HuggingFace Hub's progress bars |
30 |
| -try: |
31 |
| - import transformers |
32 |
| - transformers.utils.logging.enable_progress_bar() |
33 |
| - # Set transformers logging to only show warnings and errors |
34 |
| - transformers.logging.set_verbosity_warning() |
35 |
| -except ImportError: |
36 |
| - logger.debug("Could not configure transformers progress bars") |
37 |
| -except Exception as e: |
38 |
| - logger.debug(f"Error configuring transformers progress bars: {str(e)}") |
| 29 | +enable_hf_progress_bars() |
| 30 | + |
| 31 | +# Import transformers after configuring logging to ensure proper display |
| 32 | +from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
39 | 33 |
|
40 | 34 | QUANTIZATION_SETTINGS = {
|
41 | 35 | "fp16": {
|
@@ -280,52 +274,41 @@ async def _load_model_with_optimizations(self, model_id: str):
|
280 | 274 | # Add an empty line to separate from HuggingFace progress bars
|
281 | 275 | print("")
|
282 | 276 |
|
283 |
| - # Set a flag to indicate we're downloading a model |
284 |
| - # This will help our logger know to let HuggingFace's progress bars through |
285 |
| - try: |
286 |
| - # Access the module's global variable |
287 |
| - import locallab.utils.progress |
288 |
| - locallab.utils.progress.is_downloading = True |
289 |
| - |
290 |
| - # Ensure HuggingFace Hub's progress bars are enabled |
291 |
| - from huggingface_hub.utils import logging as hf_logging |
292 |
| - hf_logging.enable_progress_bars() |
293 |
| - |
294 |
| - # Configure transformers to use progress bars |
295 |
| - import transformers |
296 |
| - transformers.utils.logging.enable_progress_bar() |
297 |
| - |
298 |
| - # Also ensure tqdm is properly configured for nice display |
299 |
| - import tqdm |
300 |
| - tqdm.tqdm.monitor_interval = 0 # Disable monitor thread which can cause issues |
| 277 | + # Announce the download with a banner padded by blank lines before progress bars start |
| 278 | + print(f"\n{Fore.CYAN}Starting model download - native progress bars will appear below{Style.RESET_ALL}\n") |
| 279 | + |
| 280 | + # Enable Hugging Face progress bars again to ensure they're properly configured |
| 281 | + enable_hf_progress_bars() |
| 282 | + |
| 283 | + # Use a context manager to ensure proper display of Hugging Face progress bars |
| 284 | + with StdoutRedirector(disable_logging=True): |
| 285 | + # Load tokenizer first |
| 286 | + logger.info(f"Loading tokenizer for {model_id}...") |
| 287 | + self.tokenizer = AutoTokenizer.from_pretrained( |
| 288 | + model_id, |
| 289 | + token=hf_token if hf_token else None |
| 290 | + ) |
| 291 | + logger.info(f"Tokenizer loaded successfully") |
301 | 292 |
|
302 |
| - # Temporarily disable our custom logger for HuggingFace logs |
303 |
| - import logging |
304 |
| - for logger_name in ['tqdm', 'huggingface_hub', 'transformers', 'filelock']: |
305 |
| - logging.getLogger(logger_name).handlers = [] # Remove any handlers |
306 |
| - logging.getLogger(logger_name).propagate = False # Don't propagate to parent loggers |
307 |
| - except: |
308 |
| - # Fallback if import fails |
309 |
| - pass |
| 293 | + # Load model with optimizations |
| 294 | + logger.info(f"Loading model weights for {model_id}...") |
310 | 295 |
|
311 |
| - # Add an empty line before progress bars start |
312 |
| - print("\n") |
| 296 | + # This is the critical part where we want to see nice progress bars |
| 297 | + # We'll temporarily detach the root logger's handlers to prevent interference |
| 298 | + root_logger = logging.getLogger() |
| 299 | + original_handlers = root_logger.handlers.copy() |
| 300 | + root_logger.handlers = [] |
313 | 301 |
|
314 |
| - # Load tokenizer first |
315 |
| - logger.info(f"Loading tokenizer for {model_id}...") |
316 |
| - self.tokenizer = AutoTokenizer.from_pretrained( |
317 |
| - model_id, |
318 |
| - token=hf_token if hf_token else None |
319 |
| - ) |
320 |
| - logger.info(f"Tokenizer loaded successfully") |
321 |
| - |
322 |
| - # Load model with optimizations |
323 |
| - logger.info(f"Loading model weights for {model_id}...") |
324 |
| - self.model = AutoModelForCausalLM.from_pretrained( |
325 |
| - model_id, |
326 |
| - token=hf_token if hf_token else None, |
327 |
| - **quant_config |
328 |
| - ) |
| 302 | + try: |
| 303 | + # Load the model with Hugging Face's native progress bars |
| 304 | + self.model = AutoModelForCausalLM.from_pretrained( |
| 305 | + model_id, |
| 306 | + token=hf_token if hf_token else None, |
| 307 | + **quant_config |
| 308 | + ) |
| 309 | + finally: |
| 310 | + # Restore the root logger's handlers |
| 311 | + root_logger.handlers = original_handlers |
329 | 312 | # Reset the downloading flag
|
330 | 313 | try:
|
331 | 314 | # Access the module's global variable
|
@@ -1106,47 +1089,41 @@ async def load_custom_model(self, model_name: str, fallback_model: Optional[str]
|
1106 | 1089 | # Add an empty line to separate from HuggingFace progress bars
|
1107 | 1090 | print("")
|
1108 | 1091 |
|
1109 |
| - # Set a flag to indicate we're downloading a model |
1110 |
| - try: |
1111 |
| - # Access the module's global variable |
1112 |
| - import locallab.utils.progress |
1113 |
| - locallab.utils.progress.is_downloading = True |
1114 |
| - |
1115 |
| - # Ensure HuggingFace Hub's progress bars are enabled |
1116 |
| - from huggingface_hub.utils import logging as hf_logging |
1117 |
| - hf_logging.enable_progress_bars() |
| 1092 | + # Announce the custom model download with a banner padded by blank lines before progress bars start |
| 1093 | + print(f"\n{Fore.CYAN}Starting custom model download - native progress bars will appear below{Style.RESET_ALL}\n") |
1118 | 1094 |
|
1119 |
| - # Configure transformers to use progress bars |
1120 |
| - import transformers |
1121 |
| - transformers.utils.logging.enable_progress_bar() |
| 1095 | + # Enable Hugging Face progress bars again to ensure they're properly configured |
| 1096 | + enable_hf_progress_bars() |
1122 | 1097 |
|
1123 |
| - # Also ensure tqdm is properly configured for nice display |
1124 |
| - import tqdm |
1125 |
| - tqdm.tqdm.monitor_interval = 0 # Disable monitor thread which can cause issues |
| 1098 | + # Use a context manager to ensure proper display of Hugging Face progress bars |
| 1099 | + with StdoutRedirector(disable_logging=True): |
| 1100 | + # Load tokenizer first |
| 1101 | + logger.info(f"Loading tokenizer for custom model {model_name}...") |
1126 | 1102 |
|
1127 |
| - # Temporarily disable our custom logger for HuggingFace logs |
1128 |
| - import logging |
1129 |
| - for logger_name in ['tqdm', 'huggingface_hub', 'transformers', 'filelock']: |
1130 |
| - logging.getLogger(logger_name).handlers = [] # Remove any handlers |
1131 |
| - logging.getLogger(logger_name).propagate = False # Don't propagate to parent loggers |
1132 |
| - except: |
1133 |
| - # Fallback if import fails |
1134 |
| - pass |
| 1103 | + # This is the critical part where we want to see nice progress bars |
| 1104 | + # We'll temporarily detach the root logger's handlers to prevent interference |
| 1105 | + root_logger = logging.getLogger() |
| 1106 | + original_handlers = root_logger.handlers.copy() |
| 1107 | + root_logger.handlers = [] |
1135 | 1108 |
|
1136 |
| - # Add an empty line before progress bars start |
1137 |
| - print("\n") |
1138 |
| - |
1139 |
| - self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
1140 |
| - logger.info(f"Tokenizer loaded successfully") |
1141 |
| - |
1142 |
| - # Load model with optimizations |
1143 |
| - logger.info(f"Loading model weights for custom model {model_name}...") |
1144 |
| - self.model = AutoModelForCausalLM.from_pretrained( |
1145 |
| - model_name, |
1146 |
| - torch_dtype=torch.float16, |
1147 |
| - device_map="auto", |
1148 |
| - quantization_config=quant_config |
1149 |
| - ) |
| 1109 | + try: |
| 1110 | + # Load tokenizer with Hugging Face's native progress bars |
| 1111 | + self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
| 1112 | + logger.info(f"Tokenizer loaded successfully") |
| 1113 | + |
| 1114 | + # Load model with optimizations |
| 1115 | + logger.info(f"Loading model weights for custom model {model_name}...") |
| 1116 | + |
| 1117 | + # Load the model with Hugging Face's native progress bars |
| 1118 | + self.model = AutoModelForCausalLM.from_pretrained( |
| 1119 | + model_name, |
| 1120 | + torch_dtype=torch.float16, |
| 1121 | + device_map="auto", |
| 1122 | + quantization_config=quant_config |
| 1123 | + ) |
| 1124 | + finally: |
| 1125 | + # Restore the root logger's handlers |
| 1126 | + root_logger.handlers = original_handlers |
1150 | 1127 | # Reset the downloading flag
|
1151 | 1128 | try:
|
1152 | 1129 | # Access the module's global variable
|
|
0 commit comments