Commit 58888f3

fixed streaming errors and improved formatting
1 parent 518dee6

File tree

1 file changed: +20 -7 lines changed

locallab/model_manager.py

Lines changed: 20 additions & 7 deletions
@@ -747,14 +747,26 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
 
         # Create a custom stream generator with improved quality
         async def improved_stream_generator():
-            # Use the same stopping conditions as non-streaming
             stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
             accumulated_text = ""
+            last_token_was_space = False  # Track if last token was a space
 
-            # Use a generator that produces high-quality chunks
             try:
                 for token_chunk in self._stream_generate(inputs, gen_params=gen_params):
+                    # Skip empty chunks
+                    if not token_chunk or token_chunk.isspace():
+                        continue
+
+                    # Add space between words if needed
+                    if not token_chunk.startswith(" ") and not last_token_was_space and accumulated_text:
+                        token_chunk = " " + token_chunk
+
                     accumulated_text += token_chunk
+                    last_token_was_space = token_chunk.endswith(" ")
+
+                    # Clean up the token
+                    token_chunk = token_chunk.replace("|user|", "").replace("|The", "The")
+                    token_chunk = token_chunk.replace("��", "").replace("\\n", "\n")
 
                     # Check for stop sequences
                     should_stop = False
@@ -765,23 +777,24 @@ async def improved_stream_generator():
                             should_stop = True
                             break
 
-                    # Yield the token chunk
+                    # Yield the cleaned token chunk
                     yield token_chunk
 
                     # Stop if we've reached a stop sequence
                     if should_stop:
                         break
 
-                    # Also stop if we've generated too much text (safety measure)
-                    if len(accumulated_text) > gen_params.get("max_length", 512) * 4:  # Character estimate
+                    # Also stop if we've generated too much text
+                    if len(accumulated_text) > gen_params.get("max_length", 512) * 4:
                         logger.warning("Stream generation exceeded maximum length - stopping")
                         break
 
                     await asyncio.sleep(0)
+
             except Exception as e:
                 logger.error(f"Error in stream generation: {str(e)}")
-                # Don't propagate the error to avoid breaking the stream
-                # Just stop generating
+                # Send error message to client
+                yield f"\nError: {str(e)}"
 
         # Use the improved generator
         async for token in improved_stream_generator():
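
Taken out of its class, the new streaming logic is a small normalizing wrapper around a raw token stream. The sketch below is a minimal, self-contained version of the same technique, assuming a hypothetical fake_token_stream stand-in for self._stream_generate and a plain print consumer; the stop-sequence loop from the diff is condensed into an any() check, and logging is omitted.

    import asyncio

    # Hypothetical stand-in for self._stream_generate, which in the real
    # code yields decoded token chunks from the model.
    def fake_token_stream():
        for chunk in ["Hello", " ", "world", "!", "</s>", "never reached"]:
            yield chunk

    async def improved_stream_generator(token_source, max_length=512):
        stop_sequences = ["</s>", "<|endoftext|>", "<|im_end|>", "<|assistant|>"]
        accumulated_text = ""
        last_token_was_space = False
        try:
            for token_chunk in token_source:
                # Skip empty or whitespace-only chunks
                if not token_chunk or token_chunk.isspace():
                    continue

                # Re-insert the inter-word space some tokenizers drop
                if not token_chunk.startswith(" ") and not last_token_was_space and accumulated_text:
                    token_chunk = " " + token_chunk

                accumulated_text += token_chunk
                last_token_was_space = token_chunk.endswith(" ")

                # Strip artifacts the model sometimes emits
                token_chunk = token_chunk.replace("|user|", "").replace("\\n", "\n")

                # Stop once any stop sequence appears in the accumulated text
                # (condenses the diff's explicit for-loop over stop_sequences)
                should_stop = any(s in accumulated_text for s in stop_sequences)

                yield token_chunk
                if should_stop:
                    break

                # Safety cap, using the diff's rough 4-characters-per-token estimate
                if len(accumulated_text) > max_length * 4:
                    break

                await asyncio.sleep(0)  # hand control back to the event loop
        except Exception as e:
            # Surface the error in-band so the client sees why the stream ended
            yield f"\nError: {str(e)}"

    async def main():
        async for token in improved_stream_generator(fake_token_stream()):
            print(token, end="")

    asyncio.run(main())

Run as-is, this prints "Hello world ! </s>" and never reaches the final chunk: the chunk containing the stop sequence is still yielded before the break, and the space-stitching heuristic glues bare word tokens together but also inserts a space before punctuation, the usual trade-off of reassembling text from token-level chunks without a proper detokenizer.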
