Commit 1be1a30

Create whisper_evaluator.py (#3990)
* Create whisper_evaluator.py
* Add OptimumIntelPipeline to whisper_evaluator.py
* Update whisper_evaluator.py
* Update OptimumIntelPipeline
* Update naming, avoid errors for long audio
* Create test_whisper_evaluator.py
* Add datasets to requirements-test.in
* Add inflect to requirements-extra.in
* Add cleanup test_whisper_evaluator.py
* Cleanup of test_whisper_evaluator.py
* Skip tests if modules not available
* Update copyright
* Pylint fixes
1 parent 7d15a7b commit 1be1a30

File tree

4 files changed: +282 −0 lines changed
tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/whisper_evaluator.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
"""
Copyright (c) 2024 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import re

from ...representation import CharacterRecognitionPrediction
from ...utils import UnsupportedPackage, extract_image_representations
from .base_custom_evaluator import BaseCustomEvaluator

try:
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
except ImportError as import_err:
    AutoModelForSpeechSeq2Seq = UnsupportedPackage("transformers", import_err.msg)
    AutoProcessor = UnsupportedPackage("transformers", import_err.msg)

try:
    from transformers.pipelines.automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
except ImportError as import_err:
    AutomaticSpeechRecognitionPipeline = UnsupportedPackage("transformers", import_err.msg)

try:
    import inflect
except ImportError as import_err:
    inflect = UnsupportedPackage("inflect", import_err.msg)


class WhisperEvaluator(BaseCustomEvaluator):
    VALID_PIPELINE_CLASSES = [
        "GenAIWhisperPipeline",
        "HFWhisperPipeline",
        "OptimumWhisperPipeline"
    ]

    def __init__(self, dataset_config, pipe, orig_config):
        super().__init__(dataset_config, None, orig_config)
        self.pipe = pipe
        if hasattr(self.pipe, "adapter"):
            self.adapter_type = self.pipe.adapter.__provider__

    @classmethod
    def from_configs(cls, config, delayed_model_loading=False, orig_config=None):
        dataset_config = config["datasets"]
        pipeline_class_name = config["pipeline_class"]
        if 'device' in config['launchers'][0]:
            config["_device"] = config['launchers'][0]['device']

        if pipeline_class_name not in cls.VALID_PIPELINE_CLASSES:
            raise ValueError(f"Invalid pipeline class name: {pipeline_class_name}. "
                             f"Must be one of {cls.VALID_PIPELINE_CLASSES}")

        pipeline_class = globals()[pipeline_class_name]
        pipe = pipeline_class(config)
        return cls(dataset_config, pipe, orig_config)

    def _process(self, output_callback, calculate_metrics, progress_reporter, metric_config, csv_file):
        for batch_id, (batch_input_ids, batch_annotation, batch_inputs, batch_identifiers) in enumerate(self.dataset):
            batch_inputs = self.preprocessor.process(batch_inputs, batch_annotation)
            batch_inputs_extr, batch_meta = extract_image_representations(batch_inputs)

            batch_raw_prediction, batch_prediction = self.pipe.predict(
                batch_identifiers, batch_inputs_extr, batch_meta
            )
            metrics_result = self._get_metrics_result(batch_input_ids, batch_annotation, batch_prediction,
                                                      calculate_metrics)
            if output_callback:
                output_callback(batch_raw_prediction[0], metrics_result=metrics_result,
                                element_identifiers=batch_identifiers, dataset_indices=batch_input_ids)
            self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)

    def release(self):
        pass
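For orientation, a minimal sketch of the config shape that `from_configs` consumes; the concrete values below are hypothetical, not taken from the commit:

    # Hypothetical config sketch for WhisperEvaluator.from_configs (values are
    # illustrative; only the keys mirror what the method actually reads).
    config = {
        "datasets": [],                              # dataset sections, as in other accuracy_checker configs
        "pipeline_class": "OptimumWhisperPipeline",  # must be one of VALID_PIPELINE_CLASSES
        "launchers": [{"device": "CPU"}],            # 'device' is copied into config["_device"]
        "_models": ["/path/to/exported/whisper"],    # model directory consumed by the pipeline class
    }
    evaluator = WhisperEvaluator.from_configs(config)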
def normalize_transcription(engine, text):
    # Convert numbers to words
    tokens = (engine.number_to_words(token) if token.isdigit() else token for token in text.split())
    # Remove punctuation except for apostrophes that are in the middle of words
    text = re.sub(r"\b'\b|[^\w\s]", "", " ".join(tokens))
    # Remove leading, trailing, and multiple consecutive spaces, and convert to uppercase
    return " ".join(text.upper().split())
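As a quick illustration of what this normalization does (the inputs here are made up for the example, and the function above is assumed to be in scope):

    import inflect

    engine = inflect.engine()
    # Digits become words, punctuation (including apostrophes) is stripped,
    # whitespace is collapsed, and the result is uppercased:
    normalize_transcription(engine, "I have 2 cats!")   # -> "I HAVE TWO CATS"
    normalize_transcription(engine, "it's 7 o'clock.")  # -> "ITS SEVEN OCLOCK"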
class WhisperPipeline:
    def __init__(self, config):
        self.engine = inflect.engine()
        self.pipeline = self._initialize_pipeline(config)

    def _initialize_pipeline(self, config):
        raise NotImplementedError

    def _get_predictions(self, data, identifiers, input_meta):
        raise NotImplementedError

    def predict(self, identifiers, input_data, input_meta, encoder_callback=None):
        predictions = []
        outputs = []
        for data in input_data:
            transcription = self._get_predictions(data, identifiers, input_meta)
            prediction_text = normalize_transcription(self.engine, transcription)
            predictions.append(prediction_text)
        outputs.append(CharacterRecognitionPrediction(identifiers[0], predictions[0]))
        return [], outputs
class GenAIWhisperPipeline(WhisperPipeline):
    def _initialize_pipeline(self, config):
        try:
            import openvino_genai as ov_genai  # pylint: disable=C0415
        except ImportError as import_error:
            UnsupportedPackage("openvino_genai", import_error.msg).raise_error(self.__class__.__name__)

        model_dir = config.get("_models", [None])[0]
        device = config.get("_device", "CPU")
        pipeline = ov_genai.WhisperPipeline(str(model_dir), device=device)
        return pipeline

    def _get_predictions(self, data, identifiers, input_meta):
        return self.pipeline.generate(data[0], return_timestamps=True).texts[0]
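The call pattern this backend wraps, as a standalone sketch (assumes openvino_genai is installed and that "whisper-tiny-ov" is a hypothetical directory holding an exported Whisper model with its tokenizers):

    import openvino_genai as ov_genai

    # raw_speech: mono float32 PCM at 16 kHz; a second of silence as a placeholder.
    raw_speech = [0.0] * 16000
    pipe = ov_genai.WhisperPipeline("whisper-tiny-ov", device="CPU")
    result = pipe.generate(raw_speech, return_timestamps=True)
    print(result.texts[0])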
class HFWhisperPipeline(WhisperPipeline):
    def _initialize_pipeline(self, config):
        try:
            import torch  # pylint: disable=C0415
        except ImportError as import_error:
            UnsupportedPackage("torch", import_error.msg).raise_error(self.__class__.__name__)

        model_id = config.get("model_id")
        device = "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
        ).to(device)

        processor = AutoProcessor.from_pretrained(model_id)

        pipeline = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
        )
        return pipeline

    def _get_predictions(self, data, identifiers, input_meta):
        sampling_rate = input_meta[0].get("sample_rate")
        sample = {"path": identifiers[0], "array": data[0], "sampling_rate": sampling_rate}
        return self.pipeline(sample, return_timestamps=True)["text"]
class OptimumWhisperPipeline(WhisperPipeline):
    def _initialize_pipeline(self, config):
        try:
            from optimum.intel.openvino import OVModelForSpeechSeq2Seq  # pylint: disable=C0415
        except ImportError as import_error:
            UnsupportedPackage("optimum.intel.openvino", import_error.msg).raise_error(self.__class__.__name__)

        device = config.get("_device", "CPU")
        model_dir = config.get("_models", [None])[0]
        ov_model = OVModelForSpeechSeq2Seq.from_pretrained(str(model_dir)).to(device)
        ov_processor = AutoProcessor.from_pretrained(str(model_dir))

        pipeline = AutomaticSpeechRecognitionPipeline(
            model=ov_model,
            tokenizer=ov_processor.tokenizer,
            feature_extractor=ov_processor.feature_extractor
        )
        return pipeline

    def _get_predictions(self, data, identifiers, input_meta):
        sampling_rate = input_meta[0].get("sample_rate")
        sample = {"path": identifiers[0], "array": data[0], "sampling_rate": sampling_rate}
        return self.pipeline(sample, return_timestamps=True)["text"]
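
Both OpenVINO-based backends expect config["_models"][0] to point at an exported model directory. A minimal export sketch, assuming a recent optimum-intel (the output path is hypothetical):

    from pathlib import Path

    from optimum.intel.openvino import OVModelForSpeechSeq2Seq
    from transformers import AutoProcessor

    # Export whisper-tiny to OpenVINO IR and save the processor next to it, so
    # the directory can be passed as config["_models"][0].
    model_dir = Path("whisper-tiny-ov")
    OVModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", export=True).save_pretrained(model_dir)
    AutoProcessor.from_pretrained("openai/whisper-tiny").save_pretrained(model_dir)

The GenAI backend additionally needs exported tokenizer models; the test below produces them with export_tokenizer.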

tools/accuracy_checker/requirements-extra.in

Lines changed: 3 additions & 0 deletions
@@ -48,3 +48,6 @@ lmdb>=1.2.1
 
 # pandas datasets support
 pandas>=1.1.5,<2.1
+
+# word-based representations of numbers
+inflect>=7.4.0
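
The new dependency is what normalize_transcription uses to spell out digits; for instance (illustrative):

    import inflect

    engine = inflect.engine()
    engine.number_to_words("42")  # -> "forty-two"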

tools/accuracy_checker/requirements-test.in

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ pytest-mock~=2.0
 # will not include atomicwrites and thus will not work on Windows.
 # So as a workaround, make the atomicwrites dependency unconditional.
 atomicwrites
+datasets
tools/accuracy_checker/tests/test_whisper_evaluator.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
"""
Copyright (c) 2024-2025 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from accuracy_checker.evaluators.custom_evaluators.whisper_evaluator import (
    GenAIWhisperPipeline, HFWhisperPipeline, OptimumWhisperPipeline,
    WhisperEvaluator)
from datasets import load_dataset

AutoProcessor = pytest.importorskip("transformers", reason="transformers is not available").AutoProcessor
AutoTokenizer = pytest.importorskip("transformers", reason="transformers is not available").AutoTokenizer
export_tokenizer = pytest.importorskip("optimum.exporters.openvino.convert", reason="optimum.exporters.openvino.convert is not available").export_tokenizer
OVModelForSpeechSeq2Seq = pytest.importorskip("optimum.intel.openvino", reason="optimum.intel.openvino is not available").OVModelForSpeechSeq2Seq


model_id = "openai/whisper-tiny"
model_dir = Path("/tmp/whisper-tiny")

def setup_module(module):
    # Setup code here
    global input_data, input_meta, identifiers

    # Load a single sample from the dataset
    dataset = load_dataset("openslr/librispeech_asr", "clean", split="validation", streaming=True, trust_remote_code=True)
    sample = next(iter(dataset))
    input_data = [sample["audio"]["array"]]
    input_meta = [{"sample_rate": sample["audio"]["sampling_rate"]}]
    identifiers = [sample["id"]]

def teardown_module(module):
    # Cleanup code here
    if model_dir.exists():
        for item in model_dir.iterdir():
            if item.is_file():
                item.unlink()
        model_dir.rmdir()

def test_optimum_convert_model_to_ir():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    base_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id)

    model_dir.mkdir(parents=True, exist_ok=True)
    base_model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
    processor.save_pretrained(model_dir)
    export_tokenizer(tokenizer, model_dir)

    assert base_model.__class__.__module__.startswith('optimum.intel.openvino')

class TestWhisperEvaluator:
    def test_hf_whisper_pipeline(self):
        config = {"model_id": model_id}
        pipeline = HFWhisperPipeline(config)
        evaluator = WhisperEvaluator(None, pipeline, None)

        result = evaluator.pipe._get_predictions(input_data, identifiers, input_meta)
        assert isinstance(result, str)

    @pytest.mark.dependency(depends=["test_optimum_convert_model_to_ir"])
    def test_genai_whisper_pipeline(self):
        config = {"_models": [model_dir], "_device": "CPU"}
        pipeline = GenAIWhisperPipeline(config)
        evaluator = WhisperEvaluator(None, pipeline, None)

        result = evaluator.pipe._get_predictions(input_data, identifiers, input_meta)
        assert isinstance(result, str)

    @pytest.mark.dependency(depends=["test_optimum_convert_model_to_ir"])
    def test_optimum_whisper_pipeline(self):
        config = {"_models": [model_dir], "_device": "CPU"}
        pipeline = OptimumWhisperPipeline(config)
        evaluator = WhisperEvaluator(None, pipeline, None)

        result = evaluator.pipe._get_predictions(input_data, identifiers, input_meta)
        assert isinstance(result, str)
