inference.py (61 changes: 55 additions & 6 deletions)
@@ -7,6 +7,8 @@
from pathlib import Path
from dataclasses import dataclass
import io
import nltk
import numpy as np

from mars5.model import CodecLM, ResidualTransformer
from vocos import Vocos
@@ -59,10 +61,18 @@ class InferenceConfig():
    deep_clone: bool = True

    use_kv_cache: bool = True
-    trim_db: float = 27
+    trim_db: float = 22
    beam_width: int = 1 # only beam width of 1 is currently supported
    ref_audio_pad: float = 0

    # Number of characters in the sliding window used to chunk the text when
    # it is too long to synthesize in one pass.
    sliding_window_size: int = 120

    # If True, the reference provided by the user is used for every chunk.
    # If False, the generated audio from the previous chunk is used as the
    # reference for the next chunk.
    sliding_window_reuse_reference: bool = False

class Mars5TTS(nn.Module, ModelHubMixin):
    def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
        super().__init__()
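
To make the new chunking options concrete, here is a standalone sketch of the same greedy packing logic that this PR adds to `tts` later in the diff (the `chunk_text` helper name is hypothetical, and the `punkt` data must already be downloaded):

```python
import nltk

def chunk_text(text: str, window: int = 120) -> list[str]:
    """Greedily pack sentences into chunks of at most `window` characters."""
    chunks = []
    for sentence in nltk.tokenize.sent_tokenize(text):
        # Hard-split any sentence longer than the window at the first
        # whitespace past the boundary.
        while len(sentence) > window:
            cut = sentence.find(' ', window)
            if cut == -1:
                break  # no whitespace left to split on; keep the remainder whole
            chunks.append(sentence[:cut])
            sentence = sentence[cut + 1:]
        # Merge the (remaining) sentence into the previous chunk if it fits.
        if chunks and len(chunks[-1]) + len(sentence) <= window:
            chunks[-1] += ' ' + sentence
        else:
            chunks.append(sentence)
    return chunks

print(chunk_text("This is one sentence of a much longer passage. " * 12))
```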
@@ -97,7 +107,6 @@ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
                                             p_cond_drop=0, dropout=0)
        self.codecnar.load_state_dict(nar_ckpt['model'])
        self.codecnar = self.codecnar.to(self.device).eval()
-        self.default_T = 200

        self.sr = 24000
        self.latent_sr = 75
@@ -107,6 +116,9 @@ def __init__(self, ar_ckpt, nar_ckpt, device: str = None) -> None:
        nuke_weight_norm(self.codec)
        nuke_weight_norm(self.vocos)

        # Download `punkt` for sentence segmentation
        nltk.download('punkt', quiet=True)
Collaborator:

Can you add a warning that this is being downloaded? Or not keep it quiet? Having it like this seems a little weird.

I.e. this adds NLTK as a dependency in the code, but it isn't specified anywhere. Ideally make it an optional dependency or add it to the readme/requirements.txt dependencies.

Author:

I have removed the quiet downloading and added nltk to requirements.txt.
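
For reference, a minimal sketch of the non-quiet approach discussed above (not necessarily the exact code that was pushed): check whether the tokenizer data already exists, and warn when the one-time download happens.

```python
import warnings

import nltk

try:
    # Only touch the network when the tokenizer data is actually missing.
    nltk.data.find('tokenizers/punkt')
except LookupError:
    warnings.warn("NLTK 'punkt' tokenizer data not found; downloading it now (one-time setup).")
    nltk.download('punkt')
```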


    @classmethod
    def _from_pretrained(
        cls: Type["Mars5TTS"],
@@ -184,10 +196,48 @@ def get_speaker_embedding(self, ref_audio: Tensor) -> Tensor:
        # pass through transformer
        res = self.codeclm.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1] # select first element -> now (bs, 1, dim).
        return res.squeeze(1)

    def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
            cfg: Optional[InferenceConfig] = InferenceConfig()):

        # Split the text into sentences, then pack them into chunks of at most
        # `cfg.sliding_window_size` characters.
        sentences = nltk.tokenize.sent_tokenize(text)
        text_chunks = []

        for sentence in sentences:

            # Handle sentences that are too long: split at the first whitespace
            # past the window boundary.
            while len(sentence) > cfg.sliding_window_size:
                split_whitespace = sentence.find(' ', cfg.sliding_window_size)
                if split_whitespace == -1:
                    break  # no whitespace left to split on; keep the remainder whole
                chunk, sentence = sentence[:split_whitespace], sentence[split_whitespace + 1:]
                text_chunks.append(chunk)

            # Chunk sentences together where they still fit in the window.
            if text_chunks and len(sentence) + len(text_chunks[-1]) <= cfg.sliding_window_size:
                text_chunks[-1] += ' ' + sentence
            else:
                text_chunks.append(sentence)

        # Synthesize chunk i using chunk i-1 (text + audio) as the reference,
        # starting from the user-provided reference.
        text_chunks.insert(0, ref_transcript)
        audios = [ref_audio]
        ar_codes = []

        for ref_chunk_text, current_chunk_text in zip(text_chunks, text_chunks[1:]):
            ref_chunk_audio = audios[-1]
            current_chunk_ar_codes, current_chunk_audio = self.tts_chunk(
                text=current_chunk_text,
                ref_audio=ref_audio if cfg.sliding_window_reuse_reference else ref_chunk_audio,
                ref_transcript=ref_transcript if cfg.sliding_window_reuse_reference else ref_chunk_text,
                cfg=cfg,
            )
            audios.append(current_chunk_audio)
            ar_codes.append(current_chunk_ar_codes.cpu())

        final_audio = np.hstack(audios[1:])
        return ar_codes, final_audio

    @torch.inference_mode()
-    def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
-            cfg: Optional[InferenceConfig] = InferenceConfig()) -> Tensor:
+    def tts_chunk(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
+                  cfg: Optional[InferenceConfig] = InferenceConfig()):
        """ Perform TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,], sampled at 24kHz)
        which has an associated `ref_transcript`. Perform inference using the inference
        config given by `cfg`, which controls the temperature, top_p, etc...
@@ -270,8 +320,7 @@ def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None
        x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)

        # ---> perform DDPM NAR inference
-        T = self.default_T
-        diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)
+        diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=cfg.timesteps, device=self.device)

        dsh_cfg = DSH(last_greedy=True, x_0_temp=cfg.x_0_temp,
                      guidance_w=cfg.nar_guidance_w,
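End-to-end, the new `tts` entry point chunks long text and stitches the results, while `tts_chunk` keeps the old single-pass behavior. A minimal usage sketch (the checkpoint id, file names, and transcript are placeholders; `from_pretrained` comes from the `ModelHubMixin` base class):

```python
import librosa
import torch

from inference import InferenceConfig, Mars5TTS

# Hypothetical setup: load checkpoints as described in the repo README.
mars5 = Mars5TTS.from_pretrained("CAMB-AI/MARS5-TTS")

wav, _ = librosa.load("reference.wav", sr=mars5.sr, mono=True)
ref_audio = torch.from_numpy(wav)
ref_transcript = "An exact transcript of reference.wav."

# Always clone from the user-provided reference; the default (False) rolls
# the reference forward to the previously generated chunk.
cfg = InferenceConfig(sliding_window_size=120, sliding_window_reuse_reference=True)

long_text = "This passage is far longer than one sliding window. " * 10
ar_codes, audio = mars5.tts(long_text, ref_audio, ref_transcript, cfg=cfg)  # audio: np.ndarray at 24 kHz
```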
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
torch
torchvision
torchaudio
-numpy
+numpy<2
regex
librosa
vocos