[Bug]: Checkpoint fine-tuned with PaddleNLP on the XPU platform cannot be loaded via from_pretrained #11098

@YoctoHan

Software environment

- paddle2onnx                   2.0.1
- paddlefsl                     1.1.0
- paddlenlp                     3.0.0b4
- paddlepaddle                  3.2.0
- paddlepaddle-xpu              3.2.0
- paddleslim                    2.6.0
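
For anyone trying to reproduce this, a version list like the one above can be collected with a short script. This is only a minimal sketch using the standard library's `importlib.metadata`; the helper name `collect_versions` is made up for illustration, and the package names are simply the ones listed above.

```python
from importlib import metadata

# Package names taken from the environment list in this report.
PACKAGES = [
    "paddle2onnx",
    "paddlefsl",
    "paddlenlp",
    "paddlepaddle",
    "paddlepaddle-xpu",
    "paddleslim",
]


def collect_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in collect_versions(PACKAGES).items():
    print(f"- {name:<30}{version or 'not installed'}")
```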

Duplicate issues

  • I have searched the existing issues

Error description

A checkpoint fine-tuned with PaddlePaddle and PaddleNLP on the Kunlunxin P800 platform cannot be loaded with from_pretrained.

Steps to reproduce & code

Reproduction steps

Launch fine-tuning

#!/bin/bash

# Define the log directory and a timestamped log file
LOG_DIR="/workspace/sft/logs"
LOG_FILE="${LOG_DIR}/finetune_$(date +%Y%m%d_%H%M%S).log"

# Create the log directory if it does not exist
mkdir -p "${LOG_DIR}"

# Switch to the working directory
cd /workspace/PaddleNLP/llm || exit 1

# Run the fine-tuning script in the background and write its output to the log file
nohup python -u -m paddle.distributed.launch \
    --gpus "0,1,2,3,4,5,6,7" \
    run_finetune.py /workspace/PaddleNLP/llm/config/aiXcoder/sft_argument.json \
    > "${LOG_FILE}" 2>&1 &

# Tell the user where the log is
echo "Training started. Logs are being written to: ${LOG_FILE}"

# Follow the log in real time
tail -f "${LOG_FILE}"

After fine-tuning finishes, load the weights with from_pretrained

from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("/workspace/sft/sft_ckpts/", dtype="float32")

Output:

[2025-09-18 11:33:46,454] [    INFO] - We are using <class 'paddlenlp.transformers.llama.modeling.LlamaForCausalLM'> to load '/workspace/sft/sft_ckpts/'.
[2025-09-18 11:33:46,454] [    INFO] - Loading configuration file /workspace/sft/sft_ckpts/config.json
[2025-09-18 11:33:46,455] [    INFO] - Loading weights file /workspace/sft/sft_ckpts/model_state.pdparams
[2025-09-18 11:34:00,894] [    INFO] - Loaded weights file from disk, setting weights to model.
W0918 11:34:00.904455 727857 xpu_context.cc:187] Please NOTE: xpu device: 0
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/paddlenlp/transformers/auto/modeling.py", line 798, in from_pretrained
    return cls._from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddlenlp/transformers/auto/modeling.py", line 346, in _from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddlenlp/transformers/model_utils.py", line 2567, in from_pretrained
    model, missing_keys, unexpected_keys, mismatched_keys = cls._load_pretrained_model(
  File "/usr/local/lib/python3.10/dist-packages/paddlenlp/transformers/model_utils.py", line 2254, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
        Skip loading for lm_head.weight. lm_head.weight receives a shape [49152, 4096], but the expected shape is [4096, 49152].
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

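This reproduces reliably. Judging by the error, the shape of one weight tensor (lm_head.weight) appears to have been transposed: the checkpoint stores [49152, 4096], while the model expects [4096, 49152].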
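
To check whether lm_head.weight is the only affected tensor, the saved shapes can be compared against the shapes the model expects. The following is a rough sketch in plain Python, not PaddleNLP code: `find_transposed_weights` is a hypothetical helper, and the two shape dictionaries are filled in by hand from the error message rather than read from the checkpoint.

```python
from typing import Dict, List, Sequence, Tuple


def find_transposed_weights(
    checkpoint_shapes: Dict[str, Sequence[int]],
    expected_shapes: Dict[str, Sequence[int]],
) -> List[Tuple[str, List[int], List[int]]]:
    """Return (name, checkpoint_shape, expected_shape) for every 2-D weight
    whose checkpoint shape is exactly the reverse of the expected shape."""
    transposed = []
    for name, expected in expected_shapes.items():
        actual = checkpoint_shapes.get(name)
        if actual is None:
            continue  # a missing key is a different problem; skip it here
        actual, expected = list(actual), list(expected)
        if len(actual) == 2 and actual != expected and actual == expected[::-1]:
            transposed.append((name, actual, expected))
    return transposed


# Shapes copied by hand from the error message above (assumed inputs).
checkpoint_shapes = {"lm_head.weight": [49152, 4096]}
expected_shapes = {"lm_head.weight": [4096, 49152]}

for name, actual, expected in find_transposed_weights(checkpoint_shapes, expected_shapes):
    print(f"{name}: checkpoint {actual} vs. expected {expected} (transposed)")
```

In a real check, the two dictionaries would have to be built from the saved model_state.pdparams and from a freshly constructed model's state dict; that step is left out here because it depends on the local environment.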
