
Commit feecc04

Working f5 with proper streamed sentence split
1 parent 5ffa8cc commit feecc04

4 files changed (+219, -68 lines)

entrypoint.sh

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ set -Eeuo pipefail
 : "${THTTS_DEVICE:=auto}"        # auto|cpu|cuda
 : "${THTTS_SPEED:=1.0}"
 : "${THTTS_NFE_STEPS:=32}"
-: "${THTTS_MAX_CONCURRENT:=2}"
+: "${THTTS_MAX_CONCURRENT:=1}"
 : "${THTTS_CKPT_FILE:=}"         # optional override
 : "${THTTS_VOCAB_FILE:=}"        # optional override

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ dependencies = [
     "cached-path",
     "f5-tts",
     "omegaconf",
+    "pythainlp>=5.1.2",
     "torch>=2.8.0",
     "torchaudio",
     "transformers>=4.55.2",

src/wyoming_thai_f5.py

Lines changed: 195 additions & 60 deletions
@@ -1,40 +1,4 @@
 #!/usr/bin/env python3
-import argparse
-import asyncio
-import logging
-import re
-import os
-from typing import List
-import numpy as np
-
-# ---- Wyoming imports (official lib) ----
-from wyoming.event import Event
-from wyoming.server import AsyncEventHandler, AsyncServer
-
-from wyoming.tts import (
-    Synthesize,
-    SynthesizeStart,
-    SynthesizeChunk,
-    SynthesizeStop,
-    SynthesizeStopped,
-)
-
-from wyoming.audio import (
-    AudioStart,
-    AudioChunk,
-    AudioStop,
-)
-
-from wyoming.info import Info, Describe, TtsProgram, TtsVoice, Attribution
-from wyoming.error import Error
-
-# ---- F5-TTS (Thai) imports ----
-import torch
-from importlib.resources import files
-from cached_path import cached_path
-from omegaconf import OmegaConf
-
-from f5_tts.model import DiT
 from f5_tts.infer.utils_infer import (
     mel_spec_type,  # "vocos" (24 kHz)
     target_rms,
@@ -49,6 +13,38 @@
     load_vocoder,
     preprocess_ref_audio_text,
 )
+from f5_tts.model import DiT
+from cached_path import cached_path
+import torch
+from wyoming.error import Error
+from wyoming.info import Info, Describe, TtsProgram, TtsVoice, Attribution
+from wyoming.audio import (
+    AudioStart,
+    AudioChunk,
+    AudioStop,
+)
+from wyoming.tts import (
+    Synthesize,
+    SynthesizeStart,
+    SynthesizeChunk,
+    SynthesizeStop,
+    SynthesizeStopped,
+)
+from wyoming.server import AsyncEventHandler, AsyncServer
+from wyoming.event import Event
+from pythainlp.tokenize import sent_tokenize
+import argparse
+import asyncio
+import logging
+from typing import List
+import numpy as np
+
+MIN_CHARS = 48       # flush when buffer reaches this length
+MAX_WAIT_MS = 220    # flush if idle for this many ms
+MAX_SENT_LEN = 180   # if a 'sentence' is too long, treat as complete to avoid stalling
+TERMINATORS = {"।", "?", "!", "…", "\n"}  # Thai often lacks punctuation; timeout/length still cover
+MIN_SENT_CHARS = 15  # do not emit a sentence shorter than this unless final flush
+

 # -----------------------
 # Utils
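
The pythainlp dependency added in pyproject.toml above supplies the sent_tokenize import used from here on. A minimal standalone sketch of that call, with a made-up sample sentence and the same arguments the server passes:

from pythainlp.tokenize import sent_tokenize

# Hypothetical example input; in the server the text arrives as streamed SynthesizeChunk events.
text = "สวัสดีครับ วันนี้อากาศดีมาก เราไปเดินเล่นกันไหม"
sentences = sent_tokenize(text, keep_whitespace=False, engine="thaisum")
print(sentences)  # list[str], one entry per sentence the "thaisum" engine detects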
@@ -62,21 +58,68 @@ def float32_to_int16_pcm(x: np.ndarray) -> bytes:


 def split_sentences_th(text: str) -> List[str]:
-    parts = re.split(r'([.!?。\n])', text)
-    chunks, buf = [], ""
-    for p in parts:
-        if p is None:
-            continue
-        buf += p
-        if p in {".", "!", "?", "。", "\n"}:
-            s = buf.strip()
-            if s:
-                chunks.append(s)
-            buf = ""
-    tail = buf.strip()
-    if tail:
-        chunks.append(tail)
-    return [c for c in (s.strip() for s in chunks) if c]
+    splitted = sent_tokenize(text, keep_whitespace=False, engine="thaisum")
+    logging.debug(f"Splitted sentences to: len={len(splitted)} {splitted}")
+    return splitted
+
+
+def _split_ready_vs_tail(text: str, *, final: bool = False) -> tuple[list[str], str]:
+    """
+    Tokenize Thai into sentences and return (ready_sentences, tail_remainder).
+    Strategy:
+      - If >=2 sentences, treat all but the last as ready; keep last as tail.
+      - If only 1 sentence:
+        - If it ends with a terminator or is very long, treat as ready.
+        - Else keep as tail (incomplete).
+      - Additionally, never emit a ready sentence with length < MIN_SENT_CHARS
+        by coalescing it with the next sentence — unless final=True.
+    """
+    sents = split_sentences_th(text)
+    if not sents:
+        return [], ""
+
+    if len(sents) >= 2:
+        base = sents[:-1]
+        last = sents[-1]
+        # Coalesce short sentences in 'base' so we only emit items >= MIN_SENT_CHARS (unless final).
+        ready: list[str] = []
+        acc = ""
+        for s in base:
+            if not final and (len(s) < MIN_SENT_CHARS):
+                acc += s
+                if len(acc) >= MIN_SENT_CHARS:
+                    ready.append(acc)
+                    acc = ""
+            else:
+                if acc:
+                    # prefer to attach short acc to this sentence if it keeps it coherent
+                    merged = acc + s
+                    if not final and len(merged) < MIN_SENT_CHARS:
+                        acc = merged
+                    else:
+                        ready.append(merged)
+                        acc = ""
+                else:
+                    ready.append(s)
+        # Whatever remains in acc is too short; push it into the tail.
+        tail = (acc + last)
+        # If tail is obviously complete, we may emit it too.
+        if tail and (tail[-1] in TERMINATORS or len(tail) >= MAX_SENT_LEN or final):
+            if final or len(tail) >= MIN_SENT_CHARS:
+                ready.append(tail)
+                tail = ""
+        return ready, tail
+
+    # single sentence
+    s = sents[0]
+    if s and (s[-1] in TERMINATORS or len(s) >= MAX_SENT_LEN or final):
+        # Only emit if it's long enough, unless final=True
+        if final or len(s) >= MIN_SENT_CHARS:
+            return [s], ""
+        else:
+            # too short and not final → keep waiting
+            return [], s
+    return [], s


 # -----------------------
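
An illustrative driver for the helper above (not part of the commit). The Thai chunks are made up and the exact split depends on what sent_tokenize returns, but the shape of the loop mirrors how the handler uses ready sentences versus the tail remainder:

# Hypothetical streaming loop around _split_ready_vs_tail.
buffer = ""
for chunk in ["ประโยคแรกจบแล้วครับ ", "ประโยคที่สองยังพิมพ์ไม่เสร็จ"]:  # made-up incoming chunks
    buffer += chunk
    ready, buffer = _split_ready_vs_tail(buffer)        # complete sentences vs. remainder
    for sentence in ready:
        print("speak now:", sentence)

# End of stream: final=True lets even a short remainder through, and the tail comes back empty.
ready, buffer = _split_ready_vs_tail(buffer, final=True)
for sentence in ready:
    print("speak last:", sentence)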
@@ -190,6 +233,8 @@ def __init__(self, *args, engine: ThaiF5Engine, sem: asyncio.Semaphore, **kwargs
         self._streaming = False
         self.engine = engine
         self.sem = sem
+        self._buf: list[str] = []    # list of chunk strings
+        self._flush_task = None      # asyncio.Task or None
         peer = getattr(self, "writer", None)
         try:
             addr = peer.get_extra_info("peername") if peer else None
@@ -258,6 +303,7 @@ async def handle_event(self, event: Event) -> bool:

         if SynthesizeStart.is_type(event.type):
             self._streaming = True
+            self._reset_buffer()
             logging.info("Synthesize streaming START: %s", event)
             return True

@@ -267,14 +313,31 @@ async def handle_event(self, event: Event) -> bool:
             if not text:
                 logging.debug("Empty chunk")
                 return True
-            sents = split_sentences_th(text)
-            logging.info("Synthesize streaming CHUNK: %d sentences", len(sents))
-            for sentence in sents:
-                await self._speak_text(sentence)
+
+            # Accumulate
+            self._buf.append(text)
+            buf_str = "".join(self._buf)
+            logging.debug("Accumulated chunk; buffer_len=%d (just got: %r)", len(buf_str), text)
+
+            # Peek at sentence segmentation to see if we have a *complete* sentence
+            sents = split_sentences_th(buf_str)
+            if len(sents) >= 2:
+                # We have at least one complete sentence; flush the ready part(s) now
+                await self._flush_buffer()
+            elif buf_str and buf_str[-1] in TERMINATORS:
+                # Single sentence but explicitly terminated; safe to flush
+                await self._flush_buffer()
+            else:
+                # Still constructing the first/only sentence → do NOT flush now.
+                # Re-arm the idle timer so we don't stall if the producer pauses.
+                self._schedule_idle_flush()
+
             return True

         if SynthesizeStop.is_type(event.type):
             logging.info("Synthesize streaming STOP")
+            # Flush any remaining text first
+            await self._flush_buffer(force_all=True)
             await self.write_event(SynthesizeStopped().event())
             self._streaming = False
             return True
@@ -294,6 +357,10 @@ async def _speak_text(self, text: str):
             return
         rate, width, channels = self.engine.sr, 2, 1  # 16-bit, mono
         logging.debug("AudioStart: rate=%d width=%d channels=%d", rate, width, channels)
+        # Make sure no pending idle task fires mid-speak
+        if self._flush_task and not self._flush_task.done():
+            self._flush_task.cancel()
+        self._flush_task = None
         await self.write_event(AudioStart(rate=rate, width=width, channels=channels).event())

         loop = asyncio.get_running_loop()
@@ -311,10 +378,78 @@ async def _speak_text(self, text: str):
         logging.info("Streamed audio: text_len=%d samples=%d bytes=%d",
                      len(text), len(wav), total_bytes)

+    def _reset_buffer(self):
+        self._buf = []
+        if self._flush_task and not self._flush_task.done():
+            self._flush_task.cancel()
+        self._flush_task = None
+        logging.debug("Resetted buffer")
+
+    def _schedule_idle_flush(self):
+        # Cancel any previous idle flush task and schedule a new one
+        if self._flush_task and not self._flush_task.done():
+            self._flush_task.cancel()
+
+        self._flush_task = asyncio.create_task(self._idle_wait_and_flush())
+
+    async def _idle_wait_and_flush(self):
+        try:
+            await asyncio.sleep(MAX_WAIT_MS / 1000.0)
+            # Only flush if we truly have ready sentences; otherwise keep waiting.
+            if not self._buf:
+                return
+            buf_str = "".join(self._buf)
+            sents = split_sentences_th(buf_str)
+            if len(sents) >= 2 or (buf_str and buf_str[-1] in TERMINATORS) or len(buf_str) >= MAX_SENT_LEN:
+                await self._flush_buffer()
+            else:
+                # Not ready yet; re-arm the timer to check again later.
+                self._schedule_idle_flush()
+        except asyncio.CancelledError:
+            pass
+
+    async def _flush_buffer(self, force_all: bool = False):
+        """
+        Flush accumulated text:
+          - If force_all=True, synth everything in the buffer (no remainder).
+          - Else, synth only full sentences and keep tail remainder.
+        """
+        logging.debug(f"Flushing buffer force_all={force_all}")
+        if not self._buf:
+            return
+
+        buf_str = "".join(self._buf)
+        ready_sents: list[str]
+        tail: str
+
+        if force_all:
+            # Treat the entire buffer as ready (split just to get clean sentences)
+            # final=True allows emitting < MIN_SENT_CHARS at end of stream
+            ready_sents, tail = _split_ready_vs_tail(buf_str, final=True)
+            tail = ""  # by definition, final
+        else:
+            ready_sents, tail = _split_ready_vs_tail(buf_str, final=False)
+
+        if ready_sents:
+            logging.info("Flushing %d ready sentence(s)", len(ready_sents))
+            # Prevent a racing idle task from re-flushing the same sentences
+            if self._flush_task and not self._flush_task.done():
+                self._flush_task.cancel()
+            self._flush_task = None
+            # Move tail back to buffer BEFORE we speak, so any concurrent timer sees only the tail
+            self._buf = [tail] if tail else []
+            for sentence in ready_sents:
+                await self._speak_text(sentence)
+
+        else:
+            # No ready sentences; keep the current buffer as-is
+            pass
+
+        # If there is still a tail, keep the idle flush armed (so it won't stall forever)
+        if self._buf:
+            self._schedule_idle_flush()
+

-# -----------------------
-# Main
-# -----------------------
 async def main():
     ap = argparse.ArgumentParser()
     ap.add_argument("--host", default="0.0.0.0")
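
The _schedule_idle_flush/_idle_wait_and_flush pair above is essentially a debounce: every new chunk cancels the pending timer task and arms a fresh one, so a pause in the incoming text still forces a flush after MAX_WAIT_MS. A stripped-down sketch of that pattern with hypothetical names, separate from the commit:

import asyncio

class IdleDebouncer:
    """Re-armable one-shot timer, same shape as the handler's idle-flush task."""

    def __init__(self, delay_s: float, callback):
        self._delay_s = delay_s      # e.g. MAX_WAIT_MS / 1000.0
        self._callback = callback    # async callable to run when the timer fires
        self._task = None

    def arm(self):
        # Cancel any previous timer and start a new one.
        if self._task and not self._task.done():
            self._task.cancel()
        self._task = asyncio.create_task(self._wait())

    async def _wait(self):
        try:
            await asyncio.sleep(self._delay_s)
            await self._callback()
        except asyncio.CancelledError:
            pass  # re-armed or shut down before the delay elapsed

Each incoming chunk would call arm(); if no further chunk arrives within delay_s, the callback gets a chance to flush whatever is buffered.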
@@ -329,7 +464,7 @@ async def main():
     ap.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"])
     ap.add_argument("--speed", type=float, default=default_speed, help="Speech speed multiplier.")
     ap.add_argument("--nfe-steps", type=int, default=nfe_step, help="Denoising steps.")
-    ap.add_argument("--max-concurrent", type=int, default=2)
+    ap.add_argument("--max-concurrent", type=int, default=1, help="Legacy params, do not change")

     ap.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])
     args = ap.parse_args()
@@ -351,7 +486,7 @@ async def main():
         speed=args.speed,
         nfe_steps=args.nfe_steps,
     )
-    sem = asyncio.Semaphore(args.max_concurrent)
+    sem = asyncio.Semaphore(args.max_concurrent)  # TODO: more than 1 is broken

     uri = f"tcp://{args.host}:{args.port}"
     server = AsyncServer.from_uri(uri)
