diff --git a/inference.py b/inference.py
index fef464a..54d06d0 100644
--- a/inference.py
+++ b/inference.py
@@ -7,6 +7,8 @@
 from pathlib import Path
 from dataclasses import dataclass
 import io
+import nltk
+import numpy as np
 
 from mars5.model import CodecLM, ResidualTransformer
 from vocos import Vocos
@@ -68,7 +70,6 @@ class InferenceConfig():
     # disabling/enabling kv caching won't affect output quality
     use_kv_cache: bool = True
 
-    # Leading and trailing silences will be trimmed from final output
     # Trim_db is the threshold (in decibels) below reference to consider as silence
     trim_db: float = 27
 
@@ -76,6 +77,14 @@ class InferenceConfig():
 
     ref_audio_pad: float = 0
 
+    # Maximum number of characters per chunk when input text is too long and
+    # must be split with a sliding window.
+    sliding_window_size: int = 120
+
+    # If True, the reference provided by the user is always used as the prompt
+    # for every chunk. If False, the previous chunk's generated audio is used.
+    sliding_window_reuse_reference: bool = False
+
 class Mars5TTS(nn.Module, ModelHubMixin):
     def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
         super().__init__()
@@ -110,7 +119,6 @@ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
             p_cond_drop=0, dropout=0)
         self.codecnar.load_state_dict(nar_ckpt['model'])
         self.codecnar = self.codecnar.to(self.device).eval()
-        self.default_T = 200
         self.sr = 24000
         self.latent_sr = 75
 
@@ -120,6 +128,9 @@ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
         nuke_weight_norm(self.codec)
         nuke_weight_norm(self.vocos)
 
+        # Download `punkt` for sentence segmentation
+        nltk.download('punkt')
+
     @classmethod
     def _from_pretrained(
         cls: Type["Mars5TTS"],
@@ -197,10 +208,48 @@ def get_speaker_embedding(self, ref_audio: Tensor) -> Tensor:
         # pass through transformer
         res = self.codeclm.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
         return res.squeeze(1)
+
+    def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
+            cfg: Optional[InferenceConfig] = InferenceConfig()):
+
+        sentences = nltk.tokenize.sent_tokenize(text)
+        text_chunks = []
+
+        for sentence in sentences:
+
+            # Handle sentences that are too long
+            while len(sentence) > cfg.sliding_window_size:
+                split_whitespace = sentence.index(' ', cfg.sliding_window_size)
+                chunk, sentence = sentence[:split_whitespace], sentence[split_whitespace:]
+                text_chunks.append(chunk)
+
+            # Chunk sentences together
+            if text_chunks and len(sentence) + len(text_chunks[-1]) <= cfg.sliding_window_size:
+                text_chunks[-1] += ' ' + sentence
+            else:
+                text_chunks.append(sentence)
+
+        text_chunks.insert(0, ref_transcript)
+        audios = [ref_audio]
+        ar_codes = []
+
+        for ref_chunk_text, current_chunk_text in zip(text_chunks, text_chunks[1:]):
+            ref_chunk_audio = audios[-1]
+            current_chunk_ar_codes, current_chunk_audio = self.tts_chunk(
+                text=current_chunk_text,
+                ref_audio=ref_audio if cfg.sliding_window_reuse_reference else ref_chunk_audio,
+                ref_transcript=ref_transcript if cfg.sliding_window_reuse_reference else ref_chunk_text,
+                cfg=cfg,
+            )
+            audios.append(current_chunk_audio)
+            ar_codes.append(current_chunk_ar_codes.cpu())
+
+        final_audio = np.hstack(audios[1:])
+        return ar_codes, final_audio
 
     @torch.inference_mode
-    def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
-            cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor:
+    def tts_chunk(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
+            cfg: Optional[InferenceConfig] = InferenceConfig()):
         """ Perform TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,], sampled at 24kHz)
         which has an associated `ref_transcript`. Perform inference using the inference config given by `cfg`,
         which controls the temperature, top_p, etc...
@@ -283,8 +332,7 @@ def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None
         x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)
 
         # ---> perform DDPM NAR inference
-        T = self.default_T
-        diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)
+        diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=cfg.timesteps, device=self.device)
 
         dsh_cfg = DSH(last_greedy=True, x_0_temp=cfg.x_0_temp,
                       guidance_w=cfg.nar_guidance_w,
diff --git a/requirements.txt b/requirements.txt
index bfb3273..930de48 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
 torch
 torchvision
 torchaudio
-numpy==1.26.4
+nltk
+numpy<2
 regex
 librosa
 vocos
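
For review purposes, a minimal sketch of how the new chunked `tts` path could be driven, assuming the model is loaded through the existing `Mars5TTS.from_pretrained` (ModelHubMixin) flow; the checkpoint id, audio path, and transcript below are illustrative placeholders, not part of this change:

import librosa
import torch

from inference import InferenceConfig, Mars5TTS

# Assumed checkpoint id and reference clip; substitute your own.
mars5 = Mars5TTS.from_pretrained("CAMB-AI/MARS5-TTS")
wav, _ = librosa.load("./reference.wav", sr=mars5.sr, mono=True)
wav = torch.from_numpy(wav)

cfg = InferenceConfig(
    sliding_window_size=120,               # max characters per text chunk
    sliding_window_reuse_reference=False,  # False: condition each chunk on the previous chunk's audio
)

long_text = "First sentence of a long passage. Second sentence. ..."  # any text longer than one chunk
ar_codes, audio = mars5.tts(long_text, wav, ref_transcript="transcript of reference.wav", cfg=cfg)
# `ar_codes` is a list of per-chunk AR codes; `audio` is the concatenated 24 kHz waveform (numpy array).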