Skip to content

Commit 998186d

Browse files
committed
- refactor model language handling
- fix streaming STT
- handlers for load/unload lang
- handlers for limited vocab
- related: OpenVoiceOS/ovos-core/pull/78
1 parent 8c46862 commit 998186d

File tree

2 files changed: +100 additions, -103 deletions

ovos_stt_plugin_vosk/__init__.py

Lines changed: 100 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -1,89 +1,83 @@
1-
from os.path import isdir
21
import json
3-
from vosk import Model as KaldiModel, KaldiRecognizer
2+
from os.path import join, exists
43
from queue import Queue
5-
import numpy as np
6-
from ovos_utils.log import LOG
4+
75
from ovos_plugin_manager.templates.stt import STT, StreamThread, StreamingSTT
86
from ovos_skill_installer import download_extract_zip, download_extract_tar
9-
from os.path import join, exists, isdir
7+
from ovos_utils.log import LOG
108
from ovos_utils.xdg_utils import xdg_data_home
11-
from ovos_utils.file_utils import read_vocab_file, resolve_resource_file, resolve_ovos_resource_file
9+
from vosk import Model as KaldiModel, KaldiRecognizer
10+
from speech_recognition import AudioData
1211

1312

14-
class VoskKaldiSTT(STT):
15-
def __init__(self, *args, **kwargs):
16-
super().__init__(*args, **kwargs)
17-
# model_folder for backwards compat
18-
self.model_path = self.config.get("model_folder") or self.config.get("model")
19-
if not self.model_path and self.lang:
20-
self.model_path = self.download_language(self.lang)
21-
if not self.model_path or not isdir(self.model_path):
22-
LOG.error("You need to provide a valid model path or url")
23-
LOG.info(
24-
"download a model from https://alphacephei.com/vosk/models")
25-
raise FileNotFoundError
26-
27-
self.engines = {
28-
self.lang: KaldiRecognizer(KaldiModel(self.model_path), 16000)
29-
}
30-
self.limited_voc_engines = {}
31-
self.limited = False
13+
class ModelContainer:
14+
def __init__(self):
15+
self.engines = {}
16+
self.models = {}
3217

33-
def download_language(self, lang=None):
34-
lang = lang or self.lang
18+
def get_engine(self, lang):
3519
lang = lang.split("-")[0].lower()
36-
model_path = self.lang2modelurl(lang)
37-
if model_path and model_path.startswith("http"):
38-
model_path = self.download_model(model_path)
39-
return model_path
20+
self.load_language(lang)
21+
return self.engines[lang]
4022

41-
def load_language(self, lang=None):
42-
lang = lang or self.lang
23+
def get_partial_transcription(self, lang):
24+
engine = self.get_engine(lang)
25+
res = engine.PartialResult()
26+
return json.loads(res)["partial"]
27+
28+
def get_final_transcription(self, lang):
29+
engine = self.get_engine(lang)
30+
res = engine.FinalResult()
31+
return json.loads(res)["text"]
32+
33+
def process_audio(self, audio, lang):
34+
engine = self.get_engine(lang)
35+
if isinstance(audio, AudioData):
36+
audio = audio.get_wav_data()
37+
return engine.AcceptWaveform(audio)
38+
39+
def enable_limited_vocabulary(self, words, lang):
40+
"""
41+
enable limited vocabulary mode
42+
will only consider pre defined .voc files
43+
"""
44+
model_path = self.models[lang]
45+
self.engines[lang] = KaldiRecognizer(
46+
KaldiModel(model_path), 16000, json.dumps(words))
47+
48+
def enable_full_vocabulary(self, lang=None):
49+
""" enable default transcription mode """
50+
model_path = self.models[lang]
51+
self.engines[lang] = KaldiRecognizer(
52+
KaldiModel(model_path), 16000)
53+
54+
def load_model(self, model_path, lang):
4355
lang = lang.split("-")[0].lower()
44-
if lang in self.engines or lang in self.limited_voc_engines:
45-
return
46-
model_path = self.download_language(lang)
56+
self.models[lang] = model_path
4757
if model_path:
4858
self.engines[lang] = KaldiRecognizer(KaldiModel(model_path), 16000)
4959
else:
50-
LOG.error(f"No default model available for {lang}")
5160
raise FileNotFoundError
5261

53-
def unload_language(self, lang=None):
54-
lang = lang or self.lang
62+
def load_language(self, lang):
63+
lang = lang.split("-")[0].lower()
64+
if lang in self.engines:
65+
return
66+
model_path = self.download_language(lang)
67+
self.load_model(model_path, lang)
68+
69+
def unload_language(self, lang):
5570
if lang in self.engines:
5671
del self.engines[lang]
5772
self.engines.pop(lang)
58-
if lang in self.limited_voc_engines:
59-
del self.limited_voc_engines[lang]
60-
self.limited_voc_engines.pop(lang)
61-
62-
def enable_full_vocabulary(self, lang=None):
63-
""" enable default transcription mode """
64-
lang = lang or self.lang
65-
self.limited = False
66-
if lang in self.limited_voc_engines:
67-
self.limited_voc_engines.pop(lang)
68-
self.engines[lang] = KaldiRecognizer(KaldiModel(model_path), 16000)
6973

