@@ -747,14 +747,26 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
 
         # Create a custom stream generator with improved quality
         async def improved_stream_generator():
-            # Use the same stopping conditions as non-streaming
             stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
             accumulated_text = ""
+            last_token_was_space = False  # Track if last token was a space
 
-            # Use a generator that produces high-quality chunks
             try:
                 for token_chunk in self._stream_generate(inputs, gen_params=gen_params):
+                    # Skip empty chunks
+                    if not token_chunk or token_chunk.isspace():
+                        continue
+
+                    # Add space between words if needed
+                    if not token_chunk.startswith(" ") and not last_token_was_space and accumulated_text:
+                        token_chunk = " " + token_chunk
+
                     accumulated_text += token_chunk
+                    last_token_was_space = token_chunk.endswith(" ")
+
+                    # Clean up the token
+                    token_chunk = token_chunk.replace("|user|", "").replace("|The", "The")
+                    token_chunk = token_chunk.replace("\ufffd", "").replace("\\n", "\n")
 
                     # Check for stop sequences
                     should_stop = False
@@ -765,23 +777,24 @@ async def improved_stream_generator():
                             should_stop = True
                             break
 
-                    # Yield the token chunk
+                    # Yield the cleaned token chunk
                     yield token_chunk
 
                     # Stop if we've reached a stop sequence
                     if should_stop:
                         break
 
-                    # Also stop if we've generated too much text (safety measure)
-                    if len(accumulated_text) > gen_params.get("max_length", 512) * 4:  # Character estimate
+                    # Also stop if we've generated too much text
+                    if len(accumulated_text) > gen_params.get("max_length", 512) * 4:
                         logger.warning("Stream generation exceeded maximum length - stopping")
                         break
 
                     await asyncio.sleep(0)
+
             except Exception as e:
                 logger.error(f"Error in stream generation: {str(e)}")
-                # Don't propagate the error to avoid breaking the stream
-                # Just stop generating
+                # Send error message to client
+                yield f"\nError: {str(e)}"
 
         # Use the improved generator
         async for token in improved_stream_generator():
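For context, the loop body added by this commit is easier to reason about in isolation. Below is a minimal, self-contained sketch of the same skip/space/clean/stop logic, driven by a fake token source; `fake_stream`, `clean_chunk`, `STOP_SEQUENCES`, and `MAX_CHARS` are illustrative names for this sketch, not identifiers from the commit:

```python
import asyncio

# Illustrative constants; the commit reads these from the surrounding class.
STOP_SEQUENCES = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
MAX_CHARS = 512 * 4  # character estimate of max_length, as in the commit

def fake_stream():
    # Stand-in for self._stream_generate(...): yields raw token chunks.
    yield from ["Hello", " world", "|user|", " done.", "</s>", "never reached"]

def clean_chunk(chunk: str) -> str:
    # Mirrors the cleanup step added by the commit.
    chunk = chunk.replace("|user|", "").replace("|The", "The")
    return chunk.replace("\ufffd", "").replace("\\n", "\n")

async def improved_stream():
    accumulated = ""
    last_was_space = False
    for chunk in fake_stream():
        # Skip empty chunks.
        if not chunk or chunk.isspace():
            continue
        # Re-insert a space between word pieces when needed.
        if not chunk.startswith(" ") and not last_was_space and accumulated:
            chunk = " " + chunk
        accumulated += chunk
        last_was_space = chunk.endswith(" ")
        chunk = clean_chunk(chunk)
        # As in the commit: the chunk that triggers a stop is still yielded.
        should_stop = any(stop in accumulated for stop in STOP_SEQUENCES)
        yield chunk
        if should_stop or len(accumulated) > MAX_CHARS:
            break
        await asyncio.sleep(0)  # cooperatively yield to the event loop

async def main():
    async for token in improved_stream():
        print(repr(token))

asyncio.run(main())
```

Note that, as in the diff, stop sequences are matched against the raw accumulated text while the yielded chunks are the cleaned ones, and the chunk containing the stop sequence is emitted before the loop exits.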