Skip to content

Commit a8d59e7

Browse files
committed
推理部分,增加gradio的wav音频流式输出与TTS API
1 parent 939971a commit a8d59e7

File tree

2 files changed

+511
-124
lines changed

2 files changed

+511
-124
lines changed

GPT_SoVITS/inference_stream.py

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
import os
2+
import tempfile, io, wave
3+
import gradio as gr
4+
import uvicorn
5+
import argparse
6+
from fastapi import FastAPI
7+
from fastapi.responses import StreamingResponse
8+
from pydub import AudioSegment
9+
from tools.i18n.i18n import I18nAuto
10+
from GPT_SoVITS.inference_webui import (
11+
get_weights_names,
12+
custom_sort_key,
13+
change_choices,
14+
change_gpt_weights,
15+
change_sovits_weights,
16+
get_tts_wav,
17+
)
18+
19+
api_app = FastAPI()
20+
i18n = I18nAuto()
21+
22+
# API mode Usage: python GPT_SoVITS/inference_stream.py --api
23+
parser = argparse.ArgumentParser(description="GPT-SoVITS Streaming API")
24+
parser.add_argument(
25+
"-api",
26+
"--api",
27+
action="store_true",
28+
default=False,
29+
help="是否开启API模式(不开启则是WebUI模式)",
30+
)
31+
parser.add_argument(
32+
"-s",
33+
"--sovits_path",
34+
type=str,
35+
default="GPT_SoVITS/pretrained_models/s2G488k.pth",
36+
help="SoVITS模型路径",
37+
)
38+
parser.add_argument(
39+
"-g",
40+
"--gpt_path",
41+
type=str,
42+
default="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
43+
help="GPT模型路径",
44+
)
45+
parser.add_argument(
46+
"-rw",
47+
"--ref_wav",
48+
type=str,
49+
# default="./example/archive_ruanmei_8.wav",
50+
help="参考音频路径",
51+
)
52+
parser.add_argument(
53+
"-rt",
54+
"--prompt_text",
55+
type=str,
56+
# default="我听不惯现代乐,听戏却极易入迷,琴弦拨动,时间便流往过去。",
57+
help="参考音频文本",
58+
)
59+
parser.add_argument(
60+
"-rl",
61+
"--prompt_language",
62+
type=str,
63+
default="中文",
64+
help="参考音频语种",
65+
)
66+
67+
args = parser.parse_args()
68+
69+
sovits_path = args.sovits_path
70+
gpt_path = args.gpt_path
71+
SoVITS_names, GPT_names = get_weights_names()
72+
73+
EXAMPLES = [
74+
[
75+
"中文",
76+
"根据过年的传说,远古时代有一隻凶残年兽,每到岁末就会从海底跑出来吃人。"
77+
+ "人们为了不被年兽吃掉,家家户户都会祭拜祖先祈求平安,也会聚在一起吃一顿丰盛的晚餐。"
78+
+ "后来人们发现年兽害怕红色、噪音与火光,便开始在当天穿上红衣、门上贴上红纸、燃烧爆竹声,藉此把年兽赶走。"
79+
+ "而这些之后也成为过年吃团圆饭、穿红衣、放鞭炮、贴春联的过年习俗。",
80+
],
81+
[
82+
"中文",
83+
"神霄折戟录其二"
84+
+"「嗯,好吃。」被附体的未央变得温柔了许多,也冷淡了很多。"
85+
+ "她拿起弥耳做的馅饼,小口小口吃了起来。第一口被烫到了,还很可爱地吐着舌头吸气。"
86+
+ "「我一下子有点接受不了, 需要消化消化。」用一只眼睛作为代价维持降灵的弥耳自己也拿了一个馅饼,「你再说一 遍?」"
87+
+ "「当年所谓的陨铁其实是神戟。它被凡人折断,铸成魔剑九柄。这一把是雾海魔剑。 加上他们之前已经收集了两柄。「然后你是?」"
88+
+ "「我是曾经的天帝之女,名字已经忘了。我司掌审判与断罪,用你们的话说,就是刑律。」"
89+
+ "因为光禄寺执掌祭祀典礼的事情,所以仪式、祝词什么的,弥耳被老爹逼得倒是能倒背如流。同时因为尽是接触怪力乱神,弥耳也是知道一些小]道的。 神明要是被知道了真正的秘密名讳,就只能任人驱使了。眼前这位未必是忘了。"
90+
+ "「所以朝廷是想重铸神霄之戟吗?」弥耳说服自己接受了这个设定,追问道。"
91+
+ "「我不知道。这具身体的主人并不知道别的事。她只是很愤怒,想要证明自己。」未央把手放在了胸口上。"
92+
+"「那接下来,我是应该弄个什么送神仪式把你送走吗?」弥耳摸了摸绷带下已经失去功能的眼睛,「然后我的眼睛也会回来?」"
93+
],
94+
]
95+
96+
97+
# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
98+
def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
99+
# This will create a wave header then append the frame input
100+
# It should be first on a streaming wav file
101+
# Other frames better should not have it (else you will hear some artifacts each chunk start)
102+
wav_buf = io.BytesIO()
103+
with wave.open(wav_buf, "wb") as vfout:
104+
vfout.setnchannels(channels)
105+
vfout.setsampwidth(sample_width)
106+
vfout.setframerate(sample_rate)
107+
vfout.writeframes(frame_input)
108+
109+
wav_buf.seek(0)
110+
return wav_buf.read()
111+
112+
113+
def get_streaming_tts_wav(
114+
ref_wav_path,
115+
prompt_text,
116+
prompt_language,
117+
text,
118+
text_language,
119+
how_to_cut,
120+
top_k,
121+
top_p,
122+
temperature,
123+
ref_free,
124+
byte_stream=True,
125+
):
126+
chunks = get_tts_wav(
127+
ref_wav_path=ref_wav_path,
128+
prompt_text=prompt_text,
129+
prompt_language=prompt_language,
130+
text=text,
131+
text_language=text_language,
132+
how_to_cut=how_to_cut,
133+
top_k=top_k,
134+
top_p=top_p,
135+
temperature=temperature,
136+
ref_free=ref_free,
137+
stream=True,
138+
)
139+
140+
if byte_stream:
141+
yield wave_header_chunk()
142+
for chunk in chunks:
143+
yield chunk
144+
else:
145+
# Send chunk files
146+
i = 0
147+
format = "wav"
148+
for chunk in chunks:
149+
i += 1
150+
file = f"{tempfile.gettempdir()}/{i}.{format}"
151+
segment = AudioSegment(chunk, frame_rate=32000, sample_width=2, channels=1)
152+
segment.export(file, format=format)
153+
yield file
154+
155+
156+
def webui():
157+
with gr.Blocks(title="GPT-SoVITS Streaming Demo") as app:
158+
gr.Markdown(
159+
value=i18n(
160+
"流式输出演示,分句推理后推送到组件中。由于目前bytes模式的限制,采用<a href='https://github.com/gradio-app/gradio/blob/gradio%404.17.0/demo/stream_audio_out/run.py'>stream_audio_out</a>中临时文件的方案输出分句。这种方式相比bytes,会增加wav文件解析的延迟。"
161+
),
162+
)
163+
164+
gr.Markdown(value=i18n("模型切换"))
165+
with gr.Row():
166+
GPT_dropdown = gr.Dropdown(
167+
label=i18n("GPT模型列表"),
168+
choices=sorted(GPT_names, key=custom_sort_key),
169+
value=gpt_path,
170+
interactive=True,
171+
)
172+
SoVITS_dropdown = gr.Dropdown(
173+
label=i18n("SoVITS模型列表"),
174+
choices=sorted(SoVITS_names, key=custom_sort_key),
175+
value=sovits_path,
176+
interactive=True,
177+
)
178+
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
179+
refresh_button.click(
180+
fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]
181+
)
182+
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown], [])
183+
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
184+
185+
gr.Markdown(value=i18n("*请上传并填写参考信息"))
186+
with gr.Row():
187+
inp_ref = gr.Audio(
188+
label=i18n("请上传3~10秒内参考音频,超过会报错!"), value=args.ref_wav, type="filepath"
189+
)
190+
with gr.Column():
191+
ref_text_free = gr.Checkbox(
192+
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
193+
value=False,
194+
interactive=True,
195+
show_label=True,
196+
)
197+
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT"))
198+
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value=args.prompt_text)
199+
prompt_language = gr.Dropdown(
200+
label=i18n("参考音频的语种"),
201+
choices=[
202+
i18n("中文"),
203+
i18n("英文"),
204+
i18n("日文"),
205+
i18n("中英混合"),
206+
i18n("日英混合"),
207+
i18n("多语种混合"),
208+
],
209+
value=args.prompt_language,
210+
)
211+
212+
def load_text(file):
213+
with open(file.name, "r", encoding="utf-8") as file:
214+
return file.read()
215+
216+
load_button = gr.UploadButton(i18n("加载参考文本"), variant="secondary")
217+
load_button.upload(load_text, load_button, prompt_text)
218+
219+
gr.Markdown(
220+
value=i18n(
221+
"*请填写需要合成的目标文本。中英混合选中文,日英混合选日文,中日混合暂不支持,非目标语言文本自动遗弃。"
222+
)
223+
)
224+
with gr.Row():
225+
text = gr.Textbox(
226+
label=i18n("需要合成的文本"), value="", lines=5, interactive=True
227+
)
228+
text_language = gr.Dropdown(
229+
label=i18n("需要合成的语种"),
230+
choices=[
231+
i18n("中文"),
232+
i18n("英文"),
233+
i18n("日文"),
234+
i18n("中英混合"),
235+
i18n("日英混合"),
236+
i18n("多语种混合"),
237+
],
238+
value=i18n("中文"),
239+
)
240+
how_to_cut = gr.Radio(
241+
label=i18n("怎么切"),
242+
choices=[
243+
i18n("不切"),
244+
i18n("凑四句一切"),
245+
i18n("凑50字一切"),
246+
i18n("按中文句号。切"),
247+
i18n("按英文句号.切"),
248+
i18n("按标点符号切"),
249+
],
250+
value=i18n("按标点符号切"),
251+
interactive=True,
252+
)
253+
254+
gr.Markdown(value=i18n("* 参数设置"))
255+
with gr.Row():
256+
with gr.Column():
257+
top_k = gr.Slider(
258+
minimum=1,
259+
maximum=100,
260+
step=1,
261+
label=i18n("top_k"),
262+
value=5,
263+
interactive=True,
264+
)
265+
top_p = gr.Slider(
266+
minimum=0,
267+
maximum=1,
268+
step=0.05,
269+
label=i18n("top_p"),
270+
value=1,
271+
interactive=True,
272+
)
273+
temperature = gr.Slider(
274+
minimum=0,
275+
maximum=1,
276+
step=0.05,
277+
label=i18n("temperature"),
278+
value=1,
279+
interactive=True,
280+
)
281+
inference_button = gr.Button(i18n("合成语音"), variant="primary")
282+
283+
gr.Markdown(value=i18n("* 结果输出(等待第2句推理结束后会自动播放)"))
284+
with gr.Row():
285+
audio_file = gr.Audio(
286+
value=None,
287+
label=i18n("输出的语音"),
288+
streaming=True,
289+
autoplay=True,
290+
interactive=False,
291+
show_label=True,
292+
)
293+
294+
inference_button.click(
295+
get_streaming_tts_wav,
296+
[
297+
inp_ref,
298+
prompt_text,
299+
prompt_language,
300+
text,
301+
text_language,
302+
how_to_cut,
303+
top_k,
304+
top_p,
305+
temperature,
306+
ref_text_free,
307+
],
308+
[audio_file],
309+
).then(lambda: gr.update(interactive=True), None, [text], queue=False)
310+
311+
with gr.Row():
312+
gr.Examples(
313+
EXAMPLES,
314+
[text_language, text],
315+
cache_examples=False,
316+
run_on_click=False, # Will not work , user should submit it
317+
)
318+
319+
app.queue().launch(
320+
server_name="0.0.0.0",
321+
inbrowser=True,
322+
share=False,
323+
server_port=8080,
324+
quiet=True,
325+
)
326+
327+
328+
@api_app.get("/")
329+
async def tts(
330+
text: str, # 必选参数
331+
language: str = i18n("中文"),
332+
top_k: int = 5,
333+
top_p: float = 1,
334+
temperature: float = 1,
335+
):
336+
ref_wav_path = args.ref_wav
337+
prompt_text = args.prompt_text
338+
prompt_language = args.prompt_language
339+
how_to_cut = i18n("按标点符号切")
340+
341+
return StreamingResponse(
342+
get_streaming_tts_wav(
343+
ref_wav_path=ref_wav_path,
344+
prompt_text=prompt_text,
345+
prompt_language=prompt_language,
346+
text=text,
347+
text_language=language,
348+
how_to_cut=how_to_cut,
349+
top_k=top_k,
350+
top_p=top_p,
351+
temperature=temperature,
352+
ref_free=False,
353+
byte_stream=True,
354+
),
355+
media_type="audio/x-wav",
356+
)
357+
358+
359+
def api():
360+
uvicorn.run(
361+
app="inference_stream:api_app", host="127.0.0.1", port=8080, reload=True
362+
)
363+
364+
365+
if __name__ == "__main__":
366+
# 模式选择,默认是webui模式
367+
if not args.api:
368+
webui()
369+
else:
370+
api()

0 commit comments

Comments
 (0)