Description
$ speech2text.py --wav_path examples/wav/BAC009S0764W0121.wav --asr_type "aed" --model_dir pretrained_models/FireRedASR-AED-L
Namespace(asr_type='aed', model_dir='pretrained_models/FireRedASR-AED-L', wav_path='examples/wav/BAC009S0764W0121.wav', wav_paths=None, wav_dir=None, wav_scp=None, output=None, use_gpu=1, batch_size=1, beam_size=1, decode_max_len=0, nbest=1, softmax_smoothing=1.0, aed_length_penalty=0.0, eos_penalty=1.0, decode_min_len=0, repetition_penalty=1.0, llm_length_penalty=0.0, temperature=1.0)
Starting: model loading and transcription
Checkpoint OK: model file exists at pretrained_models/FireRedASR-AED-L/model.pth
#wavs=1
Error: an error occurred while loading the model - Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
(1) In PyTorch 2.6, we changed the default value of the weights_only argument in torch.load from False to True. Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with weights_only=True please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL argparse.Namespace was not an allowed global by default. Please use torch.serialization.add_safe_globals([Namespace]) or the torch.serialization.safe_globals([Namespace]) context manager to allowlist this global if you trust this class/function.
Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
I have already changed the call to use weights_only=False:
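import argparse
import sys

import torch
from fireredasr.models.fireredasr import FireRedAsr


def load_fireredasr_aed_model(model_path):
    # Allowlist argparse.Namespace for the weights_only unpickler
    # (only takes effect when weights_only=True, so it is redundant below)
    torch.serialization.add_safe_globals([argparse.Namespace])
    # Load the checkpoint with weights_only=False
    try:
        checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
        print("Checkpoint OK: model weights loaded")
    except Exception as e:
        print(f"Error: failed to load checkpoint - {e}")
        sys.exit(1)
    # Assume `model` is an already-initialized model
    model = FireRedAsr()  # initialize the model according to the actual setup
    # strict=False tolerates mismatched weights: non-matching keys are skipped, not raised
    model.load_state_dict(checkpoint, strict=False)
    return model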
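One thing to watch with strict=False: load_state_dict does not raise on mismatched keys, it returns them. The minimal sketch below uses nn.Linear as a stand-in for the ASR model. It assumes, hypothetically, that the loaded checkpoint wraps the weights in a dict under "model_state_dict" next to "args"; these key names are an assumption and are not confirmed from the FireRedASR source. Passing the whole wrapper copies no weights at all, and only the returned missing_keys / unexpected_keys reveal that.

```python
import argparse

import torch
from torch import nn


def load_and_report(model: nn.Module, state: dict):
    """Load non-strictly and print the keys that did not line up."""
    result = model.load_state_dict(state, strict=False)
    print("missing:", result.missing_keys)
    print("unexpected:", result.unexpected_keys)
    return result


if __name__ == "__main__":
    # Hypothetical checkpoint layout: metadata plus the real state dict under a key.
    package = {
        "args": argparse.Namespace(idim=3, odim=2),
        "model_state_dict": nn.Linear(3, 2).state_dict(),
    }
    model = nn.Linear(3, 2)  # stand-in for the real model

    # Whole wrapper: nothing is copied; every parameter is reported missing.
    load_and_report(model, package)

    # Inner state dict: keys match, so both lists are empty.
    load_and_report(model, package["model_state_dict"])
    print("weights copied:", torch.equal(model.weight, package["model_state_dict"]["weight"]))
```

Checking the returned lists (or temporarily using strict=True) is a cheap way to confirm the weights were actually loaded.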
Environment:
NVIDIA-SMI 555.42.06 Driver Version: 555.42.06 CUDA Version: 12.5
$ python --version
Python 3.10.16
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-cusparselt-cu12 0.6.2
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
torch 2.6.0
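To check whether the PyTorch 2.6 default change described in the error message applies to a given environment, the relevant versions can be printed from Python. This is a small sketch, and the helper names are illustrative. PyTorch also ships a fuller report via `python -m torch.utils.collect_env`.

```python
import platform

import torch


def torch_major_minor(version: str) -> tuple[int, int]:
    """Turn e.g. '2.6.0+cu124' into (2, 6), ignoring local/dev suffixes."""
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def weights_only_defaults_to_true(version: str) -> bool:
    # Per the error message, torch.load's weights_only default became True in 2.6.
    return torch_major_minor(version) >= (2, 6)


if __name__ == "__main__":
    print("python:", platform.python_version())
    print("torch:", torch.__version__)
    print("torch built with CUDA:", torch.version.cuda)
    print("CUDA available:", torch.cuda.is_available())
    print("weights_only defaults to True:", weights_only_defaults_to_true(str(torch.__version__)))
```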