diff --git a/.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml b/.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml index 3d02994f53..8811c3cec1 100644 --- a/.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml +++ b/.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml @@ -1,4 +1,4 @@ -name: export-nemo-speaker-verification-to-onnx +name: export-nemo-fast-conformer-ctc-to-onnx on: workflow_dispatch: diff --git a/.github/workflows/export-nemo-fast-conformer-hybrid-transducer-transducer.yaml b/.github/workflows/export-nemo-fast-conformer-hybrid-transducer-transducer.yaml new file mode 100644 index 0000000000..854cb89df0 --- /dev/null +++ b/.github/workflows/export-nemo-fast-conformer-hybrid-transducer-transducer.yaml @@ -0,0 +1,73 @@ +name: export-nemo-fast-conformer-transducer-to-onnx + +on: + workflow_dispatch: + +concurrency: + group: export-nemo-fast-conformer-hybrid-transducer-to-onnx-${{ github.ref }} + cancel-in-progress: true + +jobs: + export-nemo-fast-conformer-hybrid-transducer-to-onnx: + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' + name: NeMo transducer + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest] + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v4 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install NeMo + shell: bash + run: | + BRANCH='main' + pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr] + pip install onnxruntime + pip install kaldi-native-fbank + pip install soundfile librosa + + - name: Run + shell: bash + run: | + cd scripts/nemo/fast-conformer-hybrid-transducer-ctc + ./run-transducer.sh + + mv -v sherpa-onnx-nemo* ../../.. 
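+          # run-transducer.sh (added below in this patch) leaves the exported models in
+          # sherpa-onnx-nemo-streaming-fast-conformer-transducer-{80,480,1040}ms/;
+          # the mv above moves those directories to the repository root for the packaging steps below.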
+ + - name: Download test waves + shell: bash + run: | + mkdir test_wavs + pushd test_wavs + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt + popd + + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms + + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms + + - name: Release + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + file: ./*.tar.bz2 + overwrite: true + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py index 5751c90b5a..f5f02e0319 100755 --- a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py +++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) import argparse from typing import Dict diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-transducer.py b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-transducer.py new file mode 100755 index 0000000000..985f1fae78 --- /dev/null +++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-transducer.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +import argparse +from typing import Dict + +import nemo.collections.asr as nemo_asr +import onnx +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + required=True, + choices=["80", "480", "1040"], + ) + return parser.parse_args() + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
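+        Values are converted to str before they are written to the model.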
+ """ + model = onnx.load(filename) + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +@torch.no_grad() +def main(): + args = get_args() + model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms" + + asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name) + + with open("./tokens.txt", "w", encoding="utf-8") as f: + for i, s in enumerate(asr_model.joint.vocabulary): + f.write(f"{s} {i}\n") + f.write(f" {i+1}\n") + print("Saved to tokens.txt") + + decoder_type = "rnnt" + asr_model.change_decoding_strategy(decoder_type=decoder_type) + asr_model.eval() + + assert asr_model.encoder.streaming_cfg is not None + if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list): + chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1] + else: + chunk_size = asr_model.encoder.streaming_cfg.chunk_size + + if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list): + pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1] + else: + pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size + window_size = chunk_size + pre_encode_cache_size + + print("chunk_size", chunk_size) + print("pre_encode_cache_size", pre_encode_cache_size) + print("window_size", window_size) + + chunk_shift = chunk_size + + # cache_last_channel: (batch_size, dim1, dim2, dim3) + cache_last_channel_dim1 = len(asr_model.encoder.layers) + cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size + cache_last_channel_dim3 = asr_model.encoder.d_model + + # cache_last_time: (batch_size, dim1, dim2, dim3) + cache_last_time_dim1 = len(asr_model.encoder.layers) + cache_last_time_dim2 = asr_model.encoder.d_model + cache_last_time_dim3 = asr_model.encoder.conv_context_size[0] + + asr_model.set_export_config({"decoder_type": "rnnt", "cache_support": True}) + + # asr_model.export("model.onnx") + asr_model.encoder.export("encoder.onnx") + asr_model.decoder.export("decoder.onnx") + asr_model.joint.export("joiner.onnx") + # model.onnx is a suffix. 
+ # It will generate two files: + # encoder-model.onnx + # decoder_joint-model.onnx + + meta_data = { + "vocab_size": asr_model.decoder.vocab_size, + "window_size": window_size, + "chunk_shift": chunk_shift, + "normalize_type": "None", + "cache_last_channel_dim1": cache_last_channel_dim1, + "cache_last_channel_dim2": cache_last_channel_dim2, + "cache_last_channel_dim3": cache_last_channel_dim3, + "cache_last_time_dim1": cache_last_time_dim1, + "cache_last_time_dim2": cache_last_time_dim2, + "cache_last_time_dim3": cache_last_time_dim3, + "pred_rnn_layers": asr_model.decoder.pred_rnn_layers, + "pred_hidden": asr_model.decoder.pred_hidden, + "subsampling_factor": 8, + "model_type": "EncDecHybridRNNTCTCBPEModel", + "version": "1", + "model_author": "NeMo", + "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}", + "comment": "Only the transducer branch is exported", + } + add_meta_data("encoder.onnx", meta_data) + + print(meta_data) + + +if __name__ == "__main__": + main() diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-ctc.sh b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-ctc.sh index 6fae9b9499..00c31b78bc 100755 --- a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-ctc.sh +++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-ctc.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) set -ex diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-transducer.sh b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-transducer.sh new file mode 100755 index 0000000000..cddee8fc42 --- /dev/null +++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-transducer.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +set -ex + +if [ ! -e ./0.wav ]; then + # curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav + curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav +fi + +ms=( +80 +480 +1040 +) + +for m in ${ms[@]}; do + ./export-onnx-transducer.py --model $m + d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms + if [ ! -f $d/encoder.onnx ]; then + mkdir -p $d + mv -v encoder.onnx $d/ + mv -v decoder.onnx $d/ + mv -v joiner.onnx $d/ + mv -v tokens.txt $d/ + ls -lh $d + fi +done + +# Now test the exported models + +for m in ${ms[@]}; do + d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms + python3 ./test-onnx-transducer.py \ + --encoder $d/encoder.onnx \ + --decoder $d/decoder.onnx \ + --joiner $d/joiner.onnx \ + --tokens $d/tokens.txt \ + --wav ./0.wav +done diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/show-onnx-transudcer.py b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/show-onnx-transudcer.py new file mode 100755 index 0000000000..18dc2ebe27 --- /dev/null +++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/show-onnx-transudcer.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) + +import onnxruntime + + +def show(filename): + session_opts = onnxruntime.SessionOptions() + session_opts.log_severity_level = 3 + sess = onnxruntime.InferenceSession(filename, session_opts) + for i in sess.get_inputs(): + print(i) + + print("-----") + + for i in sess.get_outputs(): + print(i) + + +def main(): + print("=========encoder==========") + show("./encoder.onnx") + + print("=========decoder==========") + show("./decoder.onnx") + + print("=========joiner==========") + show("./joiner.onnx") + + +if __name__ == "__main__": + main() + +""" +=========encoder========== +NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2']) +NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1']) +NodeArg(name='cache_last_channel', type='tensor(float)', shape=['cache_last_channel_dynamic_axes_1', 17, 'cache_last_channel_dynamic_axes_2', 512]) +NodeArg(name='cache_last_time', type='tensor(float)', shape=['cache_last_time_dynamic_axes_1', 17, 512, 'cache_last_time_dynamic_axes_2']) +NodeArg(name='cache_last_channel_len', type='tensor(int64)', shape=['cache_last_channel_len_dynamic_axes_1']) +----- +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 512, 'outputs_dynamic_axes_2']) +NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1']) +NodeArg(name='cache_last_channel_next', type='tensor(float)', shape=['cache_last_channel_next_dynamic_axes_1', 17, 'cache_last_channel_next_dynamic_axes_2', 512]) +NodeArg(name='cache_last_time_next', type='tensor(float)', shape=['cache_last_time_next_dynamic_axes_1', 17, 512, 'cache_last_time_next_dynamic_axes_2']) +NodeArg(name='cache_last_channel_next_len', type='tensor(int64)', shape=['cache_last_channel_next_len_dynamic_axes_1']) +=========decoder========== +NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2']) +NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1']) +NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 640]) +NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 640]) +----- +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 640, 'outputs_dynamic_axes_2']) +NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1']) +NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 640]) +NodeArg(name='74', type='tensor(float)', shape=[1, 'LSTM74_dim_1', 640]) +=========joiner========== +NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 512, 'encoder_outputs_dynamic_axes_2']) +NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 640, 'decoder_outputs_dynamic_axes_2']) +----- +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 1025]) + +""" diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py index 77c7a526b4..1ed6c4a61e 100755 --- a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py +++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-ctc.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. 
(authors: Fangjun Kuang) import argparse from pathlib import Path diff --git a/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py new file mode 100755 index 0000000000..c671851f5b --- /dev/null +++ b/scripts/nemo/fast-conformer-hybrid-transducer-ctc/test-onnx-transducer.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import argparse +from pathlib import Path + +import kaldi_native_fbank as knf +import librosa +import numpy as np +import onnxruntime as ort +import soundfile as sf +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--encoder", type=str, required=True, help="Path to encoder.onnx" + ) + parser.add_argument( + "--decoder", type=str, required=True, help="Path to decoder.onnx" + ) + parser.add_argument("--joiner", type=str, required=True, help="Path to joiner.onnx") + + parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt") + + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav") + + return parser.parse_args() + + +def create_fbank(): + opts = knf.FbankOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.remove_dc_offset = False + opts.frame_opts.window_type = "hann" + + opts.mel_opts.low_freq = 0 + opts.mel_opts.num_bins = 80 + + opts.mel_opts.is_librosa = True + + fbank = knf.OnlineFbank(opts) + return fbank + + +def compute_features(audio, fbank): + assert len(audio.shape) == 1, audio.shape + fbank.accept_waveform(16000, audio) + ans = [] + processed = 0 + while processed < fbank.num_frames_ready: + ans.append(np.array(fbank.get_frame(processed))) + processed += 1 + ans = np.stack(ans) + return ans + + +class OnnxModel: + def __init__( + self, + encoder: str, + decoder: str, + joiner: str, + ): + self.init_encoder(encoder) + self.init_decoder(decoder) + self.init_joiner(joiner) + + def init_encoder(self, encoder): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.encoder = ort.InferenceSession( + encoder, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + meta = self.encoder.get_modelmeta().custom_metadata_map + print(meta) + + self.window_size = int(meta["window_size"]) + self.chunk_shift = int(meta["chunk_shift"]) + + self.cache_last_channel_dim1 = int(meta["cache_last_channel_dim1"]) + self.cache_last_channel_dim2 = int(meta["cache_last_channel_dim2"]) + self.cache_last_channel_dim3 = int(meta["cache_last_channel_dim3"]) + + self.cache_last_time_dim1 = int(meta["cache_last_time_dim1"]) + self.cache_last_time_dim2 = int(meta["cache_last_time_dim2"]) + self.cache_last_time_dim3 = int(meta["cache_last_time_dim3"]) + + self.pred_rnn_layers = int(meta["pred_rnn_layers"]) + self.pred_hidden = int(meta["pred_hidden"]) + + self.init_cache_state() + + def init_decoder(self, decoder): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.decoder = ort.InferenceSession( + decoder, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + def init_joiner(self, joiner): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.joiner = ort.InferenceSession( + joiner, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + def get_decoder_state(self): + 
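+        """Return zero-initialized decoder (prediction network) states.
+
+        The prediction network is an LSTM, so its state is a pair of
+        tensors; each has shape (pred_rnn_layers, batch_size, pred_hidden),
+        with the sizes taken from the encoder's metadata.
+        """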
batch_size = 1 + state0 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy() + state1 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy() + return state0, state1 + + def init_cache_state(self): + self.cache_last_channel = torch.zeros( + 1, + self.cache_last_channel_dim1, + self.cache_last_channel_dim2, + self.cache_last_channel_dim3, + dtype=torch.float32, + ).numpy() + + self.cache_last_time = torch.zeros( + 1, + self.cache_last_time_dim1, + self.cache_last_time_dim2, + self.cache_last_time_dim3, + dtype=torch.float32, + ).numpy() + + self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy() + + def run_encoder(self, x: np.ndarray): + # x: (T, C) + x = torch.from_numpy(x) + x = x.t().unsqueeze(0) + # x: [1, C, T] + x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64) + + ( + encoder_out, + out_len, + cache_last_channel_next, + cache_last_time_next, + cache_last_channel_len_next, + ) = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + self.encoder.get_outputs()[2].name, + self.encoder.get_outputs()[3].name, + self.encoder.get_outputs()[4].name, + ], + { + self.encoder.get_inputs()[0].name: x.numpy(), + self.encoder.get_inputs()[1].name: x_lens.numpy(), + self.encoder.get_inputs()[2].name: self.cache_last_channel, + self.encoder.get_inputs()[3].name: self.cache_last_time, + self.encoder.get_inputs()[4].name: self.cache_last_channel_len, + }, + ) + self.cache_last_channel = cache_last_channel_next + self.cache_last_time = cache_last_time_next + self.cache_last_channel_len = cache_last_channel_len_next + + # [batch_size, dim, T] + return encoder_out + + def run_decoder( + self, + token: int, + state0: np.ndarray, + state1: np.ndarray, + ): + target = torch.tensor([[token]], dtype=torch.int32).numpy() + target_len = torch.tensor([1], dtype=torch.int32).numpy() + + ( + decoder_out, + decoder_out_length, + state0_next, + state1_next, + ) = self.decoder.run( + [ + self.decoder.get_outputs()[0].name, + self.decoder.get_outputs()[1].name, + self.decoder.get_outputs()[2].name, + self.decoder.get_outputs()[3].name, + ], + { + self.decoder.get_inputs()[0].name: target, + self.decoder.get_inputs()[1].name: target_len, + self.decoder.get_inputs()[2].name: state0, + self.decoder.get_inputs()[3].name: state1, + }, + ) + return decoder_out, state0_next, state1_next + + def run_joiner( + self, + encoder_out: np.ndarray, + decoder_out: np.ndarray, + ): + # encoder_out: [batch_size, dim, 1] + # decoder_out: [batch_size, dim, 1] + logit = self.joiner.run( + [ + self.joiner.get_outputs()[0].name, + ], + { + self.joiner.get_inputs()[0].name: encoder_out, + self.joiner.get_inputs()[1].name: decoder_out, + }, + )[0] + # logit: [batch_size, 1, 1, vocab_size] + return logit + + +def main(): + args = get_args() + assert Path(args.encoder).is_file(), args.encoder + assert Path(args.decoder).is_file(), args.decoder + assert Path(args.joiner).is_file(), args.joiner + assert Path(args.tokens).is_file(), args.tokens + assert Path(args.wav).is_file(), args.wav + + print(vars(args)) + + model = OnnxModel(args.encoder, args.decoder, args.joiner) + + id2token = dict() + with open(args.tokens, encoding="utf-8") as f: + for line in f: + t, idx = line.split() + id2token[int(idx)] = t + + fbank = create_fbank() + audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + if sample_rate != 16000: + audio = librosa.resample( + audio, + orig_sr=sample_rate, + 
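+            # resample to 16 kHz, the rate assumed by the fbank setup above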
target_sr=16000,
+        )
+        sample_rate = 16000
+
+    tail_padding = np.zeros(sample_rate * 2)
+
+    audio = np.concatenate([audio, tail_padding])
+
+    window_size = model.window_size
+    chunk_shift = model.chunk_shift
+
+    blank = len(id2token) - 1
+    ans = [blank]
+    state0, state1 = model.get_decoder_state()
+    decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1)
+
+    features = compute_features(audio, fbank)
+    num_chunks = (features.shape[0] - window_size) // chunk_shift + 1
+    for i in range(num_chunks):
+        start = i * chunk_shift
+        end = start + window_size
+        chunk = features[start:end, :]
+
+        encoder_out = model.run_encoder(chunk)
+        # encoder_out: [batch_size, dim, T]
+        for t in range(encoder_out.shape[2]):
+            encoder_out_t = encoder_out[:, :, t : t + 1]
+            logits = model.run_joiner(encoder_out_t, decoder_out)
+            logits = torch.from_numpy(logits)
+            logits = logits.squeeze()
+            idx = torch.argmax(logits, dim=-1).item()
+            if idx != blank:
+                ans.append(idx)
+                state0 = state0_next
+                state1 = state1_next
+                decoder_out, state0_next, state1_next = model.run_decoder(
+                    ans[-1], state0, state1
+                )
+
+    ans = ans[1:]  # remove the first blank
+    tokens = [id2token[i] for i in ans]
+    underline = "▁"
+    # underline = b"\xe2\x96\x81".decode()
+    text = "".join(tokens).replace(underline, " ").strip()
+    print(args.wav)
+    print(text)
+
+
+if __name__ == "__main__":
+    main()
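Note: the key/value pairs written by add_meta_data() in export-onnx-transducer.py end up as custom metadata inside the ONNX file, which is what test-onnx-transducer.py reads back via get_modelmeta(). Below is a minimal sketch (not part of the patch) for inspecting an exported encoder; it assumes ./encoder.onnx is one of the models produced by export-onnx-transducer.py.

#!/usr/bin/env python3
# Minimal sketch: print the custom metadata of an exported encoder.
# Assumes ./encoder.onnx was produced by export-onnx-transducer.py.
import onnxruntime as ort

sess = ort.InferenceSession("./encoder.onnx", providers=["CPUExecutionProvider"])
meta = sess.get_modelmeta().custom_metadata_map
for key, value in sorted(meta.items()):
    print(f"{key}: {value}")

# window_size and chunk_shift are frame counts that test-onnx-transducer.py
# uses to slice the fbank feature matrix into streaming chunks.
print("window_size:", meta["window_size"])
print("chunk_shift:", meta["chunk_shift"])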