By RuoChen from ZJU
conda create -n llama3 python=3.10 -y
conda activate llama3
pip install -r requirements.txt
# 指定TRL版本(必须为要求版本!!!)
pip install trl==0.8.0
# 如遇到torch下载问题,请运行如下命令
pip install torch --index-url https://download.pytorch.org/whl/cu124
使用以下命令下载 LLaMA3-8B 模型:
python download_model.py https://hf-mirror.com/Undi95/Meta-Llama-3-8B-hf --output ./models/llama3-8b-hf
python test_origin.py
对比原模型与奖励模型对于测试问题与回答对的预测结果:
# 测试数据:./data/test_data.json
python test_reward.py
# 使用预设测试文本,测试数据:./data/comparisonPPO_data.json
python PPO_comparison.py
# 手动输入测试文本
python PPO_comparison_chat.py
# 训练数据:./data/preference.json
python train_reward.py
# 训练数据:./data/PPOtrain_data.json
python train_PPO.py
tools/
文件夹下包含以下实用工具:
脚本名称 | 功能描述 |
---|---|
check_PPOConfig.py |
检查PPO训练配置文件 |
check_rewardConfig.py |
检查奖励模型训练配置文件 |
clear_gpu.py |
清理GPU缓存 |
download_model.py |
下载模型 |
.
├── data/
│ ├── preference.json # 奖励模型训练数据
│ ├── PPOtrain_data.json # PPO训练数据
│ ├── comparisonPPO_data.json # PPO对比测试数据
│ └── test_data.json # 奖励模型测试数据
├── models/
│ └── llama3-8b-hf/ # LLaMA3模型文件
├── tools/ # 工具脚本目录
├── requirements.txt # 项目依赖
└── [训练和测试脚本]
This project would not be possible without the following codebases: