Skip to content

Commit ec34565

Browse files
committed
✨ feat: 重构启动逻辑,将在线API检查和音频文件夹检查移至新hook模块,删除冗余代码
1 parent da2b083 commit ec34565

File tree

5 files changed

+82
-80
lines changed

5 files changed

+82
-80
lines changed

nonebot_plugin_fishspeech_tts/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
require("nonebot_plugin_alconna")
88

99
from . import matcher as _match # noqa:F401,I001
10-
from . import on_start_up # noqa:F401
10+
from . import hook # noqa:F401
1111

1212

1313
usage: str = """

nonebot_plugin_fishspeech_tts/fish_audio_api.py

+44-44
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,17 @@ class FishAudioAPI:
3636
FishAudioAPI类, 用于调用FishAudio的API接口
3737
"""
3838

39-
def __init__(self):
40-
self.api_url: str = API_URL
41-
self.path_audio: Path = Path(config.tts_audio_path)
42-
self.proxy = API_PROXY
39+
api_url: str = API_URL
40+
path_audio: Path = Path(config.tts_audio_path)
41+
proxy = API_PROXY
42+
from typing import ClassVar
4343

44-
# 如果在线授权码为空, 且使用在线api, 则抛出异常
45-
if not config.online_authorization and config.tts_is_online:
46-
raise AuthorizationException("请先在配置文件中填写在线授权码或使用离线api")
47-
self.headers = {
48-
"Authorization": f"Bearer {config.online_authorization}",
49-
}
50-
51-
# 如果音频文件夹不存在, 则创建音频文件夹
52-
if not self.path_audio.exists():
53-
self.path_audio.mkdir(parents=True)
54-
logger.warning(f"音频文件夹{self.path_audio.name}不存在, 已创建")
55-
elif not self.path_audio.is_dir():
56-
raise FileHandleException(f"{self.path_audio.name}不是一个文件夹")
57-
58-
async def _get_reference_id_by_speaker(self, speaker: str) -> str:
44+
_headers: ClassVar[dict] = {
45+
"Authorization": f"Bearer {config.online_authorization}",
46+
}
47+
48+
@classmethod
49+
async def _get_reference_id_by_speaker(cls, speaker: str) -> str:
5950
"""
6051
通过说话人姓名获取说话人的reference_id
6152
@@ -68,13 +59,13 @@ async def _get_reference_id_by_speaker(self, speaker: str) -> str:
6859
exception:
6960
APIException: 获取语音角色列表为空
7061
"""
71-
request_api = self.api_url + "/model"
62+
request_api = cls.api_url + "/model"
7263
sort_options = ["score", "task_count", "created_at"]
73-
async with AsyncClient(proxy=self.proxy) as client:
64+
async with AsyncClient(proxy=cls.proxy) as client:
7465
for sort_by in sort_options:
7566
params = {"title": speaker, "sort_by": sort_by}
7667
response = await client.get(
77-
request_api, params=params, headers=self.headers
68+
request_api, params=params, headers=cls._headers
7869
)
7970
resp_data = response.json()
8071
if resp_data["total"] == 0:
@@ -84,8 +75,9 @@ async def _get_reference_id_by_speaker(self, speaker: str) -> str:
8475
return item["_id"]
8576
raise APIException("未找到对应的角色")
8677

