
Commit 34c5ec6

vad : add initial Voice Activity Detection (VAD) support
This commit adds support for Voice Activity Detection (VAD). This is currently a work in progress and is not yet fully functional.

A silero-vad model can be converted using:

```console
(venv) $ pip install silero-vad
(venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero.bin
Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
```

And there is a test that can be run using the following command:

```console
$ cmake --build build --target test-vad && \
    ctest -R test-vad --test-dir build --output-on-failure -VV
```
1 parent 43f5030

6 files changed: +1324 −0 lines changed

include/whisper.h

Lines changed: 41 additions & 0 deletions
```diff
@@ -652,6 +652,47 @@ extern "C" {
     WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
     WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
 
+    // Voice Activity Detection (VAD)
+    struct whisper_vad_context;
+    struct whisper_vad_state;
+
+    struct whisper_vad_params {
+        float threshold;               // Probability threshold for speech detection
+        int   min_speech_duration_ms;  // Minimum speech segment duration
+        int   min_silence_duration_ms; // Minimum silence segment duration
+        int   window_size_samples;     // Window size for processing
+        int   sample_rate;             // 16000
+    };
+    WHISPER_API struct whisper_vad_params whisper_vad_default_params(void);
+
+    WHISPER_API struct whisper_vad_state * whisper_vad_init_state(struct whisper_vad_context * ctx);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params(
+            const char * path_model,
+            const whisper_vad_params params);
+
+    WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params_no_state(
+            const char * path_model,
+            const whisper_vad_params params);
+
+    struct whisper_vad_segment {
+        float start; // Start time in seconds
+        float end;   // End time in seconds
+    };
+
+    struct whisper_vad_segments {
+        int n_segments;
+        whisper_vad_segment * segments;
+    };
+
+    WHISPER_API struct whisper_vad_segments whisper_vad_detect_speech(
+            whisper_vad_context * vctx,
+            const float * pcmf32, int n_samples, int n_threads);
+
+    WHISPER_API void whisper_vad_free       (struct whisper_vad_context * ctx);
+    WHISPER_API void whisper_vad_free_state (struct whisper_vad_state * state);
+    WHISPER_API void whisper_vad_free_params(struct whisper_vad_params * params);
+
     ////////////////////////////////////////////////////////////////////////////
 
     // Temporary helpers needed for exposing ggml interface
```
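Taken together, these declarations sketch the intended flow: take the default params, initialize a context from a converted model, run detection over 16 kHz f32 PCM, then free the context. The following is a minimal hypothetical driver, not part of the commit; the model path and the silent input buffer are placeholder assumptions:

```cpp
// Hypothetical sketch of driving the new VAD API. Assumes a converted model at
// models/silero-v5.1.2-ggml.bin and 16 kHz mono f32 samples in `pcmf32`.
#include "whisper.h"

#include <cstdio>
#include <vector>

int main() {
    // Placeholder input: 5 seconds of silence; a real caller would load audio.
    std::vector<float> pcmf32(16000 * 5, 0.0f);

    whisper_vad_params params = whisper_vad_default_params();
    params.threshold = 0.5f; // illustrative override; tune per application

    whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(
            "models/silero-v5.1.2-ggml.bin", params);
    if (!vctx) {
        fprintf(stderr, "failed to load VAD model\n");
        return 1;
    }

    whisper_vad_segments segments = whisper_vad_detect_speech(
            vctx, pcmf32.data(), (int) pcmf32.size(), /*n_threads =*/ 4);

    for (int i = 0; i < segments.n_segments; ++i) {
        printf("speech: %.2f s -> %.2f s\n",
               segments.segments[i].start, segments.segments[i].end);
    }

    whisper_vad_free(vctx);
    return 0;
}
```

Note that the diff does not yet declare how `segments.segments` is released; presumably that lands with the rest of the work in progress.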

models/convert-silero-vad-to-ggml.py

Lines changed: 182 additions & 0 deletions
```python
import os
import struct
import argparse
import torch
import numpy as np
from silero_vad import load_silero_vad, __version__ as silero_version

def convert_silero_vad(output_path, print_tensors=True):
    model = load_silero_vad()
    state_dict = model.state_dict()

    # Clean up state dict keys - filter out 8k model
    cleaned_dict = {}
    for key, value in state_dict.items():
        # Skip 8k model
        if "_8k" not in key:
            clean_key = key
            if not key.startswith("_model."):
                clean_key = "_model." + key
            cleaned_dict[clean_key] = value

    base, ext = os.path.splitext(output_path)
    output_file = f"{base}-v{silero_version}-ggml{ext}"
    print(f"Saving GGML Silero-VAD model to {output_file}")

    print("\nTensor info for debugging:")
    for key, tensor in cleaned_dict.items():
        print(f"  - {key}: {tensor.shape} ({tensor.dtype})")
    print()

    with open(output_file, "wb") as fout:
        # Write magic and version
        fout.write(struct.pack("i", 0x67676d6c))

        # Write model version - Try version 0 for simplicity
        fout.write(struct.pack("i", 0))

        # Write model architecture parameters
        n_encoder_layers = 4
        fout.write(struct.pack("i", n_encoder_layers))

        # Write encoder dimensions
        input_channels = 129
        encoder_in_channels = [input_channels, 128, 64, 64]
        encoder_out_channels = [128, 64, 64, 128]
        kernel_size = 3

        for i in range(n_encoder_layers):
            fout.write(struct.pack("i", encoder_in_channels[i]))
            fout.write(struct.pack("i", encoder_out_channels[i]))
            fout.write(struct.pack("i", kernel_size))

        # Write LSTM dimensions
        lstm_input_size = 128
        lstm_hidden_size = 128
        fout.write(struct.pack("i", lstm_input_size))
        fout.write(struct.pack("i", lstm_hidden_size))

        # Write final conv dimensions
        final_conv_in = 128
        final_conv_out = 1
        fout.write(struct.pack("i", final_conv_in))
        fout.write(struct.pack("i", final_conv_out))

        # Define tensor keys to write
        tensor_keys = []

        # Encoder weights
        for i in range(n_encoder_layers):
            weight_key = f"_model.encoder.{i}.reparam_conv.weight"
            bias_key = f"_model.encoder.{i}.reparam_conv.bias"
            if weight_key in cleaned_dict and bias_key in cleaned_dict:
                tensor_keys.append(weight_key)
                tensor_keys.append(bias_key)

        # LSTM weights
        lstm_keys = [
            "_model.decoder.rnn.weight_ih",
            "_model.decoder.rnn.weight_hh",
            "_model.decoder.rnn.bias_ih",
            "_model.decoder.rnn.bias_hh"
        ]
        tensor_keys.extend([k for k in lstm_keys if k in cleaned_dict])

        # Final conv weights
        final_keys = [
            "_model.decoder.decoder.2.weight",
            "_model.decoder.decoder.2.bias"
        ]
        tensor_keys.extend([k for k in final_keys if k in cleaned_dict])

        # STFT basis - add this last
        stft_tensor = "_model.stft.forward_basis_buffer"
        tensor_keys.append(stft_tensor)

        print(f"Writing {len(tensor_keys)} tensors:")
        for key in tensor_keys:
            if key in cleaned_dict:
                print(f"  - {key}: {cleaned_dict[key].shape}")
            else:
                print(f"  - {key}: MISSING")

        # Process each tensor
        for key in tensor_keys:
            if key not in cleaned_dict:
                print(f"Warning: Missing tensor {key}, skipping")
                continue

            tensor = cleaned_dict[key]

            # Special handling for STFT tensor
            if key == "_model.stft.forward_basis_buffer":
                # Get the original numpy array without squeezing
                data = tensor.detach().cpu().numpy()
                # Ensure it has the expected shape
                print(f"STFT tensor original shape: {data.shape}")
                n_dims = 3
                tensor_shape = [data.shape[0], data.shape[1], data.shape[2]]
                is_conv_weight = True
            else:
                # For other tensors, we can use standard processing
                data = tensor.detach().cpu().squeeze().numpy()
                tensor_shape = list(data.shape)

                # Ensure we have at most 4 dimensions for GGML
                n_dims = min(len(tensor_shape), 4)

                # Reverse dimensions for GGML
                tensor_shape = tensor_shape[:n_dims]
                tensor_shape.reverse()

                # Check if this is a convolution weight tensor
                is_conv_weight = "weight" in key and ("encoder" in key or "_model.decoder.decoder.2" in key)

            # Convert to float16 for convolution weights
            if is_conv_weight:
                data = data.astype(np.float16)
                ftype = 1  # float16
            else:
                ftype = 0  # float32

            # Debug printing of tensor info
            print(f"\nWriting tensor: {key}")
            print(f"  Original shape: {tensor.shape}")
            print(f"  Processed shape: {data.shape}")
            print(f"  GGML dimensions: {n_dims}")
            print(f"  GGML shape: {tensor_shape}")
            print(f"  Type: {'float16' if ftype == 1 else 'float32'}")

            # Convert tensor name to bytes
            name_bytes = key.encode('utf-8')
            name_length = len(name_bytes)

            # Write tensor header
            fout.write(struct.pack("i", n_dims))
            fout.write(struct.pack("i", name_length))
            fout.write(struct.pack("i", ftype))

            # Write tensor dimensions
            for i in range(n_dims):
                size = tensor_shape[i] if i < len(tensor_shape) else 1
                fout.write(struct.pack("i", size))
                print(f"  Writing dimension {i}: {size}")

            # Write tensor name
            fout.write(name_bytes)

            # Write tensor data
            data.tofile(fout)

            print(f"  Wrote {data.size * (2 if ftype == 1 else 4)} bytes")

    print(f"\nDone! Model has been converted to GGML format: {output_file}")
    print(f"File size: {os.path.getsize(output_file)} bytes")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Silero-VAD PyTorch model to GGML format")
    parser.add_argument("--output", type=str, required=True, help="Path to output GGML model file")
    parser.add_argument("--print-tensors", action="store_true", help="Print tensor values", default=True)
    args = parser.parse_args()

    convert_silero_vad(args.output, args.print_tensors)
```
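The converter serializes a small fixed header (magic, format version, encoder layer dimensions, LSTM dimensions, final conv dimensions) before the tensors. As a sanity check, those fields can be read back in the same order. The snippet below is a hypothetical verifier, not part of the commit, assuming the native little-endian 4-byte int layout that `struct.pack("i", ...)` produces on typical platforms:

```cpp
// Hypothetical header check for the GGML file written by
// convert-silero-vad-to-ggml.py, mirroring the struct.pack("i", ...) sequence.
#include <cstdint>
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) { fprintf(stderr, "usage: %s model.bin\n", argv[0]); return 1; }
    FILE * f = fopen(argv[1], "rb");
    if (!f) { perror("fopen"); return 1; }

    int32_t magic = 0, version = 0, n_layers = 0;
    fread(&magic,    sizeof(magic),    1, f); // expect 0x67676d6c ("ggml")
    fread(&version,  sizeof(version),  1, f); // 0 in this initial format
    fread(&n_layers, sizeof(n_layers), 1, f); // 4 encoder layers
    printf("magic: %#x, version: %d, encoder layers: %d\n", magic, version, n_layers);

    // Per encoder layer: in channels, out channels, kernel size.
    for (int i = 0; i < n_layers; ++i) {
        int32_t in_ch = 0, out_ch = 0, k = 0;
        fread(&in_ch,  sizeof(in_ch),  1, f);
        fread(&out_ch, sizeof(out_ch), 1, f);
        fread(&k,      sizeof(k),      1, f);
        printf("encoder %d: %d -> %d, kernel %d\n", i, in_ch, out_ch, k);
    }

    // LSTM input/hidden sizes, then final conv in/out channels.
    int32_t lstm_in = 0, lstm_hidden = 0, conv_in = 0, conv_out = 0;
    fread(&lstm_in,     sizeof(lstm_in),     1, f);
    fread(&lstm_hidden, sizeof(lstm_hidden), 1, f);
    fread(&conv_in,     sizeof(conv_in),     1, f);
    fread(&conv_out,    sizeof(conv_out),    1, f);
    printf("lstm: %d/%d, final conv: %d -> %d\n", lstm_in, lstm_hidden, conv_in, conv_out);

    fclose(f);
    return 0;
}
```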

src/whisper-arch.h

Lines changed: 56 additions & 0 deletions
```diff
@@ -139,3 +139,59 @@ static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
     {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
     {ASR_TENSOR_ATTN_OUT_BIAS,   GGML_OP_ADD},
 };
+
+enum vad_tensor {
+    VAD_TENSOR_STFT_BASIS,
+    VAD_TENSOR_ENC_0_WEIGHT,
+    VAD_TENSOR_ENC_0_BIAS,
+    VAD_TENSOR_ENC_1_WEIGHT,
+    VAD_TENSOR_ENC_1_BIAS,
+    VAD_TENSOR_ENC_2_WEIGHT,
+    VAD_TENSOR_ENC_2_BIAS,
+    VAD_TENSOR_ENC_3_WEIGHT,
+    VAD_TENSOR_ENC_3_BIAS,
+    VAD_TENSOR_LSTM_WEIGHT_IH,
+    VAD_TENSOR_LSTM_WEIGHT_HH,
+    VAD_TENSOR_LSTM_BIAS_IH,
+    VAD_TENSOR_LSTM_BIAS_HH,
+    VAD_TENSOR_FINAL_CONV_WEIGHT,
+    VAD_TENSOR_FINAL_CONV_BIAS,
+};
+
+static const std::map<vad_tensor, ggml_op> VAD_TENSOR_OPS = {
+    {VAD_TENSOR_STFT_BASIS,        GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_0_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_0_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_1_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_1_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_2_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_2_BIAS,        GGML_OP_ADD},
+    {VAD_TENSOR_ENC_3_WEIGHT,      GGML_OP_MUL_MAT},
+    {VAD_TENSOR_ENC_3_BIAS,        GGML_OP_ADD},
+
+    {VAD_TENSOR_LSTM_WEIGHT_IH,    GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,    GGML_OP_MUL_MAT},
+    {VAD_TENSOR_LSTM_BIAS_IH,      GGML_OP_ADD},
+    {VAD_TENSOR_LSTM_BIAS_HH,      GGML_OP_ADD},
+
+    {VAD_TENSOR_FINAL_CONV_WEIGHT, GGML_OP_MUL_MAT},
+    {VAD_TENSOR_FINAL_CONV_BIAS,   GGML_OP_ADD}
+};
+
+static const std::map<vad_tensor, const char *> VAD_TENSOR_NAMES = {
+    {VAD_TENSOR_STFT_BASIS,        "_model.stft.forward_basis_buffer"},
+    {VAD_TENSOR_ENC_0_WEIGHT,      "_model.encoder.0.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_0_BIAS,        "_model.encoder.0.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_1_WEIGHT,      "_model.encoder.1.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_1_BIAS,        "_model.encoder.1.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_2_WEIGHT,      "_model.encoder.2.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_2_BIAS,        "_model.encoder.2.reparam_conv.bias"},
+    {VAD_TENSOR_ENC_3_WEIGHT,      "_model.encoder.3.reparam_conv.weight"},
+    {VAD_TENSOR_ENC_3_BIAS,        "_model.encoder.3.reparam_conv.bias"},
+    {VAD_TENSOR_LSTM_WEIGHT_IH,    "_model.decoder.rnn.weight_ih"},
+    {VAD_TENSOR_LSTM_WEIGHT_HH,    "_model.decoder.rnn.weight_hh"},
+    {VAD_TENSOR_LSTM_BIAS_IH,      "_model.decoder.rnn.bias_ih"},
+    {VAD_TENSOR_LSTM_BIAS_HH,      "_model.decoder.rnn.bias_hh"},
+    {VAD_TENSOR_FINAL_CONV_WEIGHT, "_model.decoder.decoder.2.weight"},
+    {VAD_TENSOR_FINAL_CONV_BIAS,   "_model.decoder.decoder.2.bias"}
+};
```
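These tables tie each enum entry to the tensor name the converter serialized and to the ggml op the tensor feeds (useful, e.g., for backend support checks). A hypothetical lookup helper in that spirit, not part of the commit; `model_tensors` stands in for whatever name-to-tensor index the real loader builds:

```cpp
// Hypothetical sketch: resolve a VAD tensor by the name it was serialized
// under, using the VAD_TENSOR_NAMES map above.
#include "whisper-arch.h" // brings in vad_tensor, the maps, and ggml types

#include <cstdio>
#include <map>
#include <string>

static ggml_tensor * load_vad_tensor(
        const std::map<std::string, ggml_tensor *> & model_tensors, // assumed loader index
        vad_tensor id) {
    const char * name = VAD_TENSOR_NAMES.at(id);
    auto it = model_tensors.find(name);
    if (it == model_tensors.end()) {
        fprintf(stderr, "missing VAD tensor: %s\n", name);
        return nullptr;
    }
    // VAD_TENSOR_OPS.at(id) records the op this tensor feeds (GGML_OP_MUL_MAT
    // or GGML_OP_ADD), which a loader can use to verify backend support.
    return it->second;
}
```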
