
Commit a849fe4

Working f5_v2
1 parent 62462c7 commit a849fe4

File tree

9 files changed: +33343 -141 lines

pyproject.toml

Lines changed: 6 additions & 2 deletions
@@ -1,15 +1,19 @@
 [project]
 name = "thtts"
 version = "0.1.0"
-description = "Add your description here"
+description = "tts via wyoming for Thai"
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
     "cached-path",
     "f5-tts",
-    "omegaconf",
+    "langdetect>=1.0.9",
+    "phonemizer>=3.3.0",
     "pythainlp>=5.1.2",
     "python-crfsuite>=0.9.11",
+    "ssg>=0.0.8",
+    "syllapy>=0.7.2",
+    "tltk>=1.9.1",
     "torch>=2.8.0",
     "torchaudio",
     "transformers>=4.55.2",

src/util/cleantext.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 # src: https://github.com/VYNCX/F5-TTS-THAI/blob/99b8314f66a14fc2f0a6b53e5122829fbdf9c59c/src/f5_tts/cleantext/th_repeat.py
 # src: https://github.com/VYNCX/F5-TTS-THAI/blob/99b8314f66a14fc2f0a6b53e5122829fbdf9c59c/src/f5_tts/cleantext/number_tha.py
+
 import re
 
 from pythainlp.tokenize import syllable_tokenize

src/util/custom_infer.py

Lines changed: 283 additions & 0 deletions
@@ -0,0 +1,283 @@
# src: https://github.com/VYNCX/F5-TTS-THAI/blob/99b8314f66a14fc2f0a6b53e5122829fbdf9c59c/src/f5_tts/infer/utils_infer.py

import re
import tqdm
import torch
import numpy as np
import syllapy
import torchaudio

from ssg import syllable_tokenize
from concurrent.futures import ThreadPoolExecutor
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import (
    target_sample_rate, hop_length, mel_spec_type, target_rms, cross_fade_duration,
    nfe_step, cfg_strength, sway_sampling_coef, speed, fix_duration, device,
)


from util.ipa import any_ipa


def custom_chunk_text(text: str, max_chars=200):
    """
    Splits the input text into chunks by breaking at spaces, creating visually balanced chunks.

    Args:
        text (str): The text to be split.
        max_chars (int): Approximate maximum number of bytes per chunk in UTF-8 encoding.

    Returns:
        List[str]: A list of text chunks.
    """
    chunks: list[str] = []
    current_chunk = ""
    # Replace spaces with <unk> if desired, then split on <unk> or spaces
    text = text.replace(" ", "<unk>")
    segments = re.split(r"(<unk>|\s+)", text)

    for segment in segments:
        if not segment or segment in ("<unk>", " "):
            continue
        # Check the byte length for UTF-8 encoding
        if len((current_chunk + segment).encode("utf-8")) <= max_chars:
            current_chunk += segment
            current_chunk += " "  # Add space after each segment for readability
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = segment + " "

    if current_chunk:
        chunks.append(current_chunk.strip())

    # Replace <unk> back with spaces in the final output
    chunks = [chunk.replace("<unk>", " ") for chunk in chunks]

    return chunks
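
For context, a quick usage sketch (illustrative, not part of the committed file): max_chars compares UTF-8 bytes, so Thai characters at 3 bytes each fill a chunk roughly three times faster than ASCII letters.

    # Hypothetical call; the 40-byte budget is chosen only for illustration.
    chunks = custom_chunk_text("สวัสดีครับ ยินดีต้อนรับ hello world", max_chars=40)
    # The two Thai words (30 and 36 bytes) no longer fit together, while the two
    # ASCII words do: ['สวัสดีครับ', 'ยินดีต้อนรับ', 'hello world']

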
# estimated duration with syllable
FRAMES_PER_SEC = target_sample_rate / hop_length


def words_to_frame(text: str, frame_per_words: int):
    thai_pattern = r'[\u0E00-\u0E7F\s]+'
    english_pattern = r'[a-zA-Z\s]+'

    thai_segs = re.findall(thai_pattern, text)
    eng_segs = re.findall(english_pattern, text)

    syl_th = sum(len(syllable_tokenize(seg.strip())) for seg in thai_segs if seg.strip())
    syl_en = sum(syllapy._syllables(seg.strip()) for seg in eng_segs if seg.strip())
    syl_unk = text.count(',')  # count each comma as one extra syllable (a pause)

    duration = (syl_th + syl_en + syl_unk) * frame_per_words
    # print(f"Thai: {syl_th}, Eng: {syl_en}, Commas: {syl_unk}, Total: {duration} frames")
    return duration
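
To make the scaling concrete (an aside, assuming F5-TTS's usual 24 kHz sample rate and 256-sample hop for the values imported from f5_tts.infer.utils_infer):

    frames_per_sec = 24000 / 256              # 93.75 mel frames per second
    frames_per_syllable = frames_per_sec / 4  # ~23.4 frames, i.e. ~0.25 s per syllable at speed 1.0
    est = 12 * int(frames_per_syllable)       # a 12-syllable sentence -> 276 frames, roughly 2.9 s

custom_infer_batch_process below computes the same per-syllable rate (speech_rate) and adds the result to the reference audio's frame count to set the sampling duration.

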
def custom_infer_process(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    mel_spec_type=mel_spec_type,
    show_info=print,
    progress=tqdm,
    target_rms=target_rms,
    cross_fade_duration=cross_fade_duration,
    nfe_step=nfe_step,
    cfg_strength=cfg_strength,
    sway_sampling_coef=sway_sampling_coef,
    speed=speed,
    fix_duration=fix_duration,
    device=device,
    set_max_chars=250,
    use_ipa=False
):
    # Split the input text into batches
    audio, sr = torchaudio.load(ref_audio)
    # max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
    gen_text_batches = custom_chunk_text(gen_text, max_chars=set_max_chars)
    for i, gen_text in enumerate(gen_text_batches):
        print(f"gen_text {i}", gen_text)
        print("\n")

    show_info(f"Generating audio in {len(gen_text_batches)} batches...")
    return next(
        custom_infer_batch_process(
            (audio, sr),
            ref_text,
            gen_text_batches,
            model_obj,
            vocoder,
            mel_spec_type=mel_spec_type,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=device,
            use_ipa=use_ipa
        )
    )
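
A hypothetical call (the reference clip, its transcript, and the model/vocoder loading are assumptions here; the loaders are not part of this diff). In non-streaming mode custom_infer_batch_process yields a single (wave, sample_rate, spectrogram) tuple, which the next() above unwraps:

    wave, sample_rate, spectrogram = custom_infer_process(
        ref_audio="ref.wav",                     # hypothetical reference clip
        ref_text="ข้อความกำกับของไฟล์อ้างอิง",      # its transcript
        gen_text="สวัสดีครับ ยินดีต้อนรับ",
        model_obj=model_obj,                     # loaded elsewhere
        vocoder=vocoder,                         # loaded elsewhere
        use_ipa=True,                            # phonemize both texts via util.ipa.any_ipa
    )

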
# infer batches


def custom_infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef: float = -1,
    speed: float = 1,
    fix_duration=None,
    device=None,
    streaming=False,
    chunk_size=2048,
    use_ipa=False
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    # If ref_text ends in a single-byte (ASCII) character, append a space so the
    # reference and generated text do not run together.
    if len(ref_text[-1].encode("utf-8")) == 1:
        ref_text = ref_text + " "

    def process_batch(gen_text):
        local_speed = speed
        if len(gen_text.encode("utf-8")) < 15:
            local_speed = 0.3

        # Prepare the text
        if use_ipa:
            ref_text_ipa = any_ipa(ref_text)
            gen_text_ipa = any_ipa(gen_text)
            final_text_list = [ref_text_ipa + " " + gen_text_ipa]  # pyright: ignore[reportOperatorIssue]
        else:
            text_list = [ref_text + gen_text]
            final_text_list = convert_char_to_pinyin(text_list)

        ref_audio_len = audio.shape[-1] // hop_length
        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            # Calculate duration
            FRAMES_PER_WORDS = FRAMES_PER_SEC / 4
            speech_rate = int(FRAMES_PER_WORDS / local_speed)
            duration = ref_audio_len + words_to_frame(text=gen_text, frame_per_words=speech_rate)

        # inference
        with torch.inference_mode():
            generated, _ = model_obj.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                lens=torch.tensor([ref_audio_len], device=device, dtype=torch.long)
            )
            del _

            generated = generated.to(torch.float32)  # generated mel spectrogram
            generated = generated[:, ref_audio_len:, :]
            generated = generated.permute(0, 2, 1)
            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated)
            elif mel_spec_type == "bigvgan":
                generated_wave = vocoder(generated)
            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms

            # wav -> numpy
            generated_wave = generated_wave.squeeze().cpu().numpy()

            if streaming:
                for j in range(0, len(generated_wave), chunk_size):
                    yield generated_wave[j: j + chunk_size], target_sample_rate
            else:
                generated_cpu = generated[0].cpu().numpy()
                del generated
                yield generated_wave, generated_cpu

    if streaming:
        for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
            for chunk in process_batch(gen_text):
                yield chunk
    else:
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
            for future in progress.tqdm(futures) if progress is not None else futures:
                result = future.result()
                if result:
                    generated_wave, generated_mel_spec = next(result)
                    generated_waves.append(generated_wave)
                    spectrograms.append(generated_mel_spec)

    if generated_waves:
        if cross_fade_duration <= 0:
            # Simply concatenate
            final_wave = np.concatenate(generated_waves)
        else:
            # Combine all generated waves with cross-fading
            final_wave = generated_waves[0]
            for i in range(1, len(generated_waves)):
                prev_wave = final_wave
                next_wave = generated_waves[i]

                # Calculate cross-fade samples, ensuring it does not exceed wave lengths
                cross_fade_samples = int(cross_fade_duration * target_sample_rate)
                cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

                if cross_fade_samples <= 0:
                    # No overlap possible, concatenate
                    final_wave = np.concatenate([prev_wave, next_wave])
                    continue

                # Overlapping parts
                prev_overlap = prev_wave[-cross_fade_samples:]
                next_overlap = next_wave[:cross_fade_samples]

                # Fade out and fade in
                fade_out = np.linspace(1, 0, cross_fade_samples)
                fade_in = np.linspace(0, 1, cross_fade_samples)

                # Cross-faded overlap
                cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

                # Combine
                new_wave = np.concatenate(
                    [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
                )

                final_wave = new_wave

        # Create a combined spectrogram
        combined_spectrogram = np.concatenate(spectrograms, axis=1)

        yield final_wave, target_sample_rate, combined_spectrogram

    else:
        yield None, target_sample_rate, None
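
The stitching at the end is a plain linear cross-fade. A minimal self-contained sketch of the same idea (illustrative only, not part of the commit):

    import numpy as np

    def linear_crossfade(prev_wave, next_wave, overlap):
        # Fade the tail of prev_wave out while fading the head of next_wave in,
        # then stitch the three pieces together -- the same math as the loop above.
        overlap = min(overlap, len(prev_wave), len(next_wave))
        if overlap <= 0:
            return np.concatenate([prev_wave, next_wave])
        fade_out = np.linspace(1, 0, overlap)
        fade_in = np.linspace(0, 1, overlap)
        blended = prev_wave[-overlap:] * fade_out + next_wave[:overlap] * fade_in
        return np.concatenate([prev_wave[:-overlap], blended, next_wave[overlap:]])

    a = np.ones(1000)
    b = np.zeros(1000)
    out = linear_crossfade(a, b, overlap=200)  # len(out) == 1800; the overlap ramps 1 -> 0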

0 commit comments
