Skip to content

Commit 12aa078

Browse files
committed
fix: encoder args #225 #224
- 修复 ffmpeg 参数覆盖问题
- 增加单测
- 增加 AudioHandler.get_sample_rate()
1 parent e257af3 commit 12aa078

File tree

7 files changed

+121
-31
lines changed

7 files changed

+121
-31
lines changed

modules/core/handler/AudioHandler.py

+25-20
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def enqueue_stream(self) -> AsyncGenerator[NP_AUDIO, None]:
6262
"Method 'enqueue_stream' must be implemented by subclass"
6363
)
6464

65+
def get_sample_rate(self) -> int:
66+
raise NotImplementedError(
67+
"Method 'get_sample_rate' must be implemented by subclass"
68+
)
69+
6570
def set_current_request(self, request: Request):
6671
assert self.current_request is None, "current_request has been set"
6772
assert isinstance(
@@ -93,7 +98,9 @@ def get_encoder(self) -> StreamEncoder:
9398
else:
9499
raise ValueError(f"Unsupported audio format: {format}")
95100

101+
encoder.set_header(sample_rate=self.get_sample_rate())
96102
encoder.open(bitrate=bitrate, acodec=acodec)
103+
encoder.write_header_data()
97104

98105
return encoder
99106

@@ -104,9 +111,6 @@ async def enqueue_to_stream(self) -> AsyncGenerator[bytes, None]:
104111

105112
chunk_data = bytes()
106113
async for sample_rate, audio_data in self.enqueue_stream():
107-
encoder.set_header(
108-
sample_rate=sample_rate, sample_width=audio_data.dtype.itemsize
109-
)
110114
audio_bytes = covert_to_s16le(audio_data=audio_data)
111115

112116
logger.debug(f"write audio_bytes len: {len(audio_bytes)}")
@@ -153,9 +157,6 @@ async def enqueue_to_stream_join(self) -> AsyncGenerator[bytes, None]:
153157
encoder = self.get_encoder()
154158
chunk_data = bytes()
155159
async for sample_rate, audio_data in self.enqueue_stream():
156-
encoder.set_header(
157-
sample_rate=sample_rate, sample_width=audio_data.dtype.itemsize
158-
)
159160
audio_bytes = covert_to_s16le(audio_data=audio_data)
160161
encoder.write(audio_bytes)
161162

@@ -166,26 +167,30 @@ async def enqueue_to_stream_join(self) -> AsyncGenerator[bytes, None]:
166167

167168
encoder.terminate()
168169

170+
async def _enqueue_to_bytes(self) -> bytes:
171+
"""
172+
为了测试拆分的函数
173+
这个函数不依赖 current_request 状态
174+
"""
175+
encoder = self.get_encoder()
176+
buffer = bytes()
177+
try:
178+
sample_rate, audio_data = await self.enqueue()
179+
audio_bytes = covert_to_s16le(audio_data=audio_data)
180+
encoder.write(audio_bytes)
181+
encoder.close()
182+
buffer = encoder.read_all()
183+
finally:
184+
encoder.terminate()
185+
return buffer
186+
169187
async def enqueue_to_bytes(self) -> bytes:
170188
if self.current_request is None:
171189
raise ValueError("current_request is not set")
172190

173-
encoder = self.get_encoder()
174-
175191
# NOTE: 这里的逻辑类似 goto
176192
async with cancel_on_disconnect(self.current_request):
177-
try:
178-
sample_rate, audio_data = await self.enqueue()
179-
audio_bytes = covert_to_s16le(audio_data=audio_data)
180-
encoder.set_header(
181-
sample_rate=sample_rate, sample_width=audio_data.dtype.itemsize
182-
)
183-
encoder.write(audio_bytes)
184-
encoder.close()
185-
buffer = encoder.read_all()
186-
finally:
187-
encoder.terminate()
188-
return buffer
193+
return self._enqueue_to_bytes()
189194

190195
logger.debug(f"disconnected")
191196
self.interrupt()

modules/core/handler/TTSHandler.py

+3
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def interrupt(self):
100100
self.ctx.stop = True
101101
self.pipeline.model.interrupt()
102102

103+
def get_sample_rate(self):
104+
return self.pipeline.model.get_sample_rate()
105+
103106
async def enqueue(self) -> NP_AUDIO:
104107
timeout = self.ctx.infer_config.timeout
105108
return await self.pipeline.generate(timeout=timeout)

modules/core/handler/VCHandler.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
if self.model is None:
3333
raise Exception(f"Model {self.vc_config.mid} is not supported")
3434

35-
def get_model(self):
35+
def get_model(self) -> VCModel:
3636
model_id = (
3737
self.vc_config.mid.lower()
3838
.replace(" ", "")
@@ -55,3 +55,6 @@ async def enqueue_stream(self) -> AsyncGenerator[NP_AUDIO, None]:
5555
raise NotImplementedError(
5656
"Method 'enqueue_stream' not implemented in VCHandler"
5757
)
58+
59+
def get_sample_rate(self):
60+
return self.model.get_sample_rate()

modules/core/handler/encoder/StreamEncoder.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,47 @@ def __init__(self) -> None:
4343
def set_header(
4444
self, *, frame_input=b"", channels=1, sample_width=2, sample_rate=24000
4545
):
46+
"""
47+
基本上只需要改 sample_rate 因为我们输入的都是 pcm s16le (int16)
48+
"""
49+
self.channels = channels
50+
self.sample_width = sample_width
51+
self.sample_rate = sample_rate
52+
53+
logger.info(
54+
f"StreamEncoder header set, channels: {channels}, sample_width: {sample_width}, sample_rate: {sample_rate}"
55+
)
56+
57+
def write_header_data(self):
4658
if self.header:
4759
return
4860
header_bytes = wave_header_chunk(
49-
frame_input, channels, sample_width, sample_rate
61+
channels=self.channels,
62+
sample_width=self.sample_width,
63+
sample_rate=self.sample_rate,
5064
)
5165
self.header = header_bytes
5266
self.write(header_bytes)
5367

5468
logger.info(
55-
f"StreamEncoder header set, channels: {channels}, sample_width: {sample_width}, sample_rate: {sample_rate}"
69+
f"StreamEncoder header written, channels: {self.channels}, sample_width: {self.sample_width}, sample_rate: {self.sample_rate}"
5670
)
5771

5872
def open(
59-
self, format: str = "mp3", acodec: str = "libmp3lame", bitrate: str = "320k"
73+
self,
74+
format: str = "mp3",
75+
acodec: str = "libmp3lame",
76+
bitrate: str = "320k",
77+
input_dtype: str = "s16le", # s16le or s32le
6078
):
79+
"""
80+
打开编码器
81+
82+
:param format: 输出格式
83+
:param acodec: 输出编码器
84+
:param bitrate: 输出比特率
85+
:param input_dtype: 输入数据类型 s16le or s32le
86+
"""
6187
encoder = self.encoder
6288
self.p = subprocess.Popen(
6389
[
@@ -66,14 +92,12 @@ def open(
6692
"-threads",
6793
str(os.cpu_count() or 4),
6894
# NOTE: the input is raw PCM; its sample format is chosen by input_dtype (default s16le, i.e. 16-bit PCM)
69-
# NOTE: 其实文件头里面有写,但是没有文件名,所以需要手动指定
7095
"-f",
71-
"s16le",
72-
# NOTE: 不要在这里传递 ar/ac ,我们写在wav文件头上,这里会覆盖掉文件头读取的数据
73-
# "-ar",
74-
# str(self.sample_rate), # 输入采样率
75-
# "-ac",
76-
# "1", # 输入单声道
96+
input_dtype,
97+
"-ar",
98+
str(self.sample_rate), # 输入采样率
99+
"-ac",
100+
str(self.channels), # 输入单声道
77101
"-i",
78102
"pipe:0",
79103
"-f",

modules/core/models/vc/OpenVoice.py

+3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def sampling_rate(self) -> int:
5959
hps = self.model.hps
6060
return hps.data.sampling_rate
6161

62+
def get_sample_rate(self):
63+
return self.sampling_rate
64+
6265
def audio_to_se(self, audio: NP_AUDIO) -> torch.Tensor:
6366
hps = self.model.hps
6467
device = self.device

modules/core/models/vc/VCModel.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ def convert(
1010
self, src_audio: NP_AUDIO, ref_spk: TTSSpeaker, config: VCConfig
1111
) -> NP_AUDIO:
1212
raise NotImplementedError
13+
14+
def get_sample_rate(self) -> int:
15+
raise NotImplementedError

tests/encders/test_encoders.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
3+
from modules.core.handler.TTSHandler import TTSHandler
4+
from modules.core.handler.datacls.audio_model import EncoderConfig
5+
from modules.core.handler.datacls.tts_model import TTSConfig
6+
from modules.core.spk.SpkMgr import spk_mgr
7+
8+
# 这里测试 encoder 是否可以正常工作
9+
# 正常工作的定义是输出音频正常无噪音无变速
10+
11+
need_test_models = ["chat-tts", "fish-speech", "cosy-voice"]
12+
13+
14+
@pytest.mark.parametrize(
15+
"model_id, format",
16+
[
17+
# raw 格式就是 wav,只是直接输出pcm
18+
("chat-tts", "raw"),
19+
("fish-speech", "raw"),
20+
("cosy-voice", "raw"),
21+
("chat-tts", "mp3"),
22+
("fish-speech", "mp3"),
23+
("cosy-voice", "mp3"),
24+
("chat-tts", "wav"),
25+
("fish-speech", "wav"),
26+
("cosy-voice", "wav"),
27+
],
28+
)
29+
@pytest.mark.encoders
30+
@pytest.mark.asyncio
31+
async def test_encoders(model_id, format):
32+
spk_mona = spk_mgr.get_speaker("mona")
33+
handler = TTSHandler(
34+
text_content="云想衣裳花想容,春风拂槛露华浓。 若非群玉山头见,会向瑶台月下逢。",
35+
spk=spk_mona,
36+
tts_config=TTSConfig(mid=model_id),
37+
encoder_config=EncoderConfig(format=format),
38+
)
39+
file_bytes = await handler._enqueue_to_bytes()
40+
41+
ext = format
42+
if format == "raw":
43+
ext = "wav"
44+
45+
# 1. 不为空
46+
assert len(file_bytes) > 0
47+
# 2. 保存到 tests/test_outputs 之下,然后人工检查
48+
with open(f"tests/test_outputs/test_encder_{model_id}_{format}.{ext}", "wb") as f:
49+
f.write(file_bytes)

0 commit comments

Comments
 (0)