
Commit 705c00a
vad : add initial Voice Activity Detection (VAD) support
This commit adds support for Voice Activity Detection (VAD). When enabled, this feature processes the audio input and detects speech segments. This information is then used to reduce the number of samples that need to be processed by whisper_full.

This initial support is based on the Silero VAD model, which needs to be converted to GGML format:

```console
(venv) $ pip install silero-vad
(venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero.bin
Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
```

There is a test that exercises the VAD support in isolation:

```console
$ cmake --build build --target test-vad && \
  ctest -R ^test-vad$ --test-dir build -C Debug --output-on-failure -VV
```

And one that tests VAD in combination with whisper_full:

```console
$ cmake --build build --target test-vad-full && \
  ctest -R test-vad-full --test-dir build -C Debug --output-on-failure -VV
```

Resolves: ggml-org#3003
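For context, a minimal sketch of how the new per-call VAD options in whisper_full_params (added below in include/whisper.h) might be used; the model path and threshold are illustrative values, not library defaults:

```cpp
#include "whisper.h"

// Sketch: configure a whisper_full() call with VAD enabled. The field names
// come from the whisper.h changes in this commit; the path and threshold
// here are example values only.
static struct whisper_full_params make_vad_params() {
    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    wparams.vad            = true;                            // enable VAD preprocessing
    wparams.vad_model_path = "models/silero-v5.1.2-ggml.bin"; // converted Silero model
    wparams.vad_threshold  = 0.5f;                            // speech probability cutoff

    return wparams;
}
```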
1 parent b7db9e7 commit 705c00a

8 files changed (+1787 −5 lines)

include/whisper.h

+82
```diff
@@ -570,6 +570,17 @@ extern "C" {
         size_t n_grammar_rules;
         size_t i_start_rule;
         float grammar_penalty;
+
+        // Voice Activity Detection (VAD) params
+        bool vad;                        // Enable VAD
+        const char * vad_model_path;     // Path to VAD model
+        float vad_threshold;             // Probability threshold to consider as speech.
+        int vad_min_speech_duration_ms;  // Min duration for a valid speech segment.
+        int vad_min_silence_duration_ms; // Min silence duration to consider speech as ended.
+        float vad_max_speech_duration_s; // Max duration of a speech segment before forcing a break.
+        int vad_speech_pad_ms;           // Padding added before and after speech segments.
+        int vad_window_size_samples;     // Number of audio samples in each probability window.
+        float vad_samples_overlap;       // Overlap in seconds when copying audio samples from speech segment.
     };

     // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
```
```diff
@@ -652,6 +663,77 @@ extern "C" {
     WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
     WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);

+    // Voice Activity Detection (VAD)
+    struct whisper_vad_context;
+    struct whisper_vad_state;
+
+    struct whisper_vad_params {
+        float threshold;               // Probability threshold to consider as speech.
+        int   min_speech_duration_ms;  // Min duration for a valid speech segment.
+        int   min_silence_duration_ms; // Min silence duration to consider speech as ended.
+        float max_speech_duration_s;   // Max duration of a speech segment before forcing a new segment.
+        int   speech_pad_ms;           // Padding added before and after speech segments.
+        int   window_size_samples;     // Number of audio samples in each probability window.
+        float samples_overlap;         // Overlap in seconds when copying audio samples from speech segment.
+    };
+    WHISPER_API struct whisper_vad_params whisper_vad_default_params(void);
+    WHISPER_API struct whisper_vad_params whisper_vad_params_from(struct whisper_full_params wparams);
+
+    struct whisper_vad_context_params {
+        int  n_threads;  // The number of threads to use for processing.
+        bool use_gpu;
+        int  gpu_device; // CUDA device
+    };
+    WHISPER_API struct whisper_vad_context_params whisper_vad_default_context_params(void);
+
+    WHISPER_API struct whisper_vad_state * whisper_vad_init_state(struct whisper_vad_context * ctx);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params(
+        const char * path_model,
+        const struct whisper_vad_context_params params);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
+        const char * path_model,
+        const struct whisper_vad_context_params params);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_with_params_no_state(struct whisper_model_loader * loader,
+        struct whisper_vad_context_params params);
+
+    struct whisper_vad_speech {
+        int     n_probs;
+        float * probs;
+    };
+
+    WHISPER_API struct whisper_vad_speech whisper_vad_detect_speech(
+        struct whisper_vad_context * vctx,
+        const float * samples, int n_samples);
+
+    struct whisper_vad_segment {
+        float start; // Start time in seconds
+        float end;   // End time in seconds
+    };
+
+    struct whisper_vad_timestamps {
+        int n_segments;
+        struct whisper_vad_segment * segments;
+    };
+
+    WHISPER_API struct whisper_vad_timestamps whisper_vad_detect_speech_timestamps(
+        struct whisper_vad_context * vctx,
+        struct whisper_vad_params params,
+        const float * samples, int n_samples);
+
+    WHISPER_API struct whisper_vad_timestamps whisper_vad_timestamps_from_probs(
+        struct whisper_vad_context * vctx,
+        struct whisper_vad_params params,
+        struct whisper_vad_speech * probs);
+
+    WHISPER_API void whisper_vad_free           (struct whisper_vad_context * ctx);
+    WHISPER_API void whisper_vad_free_state    (struct whisper_vad_state * state);
+    WHISPER_API void whisper_vad_free_params   (struct whisper_vad_params * params);
+    WHISPER_API void whisper_vad_free_speech   (struct whisper_vad_speech * speech);
+    WHISPER_API void whisper_vad_free_timestamps(struct whisper_vad_timestamps * timestamps);
+
     ////////////////////////////////////////////////////////////////////////////

     // Temporary helpers needed for exposing ggml interface
```
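Taken together, a minimal sketch of the standalone VAD flow exposed by the API above; the model path is illustrative and the sample buffer stands in for real 16 kHz audio:

```cpp
#include "whisper.h"

#include <cstdio>
#include <vector>

int main() {
    // Load the converted Silero model (path is an example, not a default).
    struct whisper_vad_context_params cparams = whisper_vad_default_context_params();
    struct whisper_vad_context * vctx =
        whisper_vad_init_from_file_with_params("models/silero-v5.1.2-ggml.bin", cparams);
    if (!vctx) {
        return 1;
    }

    std::vector<float> samples(16000, 0.0f); // 1 s of silence at 16 kHz, stand-in for real audio

    // Run detection and obtain speech segments with start/end times in seconds.
    struct whisper_vad_params vparams = whisper_vad_default_params();
    struct whisper_vad_timestamps ts =
        whisper_vad_detect_speech_timestamps(vctx, vparams, samples.data(), (int) samples.size());

    for (int i = 0; i < ts.n_segments; ++i) {
        printf("speech: %.2f s -> %.2f s\n", ts.segments[i].start, ts.segments[i].end);
    }

    whisper_vad_free_timestamps(&ts);
    whisper_vad_free(vctx);
    return 0;
}
```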

models/convert-silero-vad-to-ggml.py

+196
@@ -0,0 +1,196 @@ (new file)

```python
import os
import struct
import argparse
import torch
import numpy as np
from silero_vad import load_silero_vad, __version__ as silero_version

def convert_silero_vad(output_path, print_tensors=True):
    model = load_silero_vad()
    state_dict = model.state_dict()

    # Clean up state dict keys - filter out 8k model
    cleaned_dict = {}
    for key, value in state_dict.items():
        # Skip 8k model
        if "_8k" not in key:
            clean_key = key
            if not key.startswith("_model."):
                clean_key = "_model." + key
            cleaned_dict[clean_key] = value

    base, ext = os.path.splitext(output_path)
    output_file = f"{base}-v{silero_version}-ggml{ext}"
    print(f"Saving GGML Silero-VAD model to {output_file}")

    print("\nTensor info for debugging:")
    for key, tensor in cleaned_dict.items():
        print(f"  - {key}: {tensor.shape} ({tensor.dtype})")
    print()

    with open(output_file, "wb") as fout:
        # Write magic and version
        fout.write(struct.pack("i", 0x67676d6c))

        model_type = "silero-16k"
        str_len = len(model_type)
        fout.write(struct.pack("i", str_len))
        fout.write(model_type.encode('utf-8'))

        version_parts = silero_version.split('.')
        major, minor, patch = map(int, version_parts)
        print(f"Version: {major}.{minor}.{patch}")
        fout.write(struct.pack("i", major))
        fout.write(struct.pack("i", minor))
        fout.write(struct.pack("i", patch))

        # Write model architecture parameters
        window_size = 512
        fout.write(struct.pack("i", window_size))
        context_size = 64
        fout.write(struct.pack("i", context_size))

        n_encoder_layers = 4
        fout.write(struct.pack("i", n_encoder_layers))

        # Write encoder dimensions
        input_channels = 129
        encoder_in_channels = [input_channels, 128, 64, 64]
        encoder_out_channels = [128, 64, 64, 128]
        kernel_size = 3

        for i in range(n_encoder_layers):
            fout.write(struct.pack("i", encoder_in_channels[i]))
            fout.write(struct.pack("i", encoder_out_channels[i]))
            fout.write(struct.pack("i", kernel_size))

        # Write LSTM dimensions
        lstm_input_size = 128
        lstm_hidden_size = 128
        fout.write(struct.pack("i", lstm_input_size))
        fout.write(struct.pack("i", lstm_hidden_size))

        # Write final conv dimensions
        final_conv_in = 128
        final_conv_out = 1
        fout.write(struct.pack("i", final_conv_in))
        fout.write(struct.pack("i", final_conv_out))

        # Define tensor keys to write
        tensor_keys = []

        # Encoder weights
        for i in range(n_encoder_layers):
            weight_key = f"_model.encoder.{i}.reparam_conv.weight"
            bias_key = f"_model.encoder.{i}.reparam_conv.bias"
            if weight_key in cleaned_dict and bias_key in cleaned_dict:
                tensor_keys.append(weight_key)
                tensor_keys.append(bias_key)

        # LSTM weights
        lstm_keys = [
            "_model.decoder.rnn.weight_ih",
            "_model.decoder.rnn.weight_hh",
            "_model.decoder.rnn.bias_ih",
            "_model.decoder.rnn.bias_hh"
        ]
        tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])

        # Final conv weights
        final_keys = [
            "_model.decoder.decoder.2.weight",
            "_model.decoder.decoder.2.bias"
        ]
        tensor_keys.extend([k for k in final_keys if k in cleaned_dict])

        # STFT basis - add this last
        stft_tensor = "_model.stft.forward_basis_buffer"
        tensor_keys.append(stft_tensor)

        print(f"Writing {len(tensor_keys)} tensors:")
        for key in tensor_keys:
            if key in cleaned_dict:
                print(f"  - {key}: {cleaned_dict[key].shape}")
            else:
                print(f"  - {key}: MISSING")

        # Process each tensor
        for key in tensor_keys:
            if key not in cleaned_dict:
                print(f"Warning: Missing tensor {key}, skipping")
                continue

            tensor = cleaned_dict[key]

            # Special handling for STFT tensor
            if key == "_model.stft.forward_basis_buffer":
                # Get the original numpy array without squeezing
                data = tensor.detach().cpu().numpy()
                # Ensure it has the expected shape
                print(f"STFT tensor original shape: {data.shape}")
                n_dims = 3
                tensor_shape = [data.shape[0], data.shape[1], data.shape[2]]
                is_conv_weight = True
            else:
                # For other tensors, we can use standard processing
                data = tensor.detach().cpu().squeeze().numpy()
                tensor_shape = list(data.shape)

                # Ensure we have at most 4 dimensions for GGML
                n_dims = min(len(tensor_shape), 4)

                # Reverse dimensions for GGML
                tensor_shape = tensor_shape[:n_dims]
                tensor_shape.reverse()

                # Check if this is a convolution weight tensor
                is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key)

            # Convert to float16 for convolution weights
            if is_conv_weight:
                data = data.astype(np.float16)
                ftype = 1  # float16
            else:
                ftype = 0  # float32

            # Debug printing of tensor info
            print(f"\nWriting tensor: {key}")
            print(f"  Original shape: {tensor.shape}")
            print(f"  Processed shape: {data.shape}")
            print(f"  GGML dimensions: {n_dims}")
            print(f"  GGML shape: {tensor_shape}")
            print(f"  Type: {'float16' if ftype == 1 else 'float32'}")

            # Convert tensor name to bytes
            name_bytes = key.encode('utf-8')
            name_length = len(name_bytes)

            # Write tensor header
            fout.write(struct.pack("i", n_dims))
            fout.write(struct.pack("i", name_length))
            fout.write(struct.pack("i", ftype))

            # Write tensor dimensions
            for i in range(n_dims):
                size = tensor_shape[i] if i < len(tensor_shape) else 1
                fout.write(struct.pack("i", size))
                print(f"  Writing dimension {i}: {size}")

            # Write tensor name
            fout.write(name_bytes)

            # Write tensor data
            data.tofile(fout)

            print(f"  Wrote {data.size * (2 if ftype==1 else 4)} bytes")

    print(f"\nDone! Model has been converted to GGML format: {output_file}")
    print(f"File size: {os.path.getsize(output_file)} bytes")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
    parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
    parser.add_argument("--print-tensors", action="store_true", help="Print tensor values", default=True)
    args = parser.parse_args()

    convert_silero_vad(args.output, args.print_tensors)
```
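For reference, a sketch of how the header written by the script above could be read back as a quick sanity check. The field order mirrors the struct.pack("i", ...) sequence in the converter; this is not the loader used by whisper.cpp itself:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>

// Hypothetical verifier: reads back the fixed-size header fields that
// convert-silero-vad-to-ggml.py writes, in the same order.
int main(int argc, char ** argv) {
    if (argc < 2) { fprintf(stderr, "usage: %s model.bin\n", argv[0]); return 1; }

    FILE * f = fopen(argv[1], "rb");
    if (!f) { return 1; }

    int32_t magic = 0;
    fread(&magic, sizeof(magic), 1, f);
    if (magic != 0x67676d6c) { fprintf(stderr, "bad magic\n"); fclose(f); return 1; }

    int32_t len = 0;
    fread(&len, sizeof(len), 1, f);
    std::string model_type(len, '\0');
    fread(&model_type[0], 1, len, f); // e.g. "silero-16k"

    int32_t major = 0, minor = 0, patch = 0;
    fread(&major, sizeof(major), 1, f);
    fread(&minor, sizeof(minor), 1, f);
    fread(&patch, sizeof(patch), 1, f);

    int32_t window_size = 0, context_size = 0;
    fread(&window_size,  sizeof(window_size),  1, f); // expected: 512
    fread(&context_size, sizeof(context_size), 1, f); // expected: 64

    printf("%s v%d.%d.%d, window=%d, context=%d\n",
           model_type.c_str(), major, minor, patch, window_size, context_size);
    fclose(f);
    return 0;
}
```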
864 KB (binary file not shown)

src/whisper-arch.h

+56
```diff
@@ -139,3 +139,59 @@ static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
     {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
     {ASR_TENSOR_ATTN_OUT_BIAS,   GGML_OP_ADD},
 };
+
+enum vad_tensor {
+    VAD_TENSOR_STFT_BASIS,
+    VAD_TENSOR_ENC_0_WEIGHT,
+    VAD_TENSOR_ENC_0_BIAS,
+    VAD_TENSOR_ENC_1_WEIGHT,
+    VAD_TENSOR_ENC_1_BIAS,
+    VAD_TENSOR_ENC_2_WEIGHT,
+    VAD_TENSOR_ENC_2_BIAS,
+    VAD_TENSOR_ENC_3_WEIGHT,
+    VAD_TENSOR_ENC_3_BIAS,
+    VAD_TENSOR_LSTM_WEIGHT_IH,
+    VAD_TENSOR_LSTM_WEIGHT_HH,
+    VAD_TENSOR_LSTM_BIAS_IH,
+    VAD_TENSOR_LSTM_BIAS_HH,
+    VAD_TENSOR_FINAL_CONV_WEIGHT,
+    VAD_TENSOR_FINAL_CONV_BIAS,
+};
+
+static const std::map<vad_tensor, ggml_op> VAD_TENSOR_OPS = {
+    {VAD_TENSOR_STFT_BASIS,        GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_0_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_0_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_1_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_1_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_2_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_2_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_3_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_3_BIAS,        GGML_OP_ADD},
+
+    {VAD_TENSOR_LSTM_WEIGHT_IH,    GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,    GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_BIAS_IH,      GGML_OP_ADD},
+    {VAD_TENSOR_LSTM_BIAS_HH,      GGML_OP_ADD},
+
+    {VAD_TENSOR_FINAL_CONV_WEIGHT, GGML_OP_MUL_MAT},
+    {VAD_TENSOR_FINAL_CONV_BIAS,   GGML_OP_ADD}
+};
+
+static const std::map<vad_tensor, const char *> VAD_TENSOR_NAMES = {
+    {VAD_TENSOR_STFT_BASIS,        "_model.stft.forward_basis_buffer"},
+    {VAD_TENSOR_ENC_0_WEIGHT,      "_model.encoder.0.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_0_BIAS,        "_model.encoder.0.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_1_WEIGHT,      "_model.encoder.1.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_1_BIAS,        "_model.encoder.1.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_2_WEIGHT,      "_model.encoder.2.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_2_BIAS,        "_model.encoder.2.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_3_WEIGHT,      "_model.encoder.3.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_3_BIAS,        "_model.encoder.3.reparam_conv.bias"},
+    {VAD_TENSOR_LSTM_WEIGHT_IH,    "_model.decoder.rnn.weight_ih"},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,    "_model.decoder.rnn.weight_hh"},
+    {VAD_TENSOR_LSTM_BIAS_IH,      "_model.decoder.rnn.bias_ih"},
+    {VAD_TENSOR_LSTM_BIAS_HH,      "_model.decoder.rnn.bias_hh"},
+    {VAD_TENSOR_FINAL_CONV_WEIGHT, "_model.decoder.decoder.2.weight"},
+    {VAD_TENSOR_FINAL_CONV_BIAS,   "_model.decoder.decoder.2.bias"}
+};
```
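As a side note, a small sketch of how these tables might be consumed by a loader: iterate VAD_TENSOR_NAMES to look up each tensor by its original Silero name, and consult VAD_TENSOR_OPS for the op each tensor participates in. The model_tensors map here is a hypothetical stand-in for the loader's state, not code from this commit:

```cpp
#include <cstdio>
#include <map>
#include <string>

#include "ggml.h"
#include "whisper-arch.h"

// Hypothetical loader fragment: verify that every expected VAD tensor was
// read from the GGML file, and report the op it participates in.
static bool check_vad_tensors(const std::map<std::string, ggml_tensor *> & model_tensors) {
    for (const auto & [tensor, name] : VAD_TENSOR_NAMES) {
        if (model_tensors.find(name) == model_tensors.end()) {
            fprintf(stderr, "missing VAD tensor: %s\n", name);
            return false;
        }
        // Primary op for this tensor, e.g. GGML_OP_MUL_MAT for weights.
        const ggml_op op = VAD_TENSOR_OPS.at(tensor);
        printf("%-40s op=%s\n", name, ggml_op_name(op));
    }
    return true;
}
```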
