1
- from os .path import isdir
2
1
import json
3
- from vosk import Model as KaldiModel , KaldiRecognizer
2
+ from os . path import join , exists
4
3
from queue import Queue
5
- import numpy as np
6
- from ovos_utils .log import LOG
4
+
7
5
from ovos_plugin_manager .templates .stt import STT , StreamThread , StreamingSTT
8
6
from ovos_skill_installer import download_extract_zip , download_extract_tar
9
- from os . path import join , exists , isdir
7
+ from ovos_utils . log import LOG
10
8
from ovos_utils .xdg_utils import xdg_data_home
11
- from ovos_utils .file_utils import read_vocab_file , resolve_resource_file , resolve_ovos_resource_file
9
+ from vosk import Model as KaldiModel , KaldiRecognizer
10
+ from speech_recognition import AudioData
12
11
13
12
14
- class VoskKaldiSTT (STT ):
15
- def __init__ (self , * args , ** kwargs ):
16
- super ().__init__ (* args , ** kwargs )
17
- # model_folder for backwards compat
18
- self .model_path = self .config .get ("model_folder" ) or self .config .get ("model" )
19
- if not self .model_path and self .lang :
20
- self .model_path = self .download_language (self .lang )
21
- if not self .model_path or not isdir (self .model_path ):
22
- LOG .error ("You need to provide a valid model path or url" )
23
- LOG .info (
24
- "download a model from https://alphacephei.com/vosk/models" )
25
- raise FileNotFoundError
26
-
27
- self .engines = {
28
- self .lang : KaldiRecognizer (KaldiModel (self .model_path ), 16000 )
29
- }
30
- self .limited_voc_engines = {}
31
- self .limited = False
13
+ class ModelContainer :
14
+ def __init__ (self ):
15
+ self .engines = {}
16
+ self .models = {}
32
17
33
- def download_language (self , lang = None ):
34
- lang = lang or self .lang
18
def get_engine(self, lang):
    """Return the recognizer for ``lang``, loading it first if needed.

    ``lang`` is reduced to its lower-cased primary subtag (the part
    before the first ``-``) before being used as the lookup key.
    """
    code = lang.split("-")[0].lower()
    # no-op when a recognizer for this code is already registered
    self.load_language(code)
    return self.engines[code]
40
22
41
- def load_language (self , lang = None ):
42
- lang = lang or self .lang
23
def get_partial_transcription(self, lang):
    """Return the ``"partial"`` field of the recognizer's partial result."""
    raw = self.get_engine(lang).PartialResult()
    payload = json.loads(raw)
    return payload["partial"]
27
+
28
def get_final_transcription(self, lang):
    """Return the ``"text"`` field of the recognizer's final result."""
    raw = self.get_engine(lang).FinalResult()
    payload = json.loads(raw)
    return payload["text"]
32
+
33
def process_audio(self, audio, lang):
    """Feed ``audio`` to the recognizer for ``lang``.

    ``audio`` may be an ``AudioData`` instance, in which case its WAV
    data is extracted, or anything else, which is passed through as-is.
    Returns whatever the recognizer's ``AcceptWaveform`` returns.
    """
    recognizer = self.get_engine(lang)
    # NOTE(review): recognizers are built for 16000 Hz elsewhere in this
    # class; presumably callers already deliver audio at that rate - confirm.
    payload = audio.get_wav_data() if isinstance(audio, AudioData) else audio
    return recognizer.AcceptWaveform(payload)
38
+
39
def enable_limited_vocabulary(self, words, lang):
    """
    enable limited vocabulary mode
    will only consider the given list of words

    Replaces the recognizer for ``lang`` with one restricted to ``words``
    (serialized as a JSON grammar).

    Raises KeyError if no model has been registered for ``lang``.
    """
    # load_model registers models under the lower-cased primary subtag,
    # so "en-US" must be looked up (and stored) as "en"
    lang = lang.split("-")[0].lower()
    model_path = self.models[lang]
    self.engines[lang] = KaldiRecognizer(
        KaldiModel(model_path), 16000, json.dumps(words))
