 limitations under the License.
 """
 import re
+import os
 
 from ...representation import CharacterRecognitionPrediction
 from ...utils import UnsupportedPackage, extract_image_representations
@@ -121,7 +122,7 @@ def _initialize_pipeline(self, config):
         except ImportError as import_error:
             UnsupportedPackage("openvino_genai", import_error.msg).raise_error(self.__class__.__name__)
 
-        model_dir = config.get("_models", [None])[0]
+        model_dir = get_model_dir(config)
         device = config.get("_device", "CPU")
         pipeline = ov_genai.WhisperPipeline(str(model_dir), device=device)
         return pipeline
@@ -169,7 +170,7 @@ def _initialize_pipeline(self, config):
             UnsupportedPackage("optimum.intel.openvino", import_error.msg).raise_error(self.__class__.__name__)
 
         device = config.get("_device", "CPU")
-        model_dir = config.get("_models", [None])[0]
+        model_dir = get_model_dir(config)
         ov_model = OVModelForSpeechSeq2Seq.from_pretrained(str(model_dir)).to(device)
         ov_processor = AutoProcessor.from_pretrained(str(model_dir))
 
@@ -184,3 +185,12 @@ def _get_predictions(self, data, identifiers, input_meta):
         sampling_rate = input_meta[0].get("sample_rate")
         sample = {"path": identifiers[0], "array": data[0], "sampling_rate": sampling_rate}
         return self.pipeline(sample, return_timestamps=True)["text"]
+
+
+
+def get_model_dir(config):
+    model_path = config.get("_models", [None])[0]
+
+    if os.path.isfile(model_path):
+        return os.path.dirname(model_path)
+    return model_path
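
For reference, a minimal self-contained sketch of the behaviour the new get_model_dir helper appears intended to provide: when the launcher config's "_models" entry points at a concrete model file, the helper resolves it to the containing directory, and when it already points at a directory the path is returned unchanged. The temporary paths and the file name below are illustrative assumptions, not taken from the change.

import os
import tempfile


def get_model_dir(config):
    # Same logic as the helper added in this change.
    model_path = config.get("_models", [None])[0]

    if os.path.isfile(model_path):
        return os.path.dirname(model_path)
    return model_path


with tempfile.TemporaryDirectory() as model_dir:
    # Hypothetical model file name, only used to exercise the helper.
    model_file = os.path.join(model_dir, "openvino_model.xml")
    open(model_file, "w").close()

    # A path to a model file resolves to its parent directory ...
    assert get_model_dir({"_models": [model_file]}) == model_dir
    # ... while a directory path is returned as-is.
    assert get_model_dir({"_models": [model_dir]}) == model_dir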