Skip to content

Commit 8414c34

Browse files
committed
web增加音色固定(#70)
1 parent eb2ae13 commit 8414c34

File tree

12 files changed

+176
-67
lines changed

12 files changed

+176
-67
lines changed

README.MD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,14 @@ flashtts infer \
209209
--torch_dtype "bfloat16" \
210210
--max_length 32768 \
211211
--llm_gpu_memory_utilization 0.6 \
212+
--fix_voice \ # 启动后将固定住spark-tts中内置音色(female和male)
212213
--host 0.0.0.0 \
213214
--port 8000
214215
```
216+
web地址:`http://localhost:8000`
217+
218+
接口文档地址:`http://localhost:8000/docs`
219+
215220
详细部署说明,请参考:[server.md](docs/zh/server/server.md)
216221

217222
## ⚡ 推理速度

README_EN.MD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,15 @@ Server deployment:
161161
--torch_dtype "bfloat16" \
162162
--max_length 32768 \
163163
--llm_gpu_memory_utilization 0.6 \
164+
--fix_voice \ # Keep the built-in spark-tts voices (female and male) fixed after startup
164165
--host 0.0.0.0 \
165166
--port 8000
166167
```
167168

169+
Web address: `http://localhost:8000`
170+
171+
API documentation: `http://localhost:8000/docs`
172+
168173
For detailed deployment instructions, please refer to: [server.md](docs/en/server/server.md)
169174

170175
## ⚡ Inference Speed

docs/en/server/server.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
--torch_dtype "bfloat16" \ # Spark-TTS does not support bfloat16 on all devices; use float32 if needed
1919
--max_length 32768 \
2020
--llm_gpu_memory_utilization 0.6 \
21+
--fix_voice \ # Keep the built-in spark-tts voices (female and male) fixed after startup
2122
--host 0.0.0.0 \
2223
--port 8000
2324
```
@@ -89,6 +90,7 @@
8990
| `--wait_timeout` | float | Timeout (in seconds) for dynamic batching | 0.01 |
9091
| `--host` | str | Host address to bind | `0.0.0.0` |
9192
| `--port` | int | Port number to listen on | 8000 |
93+
| `--fix_voice` | bool | Keeps the built-in female and male voices of the spark-tts model fixed so that they do not change. | False |
9294
9395
---
9496

docs/zh/server/server.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
--torch_dtype "bfloat16" \ # 对于spark-tts模型,不支持bfloat16的设备,只能设置为float32.
1919
--max_length 32768 \
2020
--llm_gpu_memory_utilization 0.6 \
21+
--fix_voice \ # 是否固定spark-tts音色(female和male)
2122
--host 0.0.0.0 \
2223
--port 8000
2324
```
@@ -89,6 +90,7 @@
8990
| `--wait_timeout` | float | 动态批处理请求超时秒数 | 0.01 |
9091
| `--host` | str | 服务监听地址 | `0.0.0.0` |
9192
| `--port` | int | 服务监听端口 | 8000 |
93+
| `--fix_voice` | bool | 是否固定住spark-tts模型的内置音色 | False |
9294
9395
### 3. 接口使用流程
9496

examples/inference.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,16 @@ async def retain_acoustic_example(engine: AutoEngine):
178178
name="female",
179179
return_acoustic_tokens=True
180180
)
181-
# 2. 真巧,这是我想要的音色,直接保存为txt
182-
tokens.save("acoustic_tokens.txt")
181+
# 2. 真巧,这是我想要的音色,直接保存为json
182+
tokens.save("acoustic_tokens.json")
183183
# 同时保存第一次生成的音频,以便对比
184184
engine.write_audio(wav, "first.wav")
185185

186186
# 3. 加载保存的音色,生成第二个音频
187187
wav = await engine.speak_async(
188188
text="国际局势中,某国领导人围绕地区冲突停火问题展开对话,双方同意停止攻击对方能源设施并推动谈判,但对全面停火提议的落实仍存分歧。",
189189
name="female",
190-
acoustic_tokens=SparkAcousticTokens.load("acoustic_tokens.txt"),
190+
acoustic_tokens=SparkAcousticTokens.load("acoustic_tokens.json"),
191191
)
192192
engine.write_audio(wav, "second.wav")
193193
# 4. 试听first.wav和second.wav,惊奇发现,这两个音频的音色是一致的
@@ -212,16 +212,16 @@ async def retain_acoustic_stream_example(engine: AutoEngine):
212212
audios.append(chunk)
213213
audio = np.concatenate(audios)
214214

