Commit 75b3ddf

Updated package v4.6
1 parent 217cce7 commit 75b3ddf

File tree

5 files changed: +124 -16 lines changed

- CHANGELOG.md
- locallab/__init__.py
- locallab/model_manager.py
- setup.py
- tests/test_endpoints.zsh

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
@@ -2,6 +2,16 @@
 
 All notable changes to LocalLab will be documented in this file.
 
+## 0.4.6 - 2025-03-08
+
+### Fixed
+
+- Improved streaming generation quality to match non-streaming responses
+- Added proper stopping conditions for streaming to prevent endless generation
+- Implemented repetition detection to stop low-quality streaming responses
+- Reduced token chunk size for better quality control in streaming mode
+- Ensured consistent generation parameters between streaming and non-streaming modes
+
 ## 0.4.5 - 2025-03-08
 
 ### Added
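
The changelog entries above describe server-side behavior; here is a minimal client-side sketch of consuming the improved streaming mode. It is not part of this commit: the base URL, port, and plain-text chunk framing are assumptions for illustration (only the /generate endpoint and the "stream" flag come from the test script below), so adapt it to your deployment.

# Hypothetical client sketch (not part of this commit): stream from /generate.
# BASE_URL and the raw-chunk response format are assumptions; adjust as needed.
import requests

BASE_URL = "http://localhost:8000"  # assumed LocalLab server address

def stream_generate(prompt: str) -> str:
    """Stream a /generate response chunk by chunk and return the full text."""
    text = ""
    with requests.post(
        f"{BASE_URL}/generate",
        json={"prompt": prompt, "stream": True},
        stream=True,
        timeout=300,
    ) as resp:
        resp.raise_for_status()
        resp.encoding = resp.encoding or "utf-8"  # ensure decoded text chunks
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                print(chunk, end="", flush=True)
                text += chunk
    return text

if __name__ == "__main__":
    stream_generate("Tell me a story")

With the stop sequences and repetition check introduced in this commit, a loop like this should terminate on its own rather than running until max_length is exhausted.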

locallab/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 LocalLab: Run LLMs locally with a friendly API similar to OpenAI
 """
 
-__version__ = "0.4.5"
+__version__ = "0.4.6"
 
 from typing import Dict, Any, Optional
 

locallab/model_manager.py

Lines changed: 108 additions & 10 deletions
@@ -558,14 +558,14 @@ def _stream_generate(
             temperature = gen_params.get(
                 "temperature", DEFAULT_TEMPERATURE)
             top_p = gen_params.get("top_p", DEFAULT_TOP_P)
-            top_k = gen_params.get("top_k", 40)  # Default to 40 for faster generation
+            top_k = gen_params.get("top_k", 40)  # Default to 40 for better quality
             repetition_penalty = gen_params.get("repetition_penalty", 1.1)
         else:
             # Use provided individual parameters or defaults
             max_length = max_length or min(DEFAULT_MAX_LENGTH, 512)  # Limit default max_length
-            temperature = temperature or 0.7  # Lower temperature for faster generation
+            temperature = temperature or 0.7  # Use same temperature as non-streaming
             top_p = top_p or DEFAULT_TOP_P
-            top_k = 40  # Default to 40 for faster generation
+            top_k = 40  # Default to 40 for better quality
             repetition_penalty = 1.1
 
         # Get the actual device of the model
@@ -582,15 +582,21 @@ def _stream_generate(
         attention_mask = inputs["attention_mask"]
 
         # Generate fewer tokens at once for more responsive streaming
-        # Using smaller chunks makes it appear more interactive
-        tokens_to_generate_per_step = 3  # Reduced from 8 to 3 for more responsive streaming
+        # Using smaller chunks makes it appear more interactive while maintaining quality
+        tokens_to_generate_per_step = 2  # Reduced from 3 to 2 for better quality control
+
+        # Track generated text for quality control
+        generated_text = ""
+
+        # Define stop sequences for proper termination
+        stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
 
         with torch.no_grad():
             for step in range(0, max_length, tokens_to_generate_per_step):
                 # Calculate how many tokens to generate in this step
                 current_tokens_to_generate = min(tokens_to_generate_per_step, max_length - step)
 
-                # Generate parameters
+                # Generate parameters - use the same high-quality parameters as non-streaming
                 generate_params = {
                     "input_ids": input_ids,
                     "attention_mask": attention_mask,
@@ -601,7 +607,6 @@
                     "do_sample": True,
                     "pad_token_id": self.tokenizer.eos_token_id,
                     "repetition_penalty": repetition_penalty,
-                    # Remove early_stopping to fix the warning
                     "num_beams": 1  # Explicitly set to 1 to avoid warnings
                 }
 
@@ -623,9 +628,37 @@
                 if not new_text or new_text.isspace():
                     break
 
+                # Add to generated text for quality control
+                generated_text += new_text
+
+                # Check for stop sequences
+                should_stop = False
+                for stop_seq in stop_sequences:
+                    if stop_seq in generated_text:
+                        # We've reached a stop sequence, stop generation
+                        should_stop = True
+                        break
+
+                # Check for repetition (a sign of poor quality)
+                if len(generated_text) > 50:
+                    # Check for repeating patterns of 10+ characters
+                    last_50_chars = generated_text[-50:]
+                    for pattern_len in range(10, 20):
+                        if pattern_len < len(last_50_chars) // 2:
+                            pattern = last_50_chars[-pattern_len:]
+                            if pattern in last_50_chars[:-pattern_len]:
+                                # Detected repetition, stop generation
+                                logger.warning("Detected repetition in streaming generation, stopping")
+                                should_stop = True
+                                break
+
                 # Yield the new text
                 yield new_text
 
+                # Stop if needed
+                if should_stop:
+                    break
+
                 # Update input_ids and attention_mask for next iteration
                 input_ids = outputs
                 attention_mask = torch.ones_like(input_ids)
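
The repetition check added in the hunk above can also be expressed as a standalone helper, which makes the heuristic easier to unit-test in isolation. The following is a sketch that mirrors the in-loop logic (50-character window, pattern lengths 10-19); it is illustrative and not code from this commit.

# Standalone sketch of the repetition heuristic used in _stream_generate.
def has_trailing_repetition(text: str, window: int = 50,
                            min_pattern: int = 10, max_pattern: int = 20) -> bool:
    """Return True if the tail of `text` ends with a pattern repeated in the window."""
    if len(text) <= window:
        return False
    tail = text[-window:]
    for pattern_len in range(min_pattern, max_pattern):
        if pattern_len < len(tail) // 2:
            pattern = tail[-pattern_len:]
            # A suffix that also appears earlier in the window suggests looping output
            if pattern in tail[:-pattern_len]:
                return True
    return False

# Quick check: a looping tail is flagged, ordinary prose is not
assert has_trailing_repetition("intro " + "the same phrase " * 6)
assert not has_trailing_repetition("A single coherent sentence about streaming generation quality.")
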
@@ -666,6 +699,7 @@ def _stream_generate(
                 if not new_text or new_text.isspace():
                     break
 
+                generated_text += new_text
                 yield new_text
 
                 input_ids = outputs
@@ -699,6 +733,24 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
         # Get model-specific generation parameters
         from .config import get_model_generation_params
         gen_params = get_model_generation_params(self.current_model)
+
+        # Set optimized defaults for streaming that match non-streaming quality
+        # Use the same parameters as non-streaming for consistency
+        if not kwargs.get("max_length") and not kwargs.get("max_new_tokens"):
+            # Use a reasonable default max_length
+            gen_params["max_length"] = min(gen_params.get("max_length", DEFAULT_MAX_LENGTH), 512)
+
+        if not kwargs.get("temperature"):
+            # Use the same temperature as non-streaming
+            gen_params["temperature"] = min(gen_params.get("temperature", DEFAULT_TEMPERATURE), 0.7)
+
+        if not kwargs.get("top_k"):
+            # Add top_k for better quality
+            gen_params["top_k"] = 40
+
+        if not kwargs.get("repetition_penalty"):
+            # Add repetition penalty to avoid loops
+            gen_params["repetition_penalty"] = 1.1
 
         # Update with provided kwargs
         for key, value in kwargs.items():
@@ -718,10 +770,56 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
         for key in inputs:
             inputs[key] = inputs[key].to(model_device)
 
-        # Now stream tokens using the prepared inputs and parameters
-        for token in self._stream_generate(inputs, gen_params=gen_params):
+        # Check if we need to clear CUDA cache before generation
+        if torch.cuda.is_available():
+            current_mem = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)  # GB
+            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)  # GB
+            if current_mem > 0.8 * total_mem:  # If using >80% of GPU memory
+                # Clear cache to avoid OOM
+                torch.cuda.empty_cache()
+                logger.info("Cleared CUDA cache before streaming generation to avoid out of memory error")
+
+        # Create a custom stream generator with improved quality
+        async def improved_stream_generator():
+            # Use the same stopping conditions as non-streaming
+            stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
+            accumulated_text = ""
+
+            # Use a generator that produces high-quality chunks
+            try:
+                for token_chunk in self._stream_generate(inputs, gen_params=gen_params):
+                    accumulated_text += token_chunk
+
+                    # Check for stop sequences
+                    should_stop = False
+                    for stop_seq in stop_sequences:
+                        if stop_seq in accumulated_text:
+                            # Truncate at stop sequence
+                            accumulated_text = accumulated_text.split(stop_seq)[0]
+                            should_stop = True
+                            break
+
+                    # Yield the token chunk
+                    yield token_chunk
+
+                    # Stop if we've reached a stop sequence
+                    if should_stop:
+                        break
+
+                    # Also stop if we've generated too much text (safety measure)
+                    if len(accumulated_text) > gen_params.get("max_length", 512) * 4:  # Character estimate
+                        logger.warning("Stream generation exceeded maximum length - stopping")
+                        break
+
+                    await asyncio.sleep(0)
+            except Exception as e:
+                logger.error(f"Error in stream generation: {str(e)}")
+                # Don't propagate the error to avoid breaking the stream
+                # Just stop generating
+
+        # Use the improved generator
+        async for token in improved_stream_generator():
             yield token
-            await asyncio.sleep(0)
 
     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the currently loaded model"""

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setup(
     name="locallab",
-    version="0.4.5",
+    version="0.4.6",
     packages=find_packages(include=["locallab", "locallab.*"]),
     install_requires=[
         "fastapi>=0.95.0,<1.0.0",

tests/test_endpoints.zsh

Lines changed: 4 additions & 4 deletions
@@ -51,21 +51,21 @@ test_endpoint "/models/available"
 
 # Test Text Generation
 test_endpoint "/generate" "POST" '{
-    "prompt": "Hello!",
+    "prompt": "Hello! Tell me about React Native.",
    "stream": false
 }'
 
 # Test Streaming Generation
 test_endpoint "/generate" "POST" '{
-    "prompt": "Tell me a story",
+    "prompt": "Create a professional X post on any topic you want.",
    "stream": true
 }'
 
 # Test Chat Completion
 test_endpoint "/chat" "POST" '{
    "messages": [
        {"role": "system", "content": "You are a helpful assistant"},
-        {"role": "user", "content": "Hi, how are you?"}
+        {"role": "user", "content": "Hi, how are you? Write a professional X post about Expo and React Native and explain the difference between them."}
    ],
    "stream": false
 }'
@@ -74,7 +74,7 @@ test_endpoint "/chat" "POST" '{
 test_endpoint "/generate/batch" "POST" '{
    "prompts": [
        "What is 2+2?",
-        "Who is Shakespeare?"
+        "Who is Linus Torvalds?"
    ]
 }'
 
0 commit comments