Skip to content

Commit 31984d9

Browse files
committed
update silero_vad to v5.1
1 parent 642f375 commit 31984d9

File tree

2 files changed

+24
-32
lines changed

2 files changed

+24
-32
lines changed

examples/common-portaudio.h

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ class VadIterator
336336
// This method should be called in each thread/process when running multi-threaded or multi-process
337337
session_options.SetIntraOpNumThreads(intra_threads);
338338
session_options.SetInterOpNumThreads(inter_threads);
339-
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
339+
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
340340
};
341341

342342
void init_onnx_model(const std::wstring& model_path)
@@ -350,8 +350,7 @@ class VadIterator
350350
void reset_states()
351351
{
352352
// Call reset before each audio start
353-
std::memset(_h.data(), 0, _h.size() * sizeof(float));
354-
std::memset(_c.data(), 0, _c.size() * sizeof(float));
353+
std::memset(_state.data(), 0.0f, _state.size() * sizeof(float));
355354
triggered = false;
356355
temp_end = 0;
357356
current_sample = 0;
@@ -362,39 +361,34 @@ class VadIterator
362361
current_speech = timestamp_t();
363362
};
364363

365-
void predict(const std::vector<float>& data)
364+
void predict(const std::vector<float> &data)
366365
{
367366
// Infer
368367
// Create ort tensors
369368
input.assign(data.begin(), data.end());
370369
Ort::Value input_ort = Ort::Value::CreateTensor<float>(
371370
memory_info, input.data(), input.size(), input_node_dims, 2);
371+
Ort::Value state_ort = Ort::Value::CreateTensor<float>(
372+
memory_info, _state.data(), _state.size(), state_node_dims, 3);
372373
Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t>(
373374
memory_info, sr.data(), sr.size(), sr_node_dims, 1);
374-
Ort::Value h_ort = Ort::Value::CreateTensor<float>(
375-
memory_info, _h.data(), _h.size(), hc_node_dims, 3);
376-
Ort::Value c_ort = Ort::Value::CreateTensor<float>(
377-
memory_info, _c.data(), _c.size(), hc_node_dims, 3);
378375

379376
// Clear and add inputs
380377
ort_inputs.clear();
381378
ort_inputs.emplace_back(std::move(input_ort));
379+
ort_inputs.emplace_back(std::move(state_ort));
382380
ort_inputs.emplace_back(std::move(sr_ort));
383-
ort_inputs.emplace_back(std::move(h_ort));
384-
ort_inputs.emplace_back(std::move(c_ort));
385381

386382
// Infer
387383
ort_outputs = session->Run(
388-
Ort::RunOptions{ nullptr },
384+
Ort::RunOptions{nullptr},
389385
input_node_names.data(), ort_inputs.data(), ort_inputs.size(),
390386
output_node_names.data(), output_node_names.size());
391387

392388
// Output probability & carry the recurrent state (stateN) over to the next call
393389
float speech_prob = ort_outputs[0].GetTensorMutableData<float>()[0];
394-
float* hn = ort_outputs[1].GetTensorMutableData<float>();
395-
std::memcpy(_h.data(), hn, size_hc * sizeof(float));
396-
float* cn = ort_outputs[2].GetTensorMutableData<float>();
397-
std::memcpy(_c.data(), cn, size_hc * sizeof(float));
390+
float *stateN = ort_outputs[1].GetTensorMutableData<float>();
391+
std::memcpy(_state.data(), stateN, size_state * sizeof(float));
398392

399393
// Push forward sample index
400394
current_sample += window_size_samples;
@@ -419,7 +413,7 @@ class VadIterator
419413
current_speech.start = current_sample - window_size_samples;
420414
}
421415
return;
422-
}
416+
}
423417

424418
if (
425419
(triggered == true)
@@ -429,19 +423,19 @@ class VadIterator
429423
current_speech.end = prev_end;
430424
speeches.push_back(current_speech);
431425
current_speech = timestamp_t();
432-
426+
433427
// previously reached silence(< neg_thres) and is still not speech(< thres)
434428
if (next_start < prev_end)
435429
triggered = false;
436-
else {
430+
else{
437431
current_speech.start = next_start;
438432
}
439433
prev_end = 0;
440434
next_start = 0;
441435
temp_end = 0;
442436

443437
}
444-
else {
438+
else{
445439
current_speech.end = current_sample;
446440
speeches.push_back(current_speech);
447441
current_speech = timestamp_t();
@@ -466,7 +460,7 @@ class VadIterator
466460
float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point.
467461
printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples);
468462
#endif //__DEBUG_SPEECH_PROB___
469-
}
463+
}
470464
return;
471465
}
472466

@@ -552,7 +546,7 @@ class VadIterator
552546
std::cout << speeches[i].c_str() << std::endl;
553547
#endif //#ifdef __DEBUG_SPEECH_PROB___
554548
std::vector<float> slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]);
555-
output_wav.insert(output_wav.end(), slice.begin(), slice.end());
549+
output_wav.insert(output_wav.end(),slice.begin(),slice.end());
556550
}
557551
};
558552

@@ -606,27 +600,26 @@ class VadIterator
606600
// Inputs
607601
std::vector<Ort::Value> ort_inputs;
608602

609-
std::vector<const char*> input_node_names = { "input", "sr", "h", "c" };
603+
std::vector<const char *> input_node_names = {"input", "state", "sr"};
610604
std::vector<float> input;
605+
unsigned int size_state = 2 * 1 * 128; // It's FIXED.
606+
std::vector<float> _state;
611607
std::vector<int64_t> sr;
612-
unsigned int size_hc = 2 * 1 * 64; // It's FIXED.
613-
std::vector<float> _h;
614-
std::vector<float> _c;
615608

616609
int64_t input_node_dims[2] = {};
617-
const int64_t sr_node_dims[1] = { 1 };
618-
const int64_t hc_node_dims[3] = { 2, 1, 64 };
610+
const int64_t state_node_dims[3] = {2, 1, 128};
611+
const int64_t sr_node_dims[1] = {1};
619612

620613
// Outputs
621614
std::vector<Ort::Value> ort_outputs;
622-
std::vector<const char*> output_node_names = { "output", "hn", "cn" };
615+
std::vector<const char *> output_node_names = {"output", "stateN"};
623616

624617
public:
625618
// Construction
626619
VadIterator(const std::wstring ModelPath,
627-
int Sample_rate = 16000, int windows_frame_size = 64,
620+
int Sample_rate = 16000, int windows_frame_size = 32,
628621
float Threshold = 0.5, int min_silence_duration_ms = 0,
629-
int speech_pad_ms = 64, int min_speech_duration_ms = 64,
622+
int speech_pad_ms = 32, int min_speech_duration_ms = 32,
630623
float max_speech_duration_s = std::numeric_limits<float>::infinity())
631624
{
632625
init_onnx_model(ModelPath);
@@ -652,8 +645,7 @@ class VadIterator
652645
input_node_dims[0] = 1;
653646
input_node_dims[1] = window_size_samples;
654647

655-
_h.resize(size_hc);
656-
_c.resize(size_hc);
648+
_state.resize(size_state);
657649
sr.resize(1);
658650
sr[0] = sample_rate;
659651
};

examples/silero_vad.onnx

2.22 MB
Binary file not shown.

0 commit comments

Comments
 (0)