
 class ModelManager:
     def __init__(self):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.current_model: Optional[str] = None
-        self.model: Optional[AutoModelForCausalLM] = None
-        self.tokenizer: Optional[AutoTokenizer] = None
-        self.model_config: Optional[Dict[str, Any]] = None
-        self.last_used: float = time.time()
-        self.compiled_model: bool = False
-        self.response_cache = {}  # Simple in-memory cache for responses
-
-        logger.info(f"Using device: {self.device}")
-
-        # Only try to use Flash Attention if it's explicitly enabled and not empty
-        if ENABLE_FLASH_ATTENTION and str(ENABLE_FLASH_ATTENTION).lower() not in ('false', '0', 'none', ''):
-            try:
-                import flash_attn
-                logger.info("Flash Attention enabled - will accelerate transformer attention operations")
-            except ImportError:
-                logger.info("Flash Attention not available - this is an optional optimization and won't affect basic functionality")
-                logger.info("To enable Flash Attention, install with: pip install flash-attn --no-build-isolation")
+        self.model = None
+        self.tokenizer = None
+        self.current_model = None
+        self._loading = False
+        self._last_use = time.time()

     def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
         """Get quantization configuration based on settings"""
@@ -250,181 +236,37 @@ def _apply_optimizations(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
             logger.warning(f"Some optimizations could not be applied: {str(e)}")
         return model

-    async def load_model(self, model_id: str) -> bool:
-        """Load a model from HuggingFace Hub"""
+    async def load_model(self, model_id: str) -> None:
+        """Load a model but don't persist it to config"""
+        if self._loading:
+            raise RuntimeError("Another model is currently loading")
+
+        self._loading = True
+
         try:
+            # Unload current model if any
+            if self.model:
+                self.unload_model()
+
+            # Load the new model
+            logger.info(f"Loading model: {model_id}")
             start_time = time.time()
-            logger.info(f"\n{Fore.CYAN}Loading model: {model_id}{Style.RESET_ALL}")
-
-            # Get and validate HuggingFace token
-            from .config import get_hf_token, HF_TOKEN_ENV, set_env_var
-            hf_token = get_hf_token(interactive=False)

-            if hf_token:
-                # Ensure token is properly set in environment
-                set_env_var(HF_TOKEN_ENV, hf_token)
+            # Apply optimizations based on config
+            self.model = await self._load_model_with_optimizations(model_id)
+            self.current_model = model_id
+
+            load_time = time.time() - start_time
+            logger.info(f"Model loaded in {load_time:.2f} seconds")
+
+            # Log the model load but don't persist to config
+            log_model_loaded(model_id, load_time)

-            if not hf_token and model_id in ["microsoft/phi-2"]:  # Add other gated models here
-                logger.error(f"{Fore.RED}This model requires authentication. Please configure your HuggingFace token first.{Style.RESET_ALL}")
-                logger.info(f"{Fore.YELLOW}You can set your token by running: locallab config{Style.RESET_ALL}")
-                raise HTTPException(
-                    status_code=401,
-                    detail="HuggingFace token required for this model. Run 'locallab config' to set up."
-                )
-
-            if self.model is not None:
-                prev_model = self.current_model
-                logger.info(f"Unloading previous model: {prev_model}")
-                del self.model
-                self.model = None
-                self.compiled_model = False
-                self.response_cache.clear()  # Clear cache when changing models
-                torch.cuda.empty_cache()
-                gc.collect()
-                log_model_unloaded(prev_model)
-
-            # Validate token if provided
-            if hf_token:
-                from huggingface_hub import HfApi
-                try:
-                    api = HfApi()
-                    api.whoami(token=hf_token)
-                    logger.info(f"{Fore.GREEN}✓ HuggingFace token validated{Style.RESET_ALL}")
-                except Exception as e:
-                    logger.error(f"{Fore.RED}Invalid HuggingFace token: {str(e)}{Style.RESET_ALL}")
-                    raise HTTPException(
-                        status_code=401,
-                        detail=f"Invalid HuggingFace token: {str(e)}"
-                    )
-            # Set CUDA memory allocation configuration
-            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-            try:
-                # First, try to get model config to check architecture
-                from transformers import AutoConfig, BertLMHeadModel, AutoModelForCausalLM
-                model_config = AutoConfig.from_pretrained(
-                    model_id,
-                    trust_remote_code=True,
-                    token=hf_token
-                )
-
-                # First, try to get model config to check architecture
-                from transformers import AutoConfig
-                model_config = AutoConfig.from_pretrained(
-                    model_id,
-                    trust_remote_code=True,
-                    token=hf_token
-                )
-
-                # Check if it's a BERT-based model
-                is_bert_based = any(arch.lower().startswith('bert') for arch in model_config.architectures) if hasattr(model_config, 'architectures') else False
-
-                # If it's BERT-based, set is_decoder=True
-                if is_bert_based:
-                    logger.info("Detected BERT-based model, configuring for generation...")
-                    model_config.is_decoder = True
-                    if hasattr(model_config, 'add_cross_attention'):
-                        model_config.add_cross_attention = False
-
-                # Load tokenizer first
-                self.tokenizer = AutoTokenizer.from_pretrained(
-                    model_id,
-                    trust_remote_code=True,
-                    token=hf_token
-                )
-                # Get quantization configuration
-                config = self._get_quantization_config()
-
-                # Determine if we should use CPU offloading
-                use_cpu_offload = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 4 * 1024 * 1024 * 1024  # Less than 4GB VRAM
-
-                if use_cpu_offload:
-                    logger.info("Using CPU offloading due to limited GPU memory or CPU-only environment")
-                    config["device_map"] = {
-                        "": "cpu"
-                    }
-                    config["offload_folder"] = "offload"
-                    if "torch_dtype" in config:
-                        # Use lower precision for CPU to save memory
-                        config["torch_dtype"] = torch.float32
-
-                # Load the model with the modified config
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_id,
-                    config=model_config,
-                    trust_remote_code=True,
-                    token=hf_token,
-                    low_cpu_mem_usage=True,  # Enable low memory usage
-                    **config
-                )
-
-                logger.info(f"Model loaded with device_map='auto' for automatic placement")
-
-                # Apply memory optimizations
-                if use_cpu_offload:
-                    # Enable gradient checkpointing for memory efficiency
-                    if hasattr(self.model, "gradient_checkpointing_enable"):
-                        self.model.gradient_checkpointing_enable()
-                        logger.info("Enabled gradient checkpointing for memory efficiency")
-
-                    # Enable CPU offloading if available
-                    if hasattr(self.model, "enable_cpu_offload"):
-                        self.model.enable_cpu_offload()
-                        logger.info("Enabled CPU offloading")
-
-                # Apply optimizations if needed
-                self.model = self._apply_optimizations(self.model)
-
-                # Set model to evaluation mode
-                self.model.eval()
-
-                # Clear any unused memory
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                    gc.collect()
-
-                # Capture model parameters after loading
-                model_architecture = self.model.config.architectures[0] if hasattr(self.model.config, 'architectures') else 'Unknown'
-                memory_used = torch.cuda.memory_allocated() if torch.cuda.is_available() else 'N/A'
-                logger.info(f"Model architecture: {model_architecture}")
-                logger.info(f"Memory used: {memory_used}")
-
-                self.current_model = model_id
-                if model_id in MODEL_REGISTRY:
-                    self.model_config = MODEL_REGISTRY[model_id]
-                else:
-                    self.model_config = {"max_length": DEFAULT_MAX_LENGTH}
-
-                load_time = time.time() - start_time
-                log_model_loaded(model_id, load_time)
-                logger.info(f"{Fore.GREEN}✓ Model '{model_id}' loaded successfully in {load_time:.2f} seconds{Style.RESET_ALL}")
-                return True
-
-            except Exception as e:
-                logger.error(f"{Fore.RED}✗ Error loading model {model_id}: {str(e)}{Style.RESET_ALL}")
-                if self.model is not None:
-                    del self.model
-                    self.model = None
-                    self.compiled_model = False
-                    torch.cuda.empty_cache()
-                    gc.collect()
-
-                # Try to load a smaller fallback model
-                fallback_model = "microsoft/phi-2"  # A smaller model that works well
-                if model_id != fallback_model:
-                    logger.warning(f"{Fore.YELLOW}! Attempting to load fallback model: {fallback_model}{Style.RESET_ALL}")
-                    return await self.load_model(fallback_model)
-                else:
-                    raise HTTPException(
-                        status_code=500,
-                        detail=f"Failed to load model: {str(e)}"
-                    )
-
         except Exception as e:
-            logger.error(f"{Fore.RED}✗ Failed to load model {model_id}: {str(e)}{Style.RESET_ALL}")
-            raise HTTPException(
-                status_code=500,
-                detail=f"Failed to load model: {str(e)}"
-            )
+            logger.error(f"Error loading model: {str(e)}")
+            raise
+        finally:
+            self._loading = False

     def check_model_timeout(self):
         """Check if model should be unloaded due to inactivity"""
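
The body of check_model_timeout is outside this diff. Purely as a hedged sketch of how the new _last_use timestamp could drive the inactivity check described in its docstring (the UNLOAD_TIMEOUT_SECONDS constant and the unload_model call are assumptions for illustration, not code from this change):

    # Hypothetical sketch only -- not taken from this diff.
    # Assumes an UNLOAD_TIMEOUT_SECONDS setting and the unload_model()
    # helper referenced by the new load_model implementation above.
    def check_model_timeout(self):
        """Check if model should be unloaded due to inactivity"""
        if self.model is None:
            return
        idle_seconds = time.time() - self._last_use
        if idle_seconds > UNLOAD_TIMEOUT_SECONDS:
            logger.info(f"Unloading {self.current_model} after {idle_seconds:.0f}s of inactivity")
            self.unload_model()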
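A minimal usage sketch of the refactored load_model from a caller's point of view (the switch_model wrapper and names below are assumptions for illustration, not code from this change): the _loading flag turns a concurrent load into a RuntimeError, and because the method now returns None and re-raises on failure, callers should rely on exceptions rather than a boolean result.

    # Hypothetical caller -- model_manager and switch_model are illustrative names.
    model_manager = ModelManager()

    async def switch_model(model_id: str) -> dict:
        try:
            await model_manager.load_model(model_id)
        except RuntimeError:
            # Another load is already in progress (guarded by self._loading)
            return {"status": "busy", "detail": "Another model is currently loading"}
        except Exception as e:
            # load_model logs the error and re-raises instead of returning False
            return {"status": "error", "detail": str(e)}
        return {"status": "ok", "model": model_manager.current_model}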