Skip to content

Commit 1ac5a0d

Browse files
committed
Improved steaming response generation
1 parent 2325fba commit 1ac5a0d

File tree

1 file changed

+42
-47
lines changed

1 file changed

+42
-47
lines changed

client/python_client/locallab/client.py

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ async def stream_generate(self, prompt: str, options: Optional[Union[GenerateOpt
172172
if options is None:
173173
options = GenerateOptions()
174174

175-
# Ensure stream is True and format data correctly
175+
# Format data consistently
176176
data = {
177177
"prompt": prompt,
178178
"stream": True,
@@ -184,81 +184,76 @@ async def stream_generate(self, prompt: str, options: Optional[Union[GenerateOpt
184184
# Remove None values
185185
data = {k: v for k, v in data.items() if v is not None}
186186

187-
async with self.session.post("/generate", json=data) as response:
188-
if response.status != 200:
189-
error_msg = await response.text()
190-
logger.error(f"Streaming error: {error_msg}")
191-
yield f"Error: {error_msg}"
192-
return
193-
194-
buffer = ""
195-
current_sentence = ""
196-
last_token_was_space = False
187+
if not self.session:
188+
await self.connect()
197189

198-
try:
190+
try:
191+
async with self.session.post("/generate", json=data) as response:
192+
if response.status != 200:
193+
error_msg = await response.text()
194+
logger.error(f"Streaming error: {error_msg}")
195+
yield f"Error: {error_msg}"
196+
return
197+
198+
buffer = ""
199199
async for line in response.content:
200200
if line:
201201
try:
202202
line = line.decode('utf-8').strip()
203-
# Skip empty lines
204203
if not line:
205204
continue
206-
205+
207206
# Handle SSE format
208207
if line.startswith("data: "):
209-
line = line[6:] # Remove "data: " prefix
210-
208+
line = line[6:]
209+
211210
# Skip control messages
212211
if line in ["[DONE]", "[ERROR]"]:
213212
continue
214-
213+
214+
# Parse response
215215
try:
216-
# Try to parse as JSON
217216
data = json.loads(line)
218-
text = data.get("text", data.get("response", ""))
217+
# Handle different response formats
218+
text = data.get("text", data.get("response", data.get("content", "")))
219219
except json.JSONDecodeError:
220-
# If not JSON, use the line as is
220+
# If not JSON, use raw line
221221
text = line
222-
222+
223223
if text:
224-
# Clean up any special tokens
225-
text = text.replace("<|", "").replace("|>", "")
226-
text = text.replace("<", "").replace(">", "")
227-
text = text.replace("[", "").replace("]", "")
228-
text = text.replace("{", "").replace("}", "")
229-
text = text.replace("data:", "")
230-
text = text.replace("��", "")
231-
text = text.replace("\\n", "\n")
232-
text = text.replace("|user|", "")
233-
text = text.replace("|The", "The")
234-
text = text.replace("/|assistant|", "").replace("/|user|", "")
235-
text = text.replace("assistant", "").replace("Error:", "")
224+
# Clean up special tokens and formatting
225+
text = (text.replace("<|", "").replace("|>", "")
226+
.replace("<", "").replace(">", "")
227+
.replace("[", "").replace("]", "")
228+
.replace("{", "").replace("}", "")
229+
.replace("data:", "")
230+
.replace("��", "")
231+
.replace("\\n", "\n")
232+
.replace("|user|", "")
233+
.replace("|assistant|", "")
234+
.replace("assistant:", "")
235+
.replace("user:", ""))
236236

237237
# Add space between words if needed
238-
if (not text.startswith(" ") and
238+
if (buffer and
239+
not text.startswith(" ") and
239240
not text.startswith("\n") and
240-
not last_token_was_space and
241-
buffer and
242-
not buffer.endswith(" ") and
241+
not buffer.endswith(" ") and
243242
not buffer.endswith("\n")):
244243
text = " " + text
245-
246-
# Update tracking variables
247-
buffer += text
248-
current_sentence += text
249-
last_token_was_space = text.endswith(" ") or text.endswith("\n")
250244

245+
buffer += text
251246
yield text
252247

253248
except Exception as e:
254249
logger.error(f"Error processing stream chunk: {str(e)}")
255-
yield f"\nError: {str(e)}"
250+
yield f"\nError: Failed to process response - {str(e)}"
256251
return
257252

258-
except Exception as e:
259-
logger.error(f"Stream connection error: {str(e)}")
260-
yield f"\nError: Connection error - {str(e)}"
261-
return
253+
except Exception as e:
254+
logger.error(f"Stream connection error: {str(e)}")
255+
yield f"\nError: Connection failed - {str(e)}"
256+
return
262257

263258
async def generate(self, prompt: str, options: Optional[Union[GenerateOptions, Dict]] = None) -> GenerateResponse:
264259
"""Generate text from prompt"""

0 commit comments

Comments
 (0)