
Commit 0e2c260

vad : add initial Voice Activity Detection (VAD) support
This commit adds support for Voice Activity Detection (VAD). This is currently a work in progress and is not yet fully functional.

A silero-vad model can be converted using:
```console
(venv) $ pip install silero-vad
(venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero.bin
Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
```

And there is a test that exercises the VAD support in isolation:
```console
$ cmake --build build --target test-vad && \
  ctest -R ^test-vad$ --test-dir build --output-on-failure -VV
```

And one that tests VAD in combination with whisper_full:
```console
$ cmake --build build --target test-vad-full && \
  ctest -R test-vad-full --test-dir build --output-on-failure -VV
```

Resolves: ggml-org#3003
1 parent 3a88f1e commit 0e2c260

8 files changed: +2079 −5 lines

include/whisper.h (+74)

```diff
@@ -570,6 +570,16 @@ extern "C" {
         size_t n_grammar_rules;
         size_t i_start_rule;
         float  grammar_penalty;
+
+        // Voice Activity Detection (VAD) params
+        bool         vad;                         // Enable VAD
+        const char * vad_model_path;              // Path to VAD model
+        float        vad_threshold;               // Probability threshold to consider as speech.
+        int          vad_min_speech_duration_ms;  // Min duration for a valid speech segment.
+        int          vad_min_silence_duration_ms; // Min silence duration to consider speech as ended.
+        float        vad_max_speech_duration_s;   // Max duration of a speech segment before forcing a break.
+        int          vad_speech_pad_ms;           // Padding added before and after speech segments.
+        int          vad_window_size_samples;     // Number of audio samples in each probability window.
     };
 
     // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
@@ -652,6 +662,70 @@ extern "C" {
     WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
     WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
 
+    // Voice Activity Detection (VAD)
+    struct whisper_vad_context;
+    struct whisper_vad_state;
+
+    struct whisper_vad_params {
+        float threshold;               // Probability threshold to consider as speech.
+        int   min_speech_duration_ms;  // Min duration for a valid speech segment.
+        int   min_silence_duration_ms; // Min silence duration to consider speech as ended.
+        float max_speech_duration_s;   // Max duration of a speech segment before forcing a break.
+        int   speech_pad_ms;           // Padding added before and after speech segments.
+        int   window_size_samples;     // Number of audio samples in each probability window.
+    };
+    WHISPER_API struct whisper_vad_params whisper_vad_default_params(void);
+
+    struct whisper_vad_context_params {
+        int n_threads; // The number of threads to use for processing.
+    };
+    WHISPER_API struct whisper_vad_context_params whisper_vad_default_context_params(void);
+
+    WHISPER_API struct whisper_vad_state * whisper_vad_init_state(struct whisper_vad_context * ctx);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params(
+        const char * path_model,
+        const whisper_vad_context_params params);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
+        const char * path_model,
+        const whisper_vad_context_params params);
+
+    struct whisper_vad_speech {
+        int     n_probs;
+        float * probs;
+    };
+
+    WHISPER_API struct whisper_vad_speech whisper_vad_detect_speech(
+        whisper_vad_context * vctx,
+        const float * pcmf32, int n_samples);
+
+    struct whisper_vad_segment {
+        float start; // Start time in seconds
+        float end;   // End time in seconds
+    };
+
+    struct whisper_vad_timestamps {
+        int n_segments;
+        whisper_vad_segment * segments;
+    };
+
+    WHISPER_API struct whisper_vad_timestamps whisper_vad_detect_speech_timestamps(
+        whisper_vad_context * vctx,
+        whisper_vad_params params,
+        const float * pcmf32, int n_samples);
+
+    WHISPER_API struct whisper_vad_timestamps whisper_vad_timestamps_from_probs(
+        whisper_vad_context * vctx,
+        whisper_vad_params params,
+        struct whisper_vad_speech * probs);
+
+    WHISPER_API void whisper_vad_free           (struct whisper_vad_context * ctx);
+    WHISPER_API void whisper_vad_free_state     (struct whisper_vad_state * state);
+    WHISPER_API void whisper_vad_free_params    (struct whisper_vad_params * params);
+    WHISPER_API void whisper_vad_free_speech    (struct whisper_vad_speech * speech);
+    WHISPER_API void whisper_vad_free_timestamps(struct whisper_vad_timestamps * timestamps);
 
     ////////////////////////////////////////////////////////////////////////////
 
     // Temporary helpers needed for exposing ggml interface
```
models/convert-silero-vad-to-ggml.py (new file, +185)

```python
import os
import struct
import argparse
import torch
import numpy as np
from silero_vad import load_silero_vad, __version__ as silero_version

def convert_silero_vad(output_path, print_tensors=True):
    model = load_silero_vad()
    state_dict = model.state_dict()

    # Clean up state dict keys - filter out the 8k model
    cleaned_dict = {}
    for key, value in state_dict.items():
        # Skip the 8k model
        if "_8k" not in key:
            clean_key = key
            if not key.startswith("_model."):
                clean_key = "_model." + key
            cleaned_dict[clean_key] = value

    base, ext = os.path.splitext(output_path)
    output_file = f"{base}-v{silero_version}-ggml{ext}"
    print(f"Saving GGML Silero-VAD model to {output_file}")

    print("\nTensor info for debugging:")
    for key, tensor in cleaned_dict.items():
        print(f"  - {key}: {tensor.shape} ({tensor.dtype})")
    print()

    with open(output_file, "wb") as fout:
        # Write magic
        fout.write(struct.pack("i", 0x67676d6c))

        # Write model architecture parameters
        window_size = 512
        fout.write(struct.pack("i", window_size))
        context_size = 64
        fout.write(struct.pack("i", context_size))

        n_encoder_layers = 4
        fout.write(struct.pack("i", n_encoder_layers))

        # Write encoder dimensions
        input_channels = 129
        encoder_in_channels = [input_channels, 128, 64, 64]
        encoder_out_channels = [128, 64, 64, 128]
        kernel_size = 3

        for i in range(n_encoder_layers):
            fout.write(struct.pack("i", encoder_in_channels[i]))
            fout.write(struct.pack("i", encoder_out_channels[i]))
            fout.write(struct.pack("i", kernel_size))

        # Write LSTM dimensions
        lstm_input_size = 128
        lstm_hidden_size = 128
        fout.write(struct.pack("i", lstm_input_size))
        fout.write(struct.pack("i", lstm_hidden_size))

        # Write final conv dimensions
        final_conv_in = 128
        final_conv_out = 1
        fout.write(struct.pack("i", final_conv_in))
        fout.write(struct.pack("i", final_conv_out))

        # Define tensor keys to write
        tensor_keys = []

        # Encoder weights
        for i in range(n_encoder_layers):
            weight_key = f"_model.encoder.{i}.reparam_conv.weight"
            bias_key = f"_model.encoder.{i}.reparam_conv.bias"
            if weight_key in cleaned_dict and bias_key in cleaned_dict:
                tensor_keys.append(weight_key)
                tensor_keys.append(bias_key)

        # LSTM weights
        lstm_keys = [
            "_model.decoder.rnn.weight_ih",
            "_model.decoder.rnn.weight_hh",
            "_model.decoder.rnn.bias_ih",
            "_model.decoder.rnn.bias_hh"
        ]
        tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])

        # Final conv weights
        final_keys = [
            "_model.decoder.decoder.2.weight",
            "_model.decoder.decoder.2.bias"
        ]
        tensor_keys.extend([k for k in final_keys if k in cleaned_dict])

        # STFT basis - add this last
        stft_tensor = "_model.stft.forward_basis_buffer"
        tensor_keys.append(stft_tensor)

        print(f"Writing {len(tensor_keys)} tensors:")
        for key in tensor_keys:
            if key in cleaned_dict:
                print(f"  - {key}: {cleaned_dict[key].shape}")
            else:
                print(f"  - {key}: MISSING")

        # Process each tensor
        for key in tensor_keys:
            if key not in cleaned_dict:
                print(f"Warning: Missing tensor {key}, skipping")
                continue

            tensor = cleaned_dict[key]

            # Special handling for the STFT tensor
            if key == "_model.stft.forward_basis_buffer":
                # Get the original numpy array without squeezing
                data = tensor.detach().cpu().numpy()
                # Ensure it has the expected shape
                print(f"STFT tensor original shape: {data.shape}")
                n_dims = 3
                tensor_shape = [data.shape[0], data.shape[1], data.shape[2]]
                is_conv_weight = True
            else:
                # For other tensors, we can use standard processing
                data = tensor.detach().cpu().squeeze().numpy()
                tensor_shape = list(data.shape)

                # Ensure we have at most 4 dimensions for GGML
                n_dims = min(len(tensor_shape), 4)

                # Reverse dimensions for GGML
                tensor_shape = tensor_shape[:n_dims]
                tensor_shape.reverse()

                # Check if this is a convolution weight tensor
                is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key)

            # Convert to float16 for convolution weights
            if is_conv_weight:
                data = data.astype(np.float16)
                ftype = 1  # float16
            else:
                ftype = 0  # float32

            # Debug printing of tensor info
            print(f"\nWriting tensor: {key}")
            print(f"  Original shape: {tensor.shape}")
            print(f"  Processed shape: {data.shape}")
            print(f"  GGML dimensions: {n_dims}")
            print(f"  GGML shape: {tensor_shape}")
            print(f"  Type: {'float16' if ftype == 1 else 'float32'}")

            # Convert tensor name to bytes
            name_bytes = key.encode('utf-8')
            name_length = len(name_bytes)

            # Write tensor header
            fout.write(struct.pack("i", n_dims))
            fout.write(struct.pack("i", name_length))
            fout.write(struct.pack("i", ftype))

            # Write tensor dimensions
            for i in range(n_dims):
                size = tensor_shape[i] if i < len(tensor_shape) else 1
                fout.write(struct.pack("i", size))
                print(f"  Writing dimension {i}: {size}")

            # Write tensor name
            fout.write(name_bytes)

            # Write tensor data
            data.tofile(fout)

            print(f"  Wrote {data.size * (2 if ftype == 1 else 4)} bytes")

    print(f"\nDone! Model has been converted to GGML format: {output_file}")
    print(f"File size: {os.path.getsize(output_file)} bytes")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
    parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
    parser.add_argument("--print-tensors", action="store_true", help="Print tensor values", default=True)
    args = parser.parse_args()

    convert_silero_vad(args.output, args.print_tensors)
```
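The resulting file layout is: a 4-byte magic, a fixed block of architecture hyperparameters, then one variable-length record per tensor (dimension count, name length, type, reversed dims, name bytes, raw data). As a concrete view of that layout, here is a hedged C++ sketch that reads back just the fixed header; it is an illustration of the byte format the script writes, not the loader this commit adds to whisper.cpp.

```cpp
#include <cstdint>
#include <cstdio>

// Reads the fixed header written by models/convert-silero-vad-to-ggml.py.
// Illustrative sketch; error handling is minimal on purpose.
bool read_vad_header(FILE * f) {
    auto read_i32 = [&](int32_t & v) { return fread(&v, sizeof(v), 1, f) == 1; };

    int32_t magic = 0;
    if (!read_i32(magic) || magic != 0x67676d6c) {
        return false; // not a GGML file
    }

    int32_t window_size, context_size, n_encoder_layers;
    if (!read_i32(window_size) || !read_i32(context_size) || !read_i32(n_encoder_layers)) {
        return false; // expected values per the script: 512, 64, 4
    }

    // Per encoder layer: in channels, out channels, kernel size.
    for (int32_t i = 0; i < n_encoder_layers; ++i) {
        int32_t in_ch, out_ch, kernel;
        if (!read_i32(in_ch) || !read_i32(out_ch) || !read_i32(kernel)) return false;
    }

    // LSTM input/hidden sizes, then final conv in/out channels.
    int32_t lstm_in, lstm_hidden, conv_in, conv_out;
    if (!read_i32(lstm_in) || !read_i32(lstm_hidden)) return false;
    if (!read_i32(conv_in) || !read_i32(conv_out))    return false;

    // Tensor records follow: n_dims, name_length, ftype, dims..., name, data.
    return true;
}
```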
Binary file (864 KB) not shown.

src/whisper-arch.h (+56)

```diff
@@ -139,3 +139,59 @@ static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
     {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
     {ASR_TENSOR_ATTN_OUT_BIAS,   GGML_OP_ADD},
 };
+
+enum vad_tensor {
+    VAD_TENSOR_STFT_BASIS,
+    VAD_TENSOR_ENC_0_WEIGHT,
+    VAD_TENSOR_ENC_0_BIAS,
+    VAD_TENSOR_ENC_1_WEIGHT,
+    VAD_TENSOR_ENC_1_BIAS,
+    VAD_TENSOR_ENC_2_WEIGHT,
+    VAD_TENSOR_ENC_2_BIAS,
+    VAD_TENSOR_ENC_3_WEIGHT,
+    VAD_TENSOR_ENC_3_BIAS,
+    VAD_TENSOR_LSTM_WEIGHT_IH,
+    VAD_TENSOR_LSTM_WEIGHT_HH,
+    VAD_TENSOR_LSTM_BIAS_IH,
+    VAD_TENSOR_LSTM_BIAS_HH,
+    VAD_TENSOR_FINAL_CONV_WEIGHT,
+    VAD_TENSOR_FINAL_CONV_BIAS,
+};
+
+static const std::map<vad_tensor, ggml_op> VAD_TENSOR_OPS = {
+    {VAD_TENSOR_STFT_BASIS,        GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_0_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_0_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_1_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_1_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_2_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_2_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_3_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_3_BIAS,        GGML_OP_ADD},
+
+    {VAD_TENSOR_LSTM_WEIGHT_IH,    GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,    GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_BIAS_IH,      GGML_OP_ADD},
+    {VAD_TENSOR_LSTM_BIAS_HH,      GGML_OP_ADD},
+
+    {VAD_TENSOR_FINAL_CONV_WEIGHT, GGML_OP_MUL_MAT},
+    {VAD_TENSOR_FINAL_CONV_BIAS,   GGML_OP_ADD}
+};
+
+static const std::map<vad_tensor, const char *> VAD_TENSOR_NAMES = {
+    {VAD_TENSOR_STFT_BASIS,        "_model.stft.forward_basis_buffer"},
+    {VAD_TENSOR_ENC_0_WEIGHT,      "_model.encoder.0.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_0_BIAS,        "_model.encoder.0.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_1_WEIGHT,      "_model.encoder.1.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_1_BIAS,        "_model.encoder.1.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_2_WEIGHT,      "_model.encoder.2.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_2_BIAS,        "_model.encoder.2.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_3_WEIGHT,      "_model.encoder.3.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_3_BIAS,        "_model.encoder.3.reparam_conv.bias"},
+    {VAD_TENSOR_LSTM_WEIGHT_IH,    "_model.decoder.rnn.weight_ih"},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,    "_model.decoder.rnn.weight_hh"},
+    {VAD_TENSOR_LSTM_BIAS_IH,      "_model.decoder.rnn.bias_ih"},
+    {VAD_TENSOR_LSTM_BIAS_HH,      "_model.decoder.rnn.bias_hh"},
+    {VAD_TENSOR_FINAL_CONV_WEIGHT, "_model.decoder.decoder.2.weight"},
+    {VAD_TENSOR_FINAL_CONV_BIAS,   "_model.decoder.decoder.2.bias"}
+};
```
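These tables pair each VAD tensor with the serialized name the converter used and the ggml op it feeds (GGML_OP_MUL_MAT for weights, GGML_OP_ADD for biases), so a loader can drive lookup and validation from one table instead of hard-coding each tensor. A hedged sketch of that pattern, where `get_tensor` is a hypothetical callback standing in for however the loaded model resolves tensors by name (it is not an API from this commit):

```cpp
#include <cstdio>
#include <functional>

// Assumes whisper-arch.h (and ggml.h) are included, providing vad_tensor,
// VAD_TENSOR_NAMES, VAD_TENSOR_OPS, and ggml_tensor.
static bool verify_vad_tensors(const std::function<struct ggml_tensor * (const char *)> & get_tensor) {
    for (const auto & [id, name] : VAD_TENSOR_NAMES) {
        struct ggml_tensor * t = get_tensor(name);
        if (t == nullptr) {
            fprintf(stderr, "missing VAD tensor: %s\n", name);
            return false;
        }
        // VAD_TENSOR_OPS.at(id) names the op this tensor feeds, e.g.
        // GGML_OP_MUL_MAT for weights and GGML_OP_ADD for biases; a real
        // loader could use it to check backend support for that op.
        (void) VAD_TENSOR_OPS.at(id);
    }
    return true;
}
```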
