Skip to content

Commit b6c5f9d

Browse files
authored
Merge pull request #1331 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2 parents 4159a18 + cbfed4a commit b6c5f9d

File tree

16 files changed: 39 additions and 483 deletions.

cosyvoice/cli/cosyvoice.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626

2727
class CosyVoice:
2828

29-
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
29+
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
3030
self.instruct = True if '-Instruct' in model_dir else False
3131
self.model_dir = model_dir
3232
self.fp16 = fp16
@@ -48,7 +48,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
4848
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
4949
load_jit, load_trt, fp16 = False, False, False
5050
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
51-
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
51+
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16, trt_concurrent)
5252
self.model.load('{}/llm.pt'.format(model_dir),
5353
'{}/flow.pt'.format(model_dir),
5454
'{}/hift.pt'.format(model_dir))

cosyvoice/cli/model.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -258,9 +258,6 @@ def __init__(self,
258258
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
259259
self.llm = llm
260260
self.flow = flow
261-
# NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
262-
self.flow.encoder.streaming = True
263-
self.flow.decoder.estimator.streaming = True
264261
self.hift = hift
265262
self.fp16 = fp16
266263
self.trt_concurrent = trt_concurrent
@@ -290,7 +287,7 @@ def load_jit(self, flow_encoder_model):
290287
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
291288
self.flow.encoder = flow_encoder
292289

293-
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
290+
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
294291
with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
295292
tts_mel, _ = self.flow.inference(token=token.to(self.device),
296293
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
@@ -299,6 +296,7 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, u
299296
prompt_feat=prompt_feat.to(self.device),
300297
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
301298
embedding=embedding.to(self.device),
299+
streaming=stream,
302300
finalize=finalize)
303301
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
304302
# append hift cache
@@ -356,6 +354,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
356354
embedding=flow_embedding,
357355
token_offset=token_offset,
358356
uuid=this_uuid,
357+
stream=stream,
359358
finalize=False)
360359
token_offset += this_token_hop_len
361360
yield {'tts_speech': this_tts_speech.cpu()}

cosyvoice/flow/decoder.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -419,10 +419,6 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
419419
Returns:
420420
_type_: _description_
421421
"""
422-
if hasattr(self, 'streaming'):
423-
assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
424-
streaming = self.streaming
425-
426422
t = self.time_embeddings(t).to(t.dtype)
427423
t = self.time_mlp(t)
428424

cosyvoice/flow/flow.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,7 @@ def inference(self,
241241
prompt_feat,
242242
prompt_feat_len,
243243
embedding,
244+
streaming,
244245
finalize):
245246
assert token.shape[0] == 1
246247
# xvec projection
@@ -254,10 +255,10 @@ def inference(self,
254255

255256
# text encode
256257
if finalize is True:
257-
h, h_lengths = self.encoder(token, token_len)
258+
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
258259
else:
259260
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
260-
h, h_lengths = self.encoder(token, token_len, context=context)
261+
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
261262
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
262263
h = self.encoder_proj(h)
263264

@@ -273,6 +274,7 @@ def inference(self,
273274
spks=embedding,
274275
cond=conds,
275276
n_timesteps=10,
277+
streaming=streaming
276278
)
277279
feat = feat[:, :, mel_len1:]
278280
assert feat.shape[2] == mel_len2

cosyvoice/flow/flow_matching.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None,
6969
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
7070
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
7171

72-
def solve_euler(self, x, t_span, mu, mask, spks, cond):
72+
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
7373
"""
7474
Fixed euler solver for ODEs.
7575
Args:
@@ -110,7 +110,8 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):
110110
x_in, mask_in,
111111
mu_in, t_in,
112112
spks_in,
113-
cond_in
113+
cond_in,
114+
streaming
114115
)
115116
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
116117
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -122,9 +123,9 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):
122123

123124
return sol[-1].float()
124125

125-
def forward_estimator(self, x, mask, mu, t, spks, cond):
126+
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
126127
if isinstance(self.estimator, torch.nn.Module):
127-
return self.estimator(x, mask, mu, t, spks, cond)
128+
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
128129
else:
129130
estimator, trt_engine = self.estimator.acquire_estimator()
130131
estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -196,7 +197,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
196197
self.rand_noise = torch.randn([1, 80, 50 * 300])
197198

198199
@torch.inference_mode()
199-
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
200+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
200201
"""Forward diffusion
201202
202203
Args:
@@ -220,4 +221,4 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
220221
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
221222
if self.t_scheduler == 'cosine':
222223
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
223-
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
224+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None

cosyvoice/transformer/upsample_encoder.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -272,9 +272,6 @@ def forward(
272272
checkpointing API because `__call__` attaches all the hooks of the module.
273273
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
274274
"""
275-
if hasattr(self, 'streaming'):
276-
assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
277-
streaming = self.streaming
278275
T = xs.size(1)
279276
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
280277
if self.global_cmvn is not None:

examples/libritts/cosyvoice2/conf/cosyvoice2.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
158158
center: False
159159
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
160160
feat_extractor: !ref <feat_extractor>
161+
token_mel_ratio: 2
161162
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
162163
sample_rate: !ref <sample_rate>
163164
hop_size: 480

examples/libritts/cosyvoice2/path.sh

Lines changed: 0 additions & 3 deletions
This file was deleted.

examples/libritts/cosyvoice2/path.sh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
../cosyvoice/path.sh

examples/libritts/cosyvoice2/tts_text.json

Lines changed: 0 additions & 5 deletions
This file was deleted.
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
../cosyvoice/tts_text.json
Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
../../libritts/cosyvoice/conf

examples/magicdata-read/cosyvoice/conf/cosyvoice.fromscratch.yaml

Lines changed: 0 additions & 203 deletions
This file was deleted.

0 commit comments

Comments
 (0)