From 0a17694edee04ad4605d8352034339374691d94a Mon Sep 17 00:00:00 2001
From: Jin
Date: Sat, 28 Sep 2024 19:42:27 +0800
Subject: [PATCH 1/2] api_v2.py: support ref_audio input as base64 string

---
 api_v2.py        | 21 ++++++++++++++++++++-
 requirements.txt |  1 +
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/api_v2.py b/api_v2.py
index ea1d0c7f3..3008cb382 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -212,6 +212,22 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
     return io_buffer
 
+_base64_audio_cache = {}
+def save_base64_audio(b64str:str):
+    import filetype, base64, uuid
+    global _base64_audio_cache
+    if b64str in _base64_audio_cache:
+        return _base64_audio_cache[b64str]
+    savedir = 'TEMP/upload'
+    data = base64.b64decode(b64str)
+    ft = filetype.guess(data)
+    ext = f'.{ft.extension}' if ft else ''
+    os.makedirs(savedir, exist_ok=True)
+    saveto = f'{savedir}/{uuid.uuid1()}{ext}'
+    with open(saveto, 'wb') as outf:
+        outf.write(data)
+    _base64_audio_cache[b64str] = saveto
+    return saveto
 
 
 # from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py
 def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000):
@@ -277,7 +293,7 @@ async def tts_handle(req:dict):
     {
         "text": "",                   # str.(required) text to be synthesized
         "text_lang: "",               # str.(required) language of the text to be synthesized
-        "ref_audio_path": "",         # str.(required) reference audio path
+        "ref_audio_path": "",         # str.(required) reference audio path; also accepts base64 data with the prefix "base64:"
         "aux_ref_audio_paths": [],    # list.(optional) auxiliary reference audio paths for multi-speaker synthesis
         "prompt_text": "",            # str.(optional) prompt text for the reference audio
         "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
@@ -303,6 +319,9 @@ async def tts_handle(req:dict):
     streaming_mode = req.get("streaming_mode", False)
     return_fragment = req.get("return_fragment", False)
     media_type = req.get("media_type", "wav")
+    ref_audio_path = req.get("ref_audio_path", "")
+    if ref_audio_path.startswith("base64:"):
+        req['ref_audio_path'] = ref_audio_path = save_base64_audio(ref_audio_path[len("base64:"):])
 
     check_res = check_params(req)
     if check_res is not None:
diff --git a/requirements.txt b/requirements.txt
index 280d9d992..3792d88b0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,3 +34,4 @@ opencc; sys_platform != 'linux'
 opencc==1.1.1; sys_platform == 'linux'
 python_mecab_ko; sys_platform != 'win32'
 fastapi<0.112.2
+filetype
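For illustration, a minimal client-side sketch of the new "base64:" input form. This is a hedged example, not part of the patch: the file name, text and transcript are placeholders. The server strips the prefix, decodes the payload, sniffs the container with `filetype`, and caches the result under TEMP/upload/.

import base64

with open("my_ref.wav", "rb") as f:                       # file name is an assumption
    ref_b64 = base64.b64encode(f.read()).decode("ascii")

req = {
    "text": "Hello there.", "text_lang": "en",
    "prompt_text": "reference transcript", "prompt_lang": "en",
    "ref_audio_path": "base64:" + ref_b64,                # triggers save_base64_audio() server-side
}

Because save_base64_audio() keys its cache on the base64 string itself, repeated requests with the same reference audio decode and write the file only once.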
From 27664703d2fb3c86504d8168ae79639b784c56f7 Mon Sep 17 00:00:00 2001
From: Jin
Date: Sat, 28 Sep 2024 19:42:36 +0800
Subject: [PATCH 2/2] Generate & return subtitles with the audio.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Generate subtitles synchronized with the audio and return them:
- TTS_infer_pack/TTS.py generates subtitle entries aligned with the audio
- the api_v2.py /tts endpoint can return the generated audio (encoded as
  base64) and the subtitles together in one JSON response
- subtitle generation is controlled by a request parameter and disabled by
  default, so other modules are unaffected
---
 GPT_SoVITS/TTS_infer_pack/TTS.py              | 54 ++++++++++++-------
 GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py |  1 +
 api_v2.py                                     | 37 ++++++++++++-
 3 files changed, 72 insertions(+), 20 deletions(-)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index a1eeb28ce..2f7ed7231 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -550,6 +550,7 @@ def to_batch(self, data:list,
         all_phones_len_list = []
         all_bert_features_list = []
         norm_text_batch = []
+        origin_text_batch = []
         all_bert_max_len = 0
         all_phones_max_len = 0
         for item in item_list:
@@ -575,6 +576,7 @@ def to_batch(self, data:list,
             all_phones_len_list.append(all_phones.shape[-1])
             all_bert_features_list.append(all_bert_features)
             norm_text_batch.append(item["norm_text"])
+            origin_text_batch.append(item["origin_text"])
 
         phones_batch = phones_list
         all_phones_batch = all_phones_list
@@ -606,6 +608,7 @@ def to_batch(self, data:list,
             "all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
             "all_bert_features": all_bert_features_batch,
             "norm_text": norm_text_batch,
+            "origin_text": origin_text_batch,
             "max_len": max_len,
         }
         _data.append(batch)
@@ -658,6 +661,7 @@ def run(self, inputs:dict):
                 "batch_threshold": 0.75,    # float. threshold for batch splitting.
                 "split_bucket: True,        # bool. whether to split the batch into multiple buckets.
                 "return_fragment": False,   # bool. step by step return the audio fragment.
+                "return_with_srt": "",      # str. "" for no subtitles; "orig" uses the original text, "norm" the normalized text.
                 "speed_factor":1.0,         # float. control the speed of the synthesized audio.
                 "fragment_interval":0.3,    # float. to control the interval of the audio fragment.
                 "seed": -1,                 # int. random seed for reproducibility.
@@ -685,6 +689,7 @@ def run(self, inputs:dict):
         split_bucket = inputs.get("split_bucket", True)
         return_fragment = inputs.get("return_fragment", False)
         fragment_interval = inputs.get("fragment_interval", 0.3)
+        return_with_srt = inputs.get("return_with_srt", "")
         seed = inputs.get("seed", -1)
         seed = -1 if seed in ["", None] else seed
         actual_seed = set_seed(seed)
@@ -704,6 +709,9 @@ def run(self, inputs:dict):
             split_bucket = False
             print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
 
+        ret_width = 3 if return_with_srt else 2  # yield (sr, audio, srt) or (sr, audio)
+        srt_text = "norm_text" if return_with_srt.startswith("norm") else "origin_text"
+
         if split_bucket and speed_factor==1.0:
             print(i18n("分桶处理模式已开启"))
         elif speed_factor!=1.0:
@@ -773,8 +781,7 @@ def run(self, inputs:dict):
         if not return_fragment:
             data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
             if len(data) == 0:
-                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                           dtype=np.int16)
+                yield self.audio_failure()[:ret_width]
                 return
 
             batch_index_list:list = None
@@ -806,6 +813,7 @@ def make_batch(batch_texts):
                         "phones": phones,
                         "bert_features": bert_features,
                         "norm_text": norm_text,
+                        "origin_text": text,
                     }
                     batch_data.append(res)
                 if len(batch_data) == 0:
@@ -841,10 +849,11 @@ def make_batch(batch_texts):
                 all_phoneme_ids:torch.LongTensor = item["all_phones"]
                 all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
                 all_bert_features:torch.LongTensor = item["all_bert_features"]
-                norm_text:str = item["norm_text"]
+                # norm_text:List[str] = item["norm_text"]
+                # origin_text:List[str] = item["origin_text"]
                 max_len = item["max_len"]
-                print(i18n("前端处理后的文本(每句):"), norm_text)
+                print(i18n("前端处理后的文本(每批):"), item["norm_text"])
                 if no_prompt_text :
                     prompt = None
                 else:
@@ -915,39 +924,38 @@ def make_batch(batch_texts):
                 if return_fragment:
                     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                     yield self.audio_postprocess([batch_audio_fragment],
+                                                 [item[srt_text]],
                                                  self.configs.sampling_rate,
                                                  None,
                                                  speed_factor,
                                                  False,
                                                  fragment_interval
-                                                 )
+                                                 )[:ret_width]
                 else:
                     audio.append(batch_audio_fragment)
 
                 if self.stop_flag:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield self.audio_failure()[:ret_width]
                     return
 
             if not return_fragment:
                 print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
                 if len(audio) == 0:
-                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                               dtype=np.int16)
+                    yield self.audio_failure()[:ret_width]
                     return
                 yield self.audio_postprocess(audio,
+                                             [v[srt_text] for v in data],
                                              self.configs.sampling_rate,
                                              batch_index_list,
                                              speed_factor,
                                              split_bucket,
                                              fragment_interval
-                                             )
+                                             )[:ret_width]
 
         except Exception as e:
             traceback.print_exc()
             # 必须返回一个空音频, 否则会导致显存不释放。
-            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
-                                                       dtype=np.int16)
+            yield self.audio_failure()[:ret_width]
             # 重置模型, 否则会导致显存释放不完全。
             del self.t2s_model
             del self.vits_model
@@ -968,15 +976,19 @@ def empty_cache(self):
                 torch.mps.empty_cache()
             except:
                 pass
-
+
+    def audio_failure(self):
+        return self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate), dtype=np.int16), []
+
     def audio_postprocess(self,
-                          audio:List[torch.Tensor],
+                          audio:List[torch.Tensor],
+                          texts:List[List[str]],
                           sr:int,
                           batch_index_list:list=None,
                           speed_factor:float=1.0,
                           split_bucket:bool=True,
                           fragment_interval:float=0.3
-                          )->Tuple[int, np.ndarray]:
+                          )->Tuple[int, np.ndarray, List]:
         zero_wav = torch.zeros(
                         int(self.configs.sampling_rate * fragment_interval),
                         dtype=self.precision,
@@ -993,11 +1005,17 @@ def audio_postprocess(self,
 
         if split_bucket:
            audio = self.recovery_order(audio, batch_index_list)
+           texts = self.recovery_order(texts, batch_index_list)
        else:
            # audio = [item for batch in audio for item in batch]
            audio = sum(audio, [])
-
-
+           texts = sum(texts, [])
+
+        # compute each segment's start/end time in order and pair it with its text, for subtitle generation
+        from itertools import accumulate
+        stamps = [0.0] + [x/sr for x in accumulate([v.size for v in audio])]
+        srts = list(zip(stamps[:-1], stamps[1:], texts))  # (start time, end time, text)
+
         audio = np.concatenate(audio, 0)
         audio = (audio * 32768).astype(np.int16)
 
@@ -1007,7 +1025,7 @@ def audio_postprocess(self,
         #     except Exception as e:
         #         print(f"Failed to change speed of audio: \n{e}")
 
-        return sr, audio
+        return sr, audio, srts
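For reference, a hedged sketch of driving the extended run() API directly. The config path mirrors api_v2.py's default; the texts and audio file names are placeholders:

from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config

tts = TTS(TTS_Config("GPT_SoVITS/configs/tts_infer.yaml"))  # default config, as in api_v2.py
gen = tts.run({
    "text": "First sentence. Second sentence.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",
    "prompt_text": "reference transcript",
    "prompt_lang": "en",
    "return_with_srt": "orig",    # "" -> (sr, audio); "orig"/"norm" -> (sr, audio, srt)
})
sr, audio, srt = next(gen)        # ret_width == 3 because return_with_srt is non-empty
for start, end, text in srt:      # (start seconds, end seconds, text) triples
    print(f"{start:7.3f} --> {end:7.3f}  {text}")

Note that each segment's duration includes the fragment_interval silence appended in audio_postprocess, so consecutive subtitle entries tile the waveform without gaps.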
diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index b90bd929e..7c5b256f5 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -69,6 +69,7 @@ def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2"
                 "phones": phones,
                 "bert_features": bert_features,
                 "norm_text": norm_text,
+                "origin_text": text,
             }
             result.append(res)
         return result
diff --git a/api_v2.py b/api_v2.py
index 3008cb382..74ad61937 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -36,6 +36,7 @@
     "split_bucket: True,          # bool. whether to split the batch into multiple buckets.
     "speed_factor":1.0,           # float. control the speed of the synthesized audio.
     "streaming_mode": False,      # bool. whether to return a streaming response.
+    "with_srt_format": "",        # str. "" (no subtitles) or "raw"; "srt", "lrc", "vtt", ... are not implemented yet.
     "seed": -1,                   # int. random seed for reproducibility.
     "parallel_infer": True,       # bool. whether to use parallel inference.
     "repetition_penalty": 1.35    # float. repetition penalty for T2S model.
@@ -98,7 +99,7 @@
 import os
 import sys
 import traceback
-from typing import Generator
+from typing import Generator, List, Union
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
@@ -162,6 +163,7 @@ class TTS_Request(BaseModel):
     seed:int = -1
     media_type:str = "wav"
     streaming_mode:bool = False
+    with_srt_format:str = ""
     parallel_infer:bool = True
     repetition_penalty:float = 1.35
 
@@ -211,6 +213,21 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
     io_buffer.seek(0)
     return io_buffer
 
+def pack_srt(srt:List, fmt:str):
+    if fmt == "raw":
+        return srt
+    # TODO: support formats like "srt", "lrc", "vtt", ...
+    return srt
+
+def load_base64_audio(audio):
+    import base64
+    if isinstance(audio, (bytes, bytearray)):  # raw audio bytes
+        audio = bytes(audio)
+    elif hasattr(audio, 'read'):  # file-like obj
+        audio = audio.read()
+    else:  # path-like
+        audio = open(audio, 'rb').read()
+    return base64.b64encode(audio).decode('ascii')
 
 _base64_audio_cache = {}
 def save_base64_audio(b64str:str):
@@ -309,6 +326,7 @@ async def tts_handle(req:dict):
         "seed": -1,                   # int. random seed for reproducibility.
         "media_type": "wav",          # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
         "streaming_mode": False,      # bool. whether to return a streaming response.
+        "with_srt_format": "",        # str. "" (no subtitles) or "raw"; "srt", "lrc", "vtt", ... are not implemented yet.
         "parallel_infer": True,       # bool.(optional) whether to use parallel inference.
         "repetition_penalty": 1.35    # float.(optional) repetition penalty for T2S model.
     }
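pack_srt currently passes the raw (start, end, text) tuples through for every format; the named formats remain a TODO. A minimal sketch of what an "srt" renderer could look like, assuming the tuples produced by audio_postprocess; the helper name format_srt is hypothetical and not part of the patch:

# Hypothetical renderer for pack_srt's "srt" TODO. Input: a list of
# (start_seconds, end_seconds, text) tuples from audio_postprocess.
def format_srt(srt):
    def ts(t):
        ms = int(round(t * 1000))            # total milliseconds
        h, ms = divmod(ms, 3600_000)
        m, ms = divmod(ms, 60_000)
        s, ms = divmod(ms, 1000)
        return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
    cues = []
    for i, (start, end, text) in enumerate(srt, 1):
        cues.append(f"{i}\n{ts(start)} --> {ts(end)}\n{text}\n")
    return "\n".join(cues)                   # blank line separates SRT cues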
@@ -319,6 +337,7 @@ async def tts_handle(req:dict):
     streaming_mode = req.get("streaming_mode", False)
     return_fragment = req.get("return_fragment", False)
     media_type = req.get("media_type", "wav")
+    with_srt_format = req.get("with_srt_format", "")
     ref_audio_path = req.get("ref_audio_path", "")
     if ref_audio_path.startswith("base64:"):
         req['ref_audio_path'] = ref_audio_path = save_base64_audio(ref_audio_path[len("base64:"):])
@@ -329,7 +348,10 @@ async def tts_handle(req:dict):
 
     if streaming_mode or return_fragment:
         req["return_fragment"] = True
-
+
+    if streaming_mode: with_srt_format = ""  # streaming does not support subtitles
+    req["return_with_srt"] = "orig" if with_srt_format else ""
+
     try:
         tts_generator=tts_pipeline.run(req)
 
@@ -343,6 +365,16 @@ def streaming_generator(tts_generator:Generator, media_type:str):
         #     _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
             return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
 
+        elif with_srt_format:
+            output = []
+            for sr, audio_data, srt_data in tts_generator:
+                audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+                output.append({
+                    "audio": load_base64_audio(audio_data), "media_type": f"audio/{media_type}",
+                    "srt": pack_srt(srt_data, with_srt_format), "srt_fmt": with_srt_format,
+                })
+            return { "message":"succeed", "output":output }  # JSONResponse(status_code=200, content=...)
+
         else:
             sr, audio_data = next(tts_generator)
             audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
@@ -383,6 +415,7 @@ async def tts_get_endpoint(
                         seed:int = -1,
                         media_type:str = "wav",
                         streaming_mode:bool = False,
+                        with_srt_format:str = "",
                         parallel_infer:bool = True,
                         repetition_penalty:float = 1.35
                         ):
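Putting both patches together, a hedged end-to-end example against a local server; the host, port and file names are assumptions, and "raw" is the only subtitle format currently implemented:

import base64, requests

payload = {
    "text": "Hello. This is a test.",
    "text_lang": "en",
    "ref_audio_path": "base64:" + base64.b64encode(open("ref.wav", "rb").read()).decode("ascii"),
    "prompt_text": "reference transcript",
    "prompt_lang": "en",
    "media_type": "wav",
    "with_srt_format": "raw",      # also ask for (start, end, text) tuples
}
resp = requests.post("http://127.0.0.1:9880/tts", json=payload).json()
for i, frag in enumerate(resp["output"]):
    with open(f"out_{i}.wav", "wb") as f:
        f.write(base64.b64decode(frag["audio"]))   # audio comes back base64-encoded
    print(frag["srt"])                             # [[start_seconds, end_seconds, text], ...]

With streaming_mode off and return_fragment off, "output" holds a single element; with return_fragment on, one element per fragment, each carrying its own subtitle list.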