
Commit cbfed4a

send streaming as args
1 parent 54d21b4

5 files changed: +14 -19 lines

cosyvoice/cli/model.py

Lines changed: 3 additions & 4 deletions
@@ -258,9 +258,6 @@ def __init__(self,
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.llm = llm
         self.flow = flow
-        # NOTE default setting for jit/onnx export, you can set to False when using pytorch inference
-        self.flow.encoder.streaming = True
-        self.flow.decoder.estimator.streaming = True
         self.hift = hift
         self.fp16 = fp16
         self.trt_concurrent = trt_concurrent
@@ -290,7 +287,7 @@ def load_jit(self, flow_encoder_model):
         flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
         self.flow.encoder = flow_encoder

-    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, finalize=False, speed=1.0):
+    def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
         with torch.cuda.amp.autocast(self.fp16), self.trt_context_dict[uuid]:
             tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                              token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
@@ -299,6 +296,7 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, u
                                              prompt_feat=prompt_feat.to(self.device),
                                              prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                              embedding=embedding.to(self.device),
+                                             streaming=stream,
                                              finalize=finalize)
         tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
         # append hift cache
@@ -356,6 +354,7 @@ def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.ze
                                                           embedding=flow_embedding,
                                                           token_offset=token_offset,
                                                           uuid=this_uuid,
+                                                          stream=stream,
                                                           finalize=False)
                 token_offset += this_token_hop_len
                 yield {'tts_speech': this_tts_speech.cpu()}
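
Taken together, these hunks turn streaming from hidden module state into a plain argument: tts() forwards its stream flag to token2wav(), which forwards it to flow.inference() as streaming. A minimal usage sketch with hypothetical driver code (model, text_tokens, spk_embedding, and play are assumptions, not part of this commit):

    # `stream` now travels as data through tts() -> token2wav() -> flow.inference()
    for chunk in model.tts(text=text_tokens,
                           flow_embedding=spk_embedding,
                           stream=True):
        play(chunk['tts_speech'])  # assumed audio sink for the yielded chunks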

cosyvoice/flow/decoder.py

Lines changed: 0 additions & 4 deletions
@@ -419,10 +419,6 @@ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
         Returns:
             _type_: _description_
         """
-        if hasattr(self, 'streaming'):
-            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
-            streaming = self.streaming
-
         t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)

cosyvoice/flow/flow.py

Lines changed: 4 additions & 2 deletions
@@ -241,6 +241,7 @@ def inference(self,
                   prompt_feat,
                   prompt_feat_len,
                   embedding,
+                  streaming,
                   finalize):
         assert token.shape[0] == 1
         # xvec projection
@@ -254,10 +255,10 @@ def inference(self,

         # text encode
         if finalize is True:
-            h, h_lengths = self.encoder(token, token_len)
+            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
         else:
             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
-            h, h_lengths = self.encoder(token, token_len, context=context)
+            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
         mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
         h = self.encoder_proj(h)

@@ -273,6 +274,7 @@ def inference(self,
                                      spks=embedding,
                                      cond=conds,
                                      n_timesteps=10,
+                                     streaming=streaming
                                      )
         feat = feat[:, :, mel_len1:]
         assert feat.shape[2] == mel_len2
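
In the non-finalize branch above, the last pre_lookahead_len tokens are split off as right context so the streaming encoder can peek slightly ahead. A toy sketch of that split, with assumed shapes and values:

    import torch

    pre_lookahead_len = 3                      # assumed value for illustration
    token = torch.arange(10).unsqueeze(0)      # (1, 10) toy token ids
    token, context = token[:, :-pre_lookahead_len], token[:, -pre_lookahead_len:]
    print(token.shape, context.shape)          # torch.Size([1, 7]) torch.Size([1, 3])

Note that streaming is added without a default value, so every caller of flow.inference() must now pass it explicitly, as model.py above does.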

cosyvoice/flow/flow_matching.py

Lines changed: 7 additions & 6 deletions
@@ -69,7 +69,7 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None,
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache

-    def solve_euler(self, x, t_span, mu, mask, spks, cond):
+    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
         """
         Fixed euler solver for ODEs.
         Args:
@@ -110,7 +110,8 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):
                 x_in, mask_in,
                 mu_in, t_in,
                 spks_in,
-                cond_in
+                cond_in,
+                streaming
             )
             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@@ -122,9 +123,9 @@ def solve_euler(self, x, t_span, mu, mask, spks, cond):

         return sol[-1].float()

-    def forward_estimator(self, x, mask, mu, t, spks, cond):
+    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
         if isinstance(self.estimator, torch.nn.Module):
-            return self.estimator(x, mask, mu, t, spks, cond)
+            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
         else:
             estimator, trt_engine = self.estimator.acquire_estimator()
             estimator.set_input_shape('x', (2, 80, x.size(2)))
@@ -196,7 +197,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
         self.rand_noise = torch.randn([1, 80, 50 * 300])

     @torch.inference_mode()
-    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
         """Forward diffusion

         Args:
@@ -220,4 +221,4 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
-        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
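
For reference, the guided update that solve_euler applies after the estimator call is a classifier-free-guidance mix followed by one explicit Euler step. A self-contained sketch of that arithmetic (cfg_rate stands in for self.inference_cfg_rate; its value here is an assumption):

    import torch

    def cfg_euler_step(x, dt, dphi_dt, cfg_dphi_dt, cfg_rate=0.7):
        # extrapolate the conditional flow away from the unconditional one,
        # then take one explicit Euler step of size dt
        v = (1.0 + cfg_rate) * dphi_dt - cfg_rate * cfg_dphi_dt
        return x + dt * v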

cosyvoice/transformer/upsample_encoder.py

Lines changed: 0 additions & 3 deletions
@@ -272,9 +272,6 @@ def forward(
             checkpointing API because `__call__` attaches all the hooks of the module.
             https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
         """
-        if hasattr(self, 'streaming'):
-            assert self.training is False, 'you have self.streaming attr, make sure that you are running inference mode'
-            streaming = self.streaming
         T = xs.size(1)
         masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
         if self.global_cmvn is not None:
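
Because neither the encoder nor the estimator reads a streaming attribute anymore, older export scripts that still set one now have no effect. A small diagnostic sketch (hypothetical helper, not part of the repo) to flag such leftovers:

    def warn_stale_streaming_attrs(model):
        # instance attributes named 'streaming' are ignored after this commit
        for name, module in model.flow.named_modules():
            if 'streaming' in module.__dict__:
                print(f'warning: {name}.streaming is set but no longer read')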
