Skip to content

Commit 07f01ae

Browse files
authored
feat/multi_lang_limited_voc (#3)
* feat/multi_lang_limited_voc: add support for loading multiple languages; allow the language to be set per request; add support for OpenVoiceOS/ovos-core#78. Authored-by: jarbasai <jarbasai@mailfence.com>
1 parent b6fc7b1 commit 07f01ae

File tree

1 file changed

+85
-11
lines changed

1 file changed

+85
-11
lines changed

ovos_stt_plugin_vosk/__init__.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from os.path import isdir
12
import json
23
from vosk import Model as KaldiModel, KaldiRecognizer
34
from queue import Queue
@@ -7,25 +8,82 @@
78
from ovos_skill_installer import download_extract_zip, download_extract_tar
89
from os.path import join, exists, isdir
910
from ovos_utils.xdg_utils import xdg_data_home
11+
from ovos_utils.file_utils import read_vocab_file, resolve_resource_file, resolve_ovos_resource_file
1012

1113

1214
class VoskKaldiSTT(STT):
1315
def __init__(self, *args, **kwargs):
    """Create the recognizer for the default language.

    The model location is read from the ``model_folder`` config key (kept
    for backwards compatibility) or from ``model``.  When neither is set,
    a default model for ``self.lang`` is fetched via ``download_language``.
    A configured value starting with "http" is treated as a download url
    and resolved with ``download_model``.

    Raises:
        FileNotFoundError: if no model directory could be resolved.
    """
    super().__init__(*args, **kwargs)
    # model_folder for backwards compat
    self.model_path = self.config.get("model_folder") or \
        self.config.get("model")
    if not self.model_path and self.lang:
        self.model_path = self.download_language(self.lang)
    elif self.model_path and self.model_path.startswith("http"):
        # the config may hold a model url rather than a local folder;
        # without this step the isdir() check below rejects every url
        self.model_path = self.download_model(self.model_path)
    if not self.model_path or not isdir(self.model_path):
        LOG.error("You need to provide a valid model path or url")
        LOG.info(
            "download a model from https://alphacephei.com/vosk/models")
        raise FileNotFoundError(
            f"vosk model folder not found: {self.model_path}")

    # recognizers keyed by language code; the default one is registered
    # under self.lang exactly as given
    self.engines = {
        self.lang: KaldiRecognizer(KaldiModel(self.model_path), 16000)
    }
    # recognizers restricted to a word list, keyed by language code
    self.limited_voc_engines = {}
    # True while limited vocabulary mode is active
    self.limited = False
32+
33+
def download_language(self, lang=None):
    """Resolve the default model for a language.

    The language code (``self.lang`` when omitted) is reduced to its
    lowercase primary subtag, e.g. "en-US" -> "en", and looked up with
    ``lang2modelurl``.  A result starting with "http" is handed to
    ``download_model`` and that call's result is returned; any other
    result (including a falsy one) is returned unchanged.
    """
    short_code = (lang or self.lang).split("-")[0].lower()
    location = self.lang2modelurl(short_code)
    if not location or not location.startswith("http"):
        return location
    return self.download_model(location)
40+
41+
def load_language(self, lang=None):
    """Make sure a recognizer exists for ``lang`` (default ``self.lang``).

    Does nothing when an engine is already registered for the language,
    either under the code as given or under its lowercase primary subtag.
    Otherwise the default model is resolved with ``download_language`` and
    a new recognizer is stored in ``self.engines`` under the primary
    subtag.

    Raises:
        FileNotFoundError: if no default model is available.
    """
    lang = lang or self.lang
    short_lang = lang.split("-")[0].lower()
    # __init__ registers the default engine under self.lang as given
    # (e.g. "en-us"), so checking only the normalized code would load a
    # second copy of the default model
    for key in (lang, short_lang):
        if key in self.engines or key in self.limited_voc_engines:
            return
    model_path = self.download_language(short_lang)
    if not model_path:
        LOG.error(f"No default model available for {short_lang}")
        raise FileNotFoundError
    self.engines[short_lang] = KaldiRecognizer(KaldiModel(model_path), 16000)
52+
53+
def unload_language(self, lang=None):
    """Drop the recognizers registered for ``lang`` (default ``self.lang``).

    Both the regular and the limited vocabulary engine are removed.  A
    language that is not loaded is ignored.
    """
    lang = lang or self.lang
    # a single pop() with a default removes the entry exactly once;
    # deleting and then popping the same key raises KeyError
    self.engines.pop(lang, None)
    self.limited_voc_engines.pop(lang, None)
61+
62+
def enable_full_vocabulary(self, lang=None):
    """ enable default transcription mode

    Clears the limited vocabulary flag and discards the limited engine of
    ``lang`` (default ``self.lang``).  A full recognizer is rebuilt only
    when the limited engine had replaced it in ``self.engines`` or when
    none is registered for that language.
    """
    lang = lang or self.lang
    self.limited = False
    if lang not in self.limited_voc_engines:
        return
    limited_engine = self.limited_voc_engines.pop(lang)
    current = self.engines.get(lang)
    if current is not None and current is not limited_engine:
        # the full engine is still registered, no need to reload the model
        return
    # same model resolution as enable_limited_vocabulary: the configured
    # model for the default language, the default download otherwise
    if lang == self.lang:
        model_path = self.model_path
    else:
        model_path = self.download_language(lang)
    if model_path:
        self.engines[lang] = KaldiRecognizer(KaldiModel(model_path), 16000)
    else:
        LOG.error(f"No default model available for {lang}")
69+
70+
def enable_limited_vocabulary(self, words, lang=None, permanent=True):
    """
    enable limited vocabulary mode

    Builds a recognizer for ``lang`` (default ``self.lang``) that receives
    ``words`` serialized as JSON as its third argument, and stores it in
    ``self.limited_voc_engines``.  With ``permanent`` set it also replaces
    the entry of that language in ``self.engines``.  Nothing changes when
    no model can be resolved for the language.
    """
    lang = lang or self.lang
    if lang == self.lang:
        model_path = self.model_path
    else:
        # lang2modelurl alone may yield a download url; download_language
        # passes "http" results through download_model first
        model_path = self.download_language(lang)
    if model_path:
        recognizer = KaldiRecognizer(KaldiModel(model_path),
                                     16000, json.dumps(words))
        self.limited_voc_engines[lang] = recognizer
        if permanent:
            # plain assignment replaces an existing engine and also works
            # when the language was never loaded (del would raise KeyError)
            self.engines[lang] = recognizer
        self.limited = True
2987

3088
@staticmethod
3189
def download_model(url):
@@ -84,11 +142,27 @@ def lang2modelurl(lang, small=True):
84142
return lang2url.get(lang)
85143

86144
def execute(self, audio, language=None):
    """Transcribe ``audio`` and return the recognized text.

    ``audio`` must provide ``get_wav_data()``.  ``language`` defaults to
    ``self.lang``; a recognizer for it is loaded on demand.

    Raises:
        KeyError: if no engine is registered for the language.
    """
    # load a new model on the fly if needed
    lang = language or self.lang
    self.load_language(lang)
    # load_language stores engines under the lowercase primary subtag
    # ("pt-BR" -> "pt") while __init__ uses self.lang as given, so both
    # spellings are tried
    short_lang = lang.split("-")[0].lower()

    engine = None
    # if limited vocabulary mode is enabled use that model instead
    if self.limited:
        engine = self.limited_voc_engines.get(lang) or \
            self.limited_voc_engines.get(short_lang)
    if engine is None:
        engine = self.engines.get(lang) or self.engines[short_lang]

    # transcribe
    engine.AcceptWaveform(audio.get_wav_data())
    res = engine.FinalResult()
    res = json.loads(res)
    return res["text"]
91160

161+
def shutdown(self):
    """Unload every language that has a recognizer registered."""
    # set union is spelled `|` (`+` is not defined for sets); building the
    # set first also snapshots the keys before unload_language mutates
    # the dicts
    for lang in set(self.engines) | set(self.limited_voc_engines):
        self.unload_language(lang)
165+
92166

93167
class VoskKaldiStreamThread(StreamThread):
94168
def __init__(self, queue, lang, kaldi, verbose=True):

0 commit comments

Comments
 (0)