Skip to content

Commit 213d4f7

Browse files
roryeckelclaude
andcommitted
Add robust error handling and abort logic for TTS streaming failures
Introduces TtsStreamResult dataclass and TtsStreamError exception to properly handle synthesis failures in both streaming and buffered modes. The _abort_synthesis method ensures clean state reset and appropriate Wyoming protocol events when synthesis fails. Includes comprehensive test coverage for various failure scenarios. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 9fbaacd commit 213d4f7

File tree

2 files changed

+290
-58
lines changed

2 files changed

+290
-58
lines changed

src/wyoming_openai/handler.py

Lines changed: 153 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import io
33
import logging
44
import wave
5+
from dataclasses import dataclass
56

67
import pysbd
78
from openai import NOT_GIVEN, AsyncStream
@@ -44,6 +45,22 @@ def _truncate_for_log(text: str, max_length: int = 100) -> str:
4445
TTS_CHUNK_SIZE = 2048 # Magical guess - but must be larger than 44 bytes for a potential WAV header
4546
TTS_CONCURRENT_REQUESTS = 3 # Number of concurrent OpenAI TTS requests when streaming sentences
4647

48+
49+
@dataclass(frozen=True)
50+
class TtsStreamResult:
51+
"""Container for TTS streaming outcomes."""
52+
streamed: bool
53+
audio: bytes | None = None
54+
55+
56+
class TtsStreamError(Exception):
57+
"""Raised when TTS streaming fails for a specific text chunk."""
58+
def __init__(self, message: str, chunk_preview: str, voice: str):
59+
super().__init__(message)
60+
self.chunk_preview = chunk_preview
61+
self.voice = voice
62+
63+
4764
class OpenAIEventHandler(AsyncEventHandler):
4865
def __init__(
4966
self,
@@ -371,7 +388,7 @@ def _meets_min_criteria(self, text: str, min_words: int) -> bool:
371388
return word_count >= min_words
372389

373390

374-
async def _process_ready_sentences(self, sentences: list[str], language: str | None = None) -> None:
391+
async def _process_ready_sentences(self, sentences: list[str], language: str | None = None) -> bool:
375392
"""
376393
Process complete sentences for immediate TTS synthesis with concurrent requests.
377394
@@ -384,36 +401,31 @@ async def _process_ready_sentences(self, sentences: list[str], language: str | N
384401
- Await tasks in order for sequential playback
385402
- Semaphore naturally limits concurrency to TTS_CONCURRENT_REQUESTS
386403
387-
Error Handling Strategy:
388-
- If synthesis fails for a sentence, we log the error and continue with the next
389-
- This ensures partial content delivery even if some sentences fail
390-
- Audio timestamps remain continuous even when sentences are skipped
391-
392404
Args:
393405
sentences (list[str]): Complete sentences ready for synthesis.
394406
language (str | None): Language code for the sentences.
407+
408+
Returns:
409+
bool: True if processing succeeded, False if synthesis was aborted.
395410
"""
396411
if not sentences or not self._synthesis_voice:
397-
return
412+
return True
398413

399414
try:
400-
# Validate voice and language
401415
requested_voice = self._synthesis_voice.name
402416
requested_language = self._synthesis_voice.language
403417
voice = self._validate_tts_voice_and_language(requested_voice, requested_language)
404418
if not voice:
405419
_LOGGER.error("Failed to validate voice for incremental synthesis")
406-
return
420+
return await self._abort_synthesis()
407421

408-
# Check if streaming is enabled for this voice
409422
use_streaming = self._is_tts_voice_streaming(voice.name)
410423

411424
if use_streaming:
412-
# Filter out empty sentences
413425
valid_sentences = [s for s in sentences if s.strip()]
414426
if not valid_sentences:
415427
_LOGGER.debug("No non-empty sentences available for incremental synthesis.")
416-
return
428+
return True
417429

418430
_LOGGER.info("Starting concurrent synthesis for %d sentences", len(valid_sentences))
419431

@@ -433,39 +445,70 @@ async def _process_ready_sentences(self, sentences: list[str], language: str | N
433445
# Await tasks IN ORDER for sequential playback
434446
# Enable streaming for whichever task we're currently awaiting
435447
for i, (task_id, task) in enumerate(synthesis_tasks):
436-
_LOGGER.debug("Processing sentence %d: %s", i + 1, _truncate_for_log(valid_sentences[i], 50))
448+
sentence_preview = _truncate_for_log(valid_sentences[i], 50)
449+
_LOGGER.debug("Processing sentence %d: %s", i + 1, sentence_preview)
437450

438-
# Enable streaming for this task (if still running, it will stream directly)
439451
self._allow_streaming_task_id = task_id
452+
try:
453+
result = await task
454+
except TtsStreamError as err:
455+
_LOGGER.error(
456+
"Failed to synthesize sentence %d (%s) with voice %s: %s",
457+
i + 1,
458+
err.chunk_preview,
459+
err.voice,
460+
err,
461+
)
462+
return await self._abort_synthesis()
463+
except Exception as err:
464+
_LOGGER.exception(
465+
"Unexpected error while synthesizing sentence %d (%s): %s",
466+
i + 1,
467+
sentence_preview,
468+
err,
469+
)
470+
return await self._abort_synthesis()
471+
finally:
472+
self._allow_streaming_task_id = None
440473

441-
# Get result from task (might be buffered data or None if it streamed)
442-
audio_data = await task
443-
444-
# Disable streaming
445-
self._allow_streaming_task_id = None
446-
447-
# If audio_data is None, task streamed directly (already sent to Wyoming)
448-
if audio_data is None:
474+
if result.streamed:
449475
_LOGGER.debug("Sentence %d streamed directly with minimal latency", i + 1)
450476
# Timestamp and audio_started already updated by _stream_tts_audio_incremental
451477
continue
452478

453479
# Otherwise, task completed and buffered - stream the buffered data now
480+
audio_data = result.audio
481+
if not audio_data:
482+
_LOGGER.error(
483+
"Buffered synthesis returned no audio for sentence %d (%s)",
484+
i + 1,
485+
sentence_preview,
486+
)
487+
return await self._abort_synthesis()
488+
454489
chunk_timestamp = await self._stream_audio_to_wyoming(
455490
audio_data,
456491
is_first_chunk=(not self._audio_started),
457-
start_timestamp=self._current_timestamp
492+
start_timestamp=self._current_timestamp,
458493
)
459494

460-
if chunk_timestamp is not None:
461-
self._current_timestamp = chunk_timestamp
462-
self._audio_started = True
463-
_LOGGER.debug("Successfully streamed buffered sentence %d, timestamp: %.2f", i + 1, chunk_timestamp)
464-
else:
495+
if chunk_timestamp is None:
465496
_LOGGER.error("Failed to stream sentence %d to Wyoming", i + 1)
497+
return await self._abort_synthesis()
498+
499+
self._current_timestamp = chunk_timestamp
500+
self._audio_started = True
501+
_LOGGER.debug(
502+
"Successfully streamed buffered sentence %d, timestamp: %.2f",
503+
i + 1,
504+
chunk_timestamp,
505+
)
466506

507+
return True
467508
except Exception as e:
468509
_LOGGER.exception("Error processing ready sentences: %s", e)
510+
return await self._abort_synthesis()
511+
469512

470513
async def _stream_tts_audio_incremental(self, text: str, voice: TtsVoiceModel) -> float | None:
471514
"""
@@ -494,6 +537,26 @@ async def _stream_tts_audio_incremental(self, text: str, voice: TtsVoiceModel) -
494537

495538
return timestamp
496539

540+
541+
async def _abort_synthesis(self) -> bool:
542+
"""Abort the current synthesis session, emitting stop events and resetting state."""
543+
if self._audio_started:
544+
await self.write_event(AudioStop(timestamp=int(self._current_timestamp)).event())
545+
546+
await self.write_event(SynthesizeStopped().event())
547+
548+
self._audio_started = False
549+
self._current_timestamp = 0
550+
self._allow_streaming_task_id = None
551+
self._is_synthesizing = False
552+
self._synthesis_buffer = []
553+
self._text_accumulator = ""
554+
self._ready_chunks = []
555+
self._pysbd_segmenters.clear()
556+
self._synthesis_voice = None
557+
558+
return False
559+
497560
def _log_unsupported_asr_model(self, model_name: str | None = None):
498561
"""Log an unsupported ASR model"""
499562
if model_name:
@@ -666,7 +729,8 @@ async def _handle_synthesize_chunk(self, synthesize_chunk: SynthesizeChunk) -> b
666729
_LOGGER.info("Detected %d ready sentences for immediate synthesis: %s",
667730
len(ready_sentences),
668731
[_truncate_for_log(s, 30) for s in ready_sentences])
669-
await self._process_ready_sentences(ready_sentences, requested_language)
732+
if not await self._process_ready_sentences(ready_sentences, requested_language):
733+
return False
670734
else:
671735
_LOGGER.debug("No complete sentences ready yet, accumulator has: '%s'",
672736
_truncate_for_log(self._text_accumulator))
@@ -687,7 +751,8 @@ async def _handle_synthesize_stop(self) -> bool:
687751
_LOGGER.info("Processing final remaining text: '%s'",
688752
_truncate_for_log(self._text_accumulator))
689753
requested_language = self._synthesis_voice.language if self._synthesis_voice else None
690-
await self._process_ready_sentences([self._text_accumulator], requested_language)
754+
if not await self._process_ready_sentences([self._text_accumulator], requested_language):
755+
return False
691756

692757
# Get accumulated text and voice for fallback
693758
full_text = "".join(self._synthesis_buffer)
@@ -746,6 +811,7 @@ async def _handle_synthesize_stop(self) -> bool:
746811
_LOGGER.debug("Text chunked into %d parts for streaming synthesis", len(chunks))
747812

748813
# Create ALL tasks with IDs - API calls start concurrently
814+
# Semaphore limits actual concurrency to TTS_CONCURRENT_REQUESTS
749815
_LOGGER.info("Starting concurrent synthesis for %d chunks", len(chunks))
750816
synthesis_tasks = [
751817
(
@@ -762,37 +828,59 @@ async def _handle_synthesize_stop(self) -> bool:
762828
# Enable streaming for whichever task we're currently awaiting
763829
total_timestamp = 0
764830
for i, (task_id, task) in enumerate(synthesis_tasks):
831+
chunk_preview = _truncate_for_log(chunks[i], 50)
765832
_LOGGER.debug("Streaming chunk %d/%d to Wyoming", i + 1, len(chunks))
766833

767-
# Enable streaming for this task
768834
self._allow_streaming_task_id = task_id
835+
try:
836+
result = await task
837+
except TtsStreamError as err:
838+
_LOGGER.error(
839+
"Failed to synthesize chunk %d (%s) with voice %s: %s",
840+
i + 1,
841+
err.chunk_preview,
842+
err.voice,
843+
err,
844+
)
845+
return await self._abort_synthesis()
846+
except Exception as err:
847+
_LOGGER.exception(
848+
"Unexpected error while synthesizing chunk %d (%s): %s",
849+
i + 1,
850+
chunk_preview,
851+
err,
852+
)
853+
return await self._abort_synthesis()
854+
finally:
855+
self._allow_streaming_task_id = None
769856

770-
# Get result from task (might be buffered data or None if it streamed)
771-
audio_data = await task
772-
773-
# Disable streaming
774-
self._allow_streaming_task_id = None
775-
776-
# If audio_data is None, task streamed directly
777-
if audio_data is None:
857+
if result.streamed:
778858
_LOGGER.debug("Chunk %d streamed directly", i + 1)
779859
# Update timestamp from streamed audio
780860
total_timestamp = self._current_timestamp
781861
continue
782862

783863
# Otherwise, stream the buffered data
864+
audio_data = result.audio
865+
if not audio_data:
866+
_LOGGER.error(
867+
"Buffered synthesis returned no audio for chunk %d (%s)",
868+
i + 1,
869+
chunk_preview,
870+
)
871+
return await self._abort_synthesis()
872+
784873
chunk_timestamp = await self._stream_audio_to_wyoming(
785874
audio_data,
786875
is_first_chunk=(i == 0),
787-
start_timestamp=total_timestamp
876+
start_timestamp=total_timestamp,
788877
)
789878

790-
if chunk_timestamp is not None:
791-
total_timestamp = chunk_timestamp
792-
else:
879+
if chunk_timestamp is None:
793880
_LOGGER.error("Failed to stream chunk %d to Wyoming", i + 1)
794-
await self.write_event(SynthesizeStopped().event())
795-
return False
881+
return await self._abort_synthesis()
882+
883+
total_timestamp = chunk_timestamp
796884

797885
# Send final audio stop
798886
await self.write_event(AudioStop(timestamp=total_timestamp).event())
@@ -813,7 +901,7 @@ async def _handle_synthesize_stop(self) -> bool:
813901
await self.write_event(SynthesizeStopped().event())
814902
return False
815903

816-
async def _get_tts_audio_stream(self, text: str, voice: TtsVoiceModel, task_id: str | None = None) -> bytes | None:
904+
async def _get_tts_audio_stream(self, text: str, voice: TtsVoiceModel, task_id: str | None = None) -> TtsStreamResult:
817905
"""
818906
Get TTS audio stream from OpenAI for a text chunk (parallel-safe).
819907
@@ -826,21 +914,22 @@ async def _get_tts_audio_stream(self, text: str, voice: TtsVoiceModel, task_id:
826914
task_id (str | None): Optional task identifier for streaming coordination.
827915
828916
Returns:
829-
bytes | None: Complete audio data for the chunk (if buffered), or None if streamed directly.
917+
TtsStreamResult: Container with streaming status and optional buffered audio.
830918
"""
919+
chunk_preview = _truncate_for_log(text, 50)
920+
831921
try:
832922
# Check if this task is allowed to stream directly
833-
should_stream = (task_id is not None and task_id == self._allow_streaming_task_id)
923+
should_stream = task_id is not None and task_id == self._allow_streaming_task_id
834924

835925
if should_stream:
836926
# Stream directly to Wyoming (no buffering) - minimal latency
837-
_LOGGER.debug("Streaming chunk directly (task %s): %s", task_id, _truncate_for_log(text, 50))
927+
_LOGGER.debug("Streaming chunk directly (task %s): %s", task_id, chunk_preview)
838928
timestamp = await self._stream_tts_audio_incremental(text, voice)
839929
if timestamp is None:
840-
_LOGGER.error("Failed to stream chunk directly")
841-
return None
842-
_LOGGER.debug("Completed direct streaming for chunk: %s", _truncate_for_log(text, 50))
843-
return None # Signal that streaming was handled
930+
raise TtsStreamError("OpenAI returned no audio while streaming chunk", chunk_preview, voice.name)
931+
_LOGGER.debug("Completed direct streaming for chunk: %s", chunk_preview)
932+
return TtsStreamResult(streamed=True)
844933

845934
# Buffer audio (default behavior for parallel tasks)
846935
audio_data = b""
@@ -855,12 +944,18 @@ async def _get_tts_audio_stream(self, text: str, voice: TtsVoiceModel, task_id:
855944
async for chunk in response.iter_bytes(chunk_size=TTS_CHUNK_SIZE):
856945
audio_data += chunk
857946

858-
_LOGGER.debug("Completed buffered synthesis for chunk: %s", _truncate_for_log(text, 50))
859-
return audio_data
947+
if not audio_data:
948+
raise TtsStreamError("OpenAI returned empty audio response", chunk_preview, voice.name)
949+
950+
_LOGGER.debug("Completed buffered synthesis for chunk: %s", chunk_preview)
951+
return TtsStreamResult(streamed=False, audio=audio_data)
952+
953+
except TtsStreamError:
954+
raise
955+
except Exception as exc:
956+
_LOGGER.exception("Error getting TTS audio stream for %s: %s", chunk_preview, exc)
957+
raise TtsStreamError("Unexpected error while retrieving TTS audio", chunk_preview, voice.name) from exc
860958

861-
except Exception as e:
862-
_LOGGER.exception("Error getting TTS audio stream: %s", e)
863-
return None
864959

865960
async def _stream_audio_to_wyoming(self, audio_data: bytes, is_first_chunk: bool, start_timestamp: float) -> float | None:
866961
"""

0 commit comments

Comments
 (0)