@@ -85,9 +85,8 @@ OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
85
85
}
86
86
}
87
87
88
-
89
88
void OnlineEbranchformerTransducerModel::InitEncoder (void *model_data,
90
- size_t model_data_length) {
89
+ size_t model_data_length) {
91
90
encoder_sess_ = std::make_unique<Ort::Session>(
92
91
env_, model_data, model_data_length, encoder_sess_opts_);
93
92
@@ -153,9 +152,8 @@ void OnlineEbranchformerTransducerModel::InitEncoder(void *model_data,
153
152
}
154
153
}
155
154
156
-
157
155
void OnlineEbranchformerTransducerModel::InitDecoder (void *model_data,
158
- size_t model_data_length) {
156
+ size_t model_data_length) {
159
157
decoder_sess_ = std::make_unique<Ort::Session>(
160
158
env_, model_data, model_data_length, decoder_sess_opts_);
161
159
@@ -180,7 +178,7 @@ void OnlineEbranchformerTransducerModel::InitDecoder(void *model_data,
180
178
}
181
179
182
180
void OnlineEbranchformerTransducerModel::InitJoiner (void *model_data,
183
- size_t model_data_length) {
181
+ size_t model_data_length) {
184
182
joiner_sess_ = std::make_unique<Ort::Session>(
185
183
env_, model_data, model_data_length, joiner_sess_opts_);
186
184
@@ -200,7 +198,6 @@ void OnlineEbranchformerTransducerModel::InitJoiner(void *model_data,
200
198
}
201
199
}
202
200
203
-
204
201
std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates (
205
202
const std::vector<std::vector<Ort::Value>> &states) const {
206
203
int32_t batch_size = static_cast <int32_t >(states.size ());
@@ -215,28 +212,28 @@ std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
215
212
ans.reserve (num_states);
216
213
217
214
for (int32_t i = 0 ; i != num_hidden_layers_; ++i) {
218
- { // cached_key
215
+ { // cached_key
219
216
for (int32_t n = 0 ; n != batch_size; ++n) {
220
217
buf[n] = &states[n][4 * i];
221
218
}
222
219
auto v = Cat (allocator, buf, /* axis */ 0 );
223
220
ans.push_back (std::move (v));
224
221
}
225
- { // cached_value
222
+ { // cached_value
226
223
for (int32_t n = 0 ; n != batch_size; ++n) {
227
224
buf[n] = &states[n][4 * i + 1 ];
228
225
}
229
226
auto v = Cat (allocator, buf, 0 );
230
227
ans.push_back (std::move (v));
231
228
}
232
- { // cached_conv
229
+ { // cached_conv
233
230
for (int32_t n = 0 ; n != batch_size; ++n) {
234
231
buf[n] = &states[n][4 * i + 2 ];
235
232
}
236
233
auto v = Cat (allocator, buf, 0 );
237
234
ans.push_back (std::move (v));
238
235
}
239
- { // cached_conv_fusion
236
+ { // cached_conv_fusion
240
237
for (int32_t n = 0 ; n != batch_size; ++n) {
241
238
buf[n] = &states[n][4 * i + 3 ];
242
239
}
@@ -245,7 +242,7 @@ std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
245
242
}
246
243
}
247
244
248
- { // processed_lens
245
+ { // processed_lens
249
246
for (int32_t n = 0 ; n != batch_size; ++n) {
250
247
buf[n] = &states[n][num_states - 1 ];
251
248
}
@@ -256,11 +253,9 @@ std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
256
253
return ans;
257
254
}
258
255
259
-
260
256
std::vector<std::vector<Ort::Value>>
261
257
OnlineEbranchformerTransducerModel::UnStackStates (
262
258
const std::vector<Ort::Value> &states) const {
263
-
264
259
assert (static_cast <int32_t >(states.size ()) == num_hidden_layers_ * 4 + 1 );
265
260
266
261
int32_t batch_size = states[0 ].GetTensorTypeAndShapeInfo ().GetShape ()[0 ];
@@ -272,31 +267,31 @@ OnlineEbranchformerTransducerModel::UnStackStates(
272
267
ans.resize (batch_size);
273
268
274
269
for (int32_t i = 0 ; i != num_hidden_layers_; ++i) {
275
- { // cached_key
270
+ { // cached_key
276
271
auto v = Unbind (allocator, &states[i * 4 ], /* axis */ 0 );
277
272
assert (static_cast <int32_t >(v.size ()) == batch_size);
278
273
279
274
for (int32_t n = 0 ; n != batch_size; ++n) {
280
275
ans[n].push_back (std::move (v[n]));
281
276
}
282
277
}
283
- { // cached_value
278
+ { // cached_value
284
279
auto v = Unbind (allocator, &states[i * 4 + 1 ], 0 );
285
280
assert (static_cast <int32_t >(v.size ()) == batch_size);
286
281
287
282
for (int32_t n = 0 ; n != batch_size; ++n) {
288
283
ans[n].push_back (std::move (v[n]));
289
284
}
290
285
}
291
- { // cached_conv
286
+ { // cached_conv
292
287
auto v = Unbind (allocator, &states[i * 4 + 2 ], 0 );
293
288
assert (static_cast <int32_t >(v.size ()) == batch_size);
294
289
295
290
for (int32_t n = 0 ; n != batch_size; ++n) {
296
291
ans[n].push_back (std::move (v[n]));
297
292
}
298
293
}
299
- { // cached_conv_fusion
294
+ { // cached_conv_fusion
300
295
auto v = Unbind (allocator, &states[i * 4 + 3 ], 0 );
301
296
assert (static_cast <int32_t >(v.size ()) == batch_size);
302
297
@@ -306,7 +301,7 @@ OnlineEbranchformerTransducerModel::UnStackStates(
306
301
}
307
302
}
308
303
309
- { // processed_lens
304
+ { // processed_lens
310
305
auto v = Unbind<int64_t >(allocator, &states.back (), 0 );
311
306
assert (static_cast <int32_t >(v.size ()) == batch_size);
312
307
@@ -318,7 +313,6 @@ OnlineEbranchformerTransducerModel::UnStackStates(
318
313
return ans;
319
314
}
320
315
321
-
322
316
std::vector<Ort::Value>
323
317
OnlineEbranchformerTransducerModel::GetEncoderInitStates () {
324
318
std::vector<Ort::Value> ans;
@@ -332,40 +326,37 @@ OnlineEbranchformerTransducerModel::GetEncoderInitStates() {
332
326
int32_t channels_conv_fusion = 2 * hidden_size_;
333
327
334
328
for (int32_t i = 0 ; i != num_hidden_layers_; ++i) {
335
- { // cached_key_{i}
329
+ { // cached_key_{i}
336
330
std::array<int64_t , 4 > s{1 , num_heads_, left_context_len_, head_dim_};
337
- auto v =
338
- Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
331
+ auto v = Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
339
332
Fill (&v, 0 );
340
333
ans.push_back (std::move (v));
341
334
}
342
335
343
- { // cahced_value_{i}
336
+ { // cached_value_{i}
344
337
std::array<int64_t , 4 > s{1 , num_heads_, left_context_len_, head_dim_};
345
- auto v =
346
- Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
338
+ auto v = Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
347
339
Fill (&v, 0 );
348
340
ans.push_back (std::move (v));
349
341
}
350
342
351
- { // cached_conv_{i}
343
+ { // cached_conv_{i}
352
344
std::array<int64_t , 3 > s{1 , channels_conv, left_context_conv};
353
- auto v =
354
- Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
345
+ auto v = Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
355
346
Fill (&v, 0 );
356
347
ans.push_back (std::move (v));
357
348
}
358
349
359
- { // cached_conv_fusion_{i}
360
- std::array<int64_t , 3 > s{1 , channels_conv_fusion, left_context_conv_fusion};
361
- auto v =
362
- Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
350
+ { // cached_conv_fusion_{i}
351
+ std::array<int64_t , 3 > s{1 , channels_conv_fusion,
352
+ left_context_conv_fusion};
353
+ auto v = Ort::Value::CreateTensor<float >(allocator_, s.data (), s.size ());
363
354
Fill (&v, 0 );
364
355
ans.push_back (std::move (v));
365
356
}
366
357
} // num_hidden_layers_
367
358
368
- { // processed_lens
359
+ { // processed_lens
369
360
std::array<int64_t , 1 > s{1 };
370
361
auto v = Ort::Value::CreateTensor<int64_t >(allocator_, s.data (), s.size ());
371
362
Fill<int64_t >(&v, 0 );
@@ -375,11 +366,10 @@ OnlineEbranchformerTransducerModel::GetEncoderInitStates() {
375
366
return ans;
376
367
}
377
368
378
-
379
369
std::pair<Ort::Value, std::vector<Ort::Value>>
380
- OnlineEbranchformerTransducerModel::RunEncoder (Ort::Value features,
381
- std::vector<Ort::Value> states,
382
- Ort::Value /* processed_frames */ ) {
370
+ OnlineEbranchformerTransducerModel::RunEncoder (
371
+ Ort::Value features, std::vector<Ort::Value> states,
372
+ Ort::Value /* processed_frames */ ) {
383
373
std::vector<Ort::Value> encoder_inputs;
384
374
encoder_inputs.reserve (1 + states.size ());
385
375
@@ -402,7 +392,6 @@ OnlineEbranchformerTransducerModel::RunEncoder(Ort::Value features,
402
392
return {std::move (encoder_out[0 ]), std::move (next_states)};
403
393
}
404
394
405
-
406
395
Ort::Value OnlineEbranchformerTransducerModel::RunDecoder (
407
396
Ort::Value decoder_input) {
408
397
auto decoder_out = decoder_sess_->Run (
@@ -411,9 +400,8 @@ Ort::Value OnlineEbranchformerTransducerModel::RunDecoder(
411
400
return std::move (decoder_out[0 ]);
412
401
}
413
402
414
-
415
- Ort::Value OnlineEbranchformerTransducerModel::RunJoiner (Ort::Value encoder_out,
416
- Ort::Value decoder_out) {
403
+ Ort::Value OnlineEbranchformerTransducerModel::RunJoiner (
404
+ Ort::Value encoder_out, Ort::Value decoder_out) {
417
405
std::array<Ort::Value, 2 > joiner_input = {std::move (encoder_out),
418
406
std::move (decoder_out)};
419
407
auto logit =
@@ -424,7 +412,6 @@ Ort::Value OnlineEbranchformerTransducerModel::RunJoiner(Ort::Value encoder_out,
424
412
return std::move (logit[0 ]);
425
413
}
426
414
427
-
428
415
#if __ANDROID_API__ >= 9
429
416
template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
430
417
AAssetManager *mgr, const OnlineModelConfig &config);
0 commit comments