Skip to content

Add C++ runtime for spleeter about source separation #2242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits from spleeter-cpp-2 into the base branch
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/export-spleeter-to-onnx.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: export-spleeter-to-onnx
on:
push:
branches:
- spleeter-2
- spleeter-cpp-2
workflow_dispatch:

concurrency:
Expand Down
1 change: 1 addition & 0 deletions cmake/cmake_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_binaries():
"sherpa-onnx-offline-denoiser",
"sherpa-onnx-offline-language-identification",
"sherpa-onnx-offline-punctuation",
"sherpa-onnx-offline-source-separation",
"sherpa-onnx-offline-speaker-diarization",
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
Expand Down
4 changes: 2 additions & 2 deletions scripts/spleeter/convert_to_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def main(name):
# for the batchnormalization in torch,
# default input shape is NCHW

# NHWC to NCHW
torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2))
torch_y1_out = torch_y1_out.permute(1, 0, 2, 3)

# print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
assert torch.allclose(
Expand Down
4 changes: 2 additions & 2 deletions scripts/spleeter/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):

def export(model, prefix):
num_splits = 1
x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)
x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)

filename = f"./2stems/{prefix}.onnx"
torch.onnx.export(
Expand All @@ -56,7 +56,7 @@ def export(model, prefix):
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "num_splits"},
"x": {1: "num_splits"},
},
opset_version=13,
)
Expand Down
8 changes: 6 additions & 2 deletions scripts/spleeter/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,17 @@ def main():
print("y2", y.shape, y.dtype)

y = y.abs()
y = y.permute(0, 3, 1, 2)
# (1, 2, 512, 1024)

y = y.permute(3, 0, 1, 2)
# (2, 1, 512, 1024)
print("y3", y.shape, y.dtype)

vocals_spec = vocals(y)
accompaniment_spec = accompaniment(y)

vocals_spec = vocals_spec.permute(1, 0, 2, 3)
accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3)

sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
print(
"vocals_spec",
Expand Down
21 changes: 10 additions & 11 deletions scripts/spleeter/separate_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@

"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])

----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
"""


Expand Down Expand Up @@ -123,16 +122,16 @@ def main():
if padding > 0:
stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
stft0 = stft0.reshape(-1, 1, 512, 1024)
stft1 = stft1.reshape(-1, 1, 512, 1024)
stft0 = stft0.reshape(1, -1, 512, 1024)
stft1 = stft1.reshape(1, -1, 512, 1024)

stft_01 = torch.cat([stft0, stft1], axis=1)
stft_01 = torch.cat([stft0, stft1], axis=0)

print("stft_01", stft_01.shape, stft_01.dtype)

vocals_spec = vocals(stft_01)
accompaniment_spec = accompaniment(stft_01)
# (num_splits, num_channels, 512, 1024)
# (num_channels, num_splits, 512, 1024)

sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10

Expand All @@ -142,8 +141,8 @@ def main():
for name, spec in zip(
["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
):
spec_c0 = spec[:, 0, :, :]
spec_c1 = spec[:, 1, :, :]
spec_c0 = spec[0]
spec_c1 = spec[1]

spec_c0 = spec_c0.reshape(-1, 1024)
spec_c1 = spec_c1.reshape(-1, 1024)
Expand Down
11 changes: 10 additions & 1 deletion scripts/spleeter/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def __init__(self):
self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)

def forward(self, x):
"""
Args:
x: (num_audio_channels, num_splits, 512, 1024)
Returns:
y: (num_audio_channels, num_splits, 512, 1024)
"""
x = x.permute(1, 0, 2, 3)

in_x = x
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
Expand Down Expand Up @@ -147,4 +155,5 @@ def forward(self, x):
up7 = self.up7(batch12)
up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)

return up7 * in_x
ans = up7 * in_x
return ans.permute(1, 0, 2, 3)
9 changes: 9 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ set(sources
offline-rnn-lm.cc
offline-sense-voice-model-config.cc
offline-sense-voice-model.cc

offline-source-separation-impl.cc
offline-source-separation-model-config.cc
offline-source-separation-spleeter-model-config.cc
offline-source-separation-spleeter-model.cc
offline-source-separation.cc

offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
Expand Down Expand Up @@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
add_executable(sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc)
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc)

Expand All @@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-parallel
sherpa-onnx-offline-punctuation
sherpa-onnx-offline-source-separation
sherpa-onnx-online-punctuation
sherpa-onnx-vad
)
Expand Down
40 changes: 40 additions & 0 deletions sherpa-onnx/csrc/offline-source-separation-impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// sherpa-onnx/csrc/offline-source-separation-impl.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-source-separation-impl.h"

#include <memory>

#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h"

namespace sherpa_onnx {

// Factory for the filesystem-based case: model files are read from disk.
// Currently always returns the spleeter implementation; dispatch on the
// config would go here once more model families are supported.
std::unique_ptr<OfflineSourceSeparationImpl> OfflineSourceSeparationImpl::Create(
    const OfflineSourceSeparationConfig &config) {
  // TODO(fangjun): Support other models
  return std::make_unique<OfflineSourceSeparationSpleeterImpl>(config);
}

// Factory for the asset-manager case (Android AAssetManager or OHOS
// NativeResourceManager): model files are read through the platform's
// resource manager instead of the filesystem.
template <typename Manager>
std::unique_ptr<OfflineSourceSeparationImpl> OfflineSourceSeparationImpl::Create(
    Manager *mgr, const OfflineSourceSeparationConfig &config) {
  // TODO(fangjun): Support other models
  return std::make_unique<OfflineSourceSeparationSpleeterImpl>(mgr, config);
}

// Explicit instantiations: the template definition lives in this .cc file,
// so each supported Manager type must be instantiated here, guarded by the
// corresponding platform macro.
#if __ANDROID_API__ >= 9
template std::unique_ptr<OfflineSourceSeparationImpl>
OfflineSourceSeparationImpl::Create(
    AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

#if __OHOS__
template std::unique_ptr<OfflineSourceSeparationImpl>
OfflineSourceSeparationImpl::Create(
    NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

}  // namespace sherpa_onnx
35 changes: 35 additions & 0 deletions sherpa-onnx/csrc/offline-source-separation-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// sherpa-onnx/csrc/offline-source-separation-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_

#include <memory>  // std::unique_ptr used below; do not rely on transitive includes
#include <vector>

#include "sherpa-onnx/csrc/offline-source-separation.h"

namespace sherpa_onnx {

// Abstract base class for offline (non-streaming) source-separation
// implementations. Concrete backends (e.g. spleeter) are created through
// the static Create() factories.
class OfflineSourceSeparationImpl {
 public:
  // Create an implementation that loads model files from the filesystem.
  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
      const OfflineSourceSeparationConfig &config);

  // Create an implementation that loads model files through a platform
  // resource manager (e.g. Android AAssetManager, OHOS
  // NativeResourceManager). Instantiated only for supported Manager types.
  template <typename Manager>
  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
      Manager *mgr, const OfflineSourceSeparationConfig &config);

  virtual ~OfflineSourceSeparationImpl() = default;

  // Separate the given audio into stems; see OfflineSourceSeparationOutput
  // for the result layout.
  virtual OfflineSourceSeparationOutput Process(
      const OfflineSourceSeparationInput &input) const = 0;

  // Sample rate (Hz) of the audio produced in the output stems.
  virtual int32_t GetOutputSampleRate() const = 0;

  // Number of stems this model separates the input into
  // (e.g. 2 for vocals + accompaniment).
  virtual int32_t GetNumberOfStems() const = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
38 changes: 38 additions & 0 deletions sherpa-onnx/csrc/offline-source-separation-model-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// sherpa-onnx/csrc/offline-source-separation-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"

#include <sstream>  // std::ostringstream in ToString(); was relying on transitive includes
#include <string>

namespace sherpa_onnx {

// Register command-line options for the generic model settings and for
// each supported model family.
void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) {
  // Model-family-specific flags first.
  spleeter.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}

// Only the model-family config needs validation here; num_threads and
// provider are not range-checked at this level.
bool OfflineSourceSeparationModelConfig::Validate() const {
  return spleeter.Validate();
}

// Human-readable dump of the config, matching the style used by the other
// *Config::ToString() methods in this directory.
std::string OfflineSourceSeparationModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineSourceSeparationModelConfig(";
  os << "spleeter=" << spleeter.ToString() << ", ";
  os << "num_threads=" << num_threads << ", ";
  os << "debug=" << (debug ? "True" : "False") << ", ";
  os << "provider=\"" << provider << "\")";

  return os.str();
}

}  // namespace sherpa_onnx
41 changes: 41 additions & 0 deletions sherpa-onnx/csrc/offline-source-separation-model-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// sherpa-onnx/csrc/offline-source-separation-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_

#include <cstdint>  // int32_t used below; was relying on transitive includes
#include <string>

#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Model-related settings for offline source separation. Holds the
// per-model-family config (currently only spleeter) plus the generic
// ONNX Runtime session options shared by all families.
struct OfflineSourceSeparationModelConfig {
  // Config for the spleeter model family.
  OfflineSourceSeparationSpleeterModelConfig spleeter;

  // Number of threads for running the neural network.
  int32_t num_threads = 1;

  // true to print model information while loading it.
  bool debug = false;

  // ONNX Runtime execution provider: cpu, cuda, or coreml.
  std::string provider = "cpu";

  OfflineSourceSeparationModelConfig() = default;

  OfflineSourceSeparationModelConfig(
      const OfflineSourceSeparationSpleeterModelConfig &spleeter,
      int32_t num_threads, bool debug, const std::string &provider)
      : spleeter(spleeter),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  // Register command-line options with the parser.
  void Register(ParseOptions *po);

  // Return true if the config (model paths etc.) is usable.
  bool Validate() const;

  // Human-readable dump of all fields, for logging.
  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
Loading
Loading