Commit d9d7f11

Support voice cloning in the OpenAI-compatible API (#55)
1 parent 1a5f75f commit d9d7f11

9 files changed (+150, -36 lines)

docs/en/server/client.md

Lines changed: 38 additions & 5 deletions
@@ -199,17 +199,50 @@ def clone_voice_stream():
 
 ### Sample Code
 
+Call a built-in audio role:
+
 ```python
 from openai import OpenAI
 
 
 def openai_speech():
-    client = OpenAI(base_url=f"{BASE_URL}/v1", api_key="YOUR_KEY")
+    client = OpenAI(
+        base_url=f"{BASE_URL}/v1",
+        api_key="not-needed"  # If an API key is set, please provide it
+    )
+    with client.audio.speech.with_streaming_response.create(
+        model="spark",
+        voice="赞助商",  # Name of the built-in voice
+        input="Hello, I am the invincible little cutie."
+    ) as response:
+        response.stream_to_file("out.mp3")
+    print("Output file: out.mp3")
+```
+
+Or provide a reference audio to use the voice cloning feature:
+
+```python
+from openai import OpenAI
+import base64
+
+
+def openai_speech():
+    client = OpenAI(
+        base_url=f"{BASE_URL}/v1",
+        api_key="not-needed"  # If an API key is set, please provide it
+    )
+    with open("data/mega-roles/御姐/御姐配音.wav", "rb") as f:
+        audio_bytes = f.read()
+    # Convert the binary audio data into a base64-encoded string
+    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+
     with client.audio.speech.with_streaming_response.create(
-        model="orpheus", voice="tara", input="Hello"
-    ) as r:
-        r.stream_to_file("out.mp3")
-    print("Output saved: out.mp3")
+        model="spark",
+        voice=audio_base64,  # Replace the 'voice' parameter with the audio's base64 to trigger voice cloning
+        input="Hello, I am the invincible little cutie."
+    ) as response:
+        response.stream_to_file("clone.mp3")
+    print("Cloned file: clone.mp3")
 ```
 
 ### Steps

docs/en/server/server.md

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ curl -X POST http://localhost:8000/clone_voice \
 - Uses `OpenAISpeechRequest` format:
 - `model`: Model ID or name
 - `input`: Text to synthesize
-- `voice`: Voice name or preset
+- `voice`: The name of the audio role you want to use, or a URL or base64 string of a reference audio.
 - Other parameters same as Clone/Speak
 
 #### 4.5 Retrieve Available Roles: `GET /audio_roles` or `GET /v1/audio_roles`
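Because the `voice` field of `OpenAISpeechRequest` now carries either a role name, a URL, or base64-encoded reference audio, the endpoint can also be exercised with plain HTTP. Below is a minimal sketch using `requests`; the route path `/v1/audio/speech`, the port, and the file name `reference.wav` are assumptions to adapt to your deployment, not values taken from this commit.

```python
# Minimal sketch: raw HTTP call to the OpenAI-compatible speech route with a
# base64-encoded reference audio in the `voice` field.
# Assumptions: the route is mounted at /v1/audio/speech on localhost:8000 and
# reference.wav is a local clip of your own.
import base64
import requests

with open("reference.wav", "rb") as f:
    voice_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "spark",
    "input": "Hello from the cloned voice.",
    "voice": voice_b64,          # a role name, a URL, or base64 audio
    "response_format": "mp3",
    "stream": False,
}

resp = requests.post("http://localhost:8000/v1/audio/speech", json=payload, timeout=120)
resp.raise_for_status()
with open("clone.mp3", "wb") as out:
    out.write(resp.content)
```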

docs/zh/server/client.md

Lines changed: 34 additions & 4 deletions
@@ -193,19 +193,49 @@ def clone_voice_stream():
 
 ### Sample Code
 
+Call a built-in audio role:
 ```python
 from openai import OpenAI
 
 
 def openai_speech():
-    client = OpenAI(base_url=f"{BASE_URL}/v1", api_key="YOUR_KEY")
+    client = OpenAI(
+        base_url=f"{BASE_URL}/v1",
+        api_key="not-needed"  # If an API key is set, please provide it
+    )
     with client.audio.speech.with_streaming_response.create(
-        model="orpheus", voice="tara", input="Hello"
-    ) as r:
-        r.stream_to_file("out.mp3")
+        model="spark",
+        voice="赞助商",
+        input="你好,我是无敌的小可爱。"
+    ) as response:
+        response.stream_to_file("out.mp3")
     print("输出文件:out.mp3")
 ```
+Or provide a reference audio to use the voice cloning feature:
 
+```python
+from openai import OpenAI
+import base64
+
+
+def openai_speech():
+    client = OpenAI(
+        base_url=f"{BASE_URL}/v1",
+        api_key="not-needed"  # If an API key is set, please provide it
+    )
+    with open("data/mega-roles/御姐/御姐配音.wav", "rb") as f:
+        audio_bytes = f.read()
+    # Convert the binary audio data into a base64 string
+    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+
+    with client.audio.speech.with_streaming_response.create(
+        model="spark",
+        voice=audio_base64,  # Replace 'voice' with the audio's base64 encoding to trigger voice cloning
+        input="你好,我是无敌的小可爱。"
+    ) as response:
+        response.stream_to_file("clone.mp3")
+    print("克隆文件:clone.mp3")
+```
 ### Steps
 
 1. Initialize the OpenAI client and specify the base_url.

docs/zh/server/server.md

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ curl -X POST http://localhost:8000/clone_voice \
 - Same path and behavior as the endpoints above, using the `OpenAISpeechRequest` protocol:
 - `model`: Model ID or name
 - `input`: Text to synthesize
-- `voice`: Base or combined voice name
+- `voice`: The name of the audio role you want to use, or a URL or base64 string of a reference audio.
 - Other parameters same as Clone/Speak.
 
 #### 4.5 Retrieve Available Roles: `GET /audio_roles` or `GET /v1/audio_roles`
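Both doc pages note that `voice` may also be a URL, but none of the examples above show that form. Here is a minimal sketch using the OpenAI SDK; the URL `https://example.com/ref.wav`, the base URL, and the output file name are placeholders, not values from this commit.

```python
# Minimal sketch: pass the reference audio by URL instead of base64.
# The URL is a placeholder; any http(s) URL to a reference clip the server can
# fetch is treated the same way as a base64-encoded clip and triggers cloning.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",  # adjust to your deployment
    api_key="not-needed"                  # if an API key is set, provide it here
)

with client.audio.speech.with_streaming_response.create(
    model="spark",
    voice="https://example.com/ref.wav",  # URL of the reference audio
    input="Hello, this voice was cloned from a remote clip."
) as response:
    response.stream_to_file("clone_from_url.mp3")
```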

examples/client.py

Lines changed: 28 additions & 3 deletions
@@ -166,12 +166,37 @@ def openai_speech():
         api_key="not-needed"  # If an API key is set, please provide it
     )
     with client.audio.speech.with_streaming_response.create(
-        model="orpheus",
-        voice="tara",
-        input="Hey there guys. It's, <giggle> Tara here, and let me introduce you to Zac.. who seems to asleep. Zac, it's time to wakey-wakey!"
+        model="spark",
+        voice="赞助商",
+        input="你好,我是无敌的小可爱。"
     ) as response:
         response.stream_to_file("output.mp3")
 
+def openai_clone():
+    """
+    OpenAI clone mode; currently only Spark-TTS is supported.
+    Returns:
+
+    """
+    from openai import OpenAI
+
+    client = OpenAI(
+        base_url=f"{BASE_URL}/v1",
+        api_key="not-needed"  # If an API key is set, please provide it
+    )
+
+    # Pick an audio clip that is not among the built-in Spark-TTS roles
+    with open("data/mega-roles/御姐/御姐配音.wav", "rb") as f:
+        audio_bytes = f.read()
+    # Convert the binary audio data into a base64 string
+    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+
+    with client.audio.speech.with_streaming_response.create(
+        model="spark",
+        voice=audio_base64,  # Replace 'voice' with the audio's base64 encoding to trigger voice cloning
+        input="你好,我是无敌的小可爱。"
+    ) as response:
+        response.stream_to_file("output.mp3")
 
 if __name__ == "__main__":
     clone_voice_stream()

flashtts/server/openai_router.py

Lines changed: 26 additions & 5 deletions
@@ -6,7 +6,7 @@
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from .protocol import OpenAISpeechRequest, ModelCard, ModelList
 from .utils.audio_writer import StreamingAudioWriter
-from .utils.utils import generate_audio, generate_audio_stream
+from .utils.utils import generate_audio, generate_audio_stream, load_base64_or_url
 from ..engine import AutoEngine
 from ..logger import get_logger
 
@@ -91,7 +91,6 @@ async def create_speech(
                 "type": "invalid_request_error",
             },
         )
-    audio_writer = StreamingAudioWriter(request.response_format, sample_rate=engine.SAMPLE_RATE)
 
     # Set content type based on format
     content_type = {
@@ -104,7 +103,6 @@
     }.get(request.response_format, f"audio/{request.response_format}")
 
     api_inputs = dict(
-        name=request.voice,
         text=request.input,
         temperature=request.temperature,
         top_k=request.top_k,
@@ -117,10 +115,33 @@
     if engine.engine_name.lower() == 'spark':
         api_inputs['pitch'] = float_to_speed_label(request.pitch)
         api_inputs['speed'] = float_to_speed_label(request.speed)
+
+    if engine._SUPPORT_CLONE and request.voice not in engine.list_roles():
+        # If the provided voice is a URL or base64 string, voice cloning is triggered; MegaTTS3 is not supported yet
+        if engine.engine_name == 'mega':
+            err_msg = ("Openai router does not currently support the voice cloning function of mega tts, "
+                       "because the model requires an additional `latent_file`.")
+            logger.error(err_msg)
+            raise HTTPException(status_code=400, detail={"error": err_msg})
+        ref_audio = await load_base64_or_url(request.voice)
+        api_inputs['reference_audio'] = ref_audio
+
+        if request.stream:
+            tts_fn = engine.clone_voice_stream_async
+        else:
+            tts_fn = engine.clone_voice_async
+    else:
+        api_inputs['name'] = request.voice
+        if request.stream:
+            tts_fn = engine.speak_stream_async
+        else:
+            tts_fn = engine.speak_async
+
+    audio_writer = StreamingAudioWriter(request.response_format, sample_rate=engine.SAMPLE_RATE)
     if request.stream:
         return StreamingResponse(
             generate_audio_stream(
-                engine.speak_stream_async,
+                tts_fn,
                 api_inputs,
                 audio_writer,
                 client_request
@@ -140,7 +161,7 @@ async def create_speech(
         }
     try:
         # Generate complete audio using public interface
-        audio_data = await engine.speak_async(
+        audio_data = await tts_fn(
            **api_inputs
        )
    except Exception as e:
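The heart of this change is the dispatch above: when the engine supports cloning and `voice` does not match a registered role, the value is treated as reference audio; otherwise it is used as a role name. Reduced to a standalone sketch for readability (the `pick_tts_call` helper and its signature are illustrative, not part of the codebase):

```python
# A condensed sketch of the routing decision introduced above. `engine` stands in
# for the real AutoEngine instance; only the attributes the router actually
# touches are used here.
from typing import Any, Callable


def pick_tts_call(engine: Any, voice: str, stream: bool, api_inputs: dict) -> Callable:
    """Choose between speak and clone based on whether `voice` names a registered role."""
    if engine._SUPPORT_CLONE and voice not in engine.list_roles():
        if engine.engine_name == "mega":
            # MegaTTS cloning needs an extra `latent_file`, so the OpenAI route rejects it
            raise ValueError("voice cloning via the OpenAI route is not supported for mega")
        # In the real router, load_base64_or_url() first resolves the URL/base64 to bytes
        api_inputs["reference_audio"] = voice
        return engine.clone_voice_stream_async if stream else engine.clone_voice_async
    api_inputs["name"] = voice
    return engine.speak_stream_async if stream else engine.speak_async
```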

flashtts/server/protocol.py

Lines changed: 1 addition & 1 deletion
@@ -215,7 +215,7 @@ class OpenAISpeechRequest(BaseModel):
     input: str = Field(..., description="The text to generate audio for")
     voice: str = Field(
         default=None,
-        description="The voice to use for generation. Can be a base voice or a combined voice name.",
+        description="The name of the audio role you want to use, or a URL or base64 string of a reference audio.",
     )
     pitch: float = Field(
         default=1.0,

flashtts/server/utils/utils.py

Lines changed: 20 additions & 15 deletions
@@ -23,23 +23,28 @@ async def get_audio_bytes_from_url(url: str) -> bytes:
     return response.content
 
 
-async def load_audio_bytes(audio_file, audio):
-    if audio_file is None:
-        # Decide how to load the audio based on the content of reference_audio
-        if audio.startswith("http://") or audio.startswith("https://"):
-            audio_bytes = await get_audio_bytes_from_url(audio)
-        else:
-            try:
-                audio_bytes = base64.b64decode(audio)
-            except Exception as e:
-                logger.warning("无效的 base64 音频数据: " + str(e))
-                raise HTTPException(status_code=400, detail="无效的 base64 音频数据: " + str(e))
-        # Wrap the byte data in BytesIO, then read it with soundfile as a numpy array
+async def load_base64_or_url(audio):
+    # Decide how to load the audio based on the content of reference_audio
+    if audio.startswith("http://") or audio.startswith("https://"):
+        audio_bytes = await get_audio_bytes_from_url(audio)
+    else:
         try:
-            bytes_io = io.BytesIO(audio_bytes)
+            audio_bytes = base64.b64decode(audio)
         except Exception as e:
-            logger.warning("读取参考音频失败: " + str(e))
-            raise HTTPException(status_code=400, detail="读取参考音频失败: " + str(e))
+            logger.warning("无效的 base64 音频数据: " + str(e))
+            raise HTTPException(status_code=400, detail="无效的 base64 音频数据: " + str(e))
+    # Wrap the byte data in BytesIO, then read it with soundfile as a numpy array
+    try:
+        bytes_io = io.BytesIO(audio_bytes)
+    except Exception as e:
+        logger.warning("读取参考音频失败: " + str(e))
+        raise HTTPException(status_code=400, detail="读取参考音频失败: " + str(e))
+    return bytes_io
+
+
+async def load_audio_bytes(audio_file, audio):
+    if audio_file is None:
+        bytes_io = await load_base64_or_url(audio)
     else:
         content = await audio_file.read()
         if not content:
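The refactor extracts the URL/base64 handling into `load_base64_or_url`, which the OpenAI router reuses. A small sketch of exercising the helper in isolation with a base64 payload; the file name `sample.wav` and the `asyncio.run` wrapper are assumptions about how you would try it locally:

```python
# Minimal sketch exercising the extracted helper with a base64 payload.
# Assumes flashtts (and its server dependencies) is installed and sample.wav exists.
import asyncio
import base64

from flashtts.server.utils.utils import load_base64_or_url


async def main() -> None:
    with open("sample.wav", "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("utf-8")

    # For a base64 string the helper decodes it; for an http(s) URL it downloads it.
    bytes_io = await load_base64_or_url(audio_b64)
    print("loaded", len(bytes_io.getvalue()), "bytes of reference audio")


if __name__ == "__main__":
    asyncio.run(main())
```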

setup.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def get_readme() -> str:
 
 setup(
     name='flashtts',
-    version='0.1.3',
+    version='0.1.4',
     description='A Fast TTS toolkit',
     long_description=get_readme(),
     long_description_content_type='text/markdown',

0 commit comments