215-
# 2. 真巧,这是我想要的音色,直接保存为txt
215+
# 2. 真巧,这是我想要的音色,直接保存为json
216216
engine.write_audio(audio, "first.wav")
217-
acoustic_tokens.save("acoustic_tokens.txt")
217+
acoustic_tokens.save("acoustic_tokens.json")
218218

219219
# 3. 加载保存的音色,生成第二个音频
220220
audios = []
221221
async for chunk in engine.speak_stream_async(
222222
text="今日是二零二五年三月十九日,国内外热点事件聚焦于国际局势、经济政策及社会民生领域。",
223223
name="female",
224-
acoustic_tokens=SparkAcousticTokens.load("acoustic_tokens.txt")
224+
acoustic_tokens=SparkAcousticTokens.load("acoustic_tokens.json")
225225
):
226226
audios.append(chunk)
227227

flashtts/commands/serve.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from flashtts.server.base_router import base_router, SPEAKER_TMP_PATH
1919
from flashtts.server.openai_router import openai_router
2020
from flashtts.commands.utils import add_model_parser
21+
from flashtts.server.protocol import StateInfo
2122

2223
logger = get_logger()
2324

@@ -132,8 +133,11 @@ async def lifespan(app: FastAPI):
132133
await warmup_engine(engine)
133134
# 将 engine 保存到 app.state 中,方便路由中使用
134135
app.state.engine = engine
135-
app.state.model_name = args.model_name or engine.engine_name
136-
app.state.db_path = args.db_path
136+
app.state.state_info = StateInfo(
137+
model_name=args.model_name or engine.engine_name,
138+
db_path=args.db_path,
139+
fix_voice=args.fix_voice
140+
)
137141
yield
138142

139143
if os.path.exists(SPEAKER_TMP_PATH):
@@ -194,7 +198,12 @@ def register_subcommand(parser: ArgumentParser):
194198
"--api_key",
195199
type=str,
196200
default=None,
197-
help="API key for request authentication",
201+
help="API key for request authentication"
202+
)
203+
serve_parser.add_argument(
204+
"--fix_voice",
205+
action="store_true",
206+
help="Fixes the female and male timbres in the spark-tts model, ensuring they remain unchanged."
198207
)
199208

200209
serve_parser.add_argument(

flashtts/engine/spark_engine.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Time :2025/3/29 11:16
33
# Author :Hui Huang
44
import asyncio
5+
import json
56
import math
67
import os.path
78
import re
@@ -42,15 +43,18 @@
4243
"very_high": 4,
4344
}
4445

45-
GENDER_MAP = {
46+
GENDER_MAP: dict[Literal["male", "female"], int] = {
4647
"female": 0,
4748
"male": 1,
4849
}
4950

51+
ID2GENDER = {v: k for k, v in GENDER_MAP.items()}
52+
5053

5154
@dataclass
5255
class SparkAcousticTokens:
5356
prompt: str
57+
gender: Literal["female", "male"]
5458
global_tokens: Optional[torch.Tensor] = None
5559

5660
def __post_init__(self):
@@ -73,15 +77,21 @@ def _parse_prompt(self):
7377
)
7478
self.global_tokens = global_token_ids
7579

80+
def to_dict(self) -> dict[str, str]:
81+
return {
82+
"prompt": self.prompt,
83+
"gender": self.gender
84+
}
85+
7686
def save(self, filepath: str):
7787
with open(filepath, 'w', encoding='utf8') as w:
78-
w.write(self.prompt)
88+
w.write(json.dumps(self.to_dict(), ensure_ascii=False, indent=2))
7989

8090
@classmethod
8191
def load(cls, filepath: str):
8292
with open(filepath, 'r', encoding='utf8') as r:
83-
prompt = r.read()
84-
return cls(prompt=prompt)
93+
data = json.load(r)
94+
return cls(**data)
8595

8696