47
+
48
def enable_full_vocabulary(self, lang=None):
    """ enable default transcription mode

    Rebuilds an unrestricted recognizer for ``lang``. When ``lang`` is
    omitted, every language that has a registered model is reset.

    Raises KeyError if an explicit ``lang`` has no registered model.
    """
    if lang is None:
        # snapshot the keys so the loop is independent of later changes
        targets = list(self.models)
    else:
        # models are registered under the lower-cased primary subtag
        targets = [lang.split("-")[0].lower()]
    for code in targets:
        model_path = self.models[code]
        self.engines[code] = KaldiRecognizer(
            KaldiModel(model_path), 16000)
53
+
54
def load_model(self, model_path, lang):
    """Build a recognizer from ``model_path`` and register it for ``lang``.

    ``lang`` is reduced to its lower-cased primary subtag before being
    used as the key in ``self.models`` and ``self.engines``.

    Raises FileNotFoundError if ``model_path`` is empty or None.
    """
    lang = lang.split("-")[0].lower()
    if not model_path:
        # fail before touching any state so an unusable path is never
        # remembered in self.models
        raise FileNotFoundError(
            f"No Vosk model available for language '{lang}'")
    # build the recognizer first; only register once construction succeeded
    engine = KaldiRecognizer(KaldiModel(model_path), 16000)
    self.models[lang] = model_path
    self.engines[lang] = engine
52
61
53
- def unload_language (self , lang = None ):
54
- lang = lang or self .lang
62
def load_language(self, lang):
    """Ensure a recognizer exists for ``lang``.

    Does nothing when the normalised language code already has a
    recognizer; otherwise resolves a model via ``download_language`` and
    hands it to ``load_model``.
    """
    code = lang.split("-")[0].lower()
    if code not in self.engines:
        self.load_model(self.download_language(code), code)
68
+
69
def unload_language(self, lang):
    """Drop the recognizer registered for ``lang``, if there is one.

    Unknown languages are ignored. The model path remembered in
    ``self.models`` is left untouched.
    """
    # engines are keyed by the lower-cased primary subtag
    lang = lang.split("-")[0].lower()
    # a single pop with a default: the previous del-then-pop sequence
    # raised KeyError on the second removal
    self.engines.pop(lang, None)
58
- if lang in self .limited_voc_engines :
59
- del self .limited_voc_engines [lang ]
60
- self .limited_voc_engines .pop (lang )
61
-
62
- def enable_full_vocabulary (self , lang = None ):
63
- """ enable default transcription mode """
64
- lang = lang or self .lang
65
- self .limited = False
66
- if lang in self .limited_voc_engines :
67
- self .limited_voc_engines .pop (lang )
68
- self .engines [lang ] = KaldiRecognizer (KaldiModel (model_path ), 16000 )
69
73
70
- def enable_limited_vocabulary (self , words , lang = None , permanent = True ):
71
- """
72
- enable limited vocabulary mode
73
- will only consider pre defined .voc files
74
- """
75
- lang = lang or self .lang
76
- if lang == self .lang :
77
- model_path = self .model_path
78
- else :
79
- model_path = self .lang2modelurl (lang )
80
- if model_path :
81
- self .limited_voc_engines [lang ] = KaldiRecognizer (KaldiModel (model_path ),
82
- 16000 , json .dumps (words ))
83
- if permanent :
84
- del self .engines [lang ]
85
- self .engines [lang ] = self .limited_voc_engines [lang ]
86
- self .limited = True
74
@staticmethod
def download_language(lang):
    """Resolve the model for ``lang`` and fetch it when it is a URL.

    Returns whatever ``lang2modelurl`` gives for the normalised code,
    unless that value starts with ``"http"``, in which case the result of
    ``download_model`` is returned instead.
    """
    code = lang.split("-")[0].lower()
    location = ModelContainer.lang2modelurl(code)
    if not location or not location.startswith("http"):
        # nothing known for this language, or not a remote URL
        return location
    return ModelContainer.download_model(location)
