@@ -36,26 +36,17 @@ class FishAudioAPI:
36
36
FishAudioAPI类, 用于调用FishAudio的API接口
37
37
"""
38
38
39
- def __init__ ( self ):
40
- self . api_url : str = API_URL
41
- self . path_audio : Path = Path ( config . tts_audio_path )
42
- self . proxy = API_PROXY
39
+ api_url : str = API_URL
40
+ path_audio : Path = Path ( config . tts_audio_path )
41
+ proxy = API_PROXY
42
+ from typing import ClassVar
43
43
44
- # 如果在线授权码为空, 且使用在线api, 则抛出异常
45
- if not config .online_authorization and config .tts_is_online :
46
- raise AuthorizationException ("请先在配置文件中填写在线授权码或使用离线api" )
47
- self .headers = {
48
- "Authorization" : f"Bearer { config .online_authorization } " ,
49
- }
50
-
51
- # 如果音频文件夹不存在, 则创建音频文件夹
52
- if not self .path_audio .exists ():
53
- self .path_audio .mkdir (parents = True )
54
- logger .warning (f"音频文件夹{ self .path_audio .name } 不存在, 已创建" )
55
- elif not self .path_audio .is_dir ():
56
- raise FileHandleException (f"{ self .path_audio .name } 不是一个文件夹" )
57
-
58
- async def _get_reference_id_by_speaker (self , speaker : str ) -> str :
44
+ _headers : ClassVar [dict ] = {
45
+ "Authorization" : f"Bearer { config .online_authorization } " ,
46
+ }
47
+
48
+ @classmethod
49
+ async def _get_reference_id_by_speaker (cls , speaker : str ) -> str :
59
50
"""
60
51
通过说话人姓名获取说话人的reference_id
61
52
@@ -68,13 +59,13 @@ async def _get_reference_id_by_speaker(self, speaker: str) -> str:
68
59
exception:
69
60
APIException: 获取语音角色列表为空
70
61
"""
71
- request_api = self .api_url + "/model"
62
+ request_api = cls .api_url + "/model"
72
63
sort_options = ["score" , "task_count" , "created_at" ]
73
- async with AsyncClient (proxy = self .proxy ) as client :
64
+ async with AsyncClient (proxy = cls .proxy ) as client :
74
65
for sort_by in sort_options :
75
66
params = {"title" : speaker , "sort_by" : sort_by }
76
67
response = await client .get (
77
- request_api , params = params , headers = self . headers
68
+ request_api , params = params , headers = cls . _headers
78
69
)
79
70
resp_data = response .json ()
80
71
if resp_data ["total" ] == 0 :
@@ -84,8 +75,9 @@ async def _get_reference_id_by_speaker(self, speaker: str) -> str:
84
75
return item ["_id" ]
85
76
raise APIException ("未找到对应的角色" )
86
77
78
+ @classmethod
87
79
async def generate_servettsrequest (
88
- self ,
80
+ cls ,
89
81
text : str ,
90
82
speaker_name : str ,
91
83
chunk_length : ChunkLength = ChunkLength .NORMAL ,
@@ -103,15 +95,18 @@ async def generate_servettsrequest(
103
95
Returns:
104
96
ServeTTSRequest: TTS请求
105
97
"""
98
+ if not config .online_authorization and config .tts_is_online :
99
+ raise AuthorizationException ("请先在配置文件中填写在线授权码或使用离线api" )
100
+
106
101
reference_id = None
107
102
references = []
108
103
try :
109
104
if is_reference_id_first :
110
- reference_id = await self ._get_reference_id_by_speaker (speaker_name )
105
+ reference_id = await cls ._get_reference_id_by_speaker (speaker_name )
111
106
else :
112
107
try :
113
108
speaker_audio_path = get_speaker_audio_path (
114
- self .path_audio , speaker_name
109
+ cls .path_audio , speaker_name
115
110
)
116
111
for audio in speaker_audio_path :
117
112
audio_bytes = audio .read_bytes ()
@@ -121,7 +116,7 @@ async def generate_servettsrequest(
121
116
)
122
117
except FileHandleException :
123
118
logger .warning ("音频文件夹不存在, 已转为在线模型优先模式" )
124
- reference_id = await self ._get_reference_id_by_speaker (speaker_name )
119
+ reference_id = await cls ._get_reference_id_by_speaker (speaker_name )
125
120
except APIException as e :
126
121
raise e from e
127
122
return ServeTTSRequest (
@@ -138,7 +133,8 @@ async def generate_servettsrequest(
138
133
references = references ,
139
134
)
140
135
141
- async def generate_tts (self , request : ServeTTSRequest ) -> bytes :
136
+ @classmethod
137
+ async def generate_tts (cls , request : ServeTTSRequest ) -> bytes :
142
138
"""
143
139
获取TTS音频
144
140
@@ -149,14 +145,14 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
149
145
bytes: TTS音频二进制数据
150
146
"""
151
147
if request .references :
152
- self . headers ["content-type" ] = "application/msgpack"
148
+ cls . _headers ["content-type" ] = "application/msgpack"
153
149
try :
154
150
async with (
155
- AsyncClient (proxy = self .proxy ) as client ,
151
+ AsyncClient (proxy = cls .proxy ) as client ,
156
152
client .stream (
157
153
"POST" ,
158
- self .api_url + "/v1/tts" ,
159
- headers = self . headers ,
154
+ cls .api_url + "/v1/tts" ,
155
+ headers = cls . _headers ,
160
156
content = ormsgpack .packb (
161
157
request , option = ormsgpack .OPT_SERIALIZE_PYDANTIC
162
158
),
@@ -169,17 +165,19 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
169
165
HTTPStatusError ,
170
166
) as e :
171
167
logger .error (f"获取TTS音频失败: { e } " )
172
- if self .proxy :
168
+ if cls .proxy :
173
169
raise HTTPException ("代理地址错误, 请检查代理地址是否正确" ) from e
174
170
raise HTTPException ("网络错误, 请检查网络连接" ) from e
175
171
else :
176
- self . headers ["content-type" ] = "application/json"
172
+ cls . _headers ["content-type" ] = "application/json"
177
173
try :
178
- async with AsyncClient (proxy = self .proxy ) as client :
174
+ async with AsyncClient (proxy = cls .proxy ) as client :
179
175
response = await client .post (
180
- self .api_url + "/v1/tts" ,
181
- headers = self .headers ,
182
- json = request .dict (),
176
+ cls .api_url + "/v1/tts" ,
177
+ headers = cls ._headers ,
178
+ json = ormsgpack .packb (
179
+ request , option = ormsgpack .OPT_SERIALIZE_PYDANTIC
180
+ ),
183
181
timeout = 60 ,
184
182
)
185
183
return response .content
@@ -190,30 +188,32 @@ async def generate_tts(self, request: ServeTTSRequest) -> bytes:
190
188
HTTPStatusError ,
191
189
) as e :
192
190
logger .error (f"获取TTS音频失败: { e } " )
193
- if self .proxy :
191
+ if cls .proxy :
194
192
raise HTTPException ("代理地址错误, 请检查代理地址是否正确" ) from e
195
193
raise HTTPException ("网络错误, 请检查网络连接" ) from e
196
194
197
- async def get_balance (self ) -> float :
195
+ @classmethod
196
+ async def get_balance (cls ) -> float :
198
197
"""
199
198
获取账户余额
200
199
"""
201
- balance_url = self .api_url + "/wallet/self/api-credit"
202
- async with AsyncClient (proxy = self .proxy ) as client :
203
- response = await client .get (balance_url , headers = self . headers )
200
+ balance_url = cls .api_url + "/wallet/self/api-credit"
201
+ async with AsyncClient (proxy = cls .proxy ) as client :
202
+ response = await client .get (balance_url , headers = cls . _headers )
204
203
try :
205
204
return response .json ()["credit" ]
206
205
except KeyError :
207
206
raise AuthorizationException ("授权码错误或已失效" ) from KeyError
208
207
209
- def get_speaker_list (self ) -> list [str ]:
208
+ @classmethod
209
+ def get_speaker_list (cls ) -> list [str ]:
210
210
"""
211
211
获取语音角色列表
212
212
"""
213
213
return_list = ["请查看官网了解更多: https://fish.audio/zh-CN/" ]
214
214
if not is_reference_id_first :
215
215
try :
216
- return_list .extend (get_path_speaker_list (self .path_audio ))
216
+ return_list .extend (get_path_speaker_list (cls .path_audio ))
217
217
except FileHandleException :
218
218
logger .warning ("音频文件夹不存在或无法读取" )
219
219
return return_list
0 commit comments