
Commit 9a33468

update whisper stream (#1796)
1 parent a7f1e39 commit 9a33468

8 files changed (+151 / -1052 lines)

.github/pylint.conf

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ persistent=yes
 
 # Minimum Python version to use for version dependent checks. Will default to
 # the version used to run pylint.
-py-version=3.9
+py-version=3.11
 
 # Discover python modules and packages in the file system subtree.
 recursive=no

.github/workflows/ci_pipeline.yaml

Lines changed: 3 additions & 3 deletions
@@ -140,7 +140,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest, macos-latest]
-        python: [3.8]
+        python: [3.9]
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v3
@@ -170,10 +170,10 @@
     if: github.event_name == 'push' && github.repository_owner == 'mindspore-lab'
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.8
+      - name: Set up Python 3.9
        uses: actions/setup-python@v4
        with:
-          python-version: 3.8
+          python-version: 3.9
      - uses: "lvyufeng/action-kaggle-gpu-test@latest"
        with:
          kaggle_username: "${{ secrets.KAGGLE_USERNAME }}"

.github/workflows/doc_rst_check.yaml

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: 3.9
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip==24.0

.gitignore

Lines changed: 4 additions & 1 deletion
@@ -163,4 +163,7 @@ data*/
 
 fusion_result.json
 aclinit.json
-xiyouji.txt
+xiyouji.txt
+*.safetensors
+*.jit
+flagged/

llm/inference/whisper/app_realtime.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import gradio as gr
+import time
+import numpy as np
+import mindspore
+from mindnlp.transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+
+ms_dtype = mindspore.float16
+MODEL_NAME = "openai/whisper-large-v3-turbo"
+
+model = AutoModelForSpeechSeq2Seq.from_pretrained(
+    MODEL_NAME, ms_dtype=ms_dtype, low_cpu_mem_usage=True
+)
+
+processor = AutoProcessor.from_pretrained(MODEL_NAME)
+
+pipe = pipeline(
+    task="automatic-speech-recognition",
+    model=model,
+    tokenizer=processor.tokenizer,
+    feature_extractor=processor.feature_extractor,
+    ms_dtype=ms_dtype,
+)
+
+prompt = "以下是普通话的句子。"  # must have periods
+prompt_ids = processor.get_prompt_ids(prompt, return_tensors="ms")
+generate_kwargs = {"prompt_ids": prompt_ids}
+
+def transcribe(inputs, previous_transcription):
+    start_time = time.time()
+    try:
+        sample_rate, audio_data = inputs
+        audio_data = audio_data.astype(np.float32)
+        audio_data /= np.max(np.abs(audio_data))
+
+        transcription = pipe({"sampling_rate": sample_rate, "raw": audio_data}, generate_kwargs=generate_kwargs)["text"]
+        previous_transcription += transcription
+
+        end_time = time.time()
+        latency = end_time - start_time
+        return previous_transcription, f"{latency:.2f}"
+    except Exception as e:
+        print(f"Error during Transcription: {e}")
+        return previous_transcription, "Error"
+
+
+def clear():
+    return ""
+
+with gr.Blocks() as microphone:
+    with gr.Column():
+        gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
+        with gr.Row():
+            input_audio_microphone = gr.Audio(streaming=True)
+            output = gr.Textbox(label="Transcription", value="")
+            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
+        with gr.Row():
+            clear_button = gr.Button("Clear Output")
+
+        input_audio_microphone.stream(transcribe, [input_audio_microphone, output], [output, latency_textbox], time_limit=45, stream_every=2, concurrency_limit=None)
+        clear_button.click(clear, outputs=[output])
+
+with gr.Blocks() as file:
+    with gr.Column():
+        gr.Markdown(f"# Realtime Whisper Large V3 Turbo: \n Transcribe Audio in Realtime. This Demo uses the Checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers.\n Note: The first token takes about 5 seconds. After that, it works flawlessly.")
+        with gr.Row():
+            input_audio_microphone = gr.Audio(sources="upload", type="numpy")
+            output = gr.Textbox(label="Transcription", value="")
+            latency_textbox = gr.Textbox(label="Latency (seconds)", value="0.0", scale=0)
+        with gr.Row():
+            submit_button = gr.Button("Submit")
+            clear_button = gr.Button("Clear Output")
+
+        submit_button.click(transcribe, [input_audio_microphone, output], [output, latency_textbox], concurrency_limit=None)
+        clear_button.click(clear, outputs=[output])
+
+with gr.Blocks(theme=gr.themes.Ocean()) as demo:
+    gr.TabbedInterface([microphone, file], ["Microphone", "Transcribe from file"])
+
+demo.launch()
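A quick smoke test before launching the UI, reusing the pipe and generate_kwargs defined in this file (a hypothetical snippet that mirrors the warm-up step app_stream.py performs):

    import numpy as np
    sample = np.random.randn(16000).astype(np.float32)  # ~1 s of noise at 16 kHz
    print(pipe({"sampling_rate": 16000, "raw": sample},
               generate_kwargs=generate_kwargs)["text"])

One caveat: audio_data /= np.max(np.abs(audio_data)) produces NaNs on an all-silent chunk; computing the peak first and dividing only when it is non-zero would be safer.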

llm/inference/whisper/app_stream.py

Lines changed: 62 additions & 18 deletions
@@ -1,50 +1,94 @@
+import queue
+import threading
+import time
+
 import gradio as gr
 import numpy as np
 import mindspore
-from mindnlp.transformers import pipeline
-
+from mindspore.dataset.audio import Resample
+from mindnlp.transformers import pipeline, AutoProcessor
+from silero_vad_mindspore import load
 
 MODEL_NAME = "openai/whisper-large-v3"
-BATCH_SIZE = 8
-FILE_LIMIT_MB = 1000
-YT_LENGTH_LIMIT_S = 3600  # limit to 1 hour YouTube files
+THRESH_HOLD = 0.5
 
+stream_queue = queue.Queue()
 
+vad_model = load('silero_vad_v4')
+
+processor = AutoProcessor.from_pretrained(MODEL_NAME)
 pipe = pipeline(
     task="automatic-speech-recognition",
     model=MODEL_NAME,
+    tokenizer=processor.tokenizer,
+    feature_extractor=processor.feature_extractor,
     ms_dtype=mindspore.float16
 )
 
+prompt = "以下是普通话的句子。"  # must have periods
+prompt_ids = processor.get_prompt_ids(prompt, return_tensors="ms")
+
+text = ""
+silence_count = 0
+
+resample = Resample(48000, 16000)
+generate_kwargs = {"language": "zh", "task": "transcribe", "prompt_ids": prompt_ids}
+# "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), "no_speech_threshold": 0.5, "logprob_threshold": -1.0}
+
+# warm up
+random_sample = np.random.randn(16000).astype(np.float32)
+vad_model(mindspore.tensor(random_sample), 16000)
+pipe(random_sample, generate_kwargs=generate_kwargs, return_timestamps='word')
+
+def pipeline_consumer():
+    global text
+    while True:
+        chunk = stream_queue.get()
+        # print(speech_score)
+        generated_text = pipe(chunk, generate_kwargs=generate_kwargs, return_timestamps='word')["text"]
+        text += generated_text + '\n'
+
+        stream_queue.task_done()
+
+        if stream_queue.empty() and stream_queue.unfinished_tasks == 0:
+            time.sleep(1)
+
 
 def transcribe(stream, new_chunk):
-    generate_kwargs = {"language": "zh", "task": "transcribe"}
+    global text
 
     sr, y = new_chunk
+
     y = y.astype(np.float32)
     y /= np.max(np.abs(y))
-    print(y)
+    # print('sample shape:', y.shape)
+    speech_score = vad_model(mindspore.tensor(y), sr)
+    speech_score = speech_score.item()
+    print('speech score', speech_score)
 
-    if stream is not None:
-        stream = np.concatenate([stream, y])
-    else:
-        stream = y
+    if speech_score > THRESH_HOLD:
+        if stream is not None:
+            if stream.shape < y.shape or (stream[-len(y):] - y).sum() != 0:
+                stream = np.concatenate([stream, y])
+        else:
+            stream = y
 
-    if stream.shape[0] < (3 * 48000):
-        return stream, None
-
-    text = pipe({"sampling_rate": sr, "raw": y}, generate_kwargs=generate_kwargs)["text"]
-
-    if str(text).endswith((".", "。", '?', "?", '!', "!", ":", ":")):
+    if stream is not None and stream.shape[0] >= (48000 * 5):  # flush after ~5 s of accumulated speech
+        print('stream shape:', stream.shape)
+        stream_queue.put({"sampling_rate": sr, "raw": stream})
         stream = None
+
     return stream, text  # type: ignore
 
+input_audio = gr.Audio(sources=["microphone"], streaming=True)
 demo = gr.Interface(
     transcribe,
-    ["state", gr.Audio(sources=["microphone"], streaming=True)],
+    ["state", input_audio],
     ["state", "text"],
     live=True,
 )
 
 if __name__ == "__main__":
+    c = threading.Thread(target=pipeline_consumer)
+    c.start()
     demo.launch()
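The rewrite turns the streaming callback into a producer/consumer pair: transcribe(), invoked by Gradio for every streamed chunk, gates audio through the silero VAD and enqueues roughly five-second segments, while pipeline_consumer() drains the queue through Whisper on a background thread so transcription never blocks audio capture. A minimal, model-free sketch of the same pattern (names here are illustrative, not from the commit):

    import queue
    import threading

    work_queue = queue.Queue()

    def consumer():
        while True:
            item = work_queue.get()        # blocks until a segment arrives
            print("transcribing", item)    # stand-in for pipe(chunk)["text"]
            work_queue.task_done()          # pairs with each get()

    threading.Thread(target=consumer, daemon=True).start()

    for segment in ["seg-1", "seg-2"]:     # producer side (the Gradio callback)
        work_queue.put(segment)
    work_queue.join()                      # wait until every segment is handled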

mindnlp/core/nn/modules/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 def _ntuple(n):
     def parse(x):
-        if isinstance(x, collections.Iterable):
+        if isinstance(x, collections.abc.Iterable):
             return x
         return tuple(repeat(x, n))
     return parse
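This change tracks the py-version bump above: the collections.Iterable alias was deprecated in Python 3.3 and removed in 3.10, so on Python 3.11 only the collections.abc spelling works. A quick self-contained check of the fixed helper:

    import collections.abc
    from itertools import repeat

    def _ntuple(n):
        def parse(x):
            if isinstance(x, collections.abc.Iterable):
                return x
            return tuple(repeat(x, n))
        return parse

    assert _ntuple(2)(3) == (3, 3)         # a scalar is repeated n times
    assert _ntuple(2)((1, 2)) == (1, 2)    # an iterable passes through unchanged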
