
Commit 36728ec

Fixed model loading issues
1 parent bd3b4d2 commit 36728ec

File tree

4 files changed: +64 -216 lines changed


locallab/cli/interactive.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -91,6 +91,8 @@ def prompt_for_config(use_ngrok: bool = None, port: int = None, ngrok_auth_token
         default=config.get("model_id", DEFAULT_MODEL)
     )
     config["model_id"] = model_id
+    # Set environment variable for model
+    os.environ["HUGGINGFACE_MODEL"] = model_id
 
     # Port configuration
     port = click.prompt(
```

locallab/core/app.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -38,6 +38,7 @@ def init(backend, **kwargs):
     ENABLE_COMPRESSION,
     QUANTIZATION_TYPE,
 )
+from ..cli.config import get_config_value
 
 # Get the logger
 logger = get_logger("locallab.app")
@@ -97,8 +98,15 @@ async def startup_event():
     else:
         logger.warning("FastAPICache not available, caching disabled")
 
-    # Check for model specified in environment variables (prioritize HUGGINGFACE_MODEL)
-    model_to_load = os.environ.get("HUGGINGFACE_MODEL", DEFAULT_MODEL)
+    # Check for model specified in environment variables or CLI config
+    # Priority: HUGGINGFACE_MODEL > CLI config > DEFAULT_MODEL
+    from ..cli.config import get_config_value
+
+    model_to_load = (
+        os.environ.get("HUGGINGFACE_MODEL") or
+        get_config_value("model_id") or
+        DEFAULT_MODEL
+    )
 
     # Log model configuration
    logger.info(f"Model configuration:")
```

locallab/model_manager.py

Lines changed: 31 additions & 189 deletions
```diff
@@ -43,25 +43,11 @@
 
 class ModelManager:
     def __init__(self):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.current_model: Optional[str] = None
-        self.model: Optional[AutoModelForCausalLM] = None
-        self.tokenizer: Optional[AutoTokenizer] = None
-        self.model_config: Optional[Dict[str, Any]] = None
-        self.last_used: float = time.time()
-        self.compiled_model: bool = False
-        self.response_cache = {}  # Simple in-memory cache for responses
-
-        logger.info(f"Using device: {self.device}")
-
-        # Only try to use Flash Attention if it's explicitly enabled and not empty
-        if ENABLE_FLASH_ATTENTION and str(ENABLE_FLASH_ATTENTION).lower() not in ('false', '0', 'none', ''):
-            try:
-                import flash_attn
-                logger.info("Flash Attention enabled - will accelerate transformer attention operations")
-            except ImportError:
-                logger.info("Flash Attention not available - this is an optional optimization and won't affect basic functionality")
-                logger.info("To enable Flash Attention, install with: pip install flash-attn --no-build-isolation")
+        self.model = None
+        self.tokenizer = None
+        self.current_model = None
+        self._loading = False
+        self._last_use = time.time()
 
     def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
         """Get quantization configuration based on settings"""
@@ -250,181 +236,37 @@ def _apply_optimizations(self, model: AutoModelForCausalLM) -> AutoModelForCausa
             logger.warning(f"Some optimizations could not be applied: {str(e)}")
         return model
 
-    async def load_model(self, model_id: str) -> bool:
-        """Load a model from HuggingFace Hub"""
+    async def load_model(self, model_id: str) -> None:
+        """Load a model but don't persist it to config"""
+        if self._loading:
+            raise RuntimeError("Another model is currently loading")
+
+        self._loading = True
+
         try:
+            # Unload current model if any
+            if self.model:
+                self.unload_model()
+
+            # Load the new model
+            logger.info(f"Loading model: {model_id}")
             start_time = time.time()
-            logger.info(f"\n{Fore.CYAN}Loading model: {model_id}{Style.RESET_ALL}")
-
-            # Get and validate HuggingFace token
-            from .config import get_hf_token, HF_TOKEN_ENV, set_env_var
-            hf_token = get_hf_token(interactive=False)
 
-            if hf_token:
-                # Ensure token is properly set in environment
-                set_env_var(HF_TOKEN_ENV, hf_token)
+            # Apply optimizations based on config
+            self.model = await self._load_model_with_optimizations(model_id)
+            self.current_model = model_id
+
+            load_time = time.time() - start_time
+            logger.info(f"Model loaded in {load_time:.2f} seconds")
+
+            # Log the model load but don't persist to config
+            log_model_loaded(model_id, load_time)
 
-            if not hf_token and model_id in ["microsoft/phi-2"]:  # Add other gated models here
-                logger.error(f"{Fore.RED}This model requires authentication. Please configure your HuggingFace token first.{Style.RESET_ALL}")
-                logger.info(f"{Fore.YELLOW}You can set your token by running: locallab config{Style.RESET_ALL}")
-                raise HTTPException(
-                    status_code=401,
-                    detail="HuggingFace token required for this model. Run 'locallab config' to set up."
-                )
-
-            if self.model is not None:
-                prev_model = self.current_model
-                logger.info(f"Unloading previous model: {prev_model}")
-                del self.model
-                self.model = None
-                self.compiled_model = False
-                self.response_cache.clear()  # Clear cache when changing models
-                torch.cuda.empty_cache()
-                gc.collect()
-                log_model_unloaded(prev_model)
-
-            # Validate token if provided
-            if hf_token:
-                from huggingface_hub import HfApi
-                try:
-                    api = HfApi()
-                    api.whoami(token=hf_token)
-                    logger.info(f"{Fore.GREEN}✓ HuggingFace token validated{Style.RESET_ALL}")
-                except Exception as e:
-                    logger.error(f"{Fore.RED}Invalid HuggingFace token: {str(e)}{Style.RESET_ALL}")
-                    raise HTTPException(
-                        status_code=401,
-                        detail=f"Invalid HuggingFace token: {str(e)}"
-                    )
-            # Set CUDA memory allocation configuration
-            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-            try:
-                # First, try to get model config to check architecture
-                from transformers import AutoConfig, BertLMHeadModel, AutoModelForCausalLM
-                model_config = AutoConfig.from_pretrained(
-                    model_id,
-                    trust_remote_code=True,
-                    token=hf_token
-                )
-
-                # First, try to get model config to check architecture
-                from transformers import AutoConfig
-                model_config = AutoConfig.from_pretrained(
-                    model_id,
-                    trust_remote_code=True,
-                    token=hf_token
-                )
-
-                # Check if it's a BERT-based model
-                is_bert_based = any(arch.lower().startswith('bert') for arch in model_config.architectures) if hasattr(model_config, 'architectures') else False
-
-                # If it's BERT-based, set is_decoder=True
-                if is_bert_based:
-                    logger.info("Detected BERT-based model, configuring for generation...")
-                    model_config.is_decoder = True
-                    if hasattr(model_config, 'add_cross_attention'):
-                        model_config.add_cross_attention = False
-
-                # Load tokenizer first
-                self.tokenizer = AutoTokenizer.from_pretrained(
-                    model_id,
-                    trust_remote_code=True,
-                    token=hf_token
-                )
-                # Get quantization configuration
-                config = self._get_quantization_config()
-
-                # Determine if we should use CPU offloading
-                use_cpu_offload = not torch.cuda.is_available() or torch.cuda.get_device_properties(0).total_memory < 4 * 1024 * 1024 * 1024  # Less than 4GB VRAM
-
-                if use_cpu_offload:
-                    logger.info("Using CPU offloading due to limited GPU memory or CPU-only environment")
-                    config["device_map"] = {
-                        "": "cpu"
-                    }
-                    config["offload_folder"] = "offload"
-                    if "torch_dtype" in config:
-                        # Use lower precision for CPU to save memory
-                        config["torch_dtype"] = torch.float32
-
-                # Load the model with the modified config
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_id,
-                    config=model_config,
-                    trust_remote_code=True,
-                    token=hf_token,
-                    low_cpu_mem_usage=True,  # Enable low memory usage
-                    **config
-                )
-
-                logger.info(f"Model loaded with device_map='auto' for automatic placement")
-
-                # Apply memory optimizations
-                if use_cpu_offload:
-                    # Enable gradient checkpointing for memory efficiency
-                    if hasattr(self.model, "gradient_checkpointing_enable"):
-                        self.model.gradient_checkpointing_enable()
-                        logger.info("Enabled gradient checkpointing for memory efficiency")
-
-                    # Enable CPU offloading if available
-                    if hasattr(self.model, "enable_cpu_offload"):
-                        self.model.enable_cpu_offload()
-                        logger.info("Enabled CPU offloading")
-
-                # Apply optimizations if needed
-                self.model = self._apply_optimizations(self.model)
-
-                # Set model to evaluation mode
-                self.model.eval()
-
-                # Clear any unused memory
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-                gc.collect()
-
-                # Capture model parameters after loading
-                model_architecture = self.model.config.architectures[0] if hasattr(self.model.config, 'architectures') else 'Unknown'
-                memory_used = torch.cuda.memory_allocated() if torch.cuda.is_available() else 'N/A'
-                logger.info(f"Model architecture: {model_architecture}")
-                logger.info(f"Memory used: {memory_used}")
-
-                self.current_model = model_id
-                if model_id in MODEL_REGISTRY:
-                    self.model_config = MODEL_REGISTRY[model_id]
-                else:
-                    self.model_config = {"max_length": DEFAULT_MAX_LENGTH}
-
-                load_time = time.time() - start_time
-                log_model_loaded(model_id, load_time)
-                logger.info(f"{Fore.GREEN}✓ Model '{model_id}' loaded successfully in {load_time:.2f} seconds{Style.RESET_ALL}")
-                return True
-
-            except Exception as e:
-                logger.error(f"{Fore.RED}✗ Error loading model {model_id}: {str(e)}{Style.RESET_ALL}")
-                if self.model is not None:
-                    del self.model
-                    self.model = None
-                    self.compiled_model = False
-                torch.cuda.empty_cache()
-                gc.collect()
-
-                # Try to load a smaller fallback model
-                fallback_model = "microsoft/phi-2"  # A smaller model that works well
-                if model_id != fallback_model:
-                    logger.warning(f"{Fore.YELLOW}! Attempting to load fallback model: {fallback_model}{Style.RESET_ALL}")
-                    return await self.load_model(fallback_model)
-                else:
-                    raise HTTPException(
-                        status_code=500,
-                        detail=f"Failed to load model: {str(e)}"
-                    )
-
         except Exception as e:
-            logger.error(f"{Fore.RED}✗ Failed to load model {model_id}: {str(e)}{Style.RESET_ALL}")
-            raise HTTPException(
-                status_code=500,
-                detail=f"Failed to load model: {str(e)}"
-            )
+            logger.error(f"Error loading model: {str(e)}")
+            raise
+        finally:
+            self._loading = False
 
     def check_model_timeout(self):
         """Check if model should be unloaded due to inactivity"""
```

locallab/routes/models.py

Lines changed: 21 additions & 25 deletions
```diff
@@ -2,36 +2,37 @@
 API routes for model management
 """
 
-from fastapi import APIRouter, HTTPException, BackgroundTasks
+from fastapi import APIRouter, HTTPException, Request
 from pydantic import BaseModel
-from typing import Dict, List, Any, Optional
-import time
-import asyncio
+from typing import Dict, Any, Optional
+import os
 
 from ..logger import get_logger
 from ..core.app import model_manager
-from ..config import MODEL_REGISTRY
+from ..logger.logger import log_model_loaded, log_model_unloaded
+from ..config import get_env_var
 
 # Get logger
 logger = get_logger("locallab.routes.models")
 
 # Create router
-router = APIRouter(tags=["Models"], prefix="/models")
+router = APIRouter(tags=["Models"])
 
+class LoadModelRequest(BaseModel):
+    """Request model for loading a model"""
+    model_id: str
 
-class ModelResponse(BaseModel):
-    """Response model for the current model info"""
-    id: str
-    name: str
-    is_loaded: bool
-    loading_progress: Optional[float] = None
-
-
-class ModelsListResponse(BaseModel):
-    """Response model for the list of available models"""
-    models: List[Dict[str, Any]]
-    current_model: Optional[str] = None
-
+@router.post("/models/load")
+async def load_model(request: LoadModelRequest) -> Dict[str, str]:
+    """Load a specific model"""
+    try:
+        # Load the model but don't persist it to config
+        # This way it won't override CLI/env settings on restart
+        await model_manager.load_model(request.model_id)
+        return {"message": f"Model {request.model_id} loaded successfully"}
+    except Exception as e:
+        logger.error(f"Failed to load model {request.model_id}: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
 
 @router.get("", response_model=ModelsListResponse)
 async def list_models() -> ModelsListResponse:
@@ -95,11 +96,6 @@ async def load_model(model_id: str, background_tasks: BackgroundTasks) -> Dict[s
         raise HTTPException(status_code=500, detail=str(e))
 
 
-class LoadModelRequest(BaseModel):
-    """Request model for loading a model with JSON body"""
-    model_id: str
-
-
 @router.post("/load", response_model=Dict[str, str])
 async def load_model_from_body(request: LoadModelRequest, background_tasks: BackgroundTasks) -> Dict[str, str]:
     """Load a specific model using model_id from request body"""
@@ -158,4 +154,4 @@ async def get_model_status(model_id: str) -> ModelResponse:
         name=model_info.get("name", model_id),
         is_loaded=False,
         loading_progress=0.0
-    )
+    )
```
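
With the new route in place, a model can be loaded over HTTP by posting the `model_id` as JSON. A minimal usage sketch with the `requests` library (host, port, and model id are placeholders; the path is `/models/load` because the router no longer carries the `/models` prefix and the route declares it explicitly):

```python
import requests

# Hypothetical base URL; adjust host/port to wherever the LocalLab server runs.
BASE_URL = "http://localhost:8000"

resp = requests.post(
    f"{BASE_URL}/models/load",
    json={"model_id": "microsoft/phi-2"},  # matches LoadModelRequest.model_id
    timeout=600,  # model downloads can take a while
)
resp.raise_for_status()
print(resp.json())  # e.g. {"message": "Model microsoft/phi-2 loaded successfully"}
```

Because the handler calls `model_manager.load_model` without persisting the choice, the loaded model lasts only for the running session and does not override the CLI or environment configuration on the next restart.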