70-
def enable_limited_vocabulary(self, words, lang=None, permanent=True):
71-
"""
72-
enable limited vocabulary mode
73-
will only consider pre defined .voc files
74-
"""
75-
lang = lang or self.lang
76-
if lang == self.lang:
77-
model_path = self.model_path
78-
else:
79-
model_path = self.lang2modelurl(lang)
80-
if model_path:
81-
self.limited_voc_engines[lang] = KaldiRecognizer(KaldiModel(model_path),
82-
16000, json.dumps(words))
83-
if permanent:
84-
del self.engines[lang]
85-
self.engines[lang] = self.limited_voc_engines[lang]
86-
self.limited = True
74+
@staticmethod
75+
def download_language(lang):
76+
lang = lang.split("-")[0].lower()
77+
model_path = ModelContainer.lang2modelurl(lang)
78+
if model_path and model_path.startswith("http"):
79+
model_path = ModelContainer.download_model(model_path)
80+
return model_path
8781

8882
@staticmethod
8983
def download_model(url):
@@ -141,64 +135,68 @@ def lang2modelurl(lang, small=True):
141135
lang = lang.split("-")[0]
142136
return lang2url.get(lang)
143137

144-
def execute(self, audio, language=None):
145-
# load a new model on the fly if needed
146-
lang = language or self.lang
147-
self.load_language(lang)
148138

149-
# if limited vocabulary mode is enabled use that model instead
150-
if self.limited:
151-
engine = self.limited_voc_engines.get(lang) or self.engines[lang]
139+
class VoskKaldiSTT(STT):
140+
def __init__(self, *args, **kwargs):
141+
super().__init__(*args, **kwargs)
142+
# model_folder for backwards compat
143+
model_path = self.config.get("model_folder") or self.config.get("model")
144+
145+
self.model = ModelContainer()
146+
if model_path:
147+
if model_path.startswith("http"):
148+
model_path = ModelContainer.download_model(model_path)
149+
self.model.load_model(model_path, self.lang)
152150
else:
153-
engine = self.engines[lang]
151+
self.model.load_language(self.lang)
152+
self.verbose = True
154153

155-
# transcribe
156-
engine.AcceptWaveform(audio.get_wav_data())
157-
res = engine.FinalResult()
158-
res = json.loads(res)
159-
return res["text"]
154+
def load_language(self, lang):
155+
self.model.load_language(lang)
160156

161-
def shutdown(self):
162-
for lang in set(self.engines.keys()) + \
163-
set(self.limited_voc_engines.keys()):
164-
self.unload_language(lang)
157+
def unload_language(self, lang):
158+
self.model.unload_language(lang)
159+
160+
def enable_limited_vocabulary(self, words, lang):
161+
self.model.enable_limited_vocabulary(words, lang or self.lang)
162+
163+
def enable_full_vocabulary(self, lang=None):
164+
self.model.enable_full_vocabulary(lang or self.lang)
165+
166+
def execute(self, audio, language=None):
167+
lang = language or self.lang
168+
self.model.process_audio(audio, lang)
169+
return self.model.get_final_transcription(lang)
165170

166171

167172
class VoskKaldiStreamThread(StreamThread):
168-
def __init__(self, queue, lang, kaldi, verbose=True):
173+
def __init__(self, queue, lang, model, verbose=True):
169174
super().__init__(queue, lang)
170-
self.kaldi = kaldi
175+
self.model = model
171176
self.verbose = verbose
172177
self.previous_partial = ""
173178
self.running = True
174179

175180
def handle_audio_stream(self, audio, language):
181+
lang = language or self.language
176182
if self.running:
177183
for a in audio:
178-
data = np.frombuffer(a, np.int16)
179-
if self.kaldi.AcceptWaveform(data):
180-
res = self.kaldi.Result()
181-
res = json.loads(res)
182-
self.text = res["text"]
183-
else:
184-
res = self.kaldi.PartialResult()
185-
res = json.loads(res)
186-
self.text = res["partial"]
187-
if self.verbose:
188-
if self.previous_partial != self.text:
189-
LOG.info("Partial Transcription: " + self.text)
190-
self.previous_partial = self.text
191-
184+
self.model.process_audio(a, lang)
185+
self.text = self.model.get_partial_transcription(lang)
186+
if self.verbose:
187+
if self.previous_partial != self.text:
188+
LOG.info("Partial Transcription: " + self.text)
189+
self.previous_partial = self.text
192190
return self.text
193191

194192
def finalize(self):
195193
self.running = False
196194
if self.previous_partial:
197195
if self.verbose:
198196
LOG.info("Finalizing stream")
199-
self.text = self.kaldi.FinalResult()
197+
self.text = self.model.get_final_transcription(self.language)
200198
self.previous_partial = ""
201-
text = self.text
199+
text = str(self.text)
202200
self.text = ""
203201
return text
204202

@@ -212,5 +210,5 @@ def __init__(self, *args, **kwargs):
212210
def create_streaming_thread(self):
213211
self.queue = Queue()
214212
return VoskKaldiStreamThread(
215-
self.queue, self.lang, self.kaldi, self.verbose
213+
self.queue, self.lang, self.model, self.verbose
216214
)

requirements/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
numpy
21
vosk
32
ovos-plugin-manager>=0.0.1
43
ovos_skill_installer

0 commit comments

Comments
 (0)