@@ -336,7 +336,7 @@ class VadIterator
         // The method should be called in each thread/proc in multi-thread/proc work
         session_options.SetIntraOpNumThreads(intra_threads);
         session_options.SetInterOpNumThreads(inter_threads);
-        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
     };
 
     void init_onnx_model(const std::wstring& model_path)
@@ -350,8 +350,7 @@ class VadIterator
     void reset_states()
     {
         // Call reset before each audio start
-        std::memset(_h.data(), 0, _h.size() * sizeof(float));
-        std::memset(_c.data(), 0, _c.size() * sizeof(float));
+        std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
         triggered = false;
         temp_end = 0;
         current_sample = 0;
@@ -362,39 +361,34 @@ class VadIterator
         current_speech = timestamp_t();
     };
 
-    void predict(const std::vector<float>& data)
+    void predict(const std::vector<float> &data)
     {
         // Infer
         // Create ort tensors
         input.assign(data.begin(), data.end());
         Ort::Value input_ort = Ort::Value::CreateTensor<float>(
             memory_info, input.data(), input.size(), input_node_dims, 2);
+        Ort::Value state_ort = Ort::Value::CreateTensor<float>(
+            memory_info, _state.data(), _state.size(), state_node_dims, 3);
         Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
             memory_info, sr.data(), sr.size(), sr_node_dims, 1);
-        Ort::Value h_ort = Ort::Value::CreateTensor<float>(
-            memory_info, _h.data(), _h.size(), hc_node_dims, 3);
-        Ort::Value c_ort = Ort::Value::CreateTensor<float>(
-            memory_info, _c.data(), _c.size(), hc_node_dims, 3);
 
         // Clear and add inputs
         ort_inputs.clear();
         ort_inputs.emplace_back(std::move(input_ort));
+        ort_inputs.emplace_back(std::move(state_ort));
         ort_inputs.emplace_back(std::move(sr_ort));
-        ort_inputs.emplace_back(std::move(h_ort));
-        ort_inputs.emplace_back(std::move(c_ort));
 
         // Infer
         ort_outputs = session->Run(
-            Ort::RunOptions{ nullptr },
+            Ort::RunOptions{nullptr},
             input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
             output_node_names.data(), output_node_names.size());
 
         // Output probability & update h,c recursively
         float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
-        float *hn = ort_outputs[1].GetTensorMutableData<float>();
-        std::memcpy(_h.data(), hn, size_hc * sizeof(float));
-        float *cn = ort_outputs[2].GetTensorMutableData<float>();
-        std::memcpy(_c.data(), cn, size_hc * sizeof(float));
+        float *stateN = ort_outputs[1].GetTensorMutableData<float>();
+        std::memcpy(_state.data(), stateN, size_state * sizeof(float));
 
         // Push forward sample index
         current_sample += window_size_samples;
@@ -419,7 +413,7 @@ class VadIterator
                 current_speech.start = current_sample - window_size_samples;
             }
             return;
-        }
+        }
 
         if (
             (triggered == true)
@@ -429,19 +423,19 @@ class VadIterator
                 current_speech.end = prev_end;
                 speeches.push_back(current_speech);
                 current_speech = timestamp_t();
-
+
                 // previously reached silence(< neg_thres) and is still not speech(< thres)
                 if (next_start < prev_end)
                     triggered = false;
-                else {
+                else {
                     current_speech.start = next_start;
                 }
                 prev_end = 0;
                 next_start = 0;
                 temp_end = 0;
 
             }
-            else {
+            else {
                 current_speech.end = current_sample;
                 speeches.push_back(current_speech);
                 current_speech = timestamp_t();
@@ -466,7 +460,7 @@ class VadIterator
             float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
             printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
 #endif // __DEBUG_SPEECH_PROB___
-        }
+        }
         return;
     }
 
@@ -552,7 +546,7 @@ class VadIterator
             std::cout << speeches[i].c_str() << std::endl;
 #endif // #ifdef __DEBUG_SPEECH_PROB___
             std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
-            output_wav.insert(output_wav.end(), slice.begin(), slice.end());
+            output_wav.insert(output_wav.end(),slice.begin(),slice.end());
         }
     };
 
@@ -606,27 +600,26 @@ class VadIterator
     // Inputs
     std::vector<Ort::Value> ort_inputs;
 
-    std::vector<const char *> input_node_names = { "input", "sr", "h", "c" };
+    std::vector<const char *> input_node_names = {"input", "state", "sr"};
     std::vector<float> input;
+    unsigned int size_state = 2 * 1 * 128; // It's FIXED.
+    std::vector<float> _state;
     std::vector<int64_t> sr;
-    unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
-    std::vector<float> _h;
-    std::vector<float> _c;
 
     int64_t input_node_dims[2] = {};
-    const int64_t sr_node_dims[1] = { 1 };
-    const int64_t hc_node_dims[3] = { 2, 1, 64 };
+    const int64_t state_node_dims[3] = {2, 1, 128};
+    const int64_t sr_node_dims[1] = {1};
 
     // Outputs
     std::vector<Ort::Value> ort_outputs;
-    std::vector<const char *> output_node_names = { "output", "hn", "cn" };
+    std::vector<const char *> output_node_names = {"output", "stateN"};
 
 public:
     // Construction
     VadIterator(const std::wstring ModelPath,
-        int Sample_rate = 16000, int windows_frame_size = 64,
+        int Sample_rate = 16000, int windows_frame_size = 32,
         float Threshold = 0.5, int min_silence_duration_ms = 0,
-        int speech_pad_ms = 64, int min_speech_duration_ms = 64,
+        int speech_pad_ms = 32, int min_speech_duration_ms = 32,
         float max_speech_duration_s = std::numeric_limits<float>::infinity())
     {
         init_onnx_model(ModelPath);
@@ -652,8 +645,7 @@ class VadIterator
         input_node_dims[0] = 1;
         input_node_dims[1] = window_size_samples;
 
-        _h.resize(size_hc);
-        _c.resize(size_hc);
+        _state.resize(size_state);
         sr.resize(1);
         sr[0] = sample_rate;
     };
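
Taken together, these hunks replace the model's two recurrent tensors `h` and `c` (each 2x1x64) with a single combined `state` tensor (2x1x128), rename the ONNX input/output nodes to "input"/"state"/"sr" and "output"/"stateN", and halve the default window, padding, and minimum speech duration from 64 ms to 32 ms. A minimal, hypothetical driver for the updated class follows; the `process()` call mirrors the rest of the example file and is not part of this diff, and the model path and audio buffer are placeholders.

// Hypothetical usage sketch (assumes the VadIterator defined above is in scope).
#include <string>
#include <vector>

int main()
{
    std::wstring model_path = L"silero_vad.onnx";   // placeholder model path

    // One second of 16 kHz silence as stand-in audio; real code would load a WAV file.
    std::vector<float> input_wav(16000, 0.0f);

    // Defaults after this change: 16 kHz, 32 ms window, 32 ms pad, 32 ms min speech.
    VadIterator vad(model_path);

    // process() (assumed, from the rest of the example) zeroes the single
    // 2x1x128 state tensor and then calls predict() once per 32 ms window.
    vad.process(input_wav);

    return 0;
}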