This repository contains experiments applying Direct Preference Optimization (DPO) and its variants to improve the mathematical reasoning abilities of SFT-trained language models. The approach includes techniques like rejection sampling and self-certainty for selecting preference pairs. The repository is loosely organized to allow fast iteration and experimentation, driven primarily by personal interest in related topics. While contributions and pull requests are welcome, the project is not currently intended to be a fully collaborative effort. Feel free to reach out if you'd like to discuss ideas or directions.
We investigate using self-certainty (a model's internal confidence measure; inspired by this work and this work) to select preference pairs for DPO training on mathematical reasoning tasks, such as GSM8K.
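For concreteness, here is a minimal sketch of how such a selection step could look. It assumes self-certainty is measured as the average KL divergence between each next-token distribution and a uniform distribution, and that the chosen/rejected pair is formed from the most confident correct and most confident incorrect samples; the function names and the pairing heuristic are illustrative, not the repository's exact implementation.

```python
# Illustrative sketch: score sampled solutions by "self-certainty" (here, the
# average per-token KL divergence to a uniform distribution) and pair the most
# confident correct solution with the most confident incorrect one.
import torch
import torch.nn.functional as F

def self_certainty(model, tokenizer, prompt: str, solution: str) -> float:
    """Average per-token KL(p || uniform) over the solution tokens (higher = more confident)."""
    # Assumes tokenizing prompt + solution yields the prompt tokens as a prefix
    # (true for typical BPE tokenizers up to boundary effects).
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + solution, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits  # (1, seq_len, vocab)
    # Keep only the positions that predict the solution tokens.
    sol_logits = logits[:, prompt_ids.shape[1] - 1 : -1, :]
    log_probs = F.log_softmax(sol_logits, dim=-1)
    vocab = log_probs.shape[-1]
    # KL(p || U) = log(V) - H(p), averaged over solution positions.
    entropy = -(log_probs.exp() * log_probs).sum(-1)
    kl_to_uniform = torch.log(torch.tensor(float(vocab))) - entropy
    return kl_to_uniform.mean().item()

def make_preference_pair(prompt, solutions, labels, scores):
    """Pick (chosen, rejected) = (most certain correct, most certain incorrect) solution."""
    correct = [(s, sc) for s, ok, sc in zip(solutions, labels, scores) if ok]
    incorrect = [(s, sc) for s, ok, sc in zip(solutions, labels, scores) if not ok]
    if not correct or not incorrect:
        return None  # skip prompts without both a correct and an incorrect sample
    chosen = max(correct, key=lambda x: x[1])[0]
    rejected = max(incorrect, key=lambda x: x[1])[0]
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
```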
- LoRA + DPO + additional regularization losses to mitigate model collapse, e.g., an NLL term (cf. ORPO) and DPOP (see the sketch below).
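The following is an illustrative sketch of a DPO loss augmented with these two regularizers. The DPOP penalty placement and the weights `lam_dpop` / `alpha_nll` are assumptions for illustration, not the repository's actual loss code or hyperparameters.

```python
# Illustrative sketch: DPO loss with (1) a DPOP-style penalty that discourages
# the policy from assigning the chosen response a lower log-likelihood than the
# reference model, and (2) an NLL term on the chosen response to reduce collapse.
import torch
import torch.nn.functional as F

def dpo_regularized_loss(
    policy_chosen_logps,    # (batch,) summed log p_theta(chosen | prompt)
    policy_rejected_logps,  # (batch,) summed log p_theta(rejected | prompt)
    ref_chosen_logps,       # (batch,) same, under the frozen reference model
    ref_rejected_logps,
    beta: float = 0.1,      # DPO temperature
    lam_dpop: float = 5.0,  # DPOP penalty weight (hypothetical value)
    alpha_nll: float = 0.1, # NLL regularization weight (hypothetical value)
):
    chosen_rewards = policy_chosen_logps - ref_chosen_logps
    rejected_rewards = policy_rejected_logps - ref_rejected_logps

    # DPOP-style penalty: positive only when the policy's likelihood of the
    # chosen response drops below the reference model's.
    dpop_penalty = torch.clamp(ref_chosen_logps - policy_chosen_logps, min=0)

    logits = beta * (chosen_rewards - rejected_rewards) - lam_dpop * dpop_penalty
    dpo_loss = -F.logsigmoid(logits)

    # NLL regularization on the chosen response.
    nll_loss = -policy_chosen_logps

    return (dpo_loss + alpha_nll * nll_loss).mean()
```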
- Baseline: Phi-3.5-mini-instruct accuracy on GSM8K: 73.2%
- DPO (LoRA, 1 epoch): 79.1% accuracy
- Statistical significance: p-value < 0.001 (McNemar's test; see the sketch after this list)
- Self-certainty as quality signal: Confirmed that correct answers have significantly higher self-certainty than incorrect ones
- DPO degradation: Vanilla DPO performance drops to 2.65% after 2 epochs, confirming known instability issues
- Mitigation strategies: Preliminary experiments with DPOP loss and NLL regularization show promise for reducing degradation
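For reference, the significance check mentioned above can be computed from paired per-question correctness with `statsmodels`. The helper below is a sketch with hypothetical inputs, not the repository's evaluation code.

```python
# Sketch: McNemar's test on paired per-question correctness of the baseline
# and the DPO-trained model over the same test set.
from statsmodels.stats.contingency_tables import mcnemar

def mcnemar_pvalue(baseline_correct, dpo_correct):
    """McNemar's test on paired correctness vectors (lists of bools)."""
    both = sum(b and d for b, d in zip(baseline_correct, dpo_correct))
    only_base = sum(b and not d for b, d in zip(baseline_correct, dpo_correct))
    only_dpo = sum(d and not b for b, d in zip(baseline_correct, dpo_correct))
    neither = sum((not b) and (not d) for b, d in zip(baseline_correct, dpo_correct))
    table = [[both, only_base], [only_dpo, neither]]
    result = mcnemar(table, exact=False, correction=True)
    return result.pvalue

# Example with hypothetical correctness vectors:
# p = mcnemar_pvalue([True, False, True], [True, True, True])
```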
- Higher-quality solution sampling (to prevent correct-answer-but-incorrect-reasoning samples caused by LLM backtracking; e.g., Table 7)
- Improved preference pair selection using multiple factors
- Process-level and token-level reward signals
- Testing on more challenging mathematical datasets
- `explore.ipynb`: initial exploratory notebook for understanding the data (e.g., GSM8K and how a baseline model such as Phi-3.5 performs).
- `prepare_training_data.py`: samples positive and negative examples for DPO training, where each sample has associated data fields such as logprobs.
- `train.py`: training script for DPO.
- `utils.py`: utility functions ranging from producing logprobs of full/answer-only sequences (see the sketch after this list) to config loading and sampling n-shot examples.
- `configs/`: various configurations.
- `data/`: filtered training data.
- `outputs/`: model eval outputs for post-training analysis.
- `figs/`: figures.
- `eval.py`: evaluation script producing the outputs.
- `eval.ipynb`: evaluation notebook for post-training analysis.
- `playground/`: toy examples and debug notes.
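As an example of the kind of helper `utils.py` describes, here is a sketch of an answer-only log-probability computation. The function name and signature are illustrative, not the repository's actual API.

```python
# Illustrative sketch: sum of log-probabilities of the answer tokens only,
# conditioned on the prompt, under a Hugging Face causal LM.
import torch
import torch.nn.functional as F

def answer_logprob(model, tokenizer, prompt: str, answer: str) -> float:
    """Sum of log p(answer tokens | prompt) under the model."""
    # Assumes tokenizing prompt + answer yields the prompt tokens as a prefix.
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + answer, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits            # (1, seq_len, vocab)
    log_probs = F.log_softmax(logits, dim=-1)
    # Position i predicts token i+1, so the answer tokens are scored by the
    # positions starting at the last prompt token.
    answer_start = prompt_ids.shape[1]
    target = full_ids[:, answer_start:]
    pred = log_probs[:, answer_start - 1 : -1, :]
    token_logps = pred.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    return token_logps.sum().item()
```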
- Prepare training data: `python prepare_training_data.py`
- Train with DPO: `python train.py --config <config_id>`
- Evaluate a trained model: `python eval.py --peft_model_path <lora_adapter_path>`