Commit 217cce7

Updated package v4.5
1 parent c7e9492 commit 217cce7

4 files changed: 166 additions & 79 deletions


CHANGELOG.md

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,22 @@
 
 All notable changes to LocalLab will be documented in this file.
 
+## 0.4.5 - 2025-03-08
+
+### Added
+
+- Added memory monitoring to prevent CUDA out of memory errors
+- Implemented adaptive token generation for streaming responses
+- Added CUDA memory configuration with expandable segments
+
+### Fixed
+
+- Fixed torch.compile() errors by adding proper error handling and fallback to eager mode
+- Fixed early stopping warning by correctly setting num_beams parameter
+- Improved streaming generation with smaller token chunks for more responsive output
+- Added memory-aware generation that adapts to available GPU resources
+- Implemented error recovery for out-of-memory situations during generation
+
 ## 0.4.4 - 2025-03-08
 
 ### Fixed
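The memory-monitoring entries above come down to a simple pattern: configure the allocator for expandable segments, then clear the CUDA cache whenever allocated memory crosses a usage threshold. Below is a minimal standalone sketch of that idea, assuming PyTorch with a CUDA device; the helper name and the 0.8 threshold are illustrative and mirror the values used in model_manager.py further down, not part of LocalLab's public API.

import os
import torch

# Assumption: set before the first CUDA allocation so the allocator picks it up
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

def maybe_clear_cuda_cache(threshold: float = 0.8) -> bool:
    """Release cached CUDA memory when usage exceeds `threshold` of total device memory."""
    if not torch.cuda.is_available():
        return False
    used = torch.cuda.memory_allocated()
    total = torch.cuda.get_device_properties(0).total_memory
    if used > threshold * total:
        torch.cuda.empty_cache()
        return True
    return False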

locallab/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 LocalLab: Run LLMs locally with a friendly API similar to OpenAI
 """
 
-__version__ = "0.4.4"
+__version__ = "0.4.5"
 
 from typing import Dict, Any, Optional

locallab/model_manager.py

Lines changed: 148 additions & 77 deletions
@@ -231,6 +231,14 @@ async def load_model(self, model_id: str) -> bool:
             gc.collect()
             log_model_unloaded(prev_model)
 
+        # Set CUDA memory allocation configuration to avoid fragmentation
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+        # Configure torch.compile to suppress errors and fall back to eager mode
+        if hasattr(torch, "_dynamo"):
+            torch._dynamo.config.suppress_errors = True
+            logger.info("Configured torch._dynamo to suppress errors and fall back to eager mode")
+
         hf_token = os.getenv("HF_TOKEN")
 
         # Check quantization settings from environment variables
@@ -270,29 +278,30 @@ async def load_model(self, model_id: str) -> bool:
         self.model = self._apply_optimizations(self.model)
 
         # Try to compile the model for faster inference if PyTorch version supports it
+        # But with more robust error handling
         try:
-            if torch.__version__ >= "2.0.0" and torch.cuda.is_available():
-                logger.info("Attempting to compile model with torch.compile() for faster inference...")
-                # Use a separate thread to avoid blocking
-                import threading
-
-                def compile_model():
+            if hasattr(torch, 'compile') and torch.cuda.is_available():
+                # Only attempt compilation if we have enough GPU memory
+                free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
+                if free_memory > 2 * 1024 * 1024 * 1024:  # Only compile if we have >2GB free
+                    logger.info("Attempting to compile model with torch.compile() for faster inference...")
                     try:
-                        # Only compile the forward method for generation
-                        self.model.forward = torch.compile(
-                            self.model.forward,
-                            mode="reduce-overhead",  # Use reduce-overhead mode for faster compilation
-                            fullgraph=False  # Partial graph compilation for better compatibility
+                        # Use a safer compilation mode
+                        self.model = torch.compile(
+                            self.model,
+                            mode="reduce-overhead",
+                            fullgraph=False,
+                            dynamic=True  # Better for variable input sizes
                         )
                        self.compiled_model = True
                        logger.info(f"{Fore.GREEN}Model successfully compiled with torch.compile(){Style.RESET_ALL}")
-                    except Exception as e:
-                        logger.warning(f"Could not compile model: {str(e)}. Continuing with uncompiled model.")
-
-                # Start compilation in background
-                threading.Thread(target=compile_model).start()
+                    except Exception as compile_error:
+                        logger.warning(f"Model compilation failed with specific error: {str(compile_error)}")
+                        logger.info("Continuing with uncompiled model")
+                else:
+                    logger.info("Skipping model compilation due to limited GPU memory")
         except Exception as e:
-            logger.warning(f"Torch compile not available: {str(e)}. Continuing with standard model.")
+            logger.warning(f"Could not compile model: {str(e)}. Continuing with uncompiled model.")
 
         # Capture model parameters after loading
         model_architecture = self.model.config.architectures[0] if hasattr(self.model.config, 'architectures') else 'Unknown'
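Taken out of its diff context, the compilation guard above reduces to: check for torch.compile and free GPU memory, let TorchDynamo suppress errors so compiled code can fall back to eager mode, and keep the uncompiled model on any failure. A standalone sketch under those assumptions follows; try_compile and min_free_bytes are illustrative names, not part of the LocalLab API.

import torch

def try_compile(model, min_free_bytes: int = 2 * 1024**3):
    """Return (model, compiled_flag); only compile when more than min_free_bytes of GPU memory is free."""
    if not (hasattr(torch, "compile") and torch.cuda.is_available()):
        return model, False
    free = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
    if free <= min_free_bytes:
        return model, False
    if hasattr(torch, "_dynamo"):
        # Fall back to eager execution instead of raising on unsupported ops
        torch._dynamo.config.suppress_errors = True
    try:
        return torch.compile(model, mode="reduce-overhead", fullgraph=False, dynamic=True), True
    except Exception:
        return model, False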
@@ -454,41 +463,62 @@ async def generate(
         if stream:
             return self.async_stream_generate(inputs, gen_params)
 
+        # Check if we need to clear CUDA cache before generation
+        if torch.cuda.is_available():
+            current_mem = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)  # GB
+            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)  # GB
+            if current_mem > 0.8 * total_mem:  # If using >80% of GPU memory
+                # Clear cache to avoid OOM
+                torch.cuda.empty_cache()
+                logger.info("Cleared CUDA cache before generation to avoid out of memory error")
+
         with torch.no_grad():
-            generate_params = {
-                **inputs,
-                "max_new_tokens": gen_params["max_length"],
-                "temperature": gen_params["temperature"],
-                "top_p": gen_params["top_p"],
-                "do_sample": True,
-                "pad_token_id": self.tokenizer.eos_token_id
-            }
+            try:
+                generate_params = {
+                    **inputs,
+                    "max_new_tokens": gen_params["max_length"],
+                    "temperature": gen_params["temperature"],
+                    "top_p": gen_params["top_p"],
+                    "do_sample": True,
+                    "pad_token_id": self.tokenizer.eos_token_id,
+                    # Fix the early stopping warning by setting num_beams explicitly
+                    "num_beams": 1
+                }
 
-            # Add optional parameters if present in gen_params
-            if "top_k" in gen_params:
-                generate_params["top_k"] = gen_params["top_k"]
-            if "repetition_penalty" in gen_params:
-                generate_params["repetition_penalty"] = gen_params["repetition_penalty"]
-
-            # Add early stopping for faster generation
-            generate_params["early_stopping"] = True
-
-            # Add batch size for faster generation (process multiple tokens at once)
-            generate_params["num_return_sequences"] = 1
-
-            # Set a reasonable max time for generation to prevent hanging
-            if "max_time" not in generate_params and not stream:
-                generate_params["max_time"] = 30.0  # 30 seconds max for generation
-
-            # Use efficient attention implementation if available
-            if hasattr(self.model.config, "attn_implementation"):
-                generate_params["attn_implementation"] = "flash_attention_2"
+                # Add optional parameters if present in gen_params
+                if "top_k" in gen_params:
+                    generate_params["top_k"] = gen_params["top_k"]
+                if "repetition_penalty" in gen_params:
+                    generate_params["repetition_penalty"] = gen_params["repetition_penalty"]
+
+                # Set a reasonable max time for generation to prevent hanging
+                if "max_time" not in generate_params and not stream:
+                    generate_params["max_time"] = 30.0  # 30 seconds max for generation
+
+                # Use efficient attention implementation if available
+                if hasattr(self.model.config, "attn_implementation"):
+                    generate_params["attn_implementation"] = "flash_attention_2"
 
-            # Generate text
-            start_time = time.time()
-            outputs = self.model.generate(**generate_params)
-            generation_time = time.time() - start_time
-            logger.info(f"Generation completed in {generation_time:.2f} seconds")
+                # Generate text
+                start_time = time.time()
+                outputs = self.model.generate(**generate_params)
+                generation_time = time.time() - start_time
+                logger.info(f"Generation completed in {generation_time:.2f} seconds")
+
+            except RuntimeError as e:
+                if "CUDA out of memory" in str(e):
+                    # If we run out of memory, clear cache and try again with smaller parameters
+                    torch.cuda.empty_cache()
+                    logger.warning("CUDA out of memory during generation. Cleared cache and reducing parameters.")
+
+                    # Reduce parameters for memory efficiency
+                    generate_params["max_new_tokens"] = min(generate_params.get("max_new_tokens", 512), 256)
+
+                    # Try again with reduced parameters
+                    outputs = self.model.generate(**generate_params)
+                else:
+                    # For other errors, re-raise
+                    raise
 
         response = self.tokenizer.decode(
             outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
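The try/except added above boils down to a retry pattern: catch the CUDA out-of-memory RuntimeError, empty the cache, cap max_new_tokens, and generate once more. A simplified sketch of that pattern, where generate_fn stands in for self.model.generate and the 512/256 defaults mirror the values in the diff:

import torch

def generate_with_oom_retry(generate_fn, params: dict, fallback_max_new_tokens: int = 256):
    """Run generate_fn(**params); on CUDA OOM, clear the cache, shrink max_new_tokens, and retry once."""
    try:
        return generate_fn(**params)
    except RuntimeError as err:
        if "CUDA out of memory" not in str(err):
            raise
        torch.cuda.empty_cache()
        retry_params = dict(params)
        retry_params["max_new_tokens"] = min(retry_params.get("max_new_tokens", 512), fallback_max_new_tokens)
        return generate_fn(**retry_params)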
@@ -551,9 +581,10 @@ def _stream_generate(
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
 
-            # Generate multiple tokens at once for efficiency
-            tokens_to_generate_per_step = 8  # Generate 8 tokens at a time for efficiency
-
+            # Generate fewer tokens at once for more responsive streaming
+            # Using smaller chunks makes it appear more interactive
+            tokens_to_generate_per_step = 3  # Reduced from 8 to 3 for more responsive streaming
+
            with torch.no_grad():
                for step in range(0, max_length, tokens_to_generate_per_step):
                    # Calculate how many tokens to generate in this step
@@ -570,38 +601,78 @@
                        "do_sample": True,
                        "pad_token_id": self.tokenizer.eos_token_id,
                        "repetition_penalty": repetition_penalty,
-                        "early_stopping": True
+                        # Remove early_stopping to fix the warning
+                        "num_beams": 1  # Explicitly set to 1 to avoid warnings
                    }
 
                    # Use efficient attention if available
                    if hasattr(self.model.config, "attn_implementation"):
                        generate_params["attn_implementation"] = "flash_attention_2"
 
-                    # Generate tokens
-                    outputs = self.model.generate(**generate_params)
-
-                    # Get the new tokens (skip the input tokens)
-                    new_tokens = outputs[0][len(input_ids[0]):]
-
-                    # Decode and yield each new token
-                    new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-
-                    # If no new text was generated or it's just whitespace, stop generation
-                    if not new_text or new_text.isspace():
-                        break
-
-                    # Yield the new text
-                    yield new_text
-
-                    # Update input_ids and attention_mask for next iteration
-                    input_ids = outputs
-                    attention_mask = torch.ones_like(input_ids)
+                    try:
+                        # Generate tokens
+                        outputs = self.model.generate(**generate_params)
+
+                        # Get the new tokens (skip the input tokens)
+                        new_tokens = outputs[0][len(input_ids[0]):]
+
+                        # Decode and yield each new token
+                        new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+                        # If no new text was generated or it's just whitespace, stop generation
+                        if not new_text or new_text.isspace():
+                            break
+
+                        # Yield the new text
+                        yield new_text
+
+                        # Update input_ids and attention_mask for next iteration
+                        input_ids = outputs
+                        attention_mask = torch.ones_like(input_ids)
+
+                        # Ensure the updated inputs are on the correct device
+                        if input_ids.device != model_device:
+                            input_ids = input_ids.to(model_device)
+                        if attention_mask.device != model_device:
+                            attention_mask = attention_mask.to(model_device)
+
+                        # Check if we're running out of memory and need to clear cache
+                        if torch.cuda.is_available():
+                            current_mem = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)  # GB
+                            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)  # GB
+                            if current_mem > 0.9 * total_mem:  # If using >90% of GPU memory
+                                # Clear cache to avoid OOM
+                                torch.cuda.empty_cache()
+                                logger.info("Cleared CUDA cache to avoid out of memory error")
 
-                    # Ensure the updated inputs are on the correct device
-                    if input_ids.device != model_device:
-                        input_ids = input_ids.to(model_device)
-                    if attention_mask.device != model_device:
-                        attention_mask = attention_mask.to(model_device)
+                    except RuntimeError as e:
+                        if "CUDA out of memory" in str(e):
+                            # If we run out of memory, clear cache and try again with smaller batch
+                            torch.cuda.empty_cache()
+                            logger.warning("CUDA out of memory during streaming. Cleared cache and reducing batch size.")
+
+                            # Reduce tokens per step for the rest of generation
+                            tokens_to_generate_per_step = 1
+                            current_tokens_to_generate = 1
+
+                            # Try again with smaller batch
+                            generate_params["max_new_tokens"] = 1
+                            outputs = self.model.generate(**generate_params)
+
+                            # Continue as before
+                            new_tokens = outputs[0][len(input_ids[0]):]
+                            new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+                            if not new_text or new_text.isspace():
+                                break
+
+                            yield new_text
+
+                            input_ids = outputs
+                            attention_mask = torch.ones_like(input_ids)
+                        else:
+                            # For other errors, re-raise
+                            raise
 
        except Exception as e:
            logger.error(f"Streaming generation failed: {str(e)}")

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setup(
     name="locallab",
-    version="0.4.4",
+    version="0.4.5",
     packages=find_packages(include=["locallab", "locallab.*"]),
     install_requires=[
         "fastapi>=0.95.0,<1.0.0",
