# src: https://github.com/VYNCX/F5-TTS-THAI/blob/99b8314f66a14fc2f0a6b53e5122829fbdf9c59c/src/f5_tts/infer/utils_infer.py

import re
import tqdm
import torch
import numpy as np
import syllapy
import torchaudio

from ssg import syllable_tokenize
from concurrent.futures import ThreadPoolExecutor
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import (
    target_sample_rate,
    hop_length,
    mel_spec_type,
    target_rms,
    cross_fade_duration,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
    speed,
    fix_duration,
    device,
)

from util.ipa import any_ipa

def custom_chunk_text(text: str, max_chars=200):
    """
    Split the input text into chunks, breaking at spaces to keep chunks roughly balanced.

    Args:
        text (str): The text to be split.
        max_chars (int): Approximate maximum number of UTF-8 bytes per chunk.

    Returns:
        List[str]: A list of text chunks.
    """
    chunks: list[str] = []
    current_chunk = ""
    # Mark spaces with a <unk> token, then split on <unk> or any remaining whitespace
    text = text.replace(" ", "<unk>")
    segments = re.split(r"(<unk>|\s+)", text)

    for segment in segments:
        if not segment or segment in ("<unk>", " "):
            continue
        # Check the UTF-8 byte length of the chunk with this segment appended
        if len((current_chunk + segment).encode("utf-8")) <= max_chars:
            current_chunk += segment
            current_chunk += " "  # Add a space after each segment for readability
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = segment + " "

    if current_chunk:
        chunks.append(current_chunk.strip())

    # Replace any remaining <unk> markers with spaces in the final output
    chunks = [chunk.replace("<unk>", " ") for chunk in chunks]

    return chunks

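# Illustrative note (not in the original file): chunks are packed greedily from
# whole space-separated segments, with max_chars counted in UTF-8 bytes, e.g.
#
#     custom_chunk_text("hello world foo", max_chars=11)
#     # -> ["hello world", "foo"]
#
# because "hello world" is exactly 11 bytes and appending "foo" would exceed it.
# A long run of Thai text with no spaces stays a single segment, so one chunk
# can still exceed max_chars.
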
# Estimate output duration from syllable counts
FRAMES_PER_SEC = target_sample_rate / hop_length


def words_to_frame(text: str, frame_per_words: int):
    thai_pattern = r'[\u0E00-\u0E7F\s]+'
    english_pattern = r'[a-zA-Z\s]+'

    thai_segs = re.findall(thai_pattern, text)
    eng_segs = re.findall(english_pattern, text)

    syl_th = sum(len(syllable_tokenize(seg.strip())) for seg in thai_segs if seg.strip())
    syl_en = sum(syllapy._syllables(seg.strip()) for seg in eng_segs if seg.strip())
    syl_unk = text.count(',')  # Count each comma as one extra syllable (pause)

    duration = (syl_th + syl_en + syl_unk) * frame_per_words
    # print(f"Thai: {syl_th}, Eng: {syl_en}, Commas: {syl_unk}, Total: {duration} frames")
    return duration

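# Worked example (illustrative, assuming the common F5-TTS defaults of a 24 kHz
# mel sample rate and hop_length = 256; the actual values come from the
# f5_tts.infer.utils_infer import above):
#
#     FRAMES_PER_SEC            = 24000 / 256 = 93.75 mel frames per second
#     FRAMES_PER_WORDS          = 93.75 / 4   ~ 23.4 frames per syllable
#     speech_rate at speed=1.0  = int(23.4)   = 23
#
# so in custom_infer_batch_process below, a 10-syllable sentence is allotted
# roughly 10 * 23 = 230 frames, i.e. about 2.5 s on top of the reference length.
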
def custom_infer_process(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    mel_spec_type=mel_spec_type,
    show_info=print,
    progress=tqdm,
    target_rms=target_rms,
    cross_fade_duration=cross_fade_duration,
    nfe_step=nfe_step,
    cfg_strength=cfg_strength,
    sway_sampling_coef=sway_sampling_coef,
    speed=speed,
    fix_duration=fix_duration,
    device=device,
    set_max_chars=250,
    use_ipa=False
):
    # Split the input text into batches
    audio, sr = torchaudio.load(ref_audio)
    # max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
    gen_text_batches = custom_chunk_text(gen_text, max_chars=set_max_chars)
    for i, gen_text in enumerate(gen_text_batches):
        print(f"gen_text {i}", gen_text)
        print("\n")

    show_info(f"Generating audio in {len(gen_text_batches)} batches...")
    return next(
        custom_infer_batch_process(
            (audio, sr),
            ref_text,
            gen_text_batches,
            model_obj,
            vocoder,
            mel_spec_type=mel_spec_type,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=device,
            use_ipa=use_ipa
        )
    )

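# ---------------------------------------------------------------------------
# Hedged usage sketch (illustrative, not part of the original file). It assumes
# `model_obj` and `vocoder` were already loaded with the project's own loaders,
# and that "ref.wav" plus its transcript exist:
#
#     wave, sample_rate, spectrogram = custom_infer_process(
#         ref_audio="ref.wav",
#         ref_text="reference transcript",
#         gen_text="text to synthesize",
#         model_obj=model_obj,
#         vocoder=vocoder,
#     )
#     torchaudio.save("out.wav", torch.from_numpy(wave).unsqueeze(0), sample_rate)
# ---------------------------------------------------------------------------
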
# infer batches


def custom_infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef: float = -1,
    speed: float = 1,
    fix_duration=None,
    device=None,
    streaming=False,
    chunk_size=2048,
    use_ipa=False
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # Bring quiet reference audio up to target_rms and resample if needed
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    # If ref_text ends with a single-byte (ASCII) character, append a space
    if len(ref_text[-1].encode("utf-8")) == 1:
        ref_text = ref_text + " "

    def process_batch(gen_text):
        local_speed = speed
        # Slow down generation for very short prompts
        if len(gen_text.encode("utf-8")) < 15:
            local_speed = 0.3

        # Prepare the text
        if use_ipa:
            ref_text_ipa = any_ipa(ref_text)
            gen_text_ipa = any_ipa(gen_text)
            final_text_list = [ref_text_ipa + " " + gen_text_ipa]  # pyright: ignore[reportOperatorIssue]
        else:
            text_list = [ref_text + gen_text]
            final_text_list = convert_char_to_pinyin(text_list)

        ref_audio_len = audio.shape[-1] // hop_length
        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            # Estimate duration from the syllable count of gen_text
            FRAMES_PER_WORDS = FRAMES_PER_SEC / 4
            speech_rate = int(FRAMES_PER_WORDS / local_speed)
            duration = ref_audio_len + words_to_frame(text=gen_text, frame_per_words=speech_rate)

        # inference
        with torch.inference_mode():
            generated, _ = model_obj.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                lens=torch.tensor([ref_audio_len], device=device, dtype=torch.long)
            )
            del _

        generated = generated.to(torch.float32)  # generated mel spectrogram
        generated = generated[:, ref_audio_len:, :]
        generated = generated.permute(0, 2, 1)
        if mel_spec_type == "vocos":
            generated_wave = vocoder.decode(generated)
        elif mel_spec_type == "bigvgan":
            generated_wave = vocoder(generated)
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        if streaming:
            for j in range(0, len(generated_wave), chunk_size):
                yield generated_wave[j: j + chunk_size], target_sample_rate
        else:
            generated_cpu = generated[0].cpu().numpy()
            del generated
            yield generated_wave, generated_cpu

    if streaming:
        for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
            for chunk in process_batch(gen_text):
                yield chunk
    else:
        # process_batch is a generator function, so each future resolves to a
        # generator object almost immediately; the actual synthesis runs when
        # next(result) is called below, preserving batch order.
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
            for future in progress.tqdm(futures) if progress is not None else futures:
                result = future.result()
                if result:
                    generated_wave, generated_mel_spec = next(result)
                    generated_waves.append(generated_wave)
                    spectrograms.append(generated_mel_spec)

    if generated_waves:
        if cross_fade_duration <= 0:
            # Simply concatenate
            final_wave = np.concatenate(generated_waves)
        else:
            # Combine all generated waves with cross-fading
            final_wave = generated_waves[0]
            for i in range(1, len(generated_waves)):
                prev_wave = final_wave
                next_wave = generated_waves[i]

                # Calculate cross-fade samples, ensuring it does not exceed wave lengths
                cross_fade_samples = int(cross_fade_duration * target_sample_rate)
                cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

                if cross_fade_samples <= 0:
                    # No overlap possible, concatenate
                    final_wave = np.concatenate([prev_wave, next_wave])
                    continue

                # Overlapping parts
                prev_overlap = prev_wave[-cross_fade_samples:]
                next_overlap = next_wave[:cross_fade_samples]

                # Fade out and fade in
                fade_out = np.linspace(1, 0, cross_fade_samples)
                fade_in = np.linspace(0, 1, cross_fade_samples)

                # Cross-faded overlap
                cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

                # Combine
                new_wave = np.concatenate(
                    [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
                )

                final_wave = new_wave

        # Create a combined spectrogram
        combined_spectrogram = np.concatenate(spectrograms, axis=1)

        yield final_wave, target_sample_rate, combined_spectrogram

    else:
        yield None, target_sample_rate, None
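
# Illustrative note (not in the original file): the cross-fade above is a plain
# linear fade. With cross_fade_samples = 4, for example:
#
#     fade_out = np.linspace(1, 0, 4)  # [1.0, 0.667, 0.333, 0.0]
#     fade_in  = np.linspace(0, 1, 4)  # [0.0, 0.333, 0.667, 1.0]
#
# so each overlapping sample pair is mixed with complementary weights that sum
# to 1, avoiding an audible level jump at batch boundaries.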