|
| 1 | +from os.path import isdir |
1 | 2 | import json
|
2 | 3 | from vosk import Model as KaldiModel, KaldiRecognizer
|
3 | 4 | from queue import Queue
|
|
7 | 8 | from ovos_skill_installer import download_extract_zip, download_extract_tar
|
8 | 9 | from os.path import join, exists, isdir
|
9 | 10 | from ovos_utils.xdg_utils import xdg_data_home
|
| 11 | +from ovos_utils.file_utils import read_vocab_file, resolve_resource_file, resolve_ovos_resource_file |
10 | 12 |
|
11 | 13 |
|
12 | 14 | class VoskKaldiSTT(STT):
|
def __init__(self, *args, **kwargs):
    """Create the recognizer for the default language.

    The model location is read from the ``model_folder`` config key (kept
    for backwards compatibility) or from ``model``. When neither is set
    and ``self.lang`` is, ``download_language`` is asked for a model for
    that language. A configured value starting with ``http`` is passed to
    ``download_model`` first and replaced by what it returns, so the
    "path or url" promise of the error message below holds.

    Raises:
        FileNotFoundError: if no model location could be resolved or the
            resolved location is not an existing directory.
    """
    super().__init__(*args, **kwargs)
    # model_folder for backwards compat
    self.model_path = (self.config.get("model_folder")
                       or self.config.get("model"))
    if not self.model_path and self.lang:
        self.model_path = self.download_language(self.lang)
    elif self.model_path and self.model_path.startswith("http"):
        # A url was configured: fetch it, as the previous implementation did.
        self.model_path = self.download_model(self.model_path)
    if not self.model_path or not isdir(self.model_path):
        LOG.error("You need to provide a valid model path or url")
        LOG.info(
            "download a model from https://alphacephei.com/vosk/models")
        raise FileNotFoundError

    # Full-vocabulary recognizers, keyed by language; the default language
    # is registered under self.lang exactly as configured.
    self.engines = {
        self.lang: KaldiRecognizer(KaldiModel(self.model_path), 16000)
    }
    # Recognizers restricted to a word list, keyed by language.
    self.limited_voc_engines = {}
    # True while limited vocabulary mode is active.
    self.limited = False
| 32 | + |
def download_language(self, lang=None):
    """Look up the default model for a language and fetch it if remote.

    Args:
        lang: language tag; falls back to ``self.lang`` when empty. Only
            the part before the first ``-`` is used, lower-cased.

    Returns:
        Whatever ``lang2modelurl`` returned for the language code, unless
        that value starts with ``http``; in that case the result of
        ``download_model`` for that url is returned instead.
    """
    code = (lang or self.lang).split("-")[0].lower()
    location = self.lang2modelurl(code)
    is_remote = bool(location) and location.startswith("http")
    return self.download_model(location) if is_remote else location
| 40 | + |
def load_language(self, lang=None):
    """Make sure a recognizer for the given language is loaded.

    Engines can be registered under the full tag (the default language is
    stored under ``self.lang`` as configured) or under the short
    lower-cased code used by this method, so both keys are checked before
    anything is downloaded.

    Args:
        lang: language tag; falls back to ``self.lang`` when empty.

    Raises:
        FileNotFoundError: if ``download_language`` yields no model for
            the language.
    """
    requested = lang or self.lang
    short = requested.split("-")[0].lower()
    for key in (requested, short):
        if key in self.engines or key in self.limited_voc_engines:
            return  # already loaded, nothing to do
    model_path = self.download_language(short)
    if not model_path:
        LOG.error(f"No default model available for {short}")
        raise FileNotFoundError
    self.engines[short] = KaldiRecognizer(KaldiModel(model_path), 16000)
| 52 | + |
def unload_language(self, lang=None):
    """Drop the recognizers held for a language, if any.

    Args:
        lang: key to remove from ``self.engines`` and
            ``self.limited_voc_engines``; falls back to ``self.lang``.
            Unknown keys are ignored.
    """
    lang = lang or self.lang
    # pop with a default removes the entry once and tolerates absence;
    # deleting and then popping the same key would always raise KeyError.
    self.engines.pop(lang, None)
    self.limited_voc_engines.pop(lang, None)
