Skip to content

Commit 2378e96

Browse files
authored
Provide dir path not xml file for whisper pipline (#4008)
1 parent 42ddeea commit 2378e96

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/whisper_evaluator.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
limitations under the License.
1515
"""
1616
import re
17+
import os
1718

1819
from ...representation import CharacterRecognitionPrediction
1920
from ...utils import UnsupportedPackage, extract_image_representations
@@ -121,7 +122,7 @@ def _initialize_pipeline(self, config):
121122
except ImportError as import_error:
122123
UnsupportedPackage("openvino_genai", import_error.msg).raise_error(self.__class__.__name__)
123124

124-
model_dir = config.get("_models", [None])[0]
125+
model_dir = get_model_dir(config)
125126
device = config.get("_device", "CPU")
126127
pipeline = ov_genai.WhisperPipeline(str(model_dir), device=device)
127128
return pipeline
@@ -169,7 +170,7 @@ def _initialize_pipeline(self, config):
169170
UnsupportedPackage("optimum.intel.openvino", import_error.msg).raise_error(self.__class__.__name__)
170171

171172
device = config.get("_device", "CPU")
172-
model_dir = config.get("_models", [None])[0]
173+
model_dir = get_model_dir(config)
173174
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(str(model_dir)).to(device)
174175
ov_processor = AutoProcessor.from_pretrained(str(model_dir))
175176

@@ -184,3 +185,12 @@ def _get_predictions(self, data, identifiers, input_meta):
184185
sampling_rate = input_meta[0].get("sample_rate")
185186
sample = {"path": identifiers[0], "array": data[0], "sampling_rate": sampling_rate}
186187
return self.pipeline(sample, return_timestamps=True)["text"]
188+
189+
190+
191+
def get_model_dir(config):
192+
model_path = config.get("_models", [None])[0]
193+
194+
if os.path.isfile(model_path):
195+
return os.path.dirname(model_path)
196+
return model_path

0 commit comments

Comments
 (0)