78+
@classmethod
8779
async def generate_servettsrequest(
88-
self,
80+
cls,
8981
text: str,
9082
speaker_name: str,
9183
chunk_length: ChunkLength = ChunkLength.NORMAL,
@@ -103,15 +95,18 @@ async def generate_servettsrequest(
10395
Returns:
10496
ServeTTSRequest: TTS请求
10597
"""
98+
if not config.online_authorization and config.tts_is_online:
99+
raise AuthorizationException("请先在配置文件中填写在线授权码或使用离线api")
100+
106101
reference_id = None
107102
references = []
108103
try:
109104
if is_reference_id_first:
110-
reference_id = await self._get_reference_id_by_speaker(speaker_name)
105+
reference_id = await cls._get_reference_id_by_speaker(speaker_name)
111106
else:
112107
try:
113108
speaker_audio_path = get_speaker_audio_path(
114-
self.path_audio, speaker_name
109+
cls.path_audio, speaker_name
115110
)
116111
for audio in speaker_audio_path:
117112
audio_bytes = audio.read_bytes()
@@ -121,7 +116,7 @@ async def generate_servettsrequest(
121116
)
122117
except FileHandleException:
123118
logger.warning("音频文件夹不存在, 已转为在线模型优先模式")
124-
reference_id = await self._get_reference_id_by_speaker(speaker_name)
119+
reference_id = await cls._get_reference_id_by_speaker(speaker_name)
125120
except APIException as e:
126121
raise e from e
127122
return ServeTTSRequest(
@@ -138,7 +133,8 @@ async def generate_servettsrequest(
138133
references=references,
139134
)
140135

141-
async def generate_tts(self, request: ServeTTSRequest) -> bytes:
136+
@classmethod
137+
async def generate_tts(cls, request: ServeTTSRequest) -> bytes:
142138
"""
143139
获取TTS音频
144140
@@ -149,14 +145,14 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
149145
bytes: TTS音频二进制数据
150146
"""
151147
if request.references:
152-
self.headers["content-type"] = "application/msgpack"
148+
cls._headers["content-type"] = "application/msgpack"
153149
try:
154150
async with (
155-
AsyncClient(proxy=self.proxy) as client,
151+
AsyncClient(proxy=cls.proxy) as client,
156152
client.stream(
157153
"POST",
158-
self.api_url + "/v1/tts",
159-
headers=self.headers,
154+
cls.api_url + "/v1/tts",
155+
headers=cls._headers,
160156
content=ormsgpack.packb(
161157
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
162158
),
@@ -169,17 +165,19 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
169165
HTTPStatusError,
170166
) as e:
171167
logger.error(f"获取TTS音频失败: {e}")
172-
if self.proxy:
168+
if cls.proxy:
173169
raise HTTPException("代理地址错误, 请检查代理地址是否正确") from e
174170
raise HTTPException("网络错误, 请检查网络连接") from e
175171
else:
176-
self.headers["content-type"] = "application/json"
172+
cls._headers["content-type"] = "application/json"
177173
try:
178-
async with AsyncClient(proxy=self.proxy) as client:
174+
async with AsyncClient(proxy=cls.proxy) as client:
179175
response = await client.post(
180-
self.api_url + "/v1/tts",
181-
headers=self.headers,
182-
json=request.dict(),
176+
cls.api_url + "/v1/tts",
177+
headers=cls._headers,
178+
json=ormsgpack.packb(
179+
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
180+
),
183181
timeout=60,
184182
)
185183
return response.content
@@ -190,30 +188,32 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
190188
HTTPStatusError,
191189
) as e:
192190
logger.error(f"获取TTS音频失败: {e}")
193-
if self.proxy:
191+
if cls.proxy:
194192
raise HTTPException("代理地址错误, 请检查代理地址是否正确") from e
195193
raise HTTPException("网络错误, 请检查网络连接") from e
196194

197-
async def get_balance(self) -> float:
195+
@classmethod
196+
async def get_balance(cls) -> float:
198197
"""
199198
获取账户余额
200199
"""
201-
balance_url = self.api_url + "/wallet/self/api-credit"
202-
async with AsyncClient(proxy=self.proxy) as client:
203-
response = await client.get(balance_url, headers=self.headers)
200+
balance_url = cls.api_url + "/wallet/self/api-credit"
201+
async with AsyncClient(proxy=cls.proxy) as client:
202+
response = await client.get(balance_url, headers=cls._headers)
204203
try:
205204
return response.json()["credit"]
206205
except KeyError:
207206
raise AuthorizationException("授权码错误或已失效") from KeyError
208207

209-
def get_speaker_list(self) -> list[str]:
208+
@classmethod
209+
def get_speaker_list(cls) -> list[str]:
210210
"""
211211
获取语音角色列表
212212
"""
213213
return_list = ["请查看官网了解更多: https://fish.audio/zh-CN/"]
214214
if not is_reference_id_first:
215215
try:
216-
return_list.extend(get_path_speaker_list(self.path_audio))
216+
return_list.extend(get_path_speaker_list(cls.path_audio))
217217
except FileHandleException:
218218
logger.warning("音频文件夹不存在或无法读取")
219219
return return_list

nonebot_plugin_fishspeech_tts/fish_speech_api.py

+17-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from typing import ClassVar
23

34
import ormsgpack
45
from httpx import (
@@ -24,22 +25,15 @@
2425

2526

2627
class FishSpeechAPI:
27-
def __init__(self):
28-
self.api_url: str = API_URL
29-
self.path_audio: Path = PATH_AUDIO
30-
self.headers = {
31-
"content-type": "application/msgpack",
32-
}
33-
34-
# 如果音频文件夹不存在, 则创建音频文件夹
35-
if not self.path_audio.exists():
36-
self.path_audio.mkdir(parents=True)
37-
logger.warning(f"音频文件夹{self.path_audio.name}不存在, 已创建")
38-
elif not self.path_audio.is_dir():
39-
raise NotADirectoryError(f"{self.path_audio.name}不是一个文件夹")
28+
api_url: str = API_URL
29+
path_audio: Path = PATH_AUDIO
30+
_headers: ClassVar[dict] = {
31+
"content-type": "application/msgpack",
32+
}
4033

34+
@classmethod
4135
async def generate_servettsrequest(
42-
self,
36+
cls,
4337
text: str,
4438
speaker_name: str,
4539
chunk_length: ChunkLength = ChunkLength.NORMAL,
@@ -59,7 +53,7 @@ async def generate_servettsrequest(
5953

6054
references = []
6155
try:
62-
speaker_audio_path = get_speaker_audio_path(self.path_audio, speaker_name)
56+
speaker_audio_path = get_speaker_audio_path(cls.path_audio, speaker_name)
6357
except FileHandleException as e:
6458
raise APIException(str(e)) from e
6559
for audio in speaker_audio_path:
@@ -82,7 +76,8 @@ async def generate_servettsrequest(
8276
mp3_bitrate=64,
8377
)
8478

85-
async def generate_tts(self, request: ServeTTSRequest) -> bytes:
79+
@classmethod
80+
async def generate_tts(cls, request: ServeTTSRequest) -> bytes:
8681
"""
8782
获取TTS音频
8883
@@ -94,12 +89,12 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
9489
try:
9590
async with AsyncClient() as client:
9691
response = await client.post(
97-
self.api_url,
98-
headers=self.headers,
92+
cls.api_url,
93+
headers=cls._headers,
9994
content=ormsgpack.packb(
10095
request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
10196
),
102-
timeout=120,
97+
timeout=60,
10398
)
10499
return response.content
105100
except (
@@ -113,14 +108,15 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
113108
except Exception as e:
114109
raise APIException(f"{e}\n获取TTS音频失败, 检查API后端") from e
115110

116-
def get_speaker_list(self) -> list[str]:
111+
@classmethod
112+
def get_speaker_list(cls) -> list[str]:
117113
"""
118114
获取说话人列表
119115
120116
Returns:
121117
list[str]: 说话人列表
122118
"""
123119
try:
124-
return get_path_speaker_list(self.path_audio)
120+
return get_path_speaker_list(cls.path_audio)
125121
except FileHandleException as e:
126122
raise APIException(str(e)) from e

nonebot_plugin_fishspeech_tts/on_start_up.py renamed to nonebot_plugin_fishspeech_tts/hook.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from nonebot.log import logger
44

55
from .config import config
6+
from .fish_audio_api import FishAudioAPI
67

78
IS_ONLINE = config.tts_is_online
89
API = config.online_api_url
@@ -18,6 +19,16 @@ async def check_online_api():
1819
response = await client.get(API)
1920
rsp_text = response.text
2021
if "Nothing" in rsp_text:
21-
logger.warning("在线API可用")
22+
logger.success("在线API可用")
2223
except TimeoutException as e:
2324
logger.warning(f"在线API不可用: {e}\n请尝试更换API地址或配置代理")
25+
26+
@driver.on_startup
27+
async def check_files():
28+
"""检查音频文件夹是否存在"""
29+
path_audio = FishAudioAPI.path_audio
30+
if not path_audio.exists():
31+
path_audio.mkdir(parents=True)
32+
logger.warning(f"音频文件夹{path_audio.name}不存在, 已创建")
33+
elif not path_audio.is_dir():
34+
logger.error(f"音频文件夹{path_audio.name}存在, 但不是文件夹")

nonebot_plugin_fishspeech_tts/matcher.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -56,22 +56,20 @@ async def tts_handle(message: UniMsg, regex_group: dict = RegexDict()): # noqa:
5656
# TODO: speed = regex_group["speed"]
5757

5858
try:
59-
fish_audio_api = FishAudioAPI()
60-
fish_speech_api = FishSpeechAPI()
6159
if is_online:
6260
await tts_handler.send("正在通过在线api合成语音, 请稍等")
63-
request = await fish_audio_api.generate_servettsrequest(
61+
request = await FishAudioAPI.generate_servettsrequest(
6462
text, speaker, chunk_length
6563
)
6664
# TODO: request = await fish_audio_api.generate_ttsrequest(text, speaker, speed)
67-
audio = await fish_audio_api.generate_tts(request)
65+
audio = await FishAudioAPI.generate_tts(request)
6866
else:
6967
await tts_handler.send("正在通过本地api合成语音, 请稍等")
70-
request = await fish_speech_api.generate_servettsrequest(
68+
request = await FishSpeechAPI.generate_servettsrequest(
7169
text, speaker, chunk_length
7270
)
7371
# TODO: request = await fish_speech_api.generate_ttsrequest(text, speaker, speed)
74-
audio = await fish_speech_api.generate_tts(request)
72+
audio = await FishSpeechAPI.generate_tts(request)
7573
await UniMessage.voice(raw=audio).finish()
7674

7775
except APIException as e:
@@ -81,13 +79,11 @@ async def tts_handle(message: UniMsg, regex_group: dict = RegexDict()): # noqa:
8179
@speaker_list.handle()
8280
async def speaker_list_handle():
8381
try:
84-
fish_audio_api = FishAudioAPI()
85-
fish_speech_api = FishSpeechAPI()
8682
if is_online:
87-
_list = fish_audio_api.get_speaker_list()
83+
_list = FishAudioAPI.get_speaker_list()
8884
await speaker_list.finish("语音角色列表: " + ", ".join(_list))
8985
else:
90-
_list = fish_speech_api.get_speaker_list()
86+
_list = FishSpeechAPI.get_speaker_list()
9187
await speaker_list.finish("语音角色列表: " + ", ".join(_list))
9288
except APIException as e:
9389
await speaker_list.finish(str(e))
@@ -96,12 +92,11 @@ async def speaker_list_handle():
9692
@balance.handle()
9793
async def balance_handle():
9894
try:
99-
fish_audio_api = FishAudioAPI()
10095
if is_online:
10196
await balance.send("正在查询在线语音余额, 请稍等")
102-
balance_float = await fish_audio_api.get_balance()
97+
balance_float = await FishAudioAPI.get_balance()
10398
await balance.finish(f"语音余额为: {balance_float}")
10499
else:
105-
await balance.finish("本地api无法查询余额")
100+
await balance.finish("本地api无需查询余额")
106101
except APIException as e:
107102
await balance.finish(str(e))

0 commit comments

Comments
 (0)