diff --git a/.github/workflows/export-spleeter-to-onnx.yaml b/.github/workflows/export-spleeter-to-onnx.yaml
index f1993ce778..068b168e2f 100644
--- a/.github/workflows/export-spleeter-to-onnx.yaml
+++ b/.github/workflows/export-spleeter-to-onnx.yaml
@@ -3,7 +3,7 @@ name: export-spleeter-to-onnx
 on:
   push:
     branches:
-      - spleeter-2
+      - spleeter-cpp-2
 
   workflow_dispatch:
 
 concurrency:
diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py
index 6b79b173c0..457b847dd6 100644
--- a/cmake/cmake_extension.py
+++ b/cmake/cmake_extension.py
@@ -56,6 +56,7 @@ def get_binaries():
         "sherpa-onnx-offline-denoiser",
         "sherpa-onnx-offline-language-identification",
         "sherpa-onnx-offline-punctuation",
+        "sherpa-onnx-offline-source-separation",
         "sherpa-onnx-offline-speaker-diarization",
         "sherpa-onnx-offline-tts",
         "sherpa-onnx-offline-tts-play",
diff --git a/scripts/spleeter/convert_to_torch.py b/scripts/spleeter/convert_to_torch.py
index dc6e75800d..e5610aa320 100755
--- a/scripts/spleeter/convert_to_torch.py
+++ b/scripts/spleeter/convert_to_torch.py
@@ -217,8 +217,8 @@ def main(name):
 
     # for the batchnormalization in torch,
     # default input shape is NCHW
-    # NHWC to NCHW
-    torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
+    torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2))
+    torch_y1_out = torch_y1_out.permute(1, 0, 2, 3)
 
     # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
     assert torch.allclose(
diff --git a/scripts/spleeter/export_onnx.py b/scripts/spleeter/export_onnx.py
index adc26048eb..af19ea8618 100755
--- a/scripts/spleeter/export_onnx.py
+++ b/scripts/spleeter/export_onnx.py
@@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):
 
 def export(model, prefix):
     num_splits = 1
-    x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)
+    x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)
 
     filename = f"./2stems/{prefix}.onnx"
     torch.onnx.export(
@@ -56,7 +56,7 @@ def export(model, prefix):
         input_names=["x"],
         output_names=["y"],
         dynamic_axes={
-            "x": {0: "num_splits"},
+            "x": {1: "num_splits"},
         },
         opset_version=13,
     )
diff --git a/scripts/spleeter/separate.py b/scripts/spleeter/separate.py
index 2a83b7ef2f..cf7eebf042 100755
--- a/scripts/spleeter/separate.py
+++ b/scripts/spleeter/separate.py
@@ -101,13 +101,17 @@ def main():
     print("y2", y.shape, y.dtype)
 
     y = y.abs()
-    y = y.permute(0, 3, 1, 2)
-    # (1, 2, 512, 1024)
+
+    y = y.permute(3, 0, 1, 2)
+    # (2, 1, 512, 1024)
     print("y3", y.shape, y.dtype)
 
     vocals_spec = vocals(y)
     accompaniment_spec = accompaniment(y)
 
+    vocals_spec = vocals_spec.permute(1, 0, 2, 3)
+    accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3)
+
     sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
     print(
         "vocals_spec",
diff --git a/scripts/spleeter/separate_onnx.py b/scripts/spleeter/separate_onnx.py
index 28ed37600d..6cc0baea2b 100755
--- a/scripts/spleeter/separate_onnx.py
+++ b/scripts/spleeter/separate_onnx.py
@@ -12,15 +12,14 @@
 """
 ----------inputs for ./2stems/vocals.onnx----------
-NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
+NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
 ----------outputs for ./2stems/vocals.onnx----------
-NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
+NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
 ----------inputs for ./2stems/accompaniment.onnx----------
-NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
+NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
 ----------outputs for ./2stems/accompaniment.onnx----------
-NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
-
+NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
 """
@@ -123,16 +122,16 @@ def main():
     if padding > 0:
         stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
         stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
-    stft0 = stft0.reshape(-1, 1, 512, 1024)
-    stft1 = stft1.reshape(-1, 1, 512, 1024)
+    stft0 = stft0.reshape(1, -1, 512, 1024)
+    stft1 = stft1.reshape(1, -1, 512, 1024)
 
-    stft_01 = torch.cat([stft0, stft1], axis=1)
+    stft_01 = torch.cat([stft0, stft1], axis=0)
     print("stft_01", stft_01.shape, stft_01.dtype)
 
     vocals_spec = vocals(stft_01)
     accompaniment_spec = accompaniment(stft_01)
-    # (num_splits, num_channels, 512, 1024)
+    # (num_channels, num_splits, 512, 1024)
 
     sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10
 
@@ -142,8 +141,8 @@ def main():
     for name, spec in zip(
         ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
     ):
-        spec_c0 = spec[:, 0, :, :]
-        spec_c1 = spec[:, 1, :, :]
+        spec_c0 = spec[0]
+        spec_c1 = spec[1]
 
         spec_c0 = spec_c0.reshape(-1, 1024)
         spec_c1 = spec_c1.reshape(-1, 1024)
diff --git a/scripts/spleeter/unet.py b/scripts/spleeter/unet.py
index cfcabb6c1b..60eaa5e21b 100644
--- a/scripts/spleeter/unet.py
+++ b/scripts/spleeter/unet.py
@@ -67,6 +67,14 @@ def __init__(self):
         self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
 
     def forward(self, x):
+        """
+        Args:
+          x: (num_audio_channels, num_splits, 512, 1024)
+        Returns:
+          y: (num_audio_channels, num_splits, 512, 1024)
+        """
+        x = x.permute(1, 0, 2, 3)
+
         in_x = x
         # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
         x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
@@ -147,4 +155,5 @@ def forward(self, x):
         up7 = self.up7(batch12)
         up7 = torch.sigmoid(up7)
         # (3, 2, 512, 1024)
-        return up7 * in_x
+        ans = up7 * in_x
+        return ans.permute(1, 0, 2, 3)
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 52fe3bb357..f704bfeb7a 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -50,6 +50,13 @@ set(sources
   offline-rnn-lm.cc
   offline-sense-voice-model-config.cc
   offline-sense-voice-model.cc
+
+  offline-source-separation-impl.cc
+  offline-source-separation-model-config.cc
+  offline-source-separation-spleeter-model-config.cc
+  offline-source-separation-spleeter-model.cc
+  offline-source-separation.cc
+
   offline-stream.cc
   offline-tdnn-ctc-model.cc
   offline-tdnn-model-config.cc
@@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
   add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
   add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
   add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
+  add_executable(sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc)
   add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
   add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc)
 
@@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
     sherpa-onnx-offline-language-identification
     sherpa-onnx-offline-parallel
     sherpa-onnx-offline-punctuation
+    sherpa-onnx-offline-source-separation
     sherpa-onnx-online-punctuation
     sherpa-onnx-vad
   )
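With the export change above, both ONNX models now take a (2, num_splits, 512, 1024) tensor whose dynamic axis is dimension 1, as the updated docstring in separate_onnx.py shows. Below is a minimal C++ sketch for double-checking the exported layout; the model path and the small helper program are assumptions for illustration only.

    // check-spleeter-input.cc: print the input shape of vocals.onnx
    #include <cstdio>
    #include <vector>

    #include "onnxruntime_cxx_api.h"

    int main() {
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "check-spleeter");
      Ort::SessionOptions opts;
      Ort::Session sess(env, "./2stems/vocals.onnx", opts);

      // dynamic axes are reported as -1; expected output: 2 -1 512 1024
      std::vector<int64_t> shape =
          sess.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
      for (auto d : shape) {
        printf("%lld ", static_cast<long long>(d));
      }
      printf("\n");
      return 0;
    }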
diff --git a/sherpa-onnx/csrc/offline-source-separation-impl.cc b/sherpa-onnx/csrc/offline-source-separation-impl.cc
new file mode 100644
index 0000000000..d2e9632828
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-impl.cc
@@ -0,0 +1,40 @@
+// sherpa-onnx/csrc/offline-source-separation-impl.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
+
+#include <memory>
+
+#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h"
+
+namespace sherpa_onnx {
+
+std::unique_ptr<OfflineSourceSeparationImpl>
+OfflineSourceSeparationImpl::Create(
+    const OfflineSourceSeparationConfig &config) {
+  // TODO(fangjun): Support other models
+  return std::make_unique<OfflineSourceSeparationSpleeterImpl>(config);
+}
+
+template <typename Manager>
+std::unique_ptr<OfflineSourceSeparationImpl>
+OfflineSourceSeparationImpl::Create(
+    Manager *mgr, const OfflineSourceSeparationConfig &config) {
+  // TODO(fangjun): Support other models
+  return std::make_unique<OfflineSourceSeparationSpleeterImpl>(mgr, config);
+}
+
+#if __ANDROID_API__ >= 9
+template std::unique_ptr<OfflineSourceSeparationImpl>
+OfflineSourceSeparationImpl::Create(
+    AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
+#endif
+
+#if __OHOS__
+template std::unique_ptr<OfflineSourceSeparationImpl>
+OfflineSourceSeparationImpl::Create(
+    NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
+#endif
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-source-separation-impl.h b/sherpa-onnx/csrc/offline-source-separation-impl.h
new file mode 100644
index 0000000000..8bb6852d0b
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-impl.h
@@ -0,0 +1,35 @@
+// sherpa-onnx/csrc/offline-source-separation-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
+
+#include <memory>
+
+#include "sherpa-onnx/csrc/offline-source-separation.h"
+
+namespace sherpa_onnx {
+
+class OfflineSourceSeparationImpl {
+ public:
+  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
+      const OfflineSourceSeparationConfig &config);
+
+  template <typename Manager>
+  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
+      Manager *mgr, const OfflineSourceSeparationConfig &config);
+
+  virtual ~OfflineSourceSeparationImpl() = default;
+
+  virtual OfflineSourceSeparationOutput Process(
+      const OfflineSourceSeparationInput &input) const = 0;
+
+  virtual int32_t GetOutputSampleRate() const = 0;
+
+  virtual int32_t GetNumberOfStems() const = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.cc b/sherpa-onnx/csrc/offline-source-separation-model-config.cc
new file mode 100644
index 0000000000..dfd765d3f1
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-model-config.cc
@@ -0,0 +1,38 @@
+// sherpa-onnx/csrc/offline-source-separation-model-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
+
+namespace sherpa_onnx {
+
+void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) {
+  spleeter.Register(po);
+
+  po->Register("num-threads", &num_threads,
+               "Number of threads to run the neural network");
+
+  po->Register("debug", &debug,
+               "true to print model information while loading it.");
+
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
+}
+
+bool OfflineSourceSeparationModelConfig::Validate() const {
+  return spleeter.Validate();
+}
+
+std::string OfflineSourceSeparationModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSourceSeparationModelConfig(";
+  os << "spleeter=" << spleeter.ToString() << ", ";
+  os << "num_threads=" << num_threads << ", ";
+  os << "debug=" << (debug ? "True" : "False") << ", ";
+  os << "provider=\"" << provider << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-source-separation-model-config.h b/sherpa-onnx/csrc/offline-source-separation-model-config.h
new file mode 100644
index 0000000000..bf88d39dbb
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-model-config.h
@@ -0,0 +1,41 @@
+// sherpa-onnx/csrc/offline-source-separation-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineSourceSeparationModelConfig {
+  OfflineSourceSeparationSpleeterModelConfig spleeter;
+
+  int32_t num_threads = 1;
+  bool debug = false;
+  std::string provider = "cpu";
+
+  OfflineSourceSeparationModelConfig() = default;
+
+  OfflineSourceSeparationModelConfig(
+      const OfflineSourceSeparationSpleeterModelConfig &spleeter,
+      int32_t num_threads, bool debug, const std::string &provider)
+      : spleeter(spleeter),
+        num_threads(num_threads),
+        debug(debug),
+        provider(provider) {}
+
+  void Register(ParseOptions *po);
+
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
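The two config structs above mirror the new command-line flags (--spleeter-vocals, --spleeter-accompaniment, --num-threads, --debug, --provider). A small sketch of filling them programmatically; the file names are assumptions.

    sherpa_onnx::OfflineSourceSeparationModelConfig model_config;
    model_config.spleeter.vocals = "./vocals.onnx";               // assumed path
    model_config.spleeter.accompaniment = "./accompaniment.onnx"; // assumed path
    model_config.num_threads = 2;
    model_config.provider = "cpu";

    if (!model_config.Validate()) {
      // a model path is empty or the file does not exist
    }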
diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
new file mode 100644
index 0000000000..7a707c63f3
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
@@ -0,0 +1,276 @@
+// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
+
+#include "Eigen/Dense"
+#include "kaldi-native-fbank/csrc/istft.h"
+#include "kaldi-native-fbank/csrc/stft.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
+#include "sherpa-onnx/csrc/offline-source-separation.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/resample.h"
+
+namespace sherpa_onnx {
+
+class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl {
+ public:
+  OfflineSourceSeparationSpleeterImpl(
+      const OfflineSourceSeparationConfig &config)
+      : config_(config), model_(config_.model) {}
+
+  template <typename Manager>
+  OfflineSourceSeparationSpleeterImpl(
+      Manager *mgr, const OfflineSourceSeparationConfig &config)
+      : config_(config), model_(mgr, config_.model) {}
+
+  OfflineSourceSeparationOutput Process(
+      const OfflineSourceSeparationInput &input) const override {
+    const OfflineSourceSeparationInput *p_input = &input;
+    OfflineSourceSeparationInput tmp_input;
+
+    int32_t output_sample_rate = GetOutputSampleRate();
+
+    if (input.sample_rate != output_sample_rate) {
+      SHERPA_ONNX_LOGE(
+          "Creating a resampler:\n"
+          "  in_sample_rate: %d\n"
+          "  output_sample_rate: %d\n",
+          input.sample_rate, output_sample_rate);
+
+      float min_freq = std::min(input.sample_rate, output_sample_rate);
+      float lowpass_cutoff = 0.99 * 0.5 * min_freq;
+
+      int32_t lowpass_filter_width = 6;
+      auto resampler = std::make_unique<LinearResample>(
+          input.sample_rate, output_sample_rate, lowpass_cutoff,
+          lowpass_filter_width);
+
+      std::vector<float> s;
+      for (const auto &samples : input.samples.data) {
+        resampler->Reset();
+        resampler->Resample(samples.data(), samples.size(), true, &s);
+        tmp_input.samples.data.push_back(std::move(s));
+      }
+
+      tmp_input.sample_rate = output_sample_rate;
+      p_input = &tmp_input;
+    }
+
+    if (p_input->samples.data.size() > 1) {
+      if (config_.model.debug) {
+        SHERPA_ONNX_LOGE("input ch1 samples size: %d",
+                         static_cast<int32_t>(p_input->samples.data[1].size()));
+      }
+
+      if (p_input->samples.data[0].size() != p_input->samples.data[1].size()) {
+        SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d",
+                         static_cast<int32_t>(p_input->samples.data[0].size()),
+                         static_cast<int32_t>(p_input->samples.data[1].size()));
+
+        SHERPA_ONNX_EXIT(-1);
+      }
+    }
+
+    auto stft_ch0 = ComputeStft(*p_input, 0);
+
+    auto stft_ch1 = ComputeStft(*p_input, 1);
+    knf::StftResult *p_stft_ch1 =
+        stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1;
+
+    int32_t num_frames = stft_ch0.num_frames;
+    int32_t fft_bins = stft_ch0.real.size() / num_frames;
+
+    int32_t pad = 512 - (stft_ch0.num_frames % 512);
+    if (pad < 512) {
+      num_frames += pad;
+    }
+
+    if (num_frames % 512) {
+      SHERPA_ONNX_LOGE("num_frames should be multiple of 512, actual: %d. %d",
+                       num_frames, num_frames % 512);
+      SHERPA_ONNX_EXIT(-1);
+    }
+
+    Eigen::VectorXf real(2 * num_frames * 1024);
+    Eigen::VectorXf imag(2 * num_frames * 1024);
+    real.setZero();
+    imag.setZero();
+
+    float *p_real = &real[0];
+    float *p_imag = &imag[0];
+
+    // copy stft result of channel 0
+    for (int32_t i = 0; i != stft_ch0.num_frames; ++i) {
+      std::copy(stft_ch0.real.data() + i * fft_bins,
+                stft_ch0.real.data() + i * fft_bins + 1024, p_real + 1024 * i);
+
+      std::copy(stft_ch0.imag.data() + i * fft_bins,
+                stft_ch0.imag.data() + i * fft_bins + 1024, p_imag + 1024 * i);
+    }
+
+    p_real += num_frames * 1024;
+    p_imag += num_frames * 1024;
+
+    // copy stft result of channel 1
+    for (int32_t i = 0; i != stft_ch1.num_frames; ++i) {
+      std::copy(p_stft_ch1->real.data() + i * fft_bins,
+                p_stft_ch1->real.data() + i * fft_bins + 1024,
+                p_real + 1024 * i);
+
+      std::copy(p_stft_ch1->imag.data() + i * fft_bins,
+                p_stft_ch1->imag.data() + i * fft_bins + 1024,
+                p_imag + 1024 * i);
+    }
+
+    Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt();
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 4> x_shape{2, num_frames / 512, 512, 1024};
+    Ort::Value x_tensor = Ort::Value::CreateTensor<float>(
+        memory_info, &x[0], x.size(), x_shape.data(), x_shape.size());
+
+    Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor));
+    Ort::Value accompaniment_spec_tensor =
+        model_.RunAccompaniment(std::move(x_tensor));
+
+    Eigen::VectorXf vocals_spec = Eigen::Map<Eigen::VectorXf>(
+        vocals_spec_tensor.GetTensorMutableData<float>(), x.size());
+
+    Eigen::VectorXf accompaniment_spec = Eigen::Map<Eigen::VectorXf>(
+        accompaniment_spec_tensor.GetTensorMutableData<float>(), x.size());
+
+    Eigen::VectorXf sum_spec = vocals_spec.array().square() +
+                               accompaniment_spec.array().square() + 1e-10;
+
+    vocals_spec = (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array();
+
+    accompaniment_spec =
+        (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array();
+
+    auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0);
+    auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1);
+
+    auto accompaniment_samples_ch0 =
+        ProcessSpec(accompaniment_spec, stft_ch0, 0);
+    auto accompaniment_samples_ch1 =
+        ProcessSpec(accompaniment_spec, *p_stft_ch1, 1);
+
+    OfflineSourceSeparationOutput ans;
+    ans.sample_rate = GetOutputSampleRate();
+
+    ans.stems.resize(2);
+    ans.stems[0].data.reserve(2);
+    ans.stems[1].data.reserve(2);
+
+    ans.stems[0].data.push_back(std::move(vocals_samples_ch0));
+    ans.stems[0].data.push_back(std::move(vocals_samples_ch1));
+
+    ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0));
+    ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1));
+
+    return ans;
+  }
+
+  int32_t GetOutputSampleRate() const override {
+    return model_.GetMetaData().sample_rate;
+  }
+
+  int32_t GetNumberOfStems() const override {
+    return model_.GetMetaData().num_stems;
+  }
+
+ private:
+  // spec is of shape (2, num_chunks, 512, 1024)
+  std::vector<float> ProcessSpec(const Eigen::VectorXf &spec,
+                                 const knf::StftResult &stft,
+                                 int32_t channel) const {
+    int32_t fft_bins = stft.real.size() / stft.num_frames;
+
+    Eigen::VectorXf mask(stft.real.size());
+    mask.setZero();
+
+    float *p_mask = &mask[0];
+
+    // assume there are 2 channels
+    const float *p_spec = &spec[0] + (spec.size() / 2) * channel;
+
+    for (int32_t i = 0; i != stft.num_frames; ++i) {
+      std::copy(p_spec + i * 1024, p_spec + (i + 1) * 1024,
+                p_mask + i * fft_bins);
+    }
+
+    knf::StftResult masked_stft;
+
+    masked_stft.num_frames = stft.num_frames;
+    masked_stft.real.resize(stft.real.size());
+    masked_stft.imag.resize(stft.imag.size());
+
+    Eigen::Map<Eigen::VectorXf>(masked_stft.real.data(),
+                                masked_stft.real.size()) =
+        mask.array() *
+        Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.real.data()),
+                                    stft.real.size())
+            .array();
+
+    Eigen::Map<Eigen::VectorXf>(masked_stft.imag.data(),
+                                masked_stft.imag.size()) =
+        mask.array() *
+        Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.imag.data()),
+                                    stft.imag.size())
+            .array();
+
+    auto stft_config = GetStftConfig();
+    knf::IStft istft(stft_config);
+
+    return istft.Compute(masked_stft);
+  }
+
+  knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input,
+                              int32_t ch) const {
+    if (ch >= input.samples.data.size()) {
+      SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch,
+                       static_cast<int32_t>(input.samples.data.size()));
+      SHERPA_ONNX_EXIT(-1);
+    }
+
+    if (input.samples.data[ch].empty()) {
+      return {};
+    }
+
+    return ComputeStft(input.samples.data[ch]);
+  }
+
+  knf::StftResult ComputeStft(const std::vector<float> &samples) const {
+    auto stft_config = GetStftConfig();
+    knf::Stft stft(stft_config);
+
+    return stft.Compute(samples.data(), samples.size());
+  }
+
+  knf::StftConfig GetStftConfig() const {
+    const auto &meta = model_.GetMetaData();
+
+    knf::StftConfig stft_config;
+    stft_config.n_fft = meta.n_fft;
+    stft_config.hop_length = meta.hop_length;
+    stft_config.win_length = meta.window_length;
+    stft_config.window_type = meta.window_type;
+    stft_config.center = meta.center;
+    stft_config.center = false;
+
+    return stft_config;
+  }
+
+ private:
+  OfflineSourceSeparationConfig config_;
+  OfflineSourceSeparationSpleeterModel model_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
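Process() feeds the networks fixed chunks of 512 STFT frames with 1024 frequency bins per channel, padding the frame count up to the next multiple of 512. A worked example of the arithmetic; the exact frame-count formula of knf::Stft with center == false is an assumption used only for illustration.

    // 30 seconds of 44.1 kHz audio, n_fft = 4096, hop_length = 1024
    int32_t num_samples = 30 * 44100;                      // 1323000
    int32_t num_frames = (num_samples - 4096) / 1024 + 1;  // 1288
    int32_t pad = 512 - (num_frames % 512);                // 248
    if (pad < 512) {
      num_frames += pad;                                   // 1536
    }
    int32_t num_chunks = num_frames / 512;                 // 3
    // x_tensor shape passed to the models: {2, 3, 512, 1024}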
diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
new file mode 100644
index 0000000000..c43f693f91
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
@@ -0,0 +1,54 @@
+// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineSourceSeparationSpleeterModelConfig::Register(ParseOptions *po) {
+  po->Register("spleeter-vocals", &vocals,
+               "Path to the spleeter vocals model");
+
+  po->Register("spleeter-accompaniment", &accompaniment,
+               "Path to the spleeter accompaniment model");
+}
+
+bool OfflineSourceSeparationSpleeterModelConfig::Validate() const {
+  if (vocals.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --spleeter-vocals");
+    return false;
+  }
+
+  if (!FileExists(vocals)) {
+    SHERPA_ONNX_LOGE("spleeter vocals '%s' does not exist. ", vocals.c_str());
+    return false;
+  }
+
+  if (accompaniment.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --spleeter-accompaniment");
+    return false;
+  }
+
+  if (!FileExists(accompaniment)) {
+    SHERPA_ONNX_LOGE("spleeter accompaniment '%s' does not exist. ",
+                     accompaniment.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineSourceSeparationSpleeterModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSourceSeparationSpleeterModelConfig(";
+  os << "vocals=\"" << vocals << "\", ";
+  os << "accompaniment=\"" << accompaniment << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
new file mode 100644
index 0000000000..7e868966ad
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
@@ -0,0 +1,35 @@
+// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineSourceSeparationSpleeterModelConfig {
+  std::string vocals;
+
+  std::string accompaniment;
+
+  OfflineSourceSeparationSpleeterModelConfig() = default;
+
+  OfflineSourceSeparationSpleeterModelConfig(const std::string &vocals,
+                                             const std::string &accompaniment)
+      : vocals(vocals), accompaniment(accompaniment) {}
+
+  void Register(ParseOptions *po);
+
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
new file mode 100644
index 0000000000..31b214cbd7
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
@@ -0,0 +1,28 @@
+// sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+namespace sherpa_onnx {
+
+// See also
+// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/spleeter/separate_onnx.py
+struct OfflineSourceSeparationSpleeterModelMetaData {
+  int32_t sample_rate = 44100;
+  int32_t num_stems = 2;
+
+  int32_t n_fft = 4096;
+  int32_t hop_length = 1024;
+  int32_t window_length = 4096;
+  bool center = false;
+  std::string window_type = "hann";
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
new file mode 100644
index 0000000000..e3c1651115
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
@@ -0,0 +1,212 @@
+// sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#if __OHOS__
+#include "rawfile/raw_file_manager.h"
+#endif
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineSourceSeparationSpleeterModel::Impl {
+ public:
+  explicit Impl(const OfflineSourceSeparationModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.spleeter.vocals);
+      InitVocals(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.spleeter.accompaniment);
+      InitAccompaniment(buf.data(), buf.size());
+    }
+  }
+
+  template <typename Manager>
+  Impl(Manager *mgr, const OfflineSourceSeparationModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.spleeter.vocals);
+      InitVocals(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.spleeter.accompaniment);
+      InitAccompaniment(buf.data(), buf.size());
+    }
+  }
+
+  const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const {
+    return meta_;
+  }
+
+  Ort::Value RunVocals(Ort::Value x) const {
+    auto out = vocals_sess_->Run({}, vocals_input_names_ptr_.data(), &x, 1,
+                                 vocals_output_names_ptr_.data(),
+                                 vocals_output_names_ptr_.size());
+    return std::move(out[0]);
+  }
+
+  Ort::Value RunAccompaniment(Ort::Value x) const {
+    auto out =
+        accompaniment_sess_->Run({}, accompaniment_input_names_ptr_.data(), &x,
+                                 1, accompaniment_output_names_ptr_.data(),
+                                 accompaniment_output_names_ptr_.size());
+    return std::move(out[0]);
+  }
+
+ private:
+  void InitVocals(void *model_data, size_t model_data_length) {
+    vocals_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(vocals_sess_.get(), &vocals_input_names_,
+                  &vocals_input_names_ptr_);
+
+    GetOutputNames(vocals_sess_.get(), &vocals_output_names_,
+                   &vocals_output_names_ptr_);
+
+    Ort::ModelMetadata meta_data = vocals_sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "---vocals model---\n";
+      PrintModelMetadata(os, meta_data);
+
+      os << "----------input names----------\n";
+      int32_t i = 0;
+      for (const auto &s : vocals_input_names_) {
+        os << i << " " << s << "\n";
+        ++i;
+      }
+      os << "----------output names----------\n";
+      i = 0;
+      for (const auto &s : vocals_output_names_) {
+        os << i << " " << s << "\n";
+        ++i;
+      }
+
+#if __OHOS__
+      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
+#else
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+#endif
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+
+    std::string model_type;
+    SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
+    if (model_type != "spleeter") {
+      SHERPA_ONNX_LOGE("Expect model type 'spleeter'. Given: '%s'",
+                       model_type.c_str());
+      SHERPA_ONNX_EXIT(-1);
+    }
+
+    SHERPA_ONNX_READ_META_DATA(meta_.num_stems, "stems");
+    if (meta_.num_stems != 2) {
+      SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems",
+                       meta_.num_stems);
+      SHERPA_ONNX_EXIT(-1);
+    }
+  }
+
+  void InitAccompaniment(void *model_data, size_t model_data_length) {
+    accompaniment_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(accompaniment_sess_.get(), &accompaniment_input_names_,
+                  &accompaniment_input_names_ptr_);
+
+    GetOutputNames(accompaniment_sess_.get(), &accompaniment_output_names_,
+                   &accompaniment_output_names_ptr_);
+  }
+
+ private:
+  OfflineSourceSeparationModelConfig config_;
+  OfflineSourceSeparationSpleeterModelMetaData meta_;
+
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> vocals_sess_;
+
+  std::vector<std::string> vocals_input_names_;
+  std::vector<const char *> vocals_input_names_ptr_;
+
+  std::vector<std::string> vocals_output_names_;
+  std::vector<const char *> vocals_output_names_ptr_;
+
+  std::unique_ptr<Ort::Session> accompaniment_sess_;
+
+  std::vector<std::string> accompaniment_input_names_;
+  std::vector<const char *> accompaniment_input_names_ptr_;
+
+  std::vector<std::string> accompaniment_output_names_;
+  std::vector<const char *> accompaniment_output_names_ptr_;
+};
+
+OfflineSourceSeparationSpleeterModel::~OfflineSourceSeparationSpleeterModel() =
+    default;
+
+OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
+    const OfflineSourceSeparationModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+template <typename Manager>
+OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
+    Manager *mgr, const OfflineSourceSeparationModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+
+Ort::Value OfflineSourceSeparationSpleeterModel::RunVocals(Ort::Value x) const {
+  return impl_->RunVocals(std::move(x));
+}
+
+Ort::Value OfflineSourceSeparationSpleeterModel::RunAccompaniment(
+    Ort::Value x) const {
+  return impl_->RunAccompaniment(std::move(x));
+}
+
+const OfflineSourceSeparationSpleeterModelMetaData &
+OfflineSourceSeparationSpleeterModel::GetMetaData() const {
+  return impl_->GetMetaData();
+}
+
+#if __ANDROID_API__ >= 9
+template OfflineSourceSeparationSpleeterModel::
+    OfflineSourceSeparationSpleeterModel(
+        AAssetManager *mgr, const OfflineSourceSeparationModelConfig &config);
+#endif
+
+#if __OHOS__
+template OfflineSourceSeparationSpleeterModel::
+    OfflineSourceSeparationSpleeterModel(
+        NativeResourceManager *mgr,
+        const OfflineSourceSeparationModelConfig &config);
+#endif
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-source-separation-spleeter-model.h b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
new file mode 100644
index 0000000000..8cad92fe55
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
@@ -0,0 +1,37 @@
+// sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
+#include <memory>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
+#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h"
+
+namespace sherpa_onnx {
+
+class OfflineSourceSeparationSpleeterModel {
+ public:
+  ~OfflineSourceSeparationSpleeterModel();
+
+  explicit OfflineSourceSeparationSpleeterModel(
+      const OfflineSourceSeparationModelConfig &config);
+
+  template <typename Manager>
+  OfflineSourceSeparationSpleeterModel(
+      Manager *mgr, const OfflineSourceSeparationModelConfig &config);
+
+  Ort::Value RunVocals(Ort::Value x) const;
+  Ort::Value RunAccompaniment(Ort::Value x) const;
+
+  const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
diff --git a/sherpa-onnx/csrc/offline-source-separation.cc b/sherpa-onnx/csrc/offline-source-separation.cc
new file mode 100644
index 0000000000..d352d9ab58
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation.cc
@@ -0,0 +1,74 @@
+// sherpa-onnx/csrc/offline-source-separation.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-source-separation.h"
+
+#include <memory>
+
+#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#if __OHOS__
+#include "rawfile/raw_file_manager.h"
+#endif
+
+namespace sherpa_onnx {
+
+void OfflineSourceSeparationConfig::Register(ParseOptions *po) {
+  model.Register(po);
+}
+
+bool OfflineSourceSeparationConfig::Validate() const {
+  return model.Validate();
+}
+
+std::string OfflineSourceSeparationConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSourceSeparationConfig(";
+  os << "model=" << model.ToString() << ")";
+
+  return os.str();
+}
+
+template <typename Manager>
+OfflineSourceSeparation::OfflineSourceSeparation(
+    Manager *mgr, const OfflineSourceSeparationConfig &config)
+    : impl_(OfflineSourceSeparationImpl::Create(mgr, config)) {}
+
+OfflineSourceSeparation::OfflineSourceSeparation(
+    const OfflineSourceSeparationConfig &config)
+    : impl_(OfflineSourceSeparationImpl::Create(config)) {}
+
+OfflineSourceSeparation::~OfflineSourceSeparation() = default;
+
+OfflineSourceSeparationOutput OfflineSourceSeparation::Process(
+    const OfflineSourceSeparationInput &input) const {
+  return impl_->Process(input);
+}
+
+int32_t OfflineSourceSeparation::GetOutputSampleRate() const {
+  return impl_->GetOutputSampleRate();
+}
+
+// e.g., it is 2 for 2stems from spleeter
+int32_t OfflineSourceSeparation::GetNumberOfStems() const {
+  return impl_->GetNumberOfStems();
+}
+
+#if __ANDROID_API__ >= 9
+template OfflineSourceSeparation::OfflineSourceSeparation(
+    AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
+#endif
+
+#if __OHOS__
+template OfflineSourceSeparation::OfflineSourceSeparation(
+    NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
+#endif
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-source-separation.h b/sherpa-onnx/csrc/offline-source-separation.h
new file mode 100644
index 0000000000..dc9e82a510
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-source-separation.h
@@ -0,0 +1,77 @@
+// sherpa-onnx/csrc/offline-source-separation.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineSourceSeparationConfig {
+  OfflineSourceSeparationModelConfig model;
+
+  OfflineSourceSeparationConfig() = default;
+
+  OfflineSourceSeparationConfig(const OfflineSourceSeparationModelConfig &model)
+      : model(model) {}
+
+  void Register(ParseOptions *po);
+
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+struct MultiChannelSamples {
+  // data[i] is for the i-th channel
+  //
+  // each sample is in the range [-1, 1]
+  std::vector<std::vector<float>> data;
+};
+
+struct OfflineSourceSeparationInput {
+  MultiChannelSamples samples;
+
+  int32_t sample_rate;
+};
+
+struct OfflineSourceSeparationOutput {
+  std::vector<MultiChannelSamples> stems;
+
+  int32_t sample_rate;
+};
+
+class OfflineSourceSeparationImpl;
+
+class OfflineSourceSeparation {
+ public:
+  ~OfflineSourceSeparation();
+
+  OfflineSourceSeparation(const OfflineSourceSeparationConfig &config);
+
+  template <typename Manager>
+  OfflineSourceSeparation(Manager *mgr,
+                          const OfflineSourceSeparationConfig &config);
+
+  OfflineSourceSeparationOutput Process(
+      const OfflineSourceSeparationInput &input) const;
+
+  int32_t GetOutputSampleRate() const;
+
+  // e.g., it is 2 for 2stems from spleeter
+  int32_t GetNumberOfStems() const;
+
+ private:
+  std::unique_ptr<OfflineSourceSeparationImpl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
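A minimal usage sketch of the API declared above (error handling omitted; the wav and model file names are assumptions). The new sherpa-onnx-offline-source-separation.cc binary further down follows the same flow.

    sherpa_onnx::OfflineSourceSeparationConfig config;
    config.model.spleeter.vocals = "./vocals.onnx";
    config.model.spleeter.accompaniment = "./accompaniment.onnx";

    sherpa_onnx::OfflineSourceSeparation sep(config);

    bool is_ok = false;
    sherpa_onnx::OfflineSourceSeparationInput input;
    input.samples.data = sherpa_onnx::ReadWaveMultiChannel(
        "./input.wav", &input.sample_rate, &is_ok);

    sherpa_onnx::OfflineSourceSeparationOutput output = sep.Process(input);
    // output.stems[0] is the vocals stem, output.stems[1] the accompaniment,
    // each holding one std::vector<float> per channel at output.sample_rate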
diff --git a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
index 8cf0cdabf8..36ebf83b69 100644
--- a/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
+++ b/sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
@@ -12,7 +12,7 @@
 namespace sherpa_onnx {
 
 // please refer to
-// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py
+// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/gtcrn/add_meta_data.py
 struct OfflineSpeechDenoiserGtcrnModelMetaData {
   int32_t sample_rate = 0;
   int32_t version = 1;
diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
index 61d6afd259..0ec314879d 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
@@ -11,7 +11,7 @@ int main(int32_t argc, char *argv[]) {
   const char *kUsageMessage = R"usage(
-Non-stremaing speech denoising with sherpa-onnx.
+Non-streaming speech denoising with sherpa-onnx.
 
 Please visit
 https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
new file mode 100644
index 0000000000..8af94aa1df
--- /dev/null
+++ b/sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
@@ -0,0 +1,138 @@
+// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#include <stdio.h>
+
+#include <chrono>  // NOLINT
+#include <string>
+
+#include "sherpa-onnx/csrc/offline-source-separation.h"
+#include "sherpa-onnx/csrc/wave-reader.h"
+#include "sherpa-onnx/csrc/wave-writer.h"
+
+int main(int32_t argc, char *argv[]) {
+  const char *kUsageMessage = R"usage(
+Non-streaming source separation with sherpa-onnx.
+
+Please visit
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
+to download models.
+
+Usage:
+
+(1) Use spleeter models
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
+tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav
+
+./bin/sherpa-onnx-offline-source-separation \
+  --spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \
+  --spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \
+  --input-wav=audio_example.wav \
+  --output-vocals-wav=output_vocals.wav \
+  --output-accompaniment-wav=output_accompaniment.wav
+)usage";
+
+  sherpa_onnx::ParseOptions po(kUsageMessage);
+  sherpa_onnx::OfflineSourceSeparationConfig config;
+
+  std::string input_wave;
+  std::string output_vocals_wave;
+  std::string output_accompaniment_wave;
+
+  config.Register(&po);
+  po.Register("input-wav", &input_wave, "Path to input wav.");
+  po.Register("output-vocals-wav", &output_vocals_wave,
+              "Path to output vocals wav");
+  po.Register("output-accompaniment-wav", &output_accompaniment_wave,
+              "Path to output accompaniment wav");
+
+  po.Read(argc, argv);
+  if (po.NumArgs() != 0) {
+    fprintf(stderr, "Please don't give positional arguments\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+  fprintf(stderr, "%s\n", config.ToString().c_str());
+
+  if (input_wave.empty()) {
+    fprintf(stderr, "Please provide --input-wav\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  if (output_vocals_wave.empty()) {
+    fprintf(stderr, "Please provide --output-vocals-wav\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  if (output_accompaniment_wave.empty()) {
+    fprintf(stderr, "Please provide --output-accompaniment-wav\n");
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  if (!config.Validate()) {
+    fprintf(stderr, "Errors in config!\n");
+    exit(EXIT_FAILURE);
+  }
+
+  bool is_ok = false;
+  sherpa_onnx::OfflineSourceSeparationInput input;
+  input.samples.data =
+      sherpa_onnx::ReadWaveMultiChannel(input_wave, &input.sample_rate, &is_ok);
+  if (!is_ok) {
+    fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
+    return -1;
+  }
+
+  fprintf(stderr, "Started\n");
+
+  sherpa_onnx::OfflineSourceSeparation sp(config);
+
+  const auto begin = std::chrono::steady_clock::now();
+  auto output = sp.Process(input);
+  const auto end = std::chrono::steady_clock::now();
+
+  float elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+          .count() /
+      1000.;
+
+  is_ok = sherpa_onnx::WriteWave(
+      output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(),
+      output.stems[0].data[1].data(), output.stems[0].data[0].size());
+
+  if (!is_ok) {
+    fprintf(stderr, "Failed to write to '%s'\n", output_vocals_wave.c_str());
+    exit(EXIT_FAILURE);
+  }
+
+  is_ok = sherpa_onnx::WriteWave(output_accompaniment_wave, output.sample_rate,
+                                 output.stems[1].data[0].data(),
+                                 output.stems[1].data[1].data(),
+                                 output.stems[1].data[0].size());
+
+  if (!is_ok) {
+    fprintf(stderr, "Failed to write to '%s'\n",
+            output_accompaniment_wave.c_str());
+    exit(EXIT_FAILURE);
+  }
+
+  fprintf(stderr, "Done\n");
+  fprintf(stderr, "Saved to '%s' and '%s'\n",
+          output_vocals_wave.c_str(), output_accompaniment_wave.c_str());
+
+  float duration =
+      input.samples.data[0].size() / static_cast<float>(input.sample_rate);
+  fprintf(stderr, "num threads: %d\n", config.model.num_threads);
+  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
+  float rtf = elapsed_seconds / duration;
+  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
+          elapsed_seconds, duration, rtf);
+
+  return 0;
+}
diff --git a/sherpa-onnx/csrc/wave-reader.cc b/sherpa-onnx/csrc/wave-reader.cc
index 0deb29a4f0..90db0d513b 100644
--- a/sherpa-onnx/csrc/wave-reader.cc
+++ b/sherpa-onnx/csrc/wave-reader.cc
@@ -63,8 +63,9 @@ in sherpa-onnx.
 
 // Read a wave file of mono-channel.
 // Return its samples normalized to the range [-1, 1).
-std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
-                                bool *is_ok) {
+std::vector<std::vector<float>> ReadWaveImpl(std::istream &is,
+                                             int32_t *sampling_rate,
+                                             bool *is_ok) {
   WaveHeader header{};
   is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id));
@@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
   is.read(reinterpret_cast<char *>(&header.num_channels),
           sizeof(header.num_channels));
 
-  if (header.num_channels != 1) {  // we support only single channel for now
-    SHERPA_ONNX_LOGE(
-        "Warning: %d channels are found. We only use the first channel.\n",
-        header.num_channels);
-  }
-
   is.read(reinterpret_cast<char *>(&header.sample_rate),
           sizeof(header.sample_rate));
@@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
 
   *sampling_rate = header.sample_rate;
 
-  std::vector<float> ans;
+  std::vector<std::vector<float>> ans(header.num_channels);
 
   if (header.bits_per_sample == 16 && header.audio_format == 1) {
     // header.subchunk2_size contains the number of bytes in the data.
@@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
       return {};
     }
 
-    ans.resize(samples.size() / header.num_channels);
+    for (auto &v : ans) {
+      v.resize(samples.size() / header.num_channels);
+    }
 
     // samples are interleaved
-    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
-      ans[i] = samples[i * header.num_channels] / 32768.;
+    for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
+         i += header.num_channels, ++k) {
+      for (int32_t c = 0; c != header.num_channels; ++c) {
+        ans[c][k] = samples[i + c] / 32768.;
+      }
     }
   } else if (header.bits_per_sample == 8 && header.audio_format == 1) {
     // number of samples == number of bytes for 8-bit encoded samples
@@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
       return {};
     }
 
-    ans.resize(samples.size() / header.num_channels);
-    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
-      // Note(fangjun): We want to normalize each sample into the range [-1, 1]
-      // Since each original sample is in the range [0, 256], dividing
-      // them by 128 converts them to the range [0, 2];
-      // so after subtracting 1, we get the range [-1, 1]
-      //
-      ans[i] = samples[i * header.num_channels] / 128. - 1;
+    for (auto &v : ans) {
+      v.resize(samples.size() / header.num_channels);
+    }
+
+    // samples are interleaved
+    for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
+         i += header.num_channels, ++k) {
+      for (int32_t c = 0; c != header.num_channels; ++c) {
+        // Note(fangjun): We want to normalize each sample into the range
+        // [-1, 1]. Since each original sample is in the range [0, 255],
+        // dividing them by 128 converts them to the range [0, 2]; so after
+        // subtracting 1, we get the range [-1, 1]
+        //
+        ans[c][k] = samples[i + c] / 128. - 1;
+      }
     }
   } else if (header.bits_per_sample == 32 && header.audio_format == 1) {
     // 32 here is for int32
@@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
       return {};
     }
 
-    ans.resize(samples.size() / header.num_channels);
-    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
-      ans[i] = static_cast<float>(samples[i * header.num_channels]) / (1 << 31);
+    for (auto &v : ans) {
+      v.resize(samples.size() / header.num_channels);
+    }
+
+    // samples are interleaved
+    for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
+         i += header.num_channels, ++k) {
+      for (int32_t c = 0; c != header.num_channels; ++c) {
+        ans[c][k] = static_cast<float>(samples[i + c]) / (1 << 31);
+      }
     }
   } else if (header.bits_per_sample == 32 && header.audio_format == 3) {
     // 32 here is for float32
@@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
       return {};
     }
 
-    ans.resize(samples.size() / header.num_channels);
-    for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
-      ans[i] = samples[i * header.num_channels];
+    for (auto &v : ans) {
+      v.resize(samples.size() / header.num_channels);
+    }
+
+    // samples are interleaved
+    for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
+         i += header.num_channels, ++k) {
+      for (int32_t c = 0; c != header.num_channels; ++c) {
+        ans[c][k] = samples[i + c];
+      }
     }
   } else {
     SHERPA_ONNX_LOGE(
@@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
 
 std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
                             bool *is_ok) {
   auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
+
+  if (samples.size() > 1) {
+    SHERPA_ONNX_LOGE(
+        "Warning: %d channels are found. We only use the first channel.\n",
+        static_cast<int32_t>(samples.size()));
+  }
+
+  return samples[0];
+}
+
+std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
+                                                     int32_t *sampling_rate,
+                                                     bool *is_ok) {
+  auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
   return samples;
 }
 
+std::vector<std::vector<float>> ReadWaveMultiChannel(
+    const std::string &filename, int32_t *sampling_rate, bool *is_ok) {
+  std::ifstream is(filename, std::ifstream::binary);
+  return ReadWaveMultiChannel(is, sampling_rate, is_ok);
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/wave-reader.h b/sherpa-onnx/csrc/wave-reader.h
index 98e956abf6..1ba9910218 100644
--- a/sherpa-onnx/csrc/wave-reader.h
+++ b/sherpa-onnx/csrc/wave-reader.h
@@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
 std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
                             bool *is_ok);
 
+std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
+                                                     int32_t *sampling_rate,
+                                                     bool *is_ok);
+
+std::vector<std::vector<float>> ReadWaveMultiChannel(
+    const std::string &filename, int32_t *sampling_rate, bool *is_ok);
+
 }  // namespace sherpa_onnx
 
 #endif  // SHERPA_ONNX_CSRC_WAVE_READER_H_
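A short sketch of the new multi-channel reader declared above; the file name is an assumption.

    int32_t sample_rate = 0;
    bool is_ok = false;
    std::vector<std::vector<float>> channels =
        sherpa_onnx::ReadWaveMultiChannel("./stereo.wav", &sample_rate, &is_ok);
    if (is_ok) {
      // channels[c][k] is the de-interleaved sample stored at index
      // k * num_channels + c in the file, normalized to [-1, 1]
      float duration = channels[0].size() / static_cast<float>(sample_rate);
    }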
diff --git a/sherpa-onnx/csrc/wave-writer.cc b/sherpa-onnx/csrc/wave-writer.cc
index f3aca8e90c..98cd24c989 100644
--- a/sherpa-onnx/csrc/wave-writer.cc
+++ b/sherpa-onnx/csrc/wave-writer.cc
@@ -4,6 +4,7 @@
 
 #include "sherpa-onnx/csrc/wave-writer.h"
 
+#include <algorithm>
 #include <cstring>
 #include <fstream>
 #include <string>
@@ -36,12 +37,44 @@ struct WaveHeader {
 }  // namespace
 
-int64_t WaveFileSize(int32_t n_samples) {
-  return sizeof(WaveHeader) + n_samples * sizeof(int16_t);
+int64_t WaveFileSize(int32_t n_samples, int32_t num_channels /*= 1*/) {
+  return sizeof(WaveHeader) + n_samples * sizeof(int16_t) * num_channels;
 }
 
 void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
                int32_t n) {
+  WriteWave(buffer, sampling_rate, samples, nullptr, n);
+}
+
+bool WriteWave(const std::string &filename, int32_t sampling_rate,
+               const float *samples, int32_t n) {
+  return WriteWave(filename, sampling_rate, samples, nullptr, n);
+}
+
+bool WriteWave(const std::string &filename, int32_t sampling_rate,
+               const float *samples_ch0, const float *samples_ch1, int32_t n) {
+  std::string buffer;
+  buffer.resize(WaveFileSize(n, samples_ch1 == nullptr ? 1 : 2));
+
+  WriteWave(buffer.data(), sampling_rate, samples_ch0, samples_ch1, n);
+
+  std::ofstream os(filename, std::ios::binary);
+  if (!os) {
+    SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
+    return false;
+  }
+
+  os << buffer;
+  if (!os) {
+    SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
+               const float *samples_ch1, int32_t n) {
   WaveHeader header{};
   header.chunk_id = 0x46464952;      // FFIR
   header.format = 0x45564157;        // EVAW
@@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
   header.subchunk1_size = 16;        // 16 for PCM
   header.audio_format = 1;           // PCM =1
 
-  int32_t num_channels = 1;
+  int32_t num_channels = samples_ch1 == nullptr ? 1 : 2;
   int32_t bits_per_sample = 16;  // int16_t
+  header.num_channels = num_channels;
   header.sample_rate = sampling_rate;
   header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8;
@@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
 
   header.chunk_size = 36 + header.subchunk2_size;
 
-  std::vector<int16_t> samples_int16(n);
+  std::vector<int16_t> samples_int16_ch0(n);
   for (int32_t i = 0; i != n; ++i) {
-    samples_int16[i] = samples[i] * 32767;
+    samples_int16_ch0[i] = std::min<int32_t>(samples_ch0[i] * 32767, 32767);
+  }
+
+  std::vector<int16_t> samples_int16_ch1;
+  if (samples_ch1) {
+    samples_int16_ch1.resize(n);
+    for (int32_t i = 0; i != n; ++i) {
+      samples_int16_ch1[i] = std::min<int32_t>(samples_ch1[i] * 32767, 32767);
+    }
   }
 
   memcpy(buffer, &header, sizeof(WaveHeader));
-  memcpy(buffer + sizeof(WaveHeader), samples_int16.data(),
-         n * sizeof(int16_t));
-}
 
-bool WriteWave(const std::string &filename, int32_t sampling_rate,
-               const float *samples, int32_t n) {
-  std::string buffer;
-  buffer.resize(WaveFileSize(n));
-  WriteWave(buffer.data(), sampling_rate, samples, n);
-  std::ofstream os(filename, std::ios::binary);
-  if (!os) {
-    SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
-    return false;
-  }
-  os << buffer;
-  if (!os) {
-    SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
-    return false;
+  if (samples_ch1 == nullptr) {
+    memcpy(buffer + sizeof(WaveHeader), samples_int16_ch0.data(),
+           n * sizeof(int16_t));
+  } else {
+    auto p = reinterpret_cast<int16_t *>(buffer + sizeof(WaveHeader));
+
+    for (int32_t i = 0; i != n; ++i) {
+      p[2 * i] = samples_int16_ch0[i];
+      p[2 * i + 1] = samples_int16_ch1[i];
+    }
   }
-  return true;
 }
 
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/wave-writer.h b/sherpa-onnx/csrc/wave-writer.h
index 054d9eaff3..0d6e8c6a1a 100644
--- a/sherpa-onnx/csrc/wave-writer.h
+++ b/sherpa-onnx/csrc/wave-writer.h
@@ -25,7 +25,13 @@
 bool WriteWave(const std::string &filename, int32_t sampling_rate,
 void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
                int32_t n);
 
-int64_t WaveFileSize(int32_t n_samples);
+bool WriteWave(const std::string &filename, int32_t sampling_rate,
+               const float *samples_ch0, const float *samples_ch1, int32_t n);
+
+void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
+               const float *samples_ch1, int32_t n);
+
+int64_t WaveFileSize(int32_t n_samples, int32_t num_channels = 1);
 
 }  // namespace sherpa_onnx
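The stereo overload interleaves the two channels into 16-bit PCM; passing nullptr for the second channel keeps the old mono behaviour. A usage sketch with assumed buffers and file names:

    std::vector<float> left(44100, 0.0f);   // 1 second of silence per channel
    std::vector<float> right(44100, 0.0f);

    bool ok = sherpa_onnx::WriteWave("./stereo-out.wav", 44100, left.data(),
                                     right.data(),
                                     static_cast<int32_t>(left.size()));

    // mono, via the original overload
    ok = sherpa_onnx::WriteWave("./mono-out.wav", 44100, left.data(),
                                static_cast<int32_t>(left.size()));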
diff --git a/wasm/speech-enhancement/app-speech-enhancement.js b/wasm/speech-enhancement/app-speech-enhancement.js
index fe67c4258b..cc89c67fe7 100644
--- a/wasm/speech-enhancement/app-speech-enhancement.js
+++ b/wasm/speech-enhancement/app-speech-enhancement.js
@@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) {
     console.log('ArrayBuffer length:', arrayBuffer.byteLength);
 
     const uint8Array = new Uint8Array(arrayBuffer);
-    const wave = readWaveFromBinaryData(uint8Array);
+    const wave = readWaveFromBinaryData(uint8Array, Module);
     if (wave == null) {
       alert(
           `${file.name} is not a valid .wav file. Please select a *.wav file`);