8797
def process_prompt(
@@ -619,6 +629,18 @@ async def _control_generate(
619629
acoustic_tokens: Optional[SparkAcousticTokens | str] = None,
620630
return_acoustic_tokens: bool = False,
621631
**kwargs):
632+
gender: Literal["female", "male"] = gender if gender in ["female", "male"] else "female"
633+
634+
if acoustic_tokens is not None and isinstance(acoustic_tokens, str):
635+
acoustic_tokens = SparkAcousticTokens.load(acoustic_tokens)
636+
637+
if acoustic_tokens is not None:
638+
if acoustic_tokens.gender != gender:
639+
logger.warning(
640+
f"The provided `acoustic_tokens` belong to the `{acoustic_tokens.gender}`, but the specified gender is {gender}. "
641+
f"The `acoustic_tokens` will therefore not be used.")
642+
acoustic_tokens = None
643+
622644
segments = self.preprocess_text(
623645
text,
624646
window_size=window_size,
@@ -654,14 +676,11 @@ async def generate_audio(
654676
"completion": generated['completion']
655677
}
656678

657-
if acoustic_tokens is not None and isinstance(acoustic_tokens, str):
658-
acoustic_tokens = SparkAcousticTokens(acoustic_tokens)
659-
660679
audios = []
661680
if acoustic_tokens is None:
662681
# 如果没有传入音色,使用第一段生成音色token,将其与后面片段一起拼接,使用相同音色token引导输出semantic tokens。
663682
first_output = await generate_audio(segments[0], acoustic_token=None)
664-
acoustic_tokens = SparkAcousticTokens(first_output['completion'])
683+
acoustic_tokens = SparkAcousticTokens(first_output['completion'], gender=gender)
665684
audios.append(first_output['audio'])
666685
segments = segments[1:]
667686

@@ -706,7 +725,7 @@ async def speak_async(
706725
logger.error(err_msg)
707726
raise ValueError(err_msg)
708727
self.set_seed(seed=self.seed)
709-
acoustic_tokens = None
728+
out_acoustic_tokens = None
710729
if name in ["female", "male"]:
711730
output = await self._control_generate(
712731
text=text,
@@ -727,7 +746,7 @@ async def speak_async(
727746
)
728747
if return_acoustic_tokens and isinstance(output, tuple):
729748
audio = output[0]
730-
acoustic_tokens = output[1]
749+
out_acoustic_tokens = output[1]
731750
else:
732751
audio = output
733752
else:
@@ -756,8 +775,8 @@ async def speak_async(
756775

757776
torch.cuda.empty_cache()
758777

759-
if acoustic_tokens is not None:
760-
return audio, acoustic_tokens
778+
if out_acoustic_tokens is not None:
779+
return audio, out_acoustic_tokens
761780
return audio
762781

763782
async def _control_stream_generate(
@@ -782,6 +801,8 @@ async def _control_stream_generate(
782801
return_acoustic_tokens: bool = False,
783802
**kwargs
784803
):
804+
gender: Literal["female", "male"] = gender if gender in ["female", "male"] else "female"
805+
785806
if audio_chunk_duration < 0.5:
786807
err_msg = "audio_chunk_duration at least 0.5 seconds"
787808
logger.error(err_msg)
@@ -792,7 +813,14 @@ async def _control_stream_generate(
792813
raise ValueError(err_msg)
793814

794815
if acoustic_tokens is not None and isinstance(acoustic_tokens, str):
795-
acoustic_tokens = SparkAcousticTokens(acoustic_tokens)
816+
acoustic_tokens = SparkAcousticTokens.load(acoustic_tokens)
817+
818+
if acoustic_tokens is not None:
819+
if acoustic_tokens.gender != gender:
820+
logger.warning(
821+
f"The provided `acoustic_tokens` belong to the `{acoustic_tokens.gender}`, but the specified gender is {gender}. "
822+
f"The `acoustic_tokens` will therefore not be used.")
823+
acoustic_tokens = None
796824

797825
audio_tokenizer_frame_rate = 50
798826
max_chunk_size = math.ceil(max_audio_chunk_duration * audio_tokenizer_frame_rate)
@@ -840,7 +868,7 @@ async def _control_stream_generate(
840868
r"(<\|start_acoustic_token\|>.*?<\|end_global_token\|>)",
841869
completion)
842870
if len(acoustics) > 0:
843-
acoustic_tokens = SparkAcousticTokens(acoustics[0])
871+
acoustic_tokens = SparkAcousticTokens(acoustics[0], gender=gender)
844872
completion = ""
845873
else:
846874
continue

0 commit comments

Comments
 (0)