60 changes: 54 additions & 6 deletions inference.py
@@ -7,6 +7,8 @@
from pathlib import Path
from dataclasses import dataclass
import io
import nltk
import numpy as np

from mars5.model import CodecLM, ResidualTransformer
from vocos import Vocos
@@ -68,14 +70,21 @@ class InferenceConfig():
    # disabling/enabling kv caching won't affect output quality
    use_kv_cache: bool = True


    # Leading and trailing silence will be trimmed from the final output.
    # `trim_db` is the threshold (in decibels) below the reference level to
    # consider as silence.
    trim_db: float = 27
    beam_width: int = 1  # only beam width of 1 is currently supported

    ref_audio_pad: float = 0

    # Maximum number of characters in the sliding window used to chunk the
    # text if it is too long.
    sliding_window_size: int = 120

    # If True, the reference provided by the user is reused for every chunk.
    # If False, the generated audio from the previous chunk is used as the
    # reference for the next one.
    sliding_window_reuse_reference: bool = False

class Mars5TTS(nn.Module, ModelHubMixin):
    def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
        super().__init__()
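A minimal usage sketch for the new sliding-window fields added to `InferenceConfig` above; the values are illustrative, not recommendations, and the import path is assumed:

```python
from inference import InferenceConfig  # assumed import path

# Shorter chunks, and every chunk cloned directly from the user's
# reference instead of chaining off the previous chunk's output.
cfg = InferenceConfig(
    sliding_window_size=80,               # max characters per text chunk
    sliding_window_reuse_reference=True,  # always reuse the user's reference
    trim_db=27,                           # silence-trim threshold in dB
)
```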
@@ -110,7 +119,6 @@ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
                                            p_cond_drop=0, dropout=0)
        self.codecnar.load_state_dict(nar_ckpt['model'])
        self.codecnar = self.codecnar.to(self.device).eval()
        self.default_T = 200

        self.sr = 24000
        self.latent_sr = 75
@@ -120,6 +128,9 @@ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
        nuke_weight_norm(self.codec)
        nuke_weight_norm(self.vocos)

        # Download `punkt` for sentence segmentation
        nltk.download('punkt')

    @classmethod
    def _from_pretrained(
        cls: Type["Mars5TTS"],
@@ -197,10 +208,48 @@ def get_speaker_embedding(self, ref_audio: Tensor) -> Tensor:
        # pass through transformer
        res = self.codeclm.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
        return res.squeeze(1)

    def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
            cfg: Optional[InferenceConfig] = InferenceConfig()):

        sentences = nltk.tokenize.sent_tokenize(text)
        text_chunks = []

        for sentence in sentences:

            # Split sentences that are longer than the sliding window at the
            # first whitespace past the window boundary (leave the sentence
            # whole if there is no whitespace to split on).
            while len(sentence) > cfg.sliding_window_size:
                split_whitespace = sentence.find(' ', cfg.sliding_window_size)
                if split_whitespace == -1:
                    break
                chunk, sentence = sentence[:split_whitespace], sentence[split_whitespace + 1:]
                text_chunks.append(chunk)

            # Merge the sentence into the previous chunk if both fit in the
            # sliding window (+1 accounts for the joining space).
            if text_chunks and len(sentence) + 1 + len(text_chunks[-1]) <= cfg.sliding_window_size:
                text_chunks[-1] += ' ' + sentence
            else:
                text_chunks.append(sentence)

        # Synthesize chunk by chunk, chaining each chunk's output as the
        # reference for the next one unless the user's reference is reused.
        text_chunks.insert(0, ref_transcript)
        audios = [ref_audio]
        ar_codes = []

        for ref_chunk_text, current_chunk_text in zip(text_chunks, text_chunks[1:]):
            ref_chunk_audio = audios[-1]
            current_chunk_ar_codes, current_chunk_audio = self.tts_chunk(
                text=current_chunk_text,
                ref_audio=ref_audio if cfg.sliding_window_reuse_reference else ref_chunk_audio,
                ref_transcript=ref_transcript if cfg.sliding_window_reuse_reference else ref_chunk_text,
                cfg=cfg,
            )
            audios.append(current_chunk_audio)
            ar_codes.append(current_chunk_ar_codes.cpu())

        # Concatenate the synthesized chunks; index 0 is the user's reference.
        final_audio = torch.cat(audios[1:], dim=-1)
        return ar_codes, final_audio
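To preview what the sliding window produces for a given input, the chunking logic can be exercised on its own. A rough standalone sketch; the helper name `chunk_text` is illustrative and not part of this PR, and it assumes `nltk` is installed:

```python
import nltk

nltk.download('punkt')  # sentence tokenizer model used below

def chunk_text(text: str, window: int = 120) -> list[str]:
    """Split `text` into chunks of at most `window` characters,
    breaking long sentences at whitespace and merging short ones."""
    chunks = []
    for sentence in nltk.tokenize.sent_tokenize(text):
        # Hard-split sentences longer than the window at whitespace.
        while len(sentence) > window:
            cut = sentence.find(' ', window)
            if cut == -1:
                break
            chunks.append(sentence[:cut])
            sentence = sentence[cut + 1:]
        # Merge with the previous chunk when both fit in the window.
        if chunks and len(chunks[-1]) + 1 + len(sentence) <= window:
            chunks[-1] += ' ' + sentence
        else:
            chunks.append(sentence)
    return chunks

# Example: a long paragraph splits into chunks of at most 120 characters.
print(chunk_text("First sentence. " * 10))
```

Merging short sentences keeps chunk boundaries on sentence ends where possible, so pauses in the synthesized audio tend to fall at natural breaks.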

    @torch.inference_mode
    def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
            cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor:
    def tts_chunk(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
                  cfg: Optional[InferenceConfig] = InferenceConfig()):
        """ Perform TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,], sampled at 24kHz)
        which has an associated `ref_transcript`. Perform inference using the inference
        config given by `cfg`, which controls the temperature, top_p, etc...
@@ -283,8 +332,7 @@ def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None
        x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)

        # ---> perform DDPM NAR inference
        T = self.default_T
        diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)
        diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=cfg.timesteps, device=self.device)

        dsh_cfg = DSH(last_greedy=True, x_0_temp=cfg.x_0_temp,
                      guidance_w=cfg.nar_guidance_w,
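With `self.default_T` removed, the diffusion step count now comes from the inference config; the new call site implies `InferenceConfig` carries a `timesteps` field, presumably defaulting to 200 to match the old hard-coded value. A hedged sketch of tuning it, where `model` stands in for a loaded `Mars5TTS` instance:

```python
# Assumes InferenceConfig exposes the `timesteps` field read by the new
# MultinomialDiffusion call site; 200 matched the old self.default_T.
fast_cfg = InferenceConfig(timesteps=100)  # fewer DDPM steps, faster NAR stage
ar_codes, audio = model.tts("Hello there!", ref_audio, ref_transcript, cfg=fast_cfg)
```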
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,7 +1,8 @@
torch
torchvision
torchaudio
numpy==1.26.4
nltk
numpy<2
regex
librosa
vocos