Skip to content

Commit a6271ec

Browse files
committed
vad : add build encoder layer graph
1 parent 93883e7 commit a6271ec

File tree

3 files changed

+156
-13
lines changed

3 files changed

+156
-13
lines changed

include/whisper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ extern "C" {
687687

688688
WHISPER_API struct whisper_vad_segments whisper_vad_detect_speech(
689689
whisper_vad_context * vctx,
690-
const float * pcmf32, int n_samples);
690+
const float * pcmf32, int n_samples, int n_threads);
691691

692692
WHISPER_API void whisper_vad_free (struct whisper_vad_context * ctx);
693693
WHISPER_API void whisper_vad_free_state (struct whisper_vad_state * state);

src/whisper.cpp

Lines changed: 147 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4431,6 +4431,17 @@ struct whisper_vad_context {
44314431
whisper_vad_model model;
44324432
whisper_vad_state * state = nullptr;
44334433

4434+
int window_size_samples;
4435+
int context_samples;
4436+
int effective_window_size;
4437+
4438+
bool triggered;
4439+
std::vector<float> context_buffer;
4440+
unsigned int current_sample;
4441+
unsigned int temp_end;
4442+
4443+
std::vector<whisper_vad_segment> detected_segments;
4444+
44344445
whisper_context_params params;
44354446

44364447
std::string path_model;
@@ -4470,11 +4481,50 @@ static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams &
44704481
return nullptr;
44714482
}
44724483

4484+
static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context* ctx0,
4485+
const whisper_vad_model & model, ggml_tensor * cur) {
4486+
WHISPER_LOG_INFO("%s: building encoder layer\n", __func__);
4487+
// Reshape from the STFT output which is [258, 1, 1, 1] where are complex
4488+
// number pairs. I think we can ignore the imaginary part and just use the
4489+
// real part here.
4490+
struct ggml_tensor * real_part = ggml_view_1d(ctx0, cur, 129, 0);
4491+
struct ggml_tensor * reshaped = ggml_reshape_3d(ctx0, real_part, 1, 129, 1);
4492+
4493+
// First Conv1D: expands to 128 channels.
4494+
cur = ggml_conv_1d(ctx0, model.encoder_0_weight, reshaped, 1, 1, 1);
4495+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1));
4496+
cur = ggml_relu(ctx0, cur);
4497+
4498+
// First Conv1D: reduces to 64 channels.
4499+
cur = ggml_conv_1d(ctx0, model.encoder_1_weight, cur, 1, 1, 1);
4500+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_1_bias, 1, 64, 1));
4501+
cur = ggml_relu(ctx0, cur);
4502+
4503+
// Third Conv1D: maintains 64 channels
4504+
cur = ggml_conv_1d(ctx0, model.encoder_2_weight, cur, 1, 1, 1);
4505+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_2_bias, 1, 64, 1));
4506+
cur = ggml_relu(ctx0, cur);
4507+
4508+
// Fourth Conv1D: expands to 128 channels
4509+
cur = ggml_conv_1d(ctx0, model.encoder_3_weight, cur, 1, 1, 1);
4510+
cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_3_bias, 1, 128, 1));
4511+
cur = ggml_relu(ctx0, cur);
4512+
4513+
return cur;
4514+
}
4515+
4516+
static ggml_tensor * whisper_vad_lstm_layer(ggml_context* ctx0,
4517+
const whisper_vad_context & vctx, ggml_tensor * cur) {
4518+
WHISPER_LOG_INFO("%s: building LSTM layer\n", __func__);
4519+
4520+
return cur;
4521+
}
4522+
44734523
static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
44744524
whisper_vad_state & vstate) {
44754525
const auto & model = vctx.model;
44764526
const auto & hparams = model.hparams;
4477-
const int n_window = 256;
4527+
const int n_window = vctx.effective_window_size;
44784528

44794529
WHISPER_LOG_INFO("%s: Building VAD graph\n", __func__);
44804530
struct ggml_init_params params = {
@@ -4487,18 +4537,25 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx,
44874537

44884538
ggml_cgraph * gf = ggml_new_graph(ctx0);
44894539

4490-
struct ggml_tensor * samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_window);
4491-
ggml_set_name(samples, "samples");
4492-
ggml_set_input(samples);
4540+
// We process one frame/segment at a time of size n_window.
4541+
struct ggml_tensor * frame = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_window);
4542+
ggml_set_name(frame, "frame");
4543+
ggml_set_input(frame);
44934544

44944545
struct ggml_tensor * cur = nullptr;
44954546
{
4496-
cur = ggml_mul_mat(ctx0, model.stft_forward_basis, samples);
4547+
cur = ggml_mul_mat(ctx0, model.stft_forward_basis, frame);
4548+
ggml_set_name(cur, "stft");
4549+
ggml_set_output(cur);
4550+
4551+
cur = whisper_vad_build_encoder_layer(ctx0, model, cur);
4552+
4553+
cur = whisper_vad_lstm_layer(ctx0, vctx, cur);
44974554
}
44984555

4499-
//ggml_build_forward_expand(gf, cur);
4556+
ggml_build_forward_expand(gf, cur);
45004557

4501-
//ggml_free(ctx0);
4558+
ggml_free(ctx0);
45024559

45034560
return gf;
45044561
}
@@ -4604,6 +4661,14 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
46044661
whisper_vad_context * vctx = new whisper_vad_context;
46054662
vctx->path_model = path_model;
46064663

4664+
vctx->window_size_samples = 192;
4665+
vctx->context_samples = 64;
4666+
vctx->effective_window_size = vctx->window_size_samples + vctx->context_samples;
4667+
vctx->triggered = false;
4668+
vctx->context_buffer.resize(vctx->context_samples, 0.0f);
4669+
vctx->current_sample = 0;
4670+
vctx->temp_end = 0;
4671+
46074672
auto & model = vctx->model;
46084673
auto & hparams = model.hparams;
46094674