| 61 | + |
def enable_full_vocabulary(self, lang=None):
    """Enable default transcription mode.

    Clears the limited-vocabulary flag and discards the limited recognizer
    for the language. A new full recognizer is only built when the full
    one is no longer registered in ``self.engines`` (it was replaced by
    the limited recognizer, or is missing).

    Args:
        lang: language key; falls back to ``self.lang`` when empty.

    Raises:
        FileNotFoundError: if a full recognizer has to be rebuilt and no
            model location can be resolved for the language.
    """
    lang = lang or self.lang
    self.limited = False
    if lang not in self.limited_voc_engines:
        return
    limited_engine = self.limited_voc_engines.pop(lang)
    if lang in self.engines and self.engines[lang] is not limited_engine:
        return  # the full recognizer was never displaced
    # The default language uses the configured model; any other language
    # goes through download_language.
    if lang == self.lang:
        model_path = self.model_path
    else:
        model_path = self.download_language(lang)
    if not model_path:
        LOG.error(f"No default model available for {lang}")
        raise FileNotFoundError
    self.engines[lang] = KaldiRecognizer(KaldiModel(model_path), 16000)
| 69 | + |
def enable_limited_vocabulary(self, words, lang=None, permanent=True):
    """Enable limited vocabulary mode for a language.

    Builds a recognizer that receives ``words`` serialized as JSON as its
    third constructor argument and stores it in
    ``self.limited_voc_engines``.

    Args:
        words: vocabulary handed to the recognizer; must be JSON
            serializable.
        lang: language key; falls back to ``self.lang`` when empty.
        permanent: when true, the limited recognizer also replaces the
            entry for the language in ``self.engines``.

    When no model location can be resolved for the language nothing is
    changed and limited mode is not switched on.
    """
    lang = lang or self.lang
    if lang == self.lang:
        model_path = self.model_path
    else:
        # download_language fetches remote models (it calls download_model
        # for http locations); the raw lang2modelurl result may still be a
        # url and cannot be handed to KaldiModel as is.
        model_path = self.download_language(lang)
    if not model_path:
        return
    limited_engine = KaldiRecognizer(KaldiModel(model_path),
                                     16000, json.dumps(words))
    self.limited_voc_engines[lang] = limited_engine
    if permanent:
        # Plain assignment replaces any existing entry and also works when
        # the language had no full recognizer yet.
        self.engines[lang] = limited_engine
    self.limited = True
29 | 87 |
|
30 | 88 | @staticmethod
|
31 | 89 | def download_model(url):
|
@@ -84,11 +142,27 @@ def lang2modelurl(lang, small=True):
|
84 | 142 | return lang2url.get(lang)
|
85 | 143 |
|
def execute(self, audio, language=None):
    """Transcribe an audio segment and return the recognized text.

    Args:
        audio: object exposing ``get_wav_data()``; its result is fed to
            the recognizer.
        language: language tag; falls back to ``self.lang`` when empty.

    Returns:
        The ``"text"`` field of the recognizer's final JSON result.

    Raises:
        KeyError: if no recognizer is registered for the language after
            ``load_language`` has run.
    """
    # load a new model on the fly if needed
    lang = language or self.lang
    self.load_language(lang)

    # Recognizers are registered under the full tag (default language) or
    # under the short lower-cased code (load_language), so try both.
    short = lang.split("-")[0].lower()

    engine = None
    # if limited vocabulary mode is enabled use that model instead
    if self.limited:
        engine = (self.limited_voc_engines.get(lang)
                  or self.limited_voc_engines.get(short))
    if not engine:
        engine = self.engines.get(lang) or self.engines[short]

    # transcribe
    engine.AcceptWaveform(audio.get_wav_data())
    return json.loads(engine.FinalResult())["text"]
|
91 | 160 |
|
def shutdown(self):
    """Unload every language that currently has a recognizer registered."""
    # Sets are combined with a union; "+" is not defined for sets. The
    # keys are snapshotted first because unload_language mutates both
    # dictionaries.
    loaded = set(self.engines) | set(self.limited_voc_engines)
    for lang in loaded:
        self.unload_language(lang)
| 165 | + |
92 | 166 |
|
93 | 167 | class VoskKaldiStreamThread(StreamThread):
|
94 | 168 | def __init__(self, queue, lang, kaldi, verbose=True):
|
|
0 commit comments