87
81
88
82
@staticmethod
89
83
def download_model (url ):
@@ -141,64 +135,68 @@ def lang2modelurl(lang, small=True):
141
135
lang = lang .split ("-" )[0 ]
142
136
return lang2url .get (lang )
143
137
144
- def execute (self , audio , language = None ):
145
- # load a new model on the fly if needed
146
- lang = language or self .lang
147
- self .load_language (lang )
148
138
149
- # if limited vocabulary mode is enabled use that model instead
150
- if self .limited :
151
- engine = self .limited_voc_engines .get (lang ) or self .engines [lang ]
139
class VoskKaldiSTT(STT):
    """Speech-to-text plugin that delegates recognition to a ModelContainer.

    Config keys read: ``model_folder`` (kept for backwards compatibility)
    or ``model``. When neither is set, a model is resolved from
    ``self.lang`` instead.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # model_folder for backwards compat
        model_path = self.config.get("model_folder") or self.config.get("model")

        self.model = ModelContainer()
        if model_path:
            if model_path.startswith("http"):
                # remote model: fetch it and use what download_model returns
                model_path = ModelContainer.download_model(model_path)
            self.model.load_model(model_path, self.lang)
        else:
            self.model.load_language(self.lang)
        self.verbose = True

    def load_language(self, lang=None):
        """Make sure a recognizer exists for ``lang`` (default: self.lang)."""
        self.model.load_language(lang or self.lang)

    def unload_language(self, lang=None):
        """Release the recognizer for ``lang`` (default: self.lang)."""
        self.model.unload_language(lang or self.lang)

    def enable_limited_vocabulary(self, words, lang=None):
        """Restrict recognition for ``lang`` (default: self.lang) to ``words``."""
        self.model.enable_limited_vocabulary(words, lang or self.lang)

    def enable_full_vocabulary(self, lang=None):
        """Restore unrestricted recognition for ``lang`` (default: self.lang)."""
        self.model.enable_full_vocabulary(lang or self.lang)

    def execute(self, audio, language=None):
        """Transcribe ``audio`` and return the final text.

        ``language`` falls back to ``self.lang`` when not given.
        """
        lang = language or self.lang
        self.model.process_audio(audio, lang)
        return self.model.get_final_transcription(lang)
165
170
166
171
167
172
class VoskKaldiStreamThread(StreamThread):
    """Stream thread that pushes audio chunks through a shared model object.

    ``model`` must provide ``process_audio``, ``get_partial_transcription``
    and ``get_final_transcription``.
    """

    def __init__(self, queue, lang, model, verbose=True):
        super().__init__(queue, lang)
        self.running = True
        self.previous_partial = ""
        self.verbose = verbose
        self.model = model

    def handle_audio_stream(self, audio, language):
        """Consume ``audio`` chunk by chunk and return the latest partial text.

        Returns the current ``self.text`` unchanged once the stream has
        been finalized.
        """
        lang = language or self.language
        if not self.running:
            return self.text
        for chunk in audio:
            # NOTE(review): the value returned by process_audio is not
            # inspected here; only the partial transcription is read.
            self.model.process_audio(chunk, lang)
            self.text = self.model.get_partial_transcription(lang)
            changed = self.previous_partial != self.text
            if self.verbose and changed:
                LOG.info("Partial Transcription: " + self.text)
            self.previous_partial = self.text
        return self.text

    def finalize(self):
        """Stop streaming and return the final text, clearing internal state."""
        self.running = False
        if self.previous_partial:
            if self.verbose:
                LOG.info("Finalizing stream")
            self.text = self.model.get_final_transcription(self.language)
            self.previous_partial = ""
        transcript, self.text = str(self.text), ""
        return transcript
204
202
@@ -212,5 +210,5 @@ def __init__(self, *args, **kwargs):
212
210
def create_streaming_thread(self):
    """Create a fresh queue and return a stream thread bound to it."""
    stream_queue = Queue()
    self.queue = stream_queue
    return VoskKaldiStreamThread(
        stream_queue, self.lang, self.model, self.verbose)
0 commit comments