@@ -118,32 +118,6 @@ def _get_predictions(self, data, identifiers, input_meta):
118
118
return self .pipeline .generate (data [0 ]).texts [0 ]
119
119
120
120
121
- class OptimumIntelPipeline (WhisperPipeline ):
122
- def _initialize_pipeline (self , config ):
123
- try :
124
- from optimum .intel .openvino import \
125
- OVModelForSpeechSeq2Seq # pylint: disable=C0415
126
- except ImportError as import_err :
127
- UnsupportedPackage ("optimum.intel.openvino" , import_err .msg ).raise_error (self .__class__ .__name__ )
128
-
129
- device = config .get ("_device" , "CPU" )
130
- model_dir = config .get ("_models" , [None ])[0 ]
131
- ov_model = OVModelForSpeechSeq2Seq .from_pretrained (str (model_dir )).to (device )
132
- ov_processor = AutoProcessor .from_pretrained (str (model_dir ))
133
-
134
- pipeline = AutomaticSpeechRecognitionPipeline (
135
- model = ov_model ,
136
- tokenizer = ov_processor .tokenizer ,
137
- feature_extractor = ov_processor .feature_extractor
138
- )
139
- return pipeline
140
-
141
- def _get_predictions (self , data , identifiers , input_meta ):
142
- sampling_rate = input_meta [0 ].get ("sample_rate" )
143
- sample = {"path" : identifiers [0 ], "array" : data [0 ], "sampling_rate" : sampling_rate }
144
- return self .pipeline (sample )["text" ]
145
-
146
-
147
121
class TransformersAsrPipeline (WhisperPipeline ):
148
122
def _initialize_pipeline (self , config ):
149
123
try :
@@ -173,3 +147,31 @@ def _get_predictions(self, data, identifiers, input_meta):
173
147
sampling_rate = input_meta [0 ].get ("sample_rate" )
174
148
sample = {"path" : identifiers [0 ], "array" : data [0 ], "sampling_rate" : sampling_rate }
175
149
return self .pipeline (sample )["text" ]
150
+
151
+
152
+ class OptimumIntelPipeline (WhisperPipeline ):
153
+ def _initialize_pipeline (self , config ):
154
+ try :
155
+ from optimum .intel .openvino import \
156
+ OVModelForSpeechSeq2Seq # pylint: disable=C0415
157
+ except ImportError as import_err :
158
+ UnsupportedPackage ("optimum.intel.openvino" , import_err .msg ).raise_error (self .__class__ .__name__ )
159
+
160
+ device = config .get ("_device" , "CPU" )
161
+ model_dir = config .get ("_models" , [None ])[0 ]
162
+ ov_model = OVModelForSpeechSeq2Seq .from_pretrained (str (model_dir ))
163
+ ov_processor = AutoProcessor .from_pretrained (str (model_dir ))
164
+
165
+ pipeline = AutomaticSpeechRecognitionPipeline (
166
+ model = ov_model ,
167
+ tokenizer = ov_processor .tokenizer ,
168
+ feature_extractor = ov_processor .feature_extractor ,
169
+ device = device ,
170
+ )
171
+ return pipeline
172
+
173
+ def _get_predictions (self , data , identifiers , input_meta ):
174
+ sampling_rate = input_meta [0 ].get ("sample_rate" )
175
+ sample = {"path" : identifiers [0 ], "array" : data [0 ], "sampling_rate" : sampling_rate }
176
+ return self .pipeline (sample )["text" ]
177
+
0 commit comments