From 4208b68ba4273fb042042ac0ebf6b9f809e93a14 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 22 May 2025 17:05:51 +0800 Subject: [PATCH 01/14] begin to add files --- sherpa-onnx/csrc/offline-source-separation-impl.cc | 0 sherpa-onnx/csrc/offline-source-separation-impl.h | 0 sherpa-onnx/csrc/offline-source-separation-model-config.cc | 0 sherpa-onnx/csrc/offline-source-separation-model-config.h | 0 sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h | 0 .../csrc/offline-source-separation-spleeter-model-config.cc | 0 .../csrc/offline-source-separation-spleeter-model-config.h | 0 .../csrc/offline-source-separation-spleeter-model-meta-data.h | 0 sherpa-onnx/csrc/offline-source-separation.cc | 0 sherpa-onnx/csrc/offline-source-separation.h | 0 .../csrc/offline-speech-denoiser-gtcrn-model-meta-data.h | 2 +- sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc | 2 +- 12 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-source-separation-impl.cc create mode 100644 sherpa-onnx/csrc/offline-source-separation-impl.h create mode 100644 sherpa-onnx/csrc/offline-source-separation-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-source-separation-model-config.h create mode 100644 sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h create mode 100644 sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h create mode 100644 sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h create mode 100644 sherpa-onnx/csrc/offline-source-separation.cc create mode 100644 sherpa-onnx/csrc/offline-source-separation.h diff --git a/sherpa-onnx/csrc/offline-source-separation-impl.cc b/sherpa-onnx/csrc/offline-source-separation-impl.cc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation-impl.h b/sherpa-onnx/csrc/offline-source-separation-impl.h new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.cc b/sherpa-onnx/csrc/offline-source-separation-model-config.cc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.h b/sherpa-onnx/csrc/offline-source-separation-model-config.h new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation.cc b/sherpa-onnx/csrc/offline-source-separation.cc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-source-separation.h b/sherpa-onnx/csrc/offline-source-separation.h new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h index 8cf0cdabf8..36ebf83b69 100644 --- a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h +++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h @@ -12,7 +12,7 @@ namespace sherpa_onnx { // please refer to -// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/gtcrn/add_meta_data.py struct OfflineSpeechDenoiserGtcrnModelMetaData { int32_t sample_rate = 0; int32_t version = 1; diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc index 61d6afd259..0ec314879d 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc @@ -11,7 +11,7 @@ int main(int32_t argc, char *argv[]) { const char *kUsageMessage = R"usage( -Non-stremaing speech denoising with sherpa-onnx. +Non-streaming speech denoising with sherpa-onnx. Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models From 4c88a6a0eca05db3ad7669cca2def187cf6923f2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 11:08:04 +0800 Subject: [PATCH 02/14] add source separation --- sherpa-onnx/csrc/offline-source-separation.cc | 72 ++++++++++++++++++ sherpa-onnx/csrc/offline-source-separation.h | 75 +++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/sherpa-onnx/csrc/offline-source-separation.cc b/sherpa-onnx/csrc/offline-source-separation.cc index e69de29bb2..233036ed64 100644 --- a/sherpa-onnx/csrc/offline-source-separation.cc +++ b/sherpa-onnx/csrc/offline-source-separation.cc @@ -0,0 +1,72 @@ +// sherpa-onnx/csrc/offline-source-separation.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-source-separation.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +namespace sherpa_onnx { + +void OfflineSourceSeparationConfig::Register(ParseOptions *po) { + model.Register(po); +} + +bool OfflineSourceSeparationConfig::Validate() const { + return model.Validate(); +} + +std::string OfflineSourceSeparationConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSourceSeparationConfig("; + os << "model=" << model.ToString() << ")"; + + return os.str(); +} + +template +OfflineSourceSeparation::OfflineSourceSeparation( + Manager *mgr, const OfflineSourceSeparationConfig &config) + : impl_(OfflineRecognizerImpl::Create(mgr, config)) {} + +OfflineSourceSeparation::OfflineSourceSeparation( + const OfflineSourceSeparationConfig &config) + : impl_(OfflineSourceSeparationImpl::Create(config)) {} + +OfflineSourceSeparation::~OfflineSourceSeparation() = default; + +OfflineSourceSeparationOutput OfflineSourceSeparation::Process( + const OfflineSourceSeparationInput &input) const { + return impl_->Process(input); +} + +int32_t OfflineSourceSeparation::GetOutputSampleRate() const { + return impl_->GetOutputSampleRate(); +} + +// e.g., it is 2 for 2stems from spleeter +int32_t OfflineSourceSeparation::GetNumberOfStems() const { + return impl_->GetNumberOfStems(); +} + +#if __ANDROID_API__ >= 9 +template OfflineSourceSeparation::OfflineSourceSeparation( + AAssetManager *mgr, const OfflineSourceSeparationConfig &config); +#endif + +#if __OHOS__ +template OfflineSourceSeparation::OfflineSourceSeparation( + NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-source-separation.h b/sherpa-onnx/csrc/offline-source-separation.h index e69de29bb2..2e12f5ee72 100644 --- a/sherpa-onnx/csrc/offline-source-separation.h +++ b/sherpa-onnx/csrc/offline-source-separation.h @@ -0,0 +1,75 @@ +// sherpa-onnx/csrc/offline-source-separation.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_ + +#include +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSourceSeparationConfig { + OfflineSourceSeparationModelConfig model; + + OfflineSourceSeparationConfig() = default; + + OfflineSourceSeparationConfig(const OfflineSourceSeparationModelConfig &model) + : model(model) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +struct MultiChannelSamples { + // data[i] is for the i-th channel + // + // each sample is in the range [-1, 1] + std::vector> data; +}; + +struct OfflineSourceSeparationInput { + MultiChannelSamples samples; + + int32_t sample_rate; +}; + +struct OfflineSourceSeparationOutput { + std::vector stems; + + int32_t sample_rate; +}; + +class OfflineSourceSeparationImpl; + +class OfflineSourceSeparation { + public: + ~OfflineSourceSeparation(); + + OfflineSourceSeparation(const OfflineSourceSeparationConfig &config); + + template + OfflineSourceSeparation(Manager *mgr, + const OfflineSourceSeparationConfig &config); + + OfflineSourceSeparationOutput Process( + const OfflineSourceSeparationInput &input) const; + + int32_t GetOutputSampleRate() const; + + // e.g., it is 2 for 2stems from spleeter + int32_t GetNumberOfStems() const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_ From e50fe9ef3d04f97461e99e4477472180cbb6e442 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 12:07:54 +0800 Subject: [PATCH 03/14] add model config --- sherpa-onnx/csrc/CMakeLists.txt | 4 ++ .../offline-source-separation-model-config.cc | 26 +++++++++ .../offline-source-separation-model-config.h | 33 ++++++++++++ ...source-separation-spleeter-model-config.cc | 54 +++++++++++++++++++ ...-source-separation-spleeter-model-config.h | 35 ++++++++++++ sherpa-onnx/csrc/offline-source-separation.h | 2 + 6 files changed, 154 insertions(+) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 52fe3bb357..b2b4db788a 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -50,6 +50,10 @@ set(sources offline-rnn-lm.cc offline-sense-voice-model-config.cc offline-sense-voice-model.cc + + offline-source-separation-model-config.cc + offline-source-separation-spleeter-model-config.cc + offline-stream.cc offline-tdnn-ctc-model.cc offline-tdnn-model-config.cc diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.cc b/sherpa-onnx/csrc/offline-source-separation-model-config.cc index e69de29bb2..168b002fae 100644 --- a/sherpa-onnx/csrc/offline-source-separation-model-config.cc +++ b/sherpa-onnx/csrc/offline-source-separation-model-config.cc @@ -0,0 +1,26 @@ +// sherpa-onnx/csrc/offline-source-separation-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h" + +namespace sherpa_onnx { + +void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) { + spleeter.Register(po); +} + +bool OfflineSourceSeparationModelConfig::Validate() const { + return spleeter.Validate(); +} + +std::string OfflineSourceSeparationModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSourceSeparationModelConfig("; + os << "spleeter=" << spleeter.ToString() << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.h b/sherpa-onnx/csrc/offline-source-separation-model-config.h index e69de29bb2..692ba09944 100644 --- a/sherpa-onnx/csrc/offline-source-separation-model-config.h +++ b/sherpa-onnx/csrc/offline-source-separation-model-config.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/offline-source-separation-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSourceSeparationModelConfig { + OfflineSourceSeparationSpleeterModelConfig spleeter; + + OfflineSourceSeparationModelConfig() = default; + + explicit OfflineSourceSeparationModelConfig( + const OfflineSourceSeparationSpleeterModelConfig &spleeter) + : spleeter(spleeter) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc index e69de29bb2..c43f693f91 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc @@ -0,0 +1,54 @@ +// sherpa-onnx/csrc/offline-source-separation-spleeter_model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineSourceSeparationSpleeterModelConfig::Register(ParseOptions *po) { + po->Register("spleeter-vocals", &vocals, "Path to the spleeter vocals model"); + + po->Register("spleeter-accompaniment", &accompaniment, + "Path to the spleeter accompaniment model"); +} + +bool OfflineSourceSeparationSpleeterModelConfig::Validate() const { + if (vocals.empty()) { + SHERPA_ONNX_LOGE("Please provide --spleeter-vocals"); + return false; + } + + if (!FileExists(vocals)) { + SHERPA_ONNX_LOGE("spleeter vocals '%s' does not exist. ", vocals.c_str()); + return false; + } + + if (accompaniment.empty()) { + SHERPA_ONNX_LOGE("Please provide --spleeter-accompaniment"); + return false; + } + + if (!FileExists(accompaniment)) { + SHERPA_ONNX_LOGE("spleeter accompaniment '%s' does not exist. ", + accompaniment.c_str()); + return false; + } + + return true; +} + +std::string OfflineSourceSeparationSpleeterModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSourceSeparationSpleeterModelConfig("; + os << "vocals=\"" << vocals << "\", "; + os << "accompaniment=\"" << accompaniment << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h index e69de29bb2..7e868966ad 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h @@ -0,0 +1,35 @@ +// sherpa-onnx/csrc/offline-source-separation-spleeter_model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineSourceSeparationSpleeterModelConfig { + std::string vocals; + + std::string accompaniment; + + OfflineSourceSeparationSpleeterModelConfig() = default; + + OfflineSourceSeparationSpleeterModelConfig(const std::string &vocals, + const std::string &accompaniment) + : vocals(vocals), accompaniment(accompaniment) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-source-separation.h b/sherpa-onnx/csrc/offline-source-separation.h index 2e12f5ee72..dc9e82a510 100644 --- a/sherpa-onnx/csrc/offline-source-separation.h +++ b/sherpa-onnx/csrc/offline-source-separation.h @@ -6,8 +6,10 @@ #define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_ #include +#include #include +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" namespace sherpa_onnx { From 13dc9b2793960eba2e1118d2ff30666d27062260 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 12:27:45 +0800 Subject: [PATCH 04/14] begin to add impl --- sherpa-onnx/csrc/CMakeLists.txt | 2 ++ .../csrc/offline-source-separation-impl.cc | 33 +++++++++++++++++ .../csrc/offline-source-separation-impl.h | 35 +++++++++++++++++++ sherpa-onnx/csrc/offline-source-separation.cc | 4 ++- .../sherpa-onnx-offline-source-separation.cc | 0 5 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index b2b4db788a..4df30953c8 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -51,8 +51,10 @@ set(sources offline-sense-voice-model-config.cc offline-sense-voice-model.cc + offline-source-separation-impl.cc offline-source-separation-model-config.cc offline-source-separation-spleeter-model-config.cc + offline-source-separation.cc offline-stream.cc offline-tdnn-ctc-model.cc diff --git a/sherpa-onnx/csrc/offline-source-separation-impl.cc b/sherpa-onnx/csrc/offline-source-separation-impl.cc index e69de29bb2..5605778310 100644 --- a/sherpa-onnx/csrc/offline-source-separation-impl.cc +++ b/sherpa-onnx/csrc/offline-source-separation-impl.cc @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/offline-source-separation-impl.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-source-separation-impl.h" +namespace sherpa_onnx { + +std::unique_ptr +OfflineSourceSeparationImpl::Create( + const OfflineSourceSeparationConfig &config) { + return nullptr; +} + +template +std::unique_ptr +OfflineSourceSeparationImpl::Create( + Manager *mgr, const OfflineSourceSeparationConfig &config) { + return nullptr; +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr +OfflineSourceSeparationImpl::Create( + AAssetManager *mgr, const OfflineSourceSeparationConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr +OfflineSourceSeparationImpl::Create( + NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-source-separation-impl.h b/sherpa-onnx/csrc/offline-source-separation-impl.h index e69de29bb2..8bb6852d0b 100644 --- a/sherpa-onnx/csrc/offline-source-separation-impl.h +++ b/sherpa-onnx/csrc/offline-source-separation-impl.h @@ -0,0 +1,35 @@ +// sherpa-onnx/csrc/offline-source-separation-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_ + +#include + +#include "sherpa-onnx/csrc/offline-source-separation.h" + +namespace sherpa_onnx { + +class OfflineSourceSeparationImpl { + public: + static std::unique_ptr Create( + const OfflineSourceSeparationConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OfflineSourceSeparationConfig &config); + + virtual ~OfflineSourceSeparationImpl() = default; + + virtual OfflineSourceSeparationOutput Process( + const OfflineSourceSeparationInput &input) const = 0; + + virtual int32_t GetOutputSampleRate() const = 0; + + virtual int32_t GetNumberOfStems() const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-source-separation.cc b/sherpa-onnx/csrc/offline-source-separation.cc index 233036ed64..d352d9ab58 100644 --- a/sherpa-onnx/csrc/offline-source-separation.cc +++ b/sherpa-onnx/csrc/offline-source-separation.cc @@ -6,6 +6,8 @@ #include +#include "sherpa-onnx/csrc/offline-source-separation-impl.h" + #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" @@ -37,7 +39,7 @@ std::string OfflineSourceSeparationConfig::ToString() const { template OfflineSourceSeparation::OfflineSourceSeparation( Manager *mgr, const OfflineSourceSeparationConfig &config) - : impl_(OfflineRecognizerImpl::Create(mgr, config)) {} + : impl_(OfflineSourceSeparationImpl::Create(mgr, config)) {} OfflineSourceSeparation::OfflineSourceSeparation( const OfflineSourceSeparationConfig &config) diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc new file mode 100644 index 0000000000..e69de29bb2 From 6250c55ae5bafb1257d31338071f804aeb54b211 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 14:24:03 +0800 Subject: [PATCH 05/14] begin to add spleeter impl --- .../csrc/offline-source-separation-spleeter-impl.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h index e69de29bb2..cbc1097374 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h @@ -0,0 +1,10 @@ +// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ + +namespace sherpa_onnx {} + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ From 5bd106a6b5dadd71438e75d5ab3e822de6fd9207 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 14:54:04 +0800 Subject: [PATCH 06/14] Fix wasm for speech enhancement --- wasm/speech-enhancement/app-speech-enhancement.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wasm/speech-enhancement/app-speech-enhancement.js b/wasm/speech-enhancement/app-speech-enhancement.js index fe67c4258b..cc89c67fe7 100644 --- a/wasm/speech-enhancement/app-speech-enhancement.js +++ b/wasm/speech-enhancement/app-speech-enhancement.js @@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) { console.log('ArrayBuffer length:', arrayBuffer.byteLength); const uint8Array = new Uint8Array(arrayBuffer); - const wave = readWaveFromBinaryData(uint8Array); + const wave = readWaveFromBinaryData(uint8Array, Module); if (wave == null) { alert( `${file.name} is not a valid .wav file. Please select a *.wav file`); From 4d5c45eb35f626cc65d062e54bba2ca393e84f71 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 15:15:57 +0800 Subject: [PATCH 07/14] Change input shape to make it easier for C++ --- .../workflows/export-spleeter-to-onnx.yaml | 2 +- scripts/spleeter/convert_to_torch.py | 4 +-- scripts/spleeter/export_onnx.py | 4 +-- scripts/spleeter/separate.py | 8 ++++-- scripts/spleeter/separate_onnx.py | 21 ++++++++-------- scripts/spleeter/unet.py | 11 +++++++- .../csrc/offline-source-separation-impl.cc | 11 ++++++-- .../offline-source-separation-spleeter-impl.h | 25 ++++++++++++++++++- 8 files changed, 64 insertions(+), 22 deletions(-) diff --git a/.github/workflows/export-spleeter-to-onnx.yaml b/.github/workflows/export-spleeter-to-onnx.yaml index f1993ce778..feb1d2a6bc 100644 --- a/.github/workflows/export-spleeter-to-onnx.yaml +++ b/.github/workflows/export-spleeter-to-onnx.yaml @@ -3,7 +3,7 @@ name: export-spleeter-to-onnx on: push: branches: - - spleeter-2 + - spleeter-cpp workflow_dispatch: concurrency: diff --git a/scripts/spleeter/convert_to_torch.py b/scripts/spleeter/convert_to_torch.py index dc6e75800d..e5610aa320 100755 --- a/scripts/spleeter/convert_to_torch.py +++ b/scripts/spleeter/convert_to_torch.py @@ -217,8 +217,8 @@ def main(name): # for the batchnormalization in torch, # default input shape is NCHW - # NHWC to NCHW - torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2)) + torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2)) + torch_y1_out = torch_y1_out.permute(1, 0, 2, 3) # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape) assert torch.allclose( diff --git a/scripts/spleeter/export_onnx.py b/scripts/spleeter/export_onnx.py index adc26048eb..af19ea8618 100755 --- a/scripts/spleeter/export_onnx.py +++ b/scripts/spleeter/export_onnx.py @@ -46,7 +46,7 @@ def add_meta_data(filename, prefix): def export(model, prefix): num_splits = 1 - x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32) + x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32) filename = f"./2stems/{prefix}.onnx" torch.onnx.export( @@ -56,7 +56,7 @@ def export(model, prefix): input_names=["x"], output_names=["y"], dynamic_axes={ - "x": {0: "num_splits"}, + "x": {1: "num_splits"}, }, opset_version=13, ) diff --git a/scripts/spleeter/separate.py b/scripts/spleeter/separate.py index 2a83b7ef2f..cf7eebf042 100755 --- a/scripts/spleeter/separate.py +++ b/scripts/spleeter/separate.py @@ -101,13 +101,17 @@ def main(): print("y2", y.shape, y.dtype) y = y.abs() - y = y.permute(0, 3, 1, 2) - # (1, 2, 512, 1024) + + y = y.permute(3, 0, 1, 2) + # (2, 1, 512, 1024) print("y3", y.shape, y.dtype) vocals_spec = vocals(y) accompaniment_spec = accompaniment(y) + vocals_spec = vocals_spec.permute(1, 0, 2, 3) + accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3) + sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10 print( "vocals_spec", diff --git a/scripts/spleeter/separate_onnx.py b/scripts/spleeter/separate_onnx.py index 28ed37600d..6cc0baea2b 100755 --- a/scripts/spleeter/separate_onnx.py +++ b/scripts/spleeter/separate_onnx.py @@ -12,15 +12,14 @@ """ ----------inputs for ./2stems/vocals.onnx---------- -NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) +NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024]) ----------outputs for ./2stems/vocals.onnx---------- -NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024]) +NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024]) ----------inputs for ./2stems/accompaniment.onnx---------- -NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) +NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024]) ----------outputs for ./2stems/accompaniment.onnx---------- -NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024]) - +NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024]) """ @@ -123,16 +122,16 @@ def main(): if padding > 0: stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding)) stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding)) - stft0 = stft0.reshape(-1, 1, 512, 1024) - stft1 = stft1.reshape(-1, 1, 512, 1024) + stft0 = stft0.reshape(1, -1, 512, 1024) + stft1 = stft1.reshape(1, -1, 512, 1024) - stft_01 = torch.cat([stft0, stft1], axis=1) + stft_01 = torch.cat([stft0, stft1], axis=0) print("stft_01", stft_01.shape, stft_01.dtype) vocals_spec = vocals(stft_01) accompaniment_spec = accompaniment(stft_01) - # (num_splits, num_channels, 512, 1024) + # (num_channels, num_splits, 512, 1024) sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10 @@ -142,8 +141,8 @@ def main(): for name, spec in zip( ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec] ): - spec_c0 = spec[:, 0, :, :] - spec_c1 = spec[:, 1, :, :] + spec_c0 = spec[0] + spec_c1 = spec[1] spec_c0 = spec_c0.reshape(-1, 1024) spec_c1 = spec_c1.reshape(-1, 1024) diff --git a/scripts/spleeter/unet.py b/scripts/spleeter/unet.py index cfcabb6c1b..60eaa5e21b 100644 --- a/scripts/spleeter/unet.py +++ b/scripts/spleeter/unet.py @@ -67,6 +67,14 @@ def __init__(self): self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3) def forward(self, x): + """ + Args: + x: (num_audio_channels, num_splits, 512, 1024) + Returns: + y: (num_audio_channels, num_splits, 512, 1024) + """ + x = x.permute(1, 0, 2, 3) + in_x = x # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024) x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0) @@ -147,4 +155,5 @@ def forward(self, x): up7 = self.up7(batch12) up7 = torch.sigmoid(up7) # (3, 2, 512, 1024) - return up7 * in_x + ans = up7 * in_x + return ans.permute(1, 0, 2, 3) diff --git a/sherpa-onnx/csrc/offline-source-separation-impl.cc b/sherpa-onnx/csrc/offline-source-separation-impl.cc index 5605778310..d2e9632828 100644 --- a/sherpa-onnx/csrc/offline-source-separation-impl.cc +++ b/sherpa-onnx/csrc/offline-source-separation-impl.cc @@ -3,19 +3,26 @@ // Copyright (c) 2025 Xiaomi Corporation #include "sherpa-onnx/csrc/offline-source-separation-impl.h" + +#include + +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h" + namespace sherpa_onnx { std::unique_ptr OfflineSourceSeparationImpl::Create( const OfflineSourceSeparationConfig &config) { - return nullptr; + // TODO(fangjun): Support other models + return std::make_unique(config); } template std::unique_ptr OfflineSourceSeparationImpl::Create( Manager *mgr, const OfflineSourceSeparationConfig &config) { - return nullptr; + // TODO(fangjun): Support other models + return std::make_unique(mgr, config); } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h index cbc1097374..78606bfb24 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h @@ -5,6 +5,29 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ -namespace sherpa_onnx {} +namespace sherpa_onnx { + +class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { + public: + OfflineSourceSeparationSpleeterImpl( + const OfflineSourceSeparationConfig &config) {} + + template + OfflineSourceSeparationSpleeterImpl( + Manager *mgr, const OfflineSourceSeparationConfig &config) {} + + OfflineSourceSeparationOutput Process( + const OfflineSourceSeparationInput &input) const override { + return {}; + } + + int32_t GetOutputSampleRate() const override { return 44100; } + + int32_t GetNumberOfStems() const override { return 2; } + + private: +}; + +} // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ From 1ead04ce2de521985677baf94866d1171a8d6093 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 16:26:57 +0800 Subject: [PATCH 08/14] Support reading multi-channel wave files --- .../workflows/export-spleeter-to-onnx.yaml | 2 +- cmake/cmake_extension.py | 1 + sherpa-onnx/csrc/CMakeLists.txt | 2 + .../sherpa-onnx-offline-source-separation.cc | 87 +++++++++++++++++ sherpa-onnx/csrc/wave-reader.cc | 93 +++++++++++++------ sherpa-onnx/csrc/wave-reader.h | 7 ++ 6 files changed, 165 insertions(+), 27 deletions(-) diff --git a/.github/workflows/export-spleeter-to-onnx.yaml b/.github/workflows/export-spleeter-to-onnx.yaml index feb1d2a6bc..068b168e2f 100644 --- a/.github/workflows/export-spleeter-to-onnx.yaml +++ b/.github/workflows/export-spleeter-to-onnx.yaml @@ -3,7 +3,7 @@ name: export-spleeter-to-onnx on: push: branches: - - spleeter-cpp + - spleeter-cpp-2 workflow_dispatch: concurrency: diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index 6b79b173c0..457b847dd6 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -56,6 +56,7 @@ def get_binaries(): "sherpa-onnx-offline-denoiser", "sherpa-onnx-offline-language-identification", "sherpa-onnx-offline-punctuation", + "sherpa-onnx-offline-source-separation", "sherpa-onnx-offline-speaker-diarization", "sherpa-onnx-offline-tts", "sherpa-onnx-offline-tts-play", diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 4df30953c8..bb6a0247c0 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -332,6 +332,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) + add_executable(sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc) add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc) @@ -352,6 +353,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) sherpa-onnx-offline-language-identification sherpa-onnx-offline-parallel sherpa-onnx-offline-punctuation + sherpa-onnx-offline-source-separation sherpa-onnx-online-punctuation sherpa-onnx-vad ) diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc index e69de29bb2..bbc1c4bd73 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc @@ -0,0 +1,87 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc +// +// Copyright (c) 2025 Xiaomi Corporation +#include + +#include // NOLINT +#include + +#include "sherpa-onnx/csrc/offline-source-separation.h" +#include "sherpa-onnx/csrc/wave-reader.h" +#include "sherpa-onnx/csrc/wave-writer.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Non-streaming source separation with sherpa-onnx. + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models +to download models. + +Usage: + +(1) Use spleeter models + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2 +tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav + +./bin/sherpa-onnx-offline-source-separation \ + --spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \ + --spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \ + --input-wav=audio_example.wav \ + --output-wav=output_example.wav +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::OfflineSourceSeparationConfig config; + + std::string input_wave; + std::string output_wave; + + config.Register(&po); + po.Register("input-wav", &input_wave, "Path to input wav."); + po.Register("output-wav", &output_wave, "Path to output wav"); + + po.Read(argc, argv); + if (po.NumArgs() != 0) { + fprintf(stderr, "Please don't give positional arguments\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (input_wave.empty()) { + fprintf(stderr, "Please provide --input-wav\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (output_wave.empty()) { + fprintf(stderr, "Please provide --output-wav\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + exit(EXIT_FAILURE); + } + + int32_t sampling_rate = -1; + bool is_ok = false; + auto samples = + sherpa_onnx::ReadWaveMultiChannel(input_wave, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str()); + return -1; + } + + fprintf(stderr, "Started\n"); + + fprintf(stderr, "Input channels: %d\n", static_cast(samples.size())); + fprintf(stderr, "Input sample rate: %d\n", sampling_rate); + + return 0; +} diff --git a/sherpa-onnx/csrc/wave-reader.cc b/sherpa-onnx/csrc/wave-reader.cc index 0deb29a4f0..a7829c9f2f 100644 --- a/sherpa-onnx/csrc/wave-reader.cc +++ b/sherpa-onnx/csrc/wave-reader.cc @@ -63,8 +63,9 @@ in sherpa-onnx. // Read a wave file of mono-channel. // Return its samples normalized to the range [-1, 1). -std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, - bool *is_ok) { +std::vector> ReadWaveImpl(std::istream &is, + int32_t *sampling_rate, + bool *is_ok) { WaveHeader header{}; is.read(reinterpret_cast(&header.chunk_id), sizeof(header.chunk_id)); @@ -144,12 +145,6 @@ std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, is.read(reinterpret_cast(&header.num_channels), sizeof(header.num_channels)); - if (header.num_channels != 1) { // we support only single channel for now - SHERPA_ONNX_LOGE( - "Warning: %d channels are found. We only use the first channel.\n", - header.num_channels); - } - is.read(reinterpret_cast(&header.sample_rate), sizeof(header.sample_rate)); @@ -219,7 +214,7 @@ std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, *sampling_rate = header.sample_rate; - std::vector ans; + std::vector> ans(header.num_channels); if (header.bits_per_sample == 16 && header.audio_format == 1) { // header.subchunk2_size contains the number of bytes in the data. @@ -233,11 +228,16 @@ std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, return {}; } - ans.resize(samples.size() / header.num_channels); + for (auto &v : ans) { + v.resize(samples.size()); + } // samples are interleaved - for (int32_t i = 0; i != static_cast(ans.size()); ++i) { - ans[i] = samples[i * header.num_channels] / 32768.; + for (int32_t i = 0, k = 0; i < static_cast(samples.size()); + i += header.num_channels, ++k) { + for (int32_t c = 0; c != header.num_channels; ++c) { + ans[c][k] = samples[i + c] / 32768.; + } } } else if (header.bits_per_sample == 8 && header.audio_format == 1) { // number of samples == number of bytes for 8-bit encoded samples @@ -252,14 +252,21 @@ std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, return {}; } - ans.resize(samples.size() / header.num_channels); - for (int32_t i = 0; i != static_cast(ans.size()); ++i) { - // Note(fangjun): We want to normalize each sample into the range [-1, 1] - // Since each original sample is in the range [0, 256], dividing - // them by 128 converts them to the range [0, 2]; - // so after subtracting 1, we get the range [-1, 1] - // - ans[i] = samples[i * header.num_channels] / 128. - 1; + for (auto &v : ans) { + v.resize(samples.size()); + } + + // samples are interleaved + for (int32_t i = 0, k = 0; i < static_cast(samples.size()); + i += header.num_channels, ++k) { + for (int32_t c = 0; c != header.num_channels; ++c) { + // Note(fangjun): We want to normalize each sample into the range [-1, + // 1] Since each original sample is in the range [0, 256], dividing them + // by 128 converts them to the range [0, 2]; so after subtracting 1, we + // get the range [-1, 1] + // + ans[c][k] = samples[i + c] / 128. - 1; + } } } else if (header.bits_per_sample == 32 && header.audio_format == 1) { // 32 here is for int32 @@ -275,9 +282,16 @@ std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, return {}; } - ans.resize(samples.size() / header.num_channels); - for (int32_t i = 0; i != static_cast(ans.size()); ++i) { - ans[i] = static_cast(samples[i * header.num_channels]) / (1 << 31); + for (auto &v : ans) { + v.resize(samples.size()); + } + + // samples are interleaved + for (int32_t i = 0, k = 0; i < static_cast(samples.size()); + i += header.num_channels, ++k) { + for (int32_t c = 0; c != header.num_channels; ++c) { + ans[c][k] = static_cast(samples[i + c]) / (1 << 31); + } } } else if (header.bits_per_sample == 32 && header.audio_format == 3) { // 32 here is for float32 @@ -293,9 +307,16 @@ std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, return {}; } - ans.resize(samples.size() / header.num_channels); - for (int32_t i = 0; i != static_cast(ans.size()); ++i) { - ans[i] = samples[i * header.num_channels]; + for (auto &v : ans) { + v.resize(samples.size()); + } + + // samples are interleaved + for (int32_t i = 0, k = 0; i < static_cast(samples.size()); + i += header.num_channels, ++k) { + for (int32_t c = 0; c != header.num_channels; ++c) { + ans[c][k] = samples[i + c]; + } } } else { SHERPA_ONNX_LOGE( @@ -321,7 +342,27 @@ std::vector ReadWave(const std::string &filename, int32_t *sampling_rate, std::vector ReadWave(std::istream &is, int32_t *sampling_rate, bool *is_ok) { auto samples = ReadWaveImpl(is, sampling_rate, is_ok); + + if (samples.size() > 1) { + SHERPA_ONNX_LOGE( + "Warning: %d channels are found. We only use the first channel.\n", + static_cast(samples.size())); + } + + return samples[0]; +} + +std::vector> ReadWaveMultiChannel(std::istream &is, + int32_t *sampling_rate, + bool *is_ok) { + auto samples = ReadWaveImpl(is, sampling_rate, is_ok); return samples; } +std::vector> ReadWaveMultiChannel( + const std::string &filename, int32_t *sampling_rate, bool *is_ok) { + std::ifstream is(filename, std::ifstream::binary); + return ReadWaveMultiChannel(is, sampling_rate, is_ok); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/wave-reader.h b/sherpa-onnx/csrc/wave-reader.h index 98e956abf6..1ba9910218 100644 --- a/sherpa-onnx/csrc/wave-reader.h +++ b/sherpa-onnx/csrc/wave-reader.h @@ -26,6 +26,13 @@ std::vector ReadWave(const std::string &filename, int32_t *sampling_rate, std::vector ReadWave(std::istream &is, int32_t *sampling_rate, bool *is_ok); +std::vector> ReadWaveMultiChannel(std::istream &is, + int32_t *sampling_rate, + bool *is_ok); + +std::vector> ReadWaveMultiChannel( + const std::string &filename, int32_t *sampling_rate, bool *is_ok); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_WAVE_READER_H_ From 6fa3854a1f905996af25befa459b34e7b52b4b0a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 17:15:52 +0800 Subject: [PATCH 09/14] Test read and write multi-channel waves --- .../sherpa-onnx-offline-source-separation.cc | 11 +++ sherpa-onnx/csrc/wave-reader.cc | 8 +- sherpa-onnx/csrc/wave-writer.cc | 80 +++++++++++++------ sherpa-onnx/csrc/wave-writer.h | 8 +- 4 files changed, 79 insertions(+), 28 deletions(-) diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc index bbc1c4bd73..931dada8a8 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc @@ -82,6 +82,17 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m fprintf(stderr, "Input channels: %d\n", static_cast(samples.size())); fprintf(stderr, "Input sample rate: %d\n", sampling_rate); + fprintf(stderr, "Input sample size: %d\n", (int)samples[0].size()); + + fprintf(stderr, "Done\n"); + is_ok = sherpa_onnx::WriteWave(output_wave, sampling_rate, samples[0].data(), + samples[1].data(), samples[0].size()); + + if (!is_ok) { + fprintf(stderr, "Failed to write to '%s'\n", output_wave.c_str()); + exit(EXIT_FAILURE); + } + fprintf(stderr, "Saved to write to '%s'\n", output_wave.c_str()); return 0; } diff --git a/sherpa-onnx/csrc/wave-reader.cc b/sherpa-onnx/csrc/wave-reader.cc index a7829c9f2f..90db0d513b 100644 --- a/sherpa-onnx/csrc/wave-reader.cc +++ b/sherpa-onnx/csrc/wave-reader.cc @@ -229,7 +229,7 @@ std::vector> ReadWaveImpl(std::istream &is, } for (auto &v : ans) { - v.resize(samples.size()); + v.resize(samples.size() / header.num_channels); } // samples are interleaved @@ -253,7 +253,7 @@ std::vector> ReadWaveImpl(std::istream &is, } for (auto &v : ans) { - v.resize(samples.size()); + v.resize(samples.size() / header.num_channels); } // samples are interleaved @@ -283,7 +283,7 @@ std::vector> ReadWaveImpl(std::istream &is, } for (auto &v : ans) { - v.resize(samples.size()); + v.resize(samples.size() / header.num_channels); } // samples are interleaved @@ -308,7 +308,7 @@ std::vector> ReadWaveImpl(std::istream &is, } for (auto &v : ans) { - v.resize(samples.size()); + v.resize(samples.size() / header.num_channels); } // samples are interleaved diff --git a/sherpa-onnx/csrc/wave-writer.cc b/sherpa-onnx/csrc/wave-writer.cc index f3aca8e90c..98cd24c989 100644 --- a/sherpa-onnx/csrc/wave-writer.cc +++ b/sherpa-onnx/csrc/wave-writer.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/csrc/wave-writer.h" +#include #include #include #include @@ -36,12 +37,44 @@ struct WaveHeader { } // namespace -int64_t WaveFileSize(int32_t n_samples) { - return sizeof(WaveHeader) + n_samples * sizeof(int16_t); +int64_t WaveFileSize(int32_t n_samples, int32_t num_channels /*= 1*/) { + return sizeof(WaveHeader) + n_samples * sizeof(int16_t) * num_channels; } void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, int32_t n) { + WriteWave(buffer, sampling_rate, samples, nullptr, n); +} + +bool WriteWave(const std::string &filename, int32_t sampling_rate, + const float *samples, int32_t n) { + return WriteWave(filename, sampling_rate, samples, nullptr, n); +} + +bool WriteWave(const std::string &filename, int32_t sampling_rate, + const float *samples_ch0, const float *samples_ch1, int32_t n) { + std::string buffer; + buffer.resize(WaveFileSize(n, samples_ch1 == nullptr ? 1 : 2)); + + WriteWave(buffer.data(), sampling_rate, samples_ch0, samples_ch1, n); + + std::ofstream os(filename, std::ios::binary); + if (!os) { + SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str()); + return false; + } + + os << buffer; + if (!os) { + SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str()); + return false; + } + + return true; +} + +void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0, + const float *samples_ch1, int32_t n) { WaveHeader header{}; header.chunk_id = 0x46464952; // FFIR header.format = 0x45564157; // EVAW @@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, header.subchunk1_size = 16; // 16 for PCM header.audio_format = 1; // PCM =1 - int32_t num_channels = 1; + int32_t num_channels = samples_ch1 == nullptr ? 1 : 2; int32_t bits_per_sample = 16; // int16_t + header.num_channels = num_channels; header.sample_rate = sampling_rate; header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8; @@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, header.chunk_size = 36 + header.subchunk2_size; - std::vector samples_int16(n); + std::vector samples_int16_ch0(n); for (int32_t i = 0; i != n; ++i) { - samples_int16[i] = samples[i] * 32767; + samples_int16_ch0[i] = std::min(samples_ch0[i] * 32767, 32767); + } + + std::vector samples_int16_ch1; + if (samples_ch1) { + samples_int16_ch1.resize(n); + for (int32_t i = 0; i != n; ++i) { + samples_int16_ch1[i] = std::min(samples_ch1[i] * 32767, 32767); + } } memcpy(buffer, &header, sizeof(WaveHeader)); - memcpy(buffer + sizeof(WaveHeader), samples_int16.data(), - n * sizeof(int16_t)); -} -bool WriteWave(const std::string &filename, int32_t sampling_rate, - const float *samples, int32_t n) { - std::string buffer; - buffer.resize(WaveFileSize(n)); - WriteWave(buffer.data(), sampling_rate, samples, n); - std::ofstream os(filename, std::ios::binary); - if (!os) { - SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str()); - return false; - } - os << buffer; - if (!os) { - SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str()); - return false; + if (samples_ch1 == nullptr) { + memcpy(buffer + sizeof(WaveHeader), samples_int16_ch0.data(), + n * sizeof(int16_t)); + } else { + auto p = reinterpret_cast(buffer + sizeof(WaveHeader)); + + for (int32_t i = 0; i != n; ++i) { + p[2 * i] = samples_int16_ch0[i]; + p[2 * i + 1] = samples_int16_ch1[i]; + } } - return true; } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/wave-writer.h b/sherpa-onnx/csrc/wave-writer.h index 054d9eaff3..0d6e8c6a1a 100644 --- a/sherpa-onnx/csrc/wave-writer.h +++ b/sherpa-onnx/csrc/wave-writer.h @@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate, void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, int32_t n); -int64_t WaveFileSize(int32_t n_samples); +bool WriteWave(const std::string &filename, int32_t sampling_rate, + const float *samples_ch0, const float *samples_ch1, int32_t n); + +void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0, + const float *samples_ch1, int32_t n); + +int64_t WaveFileSize(int32_t n_samples, int32_t num_channels = 1); } // namespace sherpa_onnx From 11cf8eaeb87fe6c42beb7c4191605e860214cbfe Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 17:20:51 +0800 Subject: [PATCH 10/14] Begin to implement processing --- .../offline-source-separation-spleeter-impl.h | 8 ++++++- .../sherpa-onnx-offline-source-separation.cc | 22 ++++++++++++------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h index 78606bfb24..3de370107e 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h @@ -5,12 +5,17 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-source-separation.h" + namespace sherpa_onnx { class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { public: OfflineSourceSeparationSpleeterImpl( - const OfflineSourceSeparationConfig &config) {} + const OfflineSourceSeparationConfig &config) { + SHERPA_ONNX_LOGE("created!"); + } template OfflineSourceSeparationSpleeterImpl( @@ -18,6 +23,7 @@ class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { OfflineSourceSeparationOutput Process( const OfflineSourceSeparationInput &input) const override { + SHERPA_ONNX_LOGE("processing!"); return {}; } diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc index 931dada8a8..b73427d1cf 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc @@ -69,10 +69,10 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m exit(EXIT_FAILURE); } - int32_t sampling_rate = -1; bool is_ok = false; - auto samples = - sherpa_onnx::ReadWaveMultiChannel(input_wave, &sampling_rate, &is_ok); + sherpa_onnx::OfflineSourceSeparationInput input; + input.samples.data = + sherpa_onnx::ReadWaveMultiChannel(input_wave, &input.sample_rate, &is_ok); if (!is_ok) { fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str()); return -1; @@ -80,13 +80,19 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m fprintf(stderr, "Started\n"); - fprintf(stderr, "Input channels: %d\n", static_cast(samples.size())); - fprintf(stderr, "Input sample rate: %d\n", sampling_rate); - fprintf(stderr, "Input sample size: %d\n", (int)samples[0].size()); + fprintf(stderr, "Input channels: %d\n", + static_cast(input.samples.data.size())); + fprintf(stderr, "Input sample rate: %d\n", input.sample_rate); + fprintf(stderr, "Input sample size: %d\n", (int)input.samples.data[0].size()); + + sherpa_onnx::OfflineSourceSeparation sp(config); + + auto output = sp.Process(input); fprintf(stderr, "Done\n"); - is_ok = sherpa_onnx::WriteWave(output_wave, sampling_rate, samples[0].data(), - samples[1].data(), samples[0].size()); + is_ok = sherpa_onnx::WriteWave( + output_wave, input.sample_rate, input.samples.data[0].data(), + input.samples.data[1].data(), input.samples.data[0].size()); if (!is_ok) { fprintf(stderr, "Failed to write to '%s'\n", output_wave.c_str()); From b36bf86af12688101cb64c6256e9f9f3f07431a8 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 17:27:04 +0800 Subject: [PATCH 11/14] add meta data --- ...urce-separation-spleeter-model-meta-data.h | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h index e69de29bb2..7e1ca94f9c 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h @@ -0,0 +1,26 @@ +// sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_onnx { + +struct OfflineSourceSeparationSpleeterModelMetaData { + int32_t sample_rate = 44100; + + int32_t num_stems = 2; + int32_t n_fft = 4096; + int32_t hop_length = 1024; + int32_t window_length = 0; + bool center = false; + std::string window_type = "hann"; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_ From 8d9e495e9a6d7b767ebb9a2c728efee6cb01cbf9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 19:50:55 +0800 Subject: [PATCH 12/14] Add spleeter model --- sherpa-onnx/csrc/CMakeLists.txt | 1 + .../offline-source-separation-model-config.cc | 14 +- .../offline-source-separation-model-config.h | 14 +- ...urce-separation-spleeter-model-meta-data.h | 4 +- ...ffline-source-separation-spleeter-model.cc | 212 ++++++++++++++++++ ...offline-source-separation-spleeter-model.h | 37 +++ 6 files changed, 277 insertions(+), 5 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc create mode 100644 sherpa-onnx/csrc/offline-source-separation-spleeter-model.h diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index bb6a0247c0..f704bfeb7a 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -54,6 +54,7 @@ set(sources offline-source-separation-impl.cc offline-source-separation-model-config.cc offline-source-separation-spleeter-model-config.cc + offline-source-separation-spleeter-model.cc offline-source-separation.cc offline-stream.cc diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.cc b/sherpa-onnx/csrc/offline-source-separation-model-config.cc index 168b002fae..dfd765d3f1 100644 --- a/sherpa-onnx/csrc/offline-source-separation-model-config.cc +++ b/sherpa-onnx/csrc/offline-source-separation-model-config.cc @@ -8,6 +8,15 @@ namespace sherpa_onnx { void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) { spleeter.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); } bool OfflineSourceSeparationModelConfig::Validate() const { @@ -18,7 +27,10 @@ std::string OfflineSourceSeparationModelConfig::ToString() const { std::ostringstream os; os << "OfflineSourceSeparationModelConfig("; - os << "spleeter=" << spleeter.ToString() << ")"; + os << "spleeter=" << spleeter.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.h b/sherpa-onnx/csrc/offline-source-separation-model-config.h index 692ba09944..bf88d39dbb 100644 --- a/sherpa-onnx/csrc/offline-source-separation-model-config.h +++ b/sherpa-onnx/csrc/offline-source-separation-model-config.h @@ -15,11 +15,19 @@ namespace sherpa_onnx { struct OfflineSourceSeparationModelConfig { OfflineSourceSeparationSpleeterModelConfig spleeter; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + OfflineSourceSeparationModelConfig() = default; - explicit OfflineSourceSeparationModelConfig( - const OfflineSourceSeparationSpleeterModelConfig &spleeter) - : spleeter(spleeter) {} + OfflineSourceSeparationModelConfig( + const OfflineSourceSeparationSpleeterModelConfig &spleeter, + int32_t num_threads, bool debug, const std::string &provider) + : spleeter(spleeter), + num_threads(num_threads), + debug(debug), + provider(provider) {} void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h index 7e1ca94f9c..fd7c5ee163 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h @@ -10,10 +10,12 @@ namespace sherpa_onnx { +// See also +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/spleeter/separate_onnx.py struct OfflineSourceSeparationSpleeterModelMetaData { int32_t sample_rate = 44100; - int32_t num_stems = 2; + int32_t n_fft = 4096; int32_t hop_length = 1024; int32_t window_length = 0; diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc new file mode 100644 index 0000000000..12f343cc67 --- /dev/null +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc @@ -0,0 +1,212 @@ +// sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineSourceSeparationSpleeterModel::Impl { + public: + explicit Impl(const OfflineSourceSeparationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.spleeter.vocals); + InitVocals(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.spleeter.accompaniment); + InitAccompaniment(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineSourceSeparationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.spleeter.vocals); + InitVocals(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.spleeter.accompaniment); + InitAccompaniment(buf.data(), buf.size()); + } + } + + const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const { + return meta_; + } + + Ort::Value RunVocals(Ort::Value x) const { + auto out = vocals_sess_->Run({}, vocals_input_names_ptr_.data(), &x, 1, + vocals_output_names_ptr_.data(), + vocals_output_names_ptr_.size()); + return std::move(out[0]); + } + + Ort::Value RunAccompaniment(Ort::Value x) const { + auto out = + accompaniment_sess_->Run({}, accompaniment_input_names_ptr_.data(), &x, + 1, accompaniment_output_names_ptr_.data(), + accompaniment_output_names_ptr_.size()); + return std::move(out[0]); + } + + private: + void InitVocals(void *model_data, size_t model_data_length) { + vocals_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(vocals_sess_.get(), &vocals_input_names_, + &vocals_input_names_ptr_); + + GetOutputNames(vocals_sess_.get(), &vocals_output_names_, + &vocals_output_names_ptr_); + + Ort::ModelMetadata meta_data = vocals_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---vocals model---\n"; + PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : vocals_input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : vocals_output_names_) { + os << i << " " << s << "\n"; + ++i; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + + std::string model_type; + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); + if (model_type != "spleeter") { + SHERPA_ONNX_LOGE("Expect model type 'spleeter'. Given: '%s'", + model_type.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + int32_t stems = -1; + SHERPA_ONNX_READ_META_DATA(stems, "stems"); + if (stems != 2) { + SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems", stems); + SHERPA_ONNX_EXIT(-1); + } + } + + void InitAccompaniment(void *model_data, size_t model_data_length) { + accompaniment_sess_ = std::make_unique( + env_, model_data, model_data_length, sess_opts_); + + GetInputNames(accompaniment_sess_.get(), &accompaniment_input_names_, + &accompaniment_input_names_ptr_); + + GetOutputNames(accompaniment_sess_.get(), &accompaniment_output_names_, + &accompaniment_output_names_ptr_); + } + + private: + OfflineSourceSeparationModelConfig config_; + OfflineSourceSeparationSpleeterModelMetaData meta_; + + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr vocals_sess_; + + std::vector vocals_input_names_; + std::vector vocals_input_names_ptr_; + + std::vector vocals_output_names_; + std::vector vocals_output_names_ptr_; + + std::unique_ptr accompaniment_sess_; + + std::vector accompaniment_input_names_; + std::vector accompaniment_input_names_ptr_; + + std::vector accompaniment_output_names_; + std::vector accompaniment_output_names_ptr_; +}; + +OfflineSourceSeparationSpleeterModel::~OfflineSourceSeparationSpleeterModel() = + default; + +OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel( + const OfflineSourceSeparationModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel( + Manager *mgr, const OfflineSourceSeparationModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +Ort::Value OfflineSourceSeparationSpleeterModel::RunVocals(Ort::Value x) const { + return impl_->RunVocals(std::move(x)); +} + +Ort::Value OfflineSourceSeparationSpleeterModel::RunAccompaniment( + Ort::Value x) const { + return impl_->RunAccompaniment(std::move(x)); +} + +const OfflineSourceSeparationSpleeterModelMetaData & +OfflineSourceSeparationSpleeterModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +#if __ANDROID_API__ >= 9 +template OfflineSourceSeparationSpleeterModel:: + OfflineSourceSeparationSpleeterModel( + AAssetManager *mgr, const OfflineSourceSeparationModelConfig &config); +#endif + +#if __OHOS__ +template OfflineSourceSeparationSpleeterModel:: + OfflineSourceSeparationSpleeterModel( + NativeResourceManager *mgr, + const OfflineSourceSeparationModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.h new file mode 100644 index 0000000000..8cad92fe55 --- /dev/null +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.h @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/offline-source-separation-spleeter-model.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_ +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h" +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h" + +namespace sherpa_onnx { + +class OfflineSourceSeparationSpleeterModel { + public: + ~OfflineSourceSeparationSpleeterModel(); + + explicit OfflineSourceSeparationSpleeterModel( + const OfflineSourceSeparationModelConfig &config); + + template + OfflineSourceSeparationSpleeterModel( + Manager *mgr, const OfflineSourceSeparationModelConfig &config); + + Ort::Value RunVocals(Ort::Value x) const; + Ort::Value RunAccompaniment(Ort::Value x) const; + + const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_ From 76578e934e4e4612cfb98d2b08bdf6ed7d27b5f6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 21:54:33 +0800 Subject: [PATCH 13/14] first working version --- .../offline-source-separation-spleeter-impl.h | 229 +++++++++++++++++- ...urce-separation-spleeter-model-meta-data.h | 2 +- ...ffline-source-separation-spleeter-model.cc | 8 +- .../sherpa-onnx-offline-source-separation.cc | 43 +++- 4 files changed, 261 insertions(+), 21 deletions(-) diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h index 3de370107e..94aca85147 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h @@ -5,33 +5,250 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ +#include "Eigen/Dense" +#include "kaldi-native-fbank/csrc/istft.h" +#include "kaldi-native-fbank/csrc/stft.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h" #include "sherpa-onnx/csrc/offline-source-separation.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/wave-writer.h" namespace sherpa_onnx { class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { public: OfflineSourceSeparationSpleeterImpl( - const OfflineSourceSeparationConfig &config) { + const OfflineSourceSeparationConfig &config) + : config_(config), model_(config_.model) { SHERPA_ONNX_LOGE("created!"); } template OfflineSourceSeparationSpleeterImpl( - Manager *mgr, const OfflineSourceSeparationConfig &config) {} + Manager *mgr, const OfflineSourceSeparationConfig &config) + : config_(config), model_(mgr, config_.model) {} OfflineSourceSeparationOutput Process( const OfflineSourceSeparationInput &input) const override { - SHERPA_ONNX_LOGE("processing!"); - return {}; + if (config_.model.debug) { + SHERPA_ONNX_LOGE("input sample rate: %d", input.sample_rate); + SHERPA_ONNX_LOGE("input ch0 samples size: %d", + static_cast(input.samples.data[0].size())); + } + + if (input.samples.data.size() > 1) { + if (config_.model.debug) { + SHERPA_ONNX_LOGE("input ch1 samples size: %d", + static_cast(input.samples.data[1].size())); + } + + if (input.samples.data[0].size() != input.samples.data[1].size()) { + SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d", + static_cast(input.samples.data[0].size()), + static_cast(input.samples.data[1].size())); + + SHERPA_ONNX_EXIT(-1); + } + } + + auto stft_ch0 = ComputeStft(input, 0); + + auto stft_ch1 = ComputeStft(input, 1); + knf::StftResult *p_stft_ch1 = stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1; + + SHERPA_ONNX_LOGE("number of frames: %d", stft_ch0.num_frames); + + int32_t num_frames = stft_ch0.num_frames; + int32_t fft_bins = stft_ch0.real.size() / num_frames; + + int32_t pad = 512 - (stft_ch0.num_frames % 512); + if (pad < 512) { + num_frames += pad; + } + + if (num_frames % 512) { + SHERPA_ONNX_LOGE("num_frames should be multiple of 512, actual: %d. %d", + num_frames, num_frames % 512); + SHERPA_ONNX_EXIT(-1); + } + + Eigen::VectorXf real(2 * num_frames * 1024); + Eigen::VectorXf imag(2 * num_frames * 1024); + real.setZero(); + imag.setZero(); + + float *p_real = &real[0]; + float *p_imag = &imag[0]; + + // copy stft result of channel 0 + for (int32_t i = 0; i != stft_ch0.num_frames; ++i) { + std::copy(stft_ch0.real.data() + i * fft_bins, + stft_ch0.real.data() + i * fft_bins + 1024, p_real + 1024 * i); + + std::copy(stft_ch0.imag.data() + i * fft_bins, + stft_ch0.imag.data() + i * fft_bins + 1024, p_imag + 1024 * i); + } + + p_real += num_frames * 1024; + p_imag += num_frames * 1024; + + // copy stft result of channel 1 + for (int32_t i = 0; i != stft_ch1.num_frames; ++i) { + std::copy(p_stft_ch1->real.data() + i * fft_bins, + p_stft_ch1->real.data() + i * fft_bins + 1024, + p_real + 1024 * i); + + std::copy(p_stft_ch1->imag.data() + i * fft_bins, + p_stft_ch1->imag.data() + i * fft_bins + 1024, + p_imag + 1024 * i); + } + + Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt(); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{2, num_frames / 512, 512, 1024}; + Ort::Value x_tensor = Ort::Value::CreateTensor( + memory_info, &x[0], x.size(), x_shape.data(), x_shape.size()); + + Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor)); + Ort::Value accompaniment_spec_tensor = + model_.RunAccompaniment(std::move(x_tensor)); + + Eigen::VectorXf vocals_spec = Eigen::Map( + vocals_spec_tensor.GetTensorMutableData(), x.size()); + + Eigen::VectorXf accompaniment_spec = Eigen::Map( + accompaniment_spec_tensor.GetTensorMutableData(), x.size()); + + Eigen::VectorXf sum_spec = vocals_spec.array().square() + + accompaniment_spec.array().square() + 1e-10; + + vocals_spec = (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array(); + + accompaniment_spec = + (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array(); + SHERPA_ONNX_LOGE("here"); + + auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0); + auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1); + + auto accompaniment_samples_ch0 = + ProcessSpec(accompaniment_spec, stft_ch0, 0); + auto accompaniment_samples_ch1 = + ProcessSpec(accompaniment_spec, *p_stft_ch1, 1); + + OfflineSourceSeparationOutput ans; + ans.sample_rate = GetOutputSampleRate(); + + ans.stems.resize(2); + ans.stems[0].data.reserve(2); + ans.stems[1].data.reserve(2); + + ans.stems[0].data.push_back(std::move(vocals_samples_ch0)); + ans.stems[0].data.push_back(std::move(vocals_samples_ch1)); + + ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0)); + ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1)); + + return ans; + } + + int32_t GetOutputSampleRate() const override { + return model_.GetMetaData().sample_rate; + } + + int32_t GetNumberOfStems() const override { + return model_.GetMetaData().num_stems; } - int32_t GetOutputSampleRate() const override { return 44100; } + private: + // spec is of shape (2, num_chunks, 512, 1024) + std::vector ProcessSpec(const Eigen::VectorXf &spec, + const knf::StftResult &stft, + int32_t channel) const { + int32_t fft_bins = stft.real.size() / stft.num_frames; + + Eigen::VectorXf mask(stft.real.size()); + mask.setZero(); + + float *p_mask = &mask[0]; + + // assume there are 2 channels + const float *p_spec = &spec[0] + (spec.size() / 2) * channel; + + for (int32_t i = 0; i != stft.num_frames; ++i) { + std::copy(p_spec + i * 1024, p_spec + (i + 1) * 1024, + p_mask + i * fft_bins); + } + + knf::StftResult masked_stft; + + masked_stft.num_frames = stft.num_frames; + masked_stft.real.resize(stft.real.size()); + masked_stft.imag.resize(stft.imag.size()); + + Eigen::Map(masked_stft.real.data(), + masked_stft.real.size()) = + mask.array() * + Eigen::Map(const_cast(stft.real.data()), + stft.real.size()) + .array(); - int32_t GetNumberOfStems() const override { return 2; } + Eigen::Map(masked_stft.imag.data(), + masked_stft.imag.size()) = + mask.array() * + Eigen::Map(const_cast(stft.imag.data()), + stft.imag.size()) + .array(); + + auto stft_config = GetStftConfig(); + knf::IStft istft(stft_config); + + return istft.Compute(masked_stft); + } + + knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input, + int32_t ch) const { + if (ch >= input.samples.data.size()) { + SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch, + static_cast(input.samples.data.size())); + SHERPA_ONNX_EXIT(-1); + } + + if (input.samples.data[ch].empty()) { + return {}; + } + + return ComputeStft(input.samples.data[ch]); + } + + knf::StftResult ComputeStft(const std::vector &samples) const { + auto stft_config = GetStftConfig(); + knf::Stft stft(stft_config); + + return stft.Compute(samples.data(), samples.size()); + } + + knf::StftConfig GetStftConfig() const { + const auto &meta = model_.GetMetaData(); + + knf::StftConfig stft_config; + stft_config.n_fft = meta.n_fft; + stft_config.hop_length = meta.hop_length; + stft_config.win_length = meta.window_length; + stft_config.window_type = meta.window_type; + stft_config.center = meta.center; + stft_config.center = false; + + return stft_config; + } private: + OfflineSourceSeparationConfig config_; + OfflineSourceSeparationSpleeterModel model_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h index fd7c5ee163..31b214cbd7 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h @@ -18,7 +18,7 @@ struct OfflineSourceSeparationSpleeterModelMetaData { int32_t n_fft = 4096; int32_t hop_length = 1024; - int32_t window_length = 0; + int32_t window_length = 4096; bool center = false; std::string window_type = "hann"; }; diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc index 12f343cc67..e3c1651115 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc @@ -126,10 +126,10 @@ class OfflineSourceSeparationSpleeterModel::Impl { SHERPA_ONNX_EXIT(-1); } - int32_t stems = -1; - SHERPA_ONNX_READ_META_DATA(stems, "stems"); - if (stems != 2) { - SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems", stems); + SHERPA_ONNX_READ_META_DATA(meta_.num_stems, "stems"); + if (meta_.num_stems != 2) { + SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems", + meta_.num_stems); SHERPA_ONNX_EXIT(-1); } } diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc index b73427d1cf..90619b7c17 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc @@ -31,18 +31,23 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m --spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \ --spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \ --input-wav=audio_example.wav \ - --output-wav=output_example.wav + --output-vocals-wav=output_vocals.wav \ + --output-accompaniment-wav=output_accompaniment.wav )usage"; sherpa_onnx::ParseOptions po(kUsageMessage); sherpa_onnx::OfflineSourceSeparationConfig config; std::string input_wave; - std::string output_wave; + std::string output_vocals_wave; + std::string output_accompaniment_wave; config.Register(&po); po.Register("input-wav", &input_wave, "Path to input wav."); - po.Register("output-wav", &output_wave, "Path to output wav"); + po.Register("output-vocals-wav", &output_vocals_wave, + "Path to output vocals wav"); + po.Register("output-accompaniment-wav", &output_accompaniment_wave, + "Path to output accompaniment wav"); po.Read(argc, argv); if (po.NumArgs() != 0) { @@ -58,8 +63,14 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m exit(EXIT_FAILURE); } - if (output_wave.empty()) { - fprintf(stderr, "Please provide --output-wav\n"); + if (output_vocals_wave.empty()) { + fprintf(stderr, "Please provide --output-vocals-wav\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (output_accompaniment_wave.empty()) { + fprintf(stderr, "Please provide --output-accompaniment-wav\n"); po.PrintUsage(); exit(EXIT_FAILURE); } @@ -89,16 +100,28 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m auto output = sp.Process(input); - fprintf(stderr, "Done\n"); is_ok = sherpa_onnx::WriteWave( - output_wave, input.sample_rate, input.samples.data[0].data(), - input.samples.data[1].data(), input.samples.data[0].size()); + output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(), + output.stems[0].data[1].data(), output.stems[0].data[0].size()); + + if (!is_ok) { + fprintf(stderr, "Failed to write to '%s'\n", output_vocals_wave.c_str()); + exit(EXIT_FAILURE); + } + + is_ok = sherpa_onnx::WriteWave(output_accompaniment_wave, output.sample_rate, + output.stems[1].data[0].data(), + output.stems[1].data[1].data(), + output.stems[1].data[0].size()); if (!is_ok) { - fprintf(stderr, "Failed to write to '%s'\n", output_wave.c_str()); + fprintf(stderr, "Failed to write to '%s'\n", + output_accompaniment_wave.c_str()); exit(EXIT_FAILURE); } - fprintf(stderr, "Saved to write to '%s'\n", output_wave.c_str()); + + fprintf(stderr, "Saved to write to '%s' and '%s'\n", + output_vocals_wave.c_str(), output_accompaniment_wave.c_str()); return 0; } From 9dba3431e2464cc4a7cd310c727bfc58ae738390 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 23 May 2025 22:17:57 +0800 Subject: [PATCH 14/14] add resampling --- .../offline-source-separation-spleeter-impl.h | 56 +++++++++++++------ .../sherpa-onnx-offline-source-separation.cc | 21 +++++-- 2 files changed, 54 insertions(+), 23 deletions(-) diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h index 94aca85147..7a707c63f3 100644 --- a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h +++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h @@ -12,7 +12,7 @@ #include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h" #include "sherpa-onnx/csrc/offline-source-separation.h" #include "sherpa-onnx/csrc/onnx-utils.h" -#include "sherpa-onnx/csrc/wave-writer.h" +#include "sherpa-onnx/csrc/resample.h" namespace sherpa_onnx { @@ -20,9 +20,7 @@ class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { public: OfflineSourceSeparationSpleeterImpl( const OfflineSourceSeparationConfig &config) - : config_(config), model_(config_.model) { - SHERPA_ONNX_LOGE("created!"); - } + : config_(config), model_(config_.model) {} template OfflineSourceSeparationSpleeterImpl( @@ -31,34 +29,57 @@ class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { OfflineSourceSeparationOutput Process( const OfflineSourceSeparationInput &input) const override { - if (config_.model.debug) { - SHERPA_ONNX_LOGE("input sample rate: %d", input.sample_rate); - SHERPA_ONNX_LOGE("input ch0 samples size: %d", - static_cast(input.samples.data[0].size())); + const OfflineSourceSeparationInput *p_input = &input; + OfflineSourceSeparationInput tmp_input; + + int32_t output_sample_rate = GetOutputSampleRate(); + + if (input.sample_rate != output_sample_rate) { + SHERPA_ONNX_LOGE( + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + input.sample_rate, output_sample_rate); + + float min_freq = std::min(input.sample_rate, output_sample_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + auto resampler = std::make_unique( + input.sample_rate, output_sample_rate, lowpass_cutoff, + lowpass_filter_width); + + std::vector s; + for (const auto &samples : input.samples.data) { + resampler->Reset(); + resampler->Resample(samples.data(), samples.size(), true, &s); + tmp_input.samples.data.push_back(std::move(s)); + } + + tmp_input.sample_rate = output_sample_rate; + p_input = &tmp_input; } - if (input.samples.data.size() > 1) { + if (p_input->samples.data.size() > 1) { if (config_.model.debug) { SHERPA_ONNX_LOGE("input ch1 samples size: %d", - static_cast(input.samples.data[1].size())); + static_cast(p_input->samples.data[1].size())); } - if (input.samples.data[0].size() != input.samples.data[1].size()) { + if (p_input->samples.data[0].size() != p_input->samples.data[1].size()) { SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d", - static_cast(input.samples.data[0].size()), - static_cast(input.samples.data[1].size())); + static_cast(p_input->samples.data[0].size()), + static_cast(p_input->samples.data[1].size())); SHERPA_ONNX_EXIT(-1); } } - auto stft_ch0 = ComputeStft(input, 0); + auto stft_ch0 = ComputeStft(*p_input, 0); - auto stft_ch1 = ComputeStft(input, 1); + auto stft_ch1 = ComputeStft(*p_input, 1); knf::StftResult *p_stft_ch1 = stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1; - SHERPA_ONNX_LOGE("number of frames: %d", stft_ch0.num_frames); - int32_t num_frames = stft_ch0.num_frames; int32_t fft_bins = stft_ch0.real.size() / num_frames; @@ -130,7 +151,6 @@ class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { accompaniment_spec = (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array(); - SHERPA_ONNX_LOGE("here"); auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0); auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1); diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc index 90619b7c17..8af94aa1df 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc @@ -91,14 +91,16 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m fprintf(stderr, "Started\n"); - fprintf(stderr, "Input channels: %d\n", - static_cast(input.samples.data.size())); - fprintf(stderr, "Input sample rate: %d\n", input.sample_rate); - fprintf(stderr, "Input sample size: %d\n", (int)input.samples.data[0].size()); - sherpa_onnx::OfflineSourceSeparation sp(config); + const auto begin = std::chrono::steady_clock::now(); auto output = sp.Process(input); + const auto end = std::chrono::steady_clock::now(); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; is_ok = sherpa_onnx::WriteWave( output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(), @@ -120,8 +122,17 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-m exit(EXIT_FAILURE); } + fprintf(stderr, "Done\n"); fprintf(stderr, "Saved to write to '%s' and '%s'\n", output_vocals_wave.c_str(), output_accompaniment_wave.c_str()); + float duration = + input.samples.data[0].size() / static_cast(input.sample_rate); + fprintf(stderr, "num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + return 0; }