@@ -4899,19 +4964,91 @@ whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
48994964
return vctx;
49004965
}
49014966

4902-
49034967
struct whisper_vad_segments whisper_vad_detect_speech(
49044968
struct whisper_vad_context * vctx,
49054969
const float * pcmf32,
4906-
int n_samples) {
4970+
int n_samples,
4971+
int n_threads) {
49074972
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
4973+
auto & sched = vctx->state->sched.sched;
49084974

49094975
struct whisper_vad_segments segments {
49104976
/* n_segments = */ 0,
49114977
/* segments = */ nullptr,
49124978
};
49134979

4914-
const ggml_cgraph * gf = whisper_vad_build_graph(*vctx, *vctx->state);
4980+
// Reset state for this detection
4981+
vctx->triggered = false;
4982+
vctx->current_sample = 0;
4983+
vctx->temp_end = 0;
4984+
std::fill(vctx->context_buffer.begin(), vctx->context_buffer.end(), 0.0f);
4985+
vctx->detected_segments.clear();
4986+
4987+
ggml_cgraph * gf = whisper_vad_build_graph(*vctx, *vctx->state);
4988+
4989+
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
4990+
// TODO(danbev) Add error handling
4991+
return segments;
4992+
}
4993+
4994+
std::vector<float> window_with_context(vctx->effective_window_size);
4995+
WHISPER_LOG_INFO("%s: window_with_context.size() = %zu\n", __func__, window_with_context.size());
4996+
WHISPER_LOG_INFO("%s: window_sample_size: %u\n", __func__, vctx->window_size_samples);
4997+
WHISPER_LOG_INFO("%s: context_sample_size: %u\n", __func__, vctx->context_samples);
4998+
WHISPER_LOG_INFO("%s: effective_window_size: %u\n", __func__, vctx->effective_window_size);
4999+
5000+
whisper_vad_segment current_segment = {-1.0f, -1.0f};
5001+
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
5002+
5003+
WHISPER_LOG_INFO("%s: frame tensor size: %ld\n", __func__, frame->ne[0]);
5004+
5005+
for (int i = 0; i < n_samples; i += vctx->window_size_samples) {
5006+
// Skip if we don't have enough samples for a full window
5007+
if (i + vctx->window_size_samples > n_samples) {
5008+
break;
5009+
}
5010+
//WHISPER_LOG_INFO("%s: processing window %d\n", __func__, i / vctx->window_size_samples);
5011+
5012+
// Copy the previous context buffer into the next window to be processed next
5013+
// context_buffer contains the 64 samples from the previous window and this is
5014+
// part of the overlapping windows to avoid spectral leakage.
5015+
std::copy(vctx->context_buffer.begin(), vctx->context_buffer.end(), window_with_context.begin());
5016+
5017+
// Copy the current samples from pcmf32 into the window_with_context,
5018+
// starting after the context buffer copied above.
5019+
std::copy(&pcmf32[i], &pcmf32[i + vctx->window_size_samples], window_with_context.begin() + vctx->context_samples);
5020+
5021+
ggml_backend_tensor_set(frame, window_with_context.data(), 0, vctx->effective_window_size * sizeof(float));
5022+
5023+
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
5024+
WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
5025+
break;
5026+
}
5027+
5028+
// TODO(danbev): get the speech probability once it is implemented
5029+
5030+
// Update the context buffer for the next iteration
5031+
std::copy(&pcmf32[i + vctx->window_size_samples - vctx->context_samples],
5032+
&pcmf32[i + vctx->window_size_samples],
5033+
vctx->context_buffer.begin());
5034+
5035+
vctx->current_sample += vctx->window_size_samples;
5036+
}
5037+
WHISPER_LOG_INFO("%s: finished processing %d samples\n", __func__, n_samples);
5038+
5039+
// Print out the result of one STFT operation
5040+
/*
5041+
{
5042+
struct ggml_tensor * stft = ggml_graph_get_tensor(gf, "stft");
5043+
std::vector<float> output;
5044+
output.resize(256);
5045+
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, stft);
5046+
ggml_backend_tensor_get(stft, output.data(), 0, ggml_nbytes(stft));
5047+
for (int i = 0; i < 10; i++) {
5048+
WHISPER_LOG_INFO("%s: output[%d]: %f\n", __func__, i, output[i]);
5049+
}
5050+
}
5051+
*/
49155052

49165053
return segments;
49175054
}

tests/test-vad.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ int main() {
2828
assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false));
2929
assert(pcmf32.size() > 0);
3030
assert(pcmf32s.size() == 0); // no stereo vector
31-
//printf("Read %zu samples from %s\n", pcmf32.size(), sample_path.c_str());
31+
32+
/*
33+
printf("Read %zu samples from %s\n", pcmf32.size(), sample_path.c_str());
34+
for (int i = 900; i < 1000; i++) {
35+
printf("%s: input pcmf32[%d]: %f\n", __func__, i, pcmf32[i]);
36+
}
37+
*/
3238

3339
// Load the VAD model
3440
struct whisper_vad_params params = whisper_vad_default_params();
@@ -45,7 +51,7 @@ int main() {
4551
// Detect speech segments
4652
struct whisper_vad_segments segments = whisper_vad_detect_speech(
4753
vctx,
48-
pcmf32.data(), pcmf32.size());
54+
pcmf32.data(), pcmf32.size(), 1);
4955

5056
//assert(segments.n_segments > 0);
5157
return 0;

0 commit comments

Comments
 (0)