Skip to content

Add C++ runtime for silero_vad with RKNN #2078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions c-api-examples/vad-whisper-c-api.c
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,11 @@ int32_t main() {

while (!is_eof) {
if (i + window_size < wave->num_samples) {
SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
window_size);
}
else {
SherpaOnnxVoiceActivityDetectorFlush(vad);
is_eof = 1;
SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
window_size);
} else {
SherpaOnnxVoiceActivityDetectorFlush(vad);
is_eof = 1;
}
while (!SherpaOnnxVoiceActivityDetectorEmpty(vad)) {
const SherpaOnnxSpeechSegment *segment =
Expand Down
75 changes: 75 additions & 0 deletions scripts/gtcrn/show.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnxruntime
import onnx

"""
[key: "model_type"
value: "gtcrn"
, key: "comment"
value: "gtcrn_simple"
, key: "version"
value: "1"
, key: "sample_rate"
value: "16000"
, key: "model_url"
value: "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx"
, key: "maintainer"
value: "k2-fsa"
, key: "comment2"
value: "Please see also https://github.com/Xiaobin-Rong/gtcrn"
, key: "conv_cache_shape"
value: "2,1,16,16,33"
, key: "tra_cache_shape"
value: "2,3,1,1,16"
, key: "inter_cache_shape"
value: "2,1,33,16"
, key: "n_fft"
value: "512"
, key: "hop_length"
value: "256"
, key: "window_length"
value: "512"
, key: "window_type"
value: "hann_sqrt"
]
"""

"""
NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
-----
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
"""


def show(filename):
    """Print the metadata and the input/output signature of an ONNX model.

    Args:
      filename:
        Path to the ONNX model file to inspect.
    """
    # Custom key/value metadata stored in the model file
    print(onnx.load(filename).metadata_props)

    options = onnxruntime.SessionOptions()
    options.log_severity_level = 3
    session = onnxruntime.InferenceSession(
        filename,
        options,
        providers=["CPUExecutionProvider"],
    )

    for node_arg in session.get_inputs():
        print(node_arg)

    # Separator between the input list and the output list
    print("-----")

    for node_arg in session.get_outputs():
        print(node_arg)


def main():
    """Show the details of the GTCRN model in the current directory."""
    model_path = "./gtcrn_simple.onnx"
    show(model_path)


if __name__ == "__main__":
main()
81 changes: 80 additions & 1 deletion scripts/silero_vad/v4/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,94 @@
import torch
from onnxsim import simplify

import torch
from torch import Tensor


def simple_pad(x: Tensor, pad: int) -> Tensor:
    """Extend dim 2 of a 3-D tensor on both sides by mirroring its content.

    The first and the last element are used as the mirror axis and are not
    duplicated themselves.

    Args:
      x:
        A 3-D tensor. Padding is applied along dim 2.
      pad:
        Number of elements to add on each side.
    Returns:
      A tensor whose dim 2 is ``2 * pad`` larger than that of ``x``.
    """
    # Elements 1 .. pad, reversed, are placed in front of x
    left_pad = torch.flip(x[:, :, 1 : 1 + pad], [-1])

    # The `pad` elements preceding the last one, reversed, are placed after x
    right_pad = torch.flip(x[:, :, (-1 - pad) : -1], [-1])

    return torch.cat([left_pad, x, right_pad], 2)


class MyModule(torch.nn.Module):
    """Re-implementation of the forward pass of the wrapped silero VAD model.

    The sub-modules of ``m._model`` are re-used as is. Only the adaptive
    normalization is re-written here so that its intermediate values stay
    inside the fp16 range (see ``adaptive_normalization_forward``).

    Args:
      m:
        The loaded model. It must expose ``_model`` with the attributes
        ``feature_extractor``, ``adaptive_normalization``, ``first_layer``,
        ``encoder`` and ``decoder``.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def adaptive_normalization_forward(self, spect):
        """Subtract a smoothed mean from the log spectrum.

        Args:
          spect:
            A 2-D or 3-D tensor. A 2-D input gets a leading dim of size 1.
        Returns:
          A 3-D tensor with the normalized log spectrum.
        """
        m = self.m._model.adaptive_normalization

        # Note(fangjun): rknn uses fp16 by default, whose max value is 65504
        # so we need to re-write the computation for spect0
        # spect0 = torch.log1p(torch.mul(spect, 1048576))
        #
        # 13.86294 is log(1048576); the scaling is moved out of log1p as an
        # additive constant.
        spect0 = torch.log1p(spect) + 13.86294

        # A plain integer comparison works both in eager mode and under
        # torch.jit.script; torch.eq() on Python ints only works when scripted.
        if len(spect0.shape) == 2:
            spect1 = torch.unsqueeze(spect0, 0)
        else:
            spect1 = spect0

        mean = torch.mean(spect1, [1], True)
        mean0 = simple_pad(mean, m.to_pad)
        mean1 = torch.conv1d(mean0, m.filter_)
        mean_mean = torch.mean(mean1, [-1], True)
        spect2 = torch.add(spect1, torch.neg(mean_mean))
        return spect2

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        """Run the model on one chunk of input.

        Args:
          x:
            Input passed to the feature extractor.
          h:
            State passed to the decoder.
          c:
            State passed to the decoder.
        Returns:
          A tuple ``(out, h0, c0)``: the decoder output averaged over dim 1
          (after squeezing dim 1) with dim 1 re-inserted, and the two states
          returned by the decoder.
        """
        m = self.m._model

        x0 = m.feature_extractor.forward(x)
        norm = self.adaptive_normalization_forward(x0)

        # The raw features and their normalized version are stacked on dim 1
        x1 = torch.cat([x0, norm], 1)
        x2 = m.first_layer.forward(x1)
        x3 = m.encoder.forward(x2)
        x4, h0, c0 = m.decoder.forward(x3, h, c)

        _0 = torch.mean(torch.squeeze(x4, 1), [1])
        out = torch.unsqueeze(_0, 1)
        return (out, h0, c0)


@torch.no_grad()
def main():
m = torch.jit.load("./silero_vad.jit")
m = MyModule(m)
x = torch.rand((1, 512), dtype=torch.float32)
h = torch.rand((2, 1, 64), dtype=torch.float32)
c = torch.rand((2, 1, 64), dtype=torch.float32)
m = torch.jit.script(m)
torch.onnx.export(
m._model,
m,
(x, h, c),
"m.onnx",
input_names=["x", "h", "c"],
Expand Down
2 changes: 1 addition & 1 deletion scripts/silero_vad/v4/show.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnxruntime
import onnx
Expand Down
141 changes: 141 additions & 0 deletions scripts/silero_vad/v4/test-on-rk3588-board.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

# Please run this file on your rk3588 board

try:
    from rknnlite.api import RKNNLite
except ImportError:
    # Only a failed import should trigger the installation hints below;
    # the original exception is re-raised so that its traceback is kept.
    print("Please run this file on your board (linux + aarch64 + npu)")
    print("You need to install rknn_toolkit_lite2")
    print(
        " from https://github.com/airockchip/rknn-toolkit2/tree/master/rknn-toolkit-lite2/packages"
    )
    print(
        "https://github.com/airockchip/rknn-toolkit2/blob/v2.1.0/rknn-toolkit-lite2/packages/rknn_toolkit_lite2-2.1.0-cp310-cp310-linux_aarch64.whl"
    )
    print("is known to work")
    raise

import time
from pathlib import Path
from typing import Tuple

import numpy as np
import soundfile as sf


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read an audio file and return the samples of its first channel.

    Args:
      filename:
        Path to the audio file.
    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a contiguous
      1-D float32 array and ``sample_rate`` is the value reported by
      ``soundfile``.
    """
    audio, rate = sf.read(filename, always_2d=True, dtype="float32")

    # always_2d=True yields a 2-D array; column 0 is the first channel
    first_channel = np.ascontiguousarray(audio[:, 0])
    return first_channel, rate


def init_model(filename, target_platform="rk3588"):
    """Load an RKNN model and initialize the runtime on NPU core 0.

    Args:
      filename:
        Path to the ``.rknn`` model file.
      target_platform:
        Currently unused. Kept so that existing callers keep working.
    Returns:
      The initialized ``RKNNLite`` object.
    Raises:
      SystemExit:
        If the file does not exist, or loading the model or initializing
        the runtime fails. The exit message describes the failure.
    """
    # Note: SystemExit is raised directly instead of calling exit(), which
    # is a helper installed by the `site` module for interactive use.
    if not Path(filename).is_file():
        raise SystemExit(f"{filename} does not exist")

    rknn_lite = RKNNLite(verbose=False)

    ret = rknn_lite.load_rknn(path=filename)
    if ret != 0:
        raise SystemExit(f"Load model {filename} failed!")

    ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        raise SystemExit(f"Failed to init rknn runtime for {filename}")

    return rknn_lite


class RKNNModel:
    """Callable wrapper around an RKNN model loaded with ``init_model``."""

    def __init__(self, model: str, target_platform="rk3588"):
        """
        Args:
          model:
            Path to the ``.rknn`` model file.
          target_platform:
            Accepted but not forwarded to ``init_model``.
        """
        self.model = init_model(model)

    def release(self):
        """Release the underlying model object."""
        self.model.release()

    def __call__(self, x: np.ndarray, h: np.ndarray, c: np.ndarray):
        """Run inference on one window of samples.

        Args:
          x: (1, 512), np.float32
          h: (2, 1, 64), np.float32
          c: (2, 1, 64), np.float32
        Returns:
          A tuple ``(prob, next_h, next_c)``: ``prob`` is the first model
          output converted to a Python scalar; ``next_h`` and ``next_c``
          are the second and third model outputs, returned unchanged.
        """
        outputs = self.model.inference(inputs=[x, h, c])
        prob, next_h, next_c = outputs
        return prob.item(), next_h, next_c


def main():
    """Load ``./m.rknn``, run the VAD test once and release the model."""
    model = RKNNModel(model="./m.rknn")
    try:
        test(model)
    finally:
        # Release the model even if the test raises
        model.release()


def _find_speech_segments(is_speech, min_speech_duration):
    """Return runs of consecutive speech windows longer than a minimum.

    Args:
      is_speech:
        A list of booleans, one per window.
      min_speech_duration:
        Minimum run length in windows; shorter or equal runs are dropped.
    Returns:
      A list of ``(start, end)`` window indexes; ``end`` is exclusive.
    """
    segments = []
    begin = -1  # -1 means we are currently not inside a speech run
    for k, flag in enumerate(is_speech):
        if flag:
            if begin == -1:
                begin = k
        elif begin != -1:
            if k - begin > min_speech_duration:
                segments.append((begin, k))
            begin = -1

    # A run that lasts until the final window ends one past the last index,
    # so that its end is exclusive like the ends of the runs found above.
    num_windows = len(is_speech)
    if begin != -1 and num_windows - begin > min_speech_duration:
        segments.append((begin, num_windows))

    return segments


def _merge_segments(segments, min_silence_duration):
    """Join neighbouring segments separated by a gap shorter than a minimum.

    Args:
      segments:
        A non-empty list of ``(start, end)`` pairs in increasing order.
      min_silence_duration:
        Gaps shorter than this number of windows are bridged.
    Returns:
      A list of merged ``(start, end)`` pairs.
    """
    merged = [segments[0]]
    for seg in segments[1:]:
        prev = merged[-1]
        if seg[0] - prev[1] < min_silence_duration:
            merged[-1] = (prev[0], seg[1])
        else:
            merged.append(seg)
    return merged


def test(model, wav_filename="./lei-jun-test.wav"):
    """Run the VAD model over a wave file and print the speech segments.

    Args:
      model:
        A callable taking ``(x, h, c)`` and returning
        ``(prob, next_h, next_c)``.
      wav_filename:
        Path to a 16 kHz wave file.
    """
    print("started")
    samples, sample_rate = load_audio(wav_filename)
    assert sample_rate == 16000, sample_rate

    window_size = 512

    h = np.zeros((2, 1, 64), dtype=np.float32)
    c = np.zeros((2, 1, 64), dtype=np.float32)

    threshold = 0.5
    num_windows = samples.shape[0] // window_size
    is_speech = []
    for i in range(num_windows):
        print(i, num_windows)
        this_samples = samples[i * window_size : (i + 1) * window_size]
        # this_samples[None] adds the leading batch dim: (512,) -> (1, 512)
        prob, h, c = model(this_samples[None], h, c)
        is_speech.append(prob > threshold)

    # Both durations are 0.25 seconds expressed in number of windows
    min_speech_duration = 0.25 * sample_rate / window_size
    min_silence_duration = 0.25 * sample_rate / window_size

    result = _find_speech_segments(is_speech, min_speech_duration)
    if not result:
        print(f"Empty for {wav_filename}")
        return

    print(result)

    for first, last in _merge_segments(result, min_silence_duration):
        # Convert window indexes to seconds
        start_sec = first * window_size / sample_rate
        end_sec = last * window_size / sample_rate
        print("{:.3f} -- {:.3f}".format(start_sec, end_sec))


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions scripts/silero_vad/v4/test-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ def main():
h, c = model.get_init_states()
window_size = 512
num_windows = samples.shape[0] // window_size

for i in range(num_windows):
start = i * window_size
end = start + window_size

p, h, c = model(samples[start:end], h, c)

probs.append(p[0].item())

threshold = 0.5
Expand Down
6 changes: 4 additions & 2 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
./rknn/online-zipformer-ctc-model-rknn.cc
./rknn/online-zipformer-transducer-model-rknn.cc
./rknn/silero-vad-model-rknn.cc
./rknn/utils.cc
)

Expand Down Expand Up @@ -468,6 +469,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
microphone.cc
)


add_executable(sherpa-onnx-microphone-offline
sherpa-onnx-microphone-offline.cc
microphone.cc
Expand Down Expand Up @@ -498,11 +500,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
)

set(exes
sherpa-onnx-microphone
sherpa-onnx-keyword-spotter-microphone
sherpa-onnx-microphone
sherpa-onnx-microphone-offline
sherpa-onnx-microphone-offline-speaker-identification
sherpa-onnx-microphone-offline-audio-tagging
sherpa-onnx-microphone-offline-speaker-identification
sherpa-onnx-vad-microphone
sherpa-onnx-vad-microphone-offline-asr
sherpa-onnx-vad-with-offline-asr
Expand Down
Loading
Loading