@@ -558,14 +558,14 @@ def _stream_generate(
             temperature = gen_params.get(
                 "temperature", DEFAULT_TEMPERATURE)
             top_p = gen_params.get("top_p", DEFAULT_TOP_P)
-            top_k = gen_params.get("top_k", 40)  # Default to 40 for faster generation
+            top_k = gen_params.get("top_k", 40)  # Default to 40 for better quality
             repetition_penalty = gen_params.get("repetition_penalty", 1.1)
         else:
             # Use provided individual parameters or defaults
             max_length = max_length or min(DEFAULT_MAX_LENGTH, 512)  # Limit default max_length
-            temperature = temperature or 0.7  # Lower temperature for faster generation
+            temperature = temperature or 0.7  # Use same temperature as non-streaming
             top_p = top_p or DEFAULT_TOP_P
-            top_k = 40  # Default to 40 for faster generation
+            top_k = 40  # Default to 40 for better quality
             repetition_penalty = 1.1

         # Get the actual device of the model
@@ -582,15 +582,21 @@ def _stream_generate(
         attention_mask = inputs["attention_mask"]

         # Generate fewer tokens at once for more responsive streaming
-        # Using smaller chunks makes it appear more interactive
-        tokens_to_generate_per_step = 3  # Reduced from 8 to 3 for more responsive streaming
+        # Using smaller chunks makes it appear more interactive while maintaining quality
+        tokens_to_generate_per_step = 2  # Reduced from 3 to 2 for better quality control
+
+        # Track generated text for quality control
+        generated_text = ""
+
+        # Define stop sequences for proper termination
+        stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]

         with torch.no_grad():
             for step in range(0, max_length, tokens_to_generate_per_step):
                 # Calculate how many tokens to generate in this step
                 current_tokens_to_generate = min(tokens_to_generate_per_step, max_length - step)

-                # Generate parameters
+                # Generate parameters - use the same high-quality parameters as non-streaming
                 generate_params = {
                     "input_ids": input_ids,
                     "attention_mask": attention_mask,
@@ -601,7 +607,6 @@ def _stream_generate(
                     "do_sample": True,
                     "pad_token_id": self.tokenizer.eos_token_id,
                     "repetition_penalty": repetition_penalty,
-                    # Remove early_stopping to fix the warning
                     "num_beams": 1  # Explicitly set to 1 to avoid warnings
                 }

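For context, the chunked streaming loop that these parameters feed into can be sketched in isolation as follows. This is a minimal illustration assuming a Hugging Face `transformers` model and tokenizer; `model`, `tokenizer`, `prompt`, and the sampling values are placeholders mirroring the defaults chosen in this diff, not the repository's full parameter handling.

```python
# Sketch: generate a few tokens per step, decode only the new tokens, yield them,
# then feed the full sequence back in as the next step's input.
import torch

def stream_in_chunks(model, tokenizer, prompt, max_new_tokens=128, chunk_size=2):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        for _ in range(0, max_new_tokens, chunk_size):
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=chunk_size,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.1,
                num_beams=1,
                pad_token_id=tokenizer.eos_token_id,
            )
            # Decode only the tokens produced in this step
            new_tokens = outputs[0][input_ids.shape[1]:]
            new_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
            if not new_text or new_text.isspace():
                break
            yield new_text

            # The full output becomes the prompt for the next chunk
            input_ids = outputs
            attention_mask = torch.ones_like(input_ids)
```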
@@ -623,9 +628,37 @@ def _stream_generate(
                 if not new_text or new_text.isspace():
                     break

+                # Add to generated text for quality control
+                generated_text += new_text
+
+                # Check for stop sequences
+                should_stop = False
+                for stop_seq in stop_sequences:
+                    if stop_seq in generated_text:
+                        # We've reached a stop sequence, stop generation
+                        should_stop = True
+                        break
+
+                # Check for repetition (a sign of poor quality)
+                if len(generated_text) > 50:
+                    # Check for repeating patterns of 10+ characters
+                    last_50_chars = generated_text[-50:]
+                    for pattern_len in range(10, 20):
+                        if pattern_len < len(last_50_chars) // 2:
+                            pattern = last_50_chars[-pattern_len:]
+                            if pattern in last_50_chars[:-pattern_len]:
+                                # Detected repetition, stop generation
+                                logger.warning("Detected repetition in streaming generation, stopping")
+                                should_stop = True
+                                break
+
                 # Yield the new text
                 yield new_text

+                # Stop if needed
+                if should_stop:
+                    break
+
                 # Update input_ids and attention_mask for next iteration
                 input_ids = outputs
                 attention_mask = torch.ones_like(input_ids)
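The repetition check added above can also be read as a small standalone helper. A sketch with a hypothetical function name, using the same 50-character window and 10–19 character pattern lengths as the diff:

```python
# Hypothetical helper equivalent to the in-loop repetition heuristic above.
def looks_repetitive(text: str, window: int = 50, min_len: int = 10, max_len: int = 20) -> bool:
    if len(text) <= window:
        return False
    tail = text[-window:]
    for pattern_len in range(min_len, max_len):
        if pattern_len < len(tail) // 2:
            pattern = tail[-pattern_len:]
            # A tail pattern that also occurs earlier in the window signals a loop
            if pattern in tail[:-pattern_len]:
                return True
    return False

# Example: a looping phrase near the end of the text trips the check
assert looks_repetitive("The answer is " + "very very very very very very " * 3)
assert not looks_repetitive("A short, non-repetitive reply.")
```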
@@ -666,6 +699,7 @@ def _stream_generate(
                 if not new_text or new_text.isspace():
                     break

+                generated_text += new_text
                 yield new_text

                 input_ids = outputs
@@ -699,6 +733,24 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
         # Get model-specific generation parameters
         from .config import get_model_generation_params
         gen_params = get_model_generation_params(self.current_model)
+
+        # Set optimized defaults for streaming that match non-streaming quality
+        # Use the same parameters as non-streaming for consistency
+        if not kwargs.get("max_length") and not kwargs.get("max_new_tokens"):
+            # Use a reasonable default max_length
+            gen_params["max_length"] = min(gen_params.get("max_length", DEFAULT_MAX_LENGTH), 512)
+
+        if not kwargs.get("temperature"):
+            # Use the same temperature as non-streaming
+            gen_params["temperature"] = min(gen_params.get("temperature", DEFAULT_TEMPERATURE), 0.7)
+
+        if not kwargs.get("top_k"):
+            # Add top_k for better quality
+            gen_params["top_k"] = 40
+
+        if not kwargs.get("repetition_penalty"):
+            # Add repetition penalty to avoid loops
+            gen_params["repetition_penalty"] = 1.1

         # Update with provided kwargs
         for key, value in kwargs.items():
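As a quick illustration of how these streaming defaults resolve when the caller passes no overrides (the config values below are made up; `DEFAULT_MAX_LENGTH` and `DEFAULT_TEMPERATURE` stand in for the module's constants):

```python
# Made-up inputs to show how the default-capping logic above behaves.
DEFAULT_MAX_LENGTH = 1024
DEFAULT_TEMPERATURE = 0.8

gen_params = {"max_length": 2048, "temperature": 0.9}  # model-specific config
kwargs = {}                                            # caller passed no overrides

if not kwargs.get("max_length") and not kwargs.get("max_new_tokens"):
    gen_params["max_length"] = min(gen_params.get("max_length", DEFAULT_MAX_LENGTH), 512)
if not kwargs.get("temperature"):
    gen_params["temperature"] = min(gen_params.get("temperature", DEFAULT_TEMPERATURE), 0.7)
if not kwargs.get("top_k"):
    gen_params["top_k"] = 40
if not kwargs.get("repetition_penalty"):
    gen_params["repetition_penalty"] = 1.1

# Result: {"max_length": 512, "temperature": 0.7, "top_k": 40, "repetition_penalty": 1.1}
```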
@@ -718,10 +770,56 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
         for key in inputs:
             inputs[key] = inputs[key].to(model_device)

-        # Now stream tokens using the prepared inputs and parameters
-        for token in self._stream_generate(inputs, gen_params=gen_params):
+        # Check if we need to clear CUDA cache before generation
+        if torch.cuda.is_available():
+            current_mem = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)  # GB
+            total_mem = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)  # GB
+            if current_mem > 0.8 * total_mem:  # If using >80% of GPU memory
+                # Clear cache to avoid OOM
+                torch.cuda.empty_cache()
+                logger.info("Cleared CUDA cache before streaming generation to avoid out of memory error")
+
+        # Create a custom stream generator with improved quality
+        async def improved_stream_generator():
+            # Use the same stopping conditions as non-streaming
+            stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
+            accumulated_text = ""
+
+            # Use a generator that produces high-quality chunks
+            try:
+                for token_chunk in self._stream_generate(inputs, gen_params=gen_params):
+                    accumulated_text += token_chunk
+
+                    # Check for stop sequences
+                    should_stop = False
+                    for stop_seq in stop_sequences:
+                        if stop_seq in accumulated_text:
+                            # Truncate at stop sequence
+                            accumulated_text = accumulated_text.split(stop_seq)[0]
+                            should_stop = True
+                            break
+
+                    # Yield the token chunk
+                    yield token_chunk
+
+                    # Stop if we've reached a stop sequence
+                    if should_stop:
+                        break
+
+                    # Also stop if we've generated too much text (safety measure)
+                    if len(accumulated_text) > gen_params.get("max_length", 512) * 4:  # Character estimate
+                        logger.warning("Stream generation exceeded maximum length - stopping")
+                        break
+
+                    await asyncio.sleep(0)
+            except Exception as e:
+                logger.error(f"Error in stream generation: {str(e)}")
+                # Don't propagate the error to avoid breaking the stream
+                # Just stop generating
+
+        # Use the improved generator
+        async for token in improved_stream_generator():
             yield token
-            await asyncio.sleep(0)

     def get_model_info(self) -> Dict[str, Any]:
         """Get information about the currently loaded model"""