Concept-aware fine-tuning (CAFT) encourages stronger conceptual understanding by incorporating multi-token prediction into fine-tuning.
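As a rough illustration of the idea (a minimal sketch, not the actual caft-llm loss implementation), the training objective combines the usual next-token cross-entropy with cross-entropy losses from auxiliary heads that predict tokens further ahead; the function name and `aux_weight` parameter below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def multi_token_loss(head_logits, labels, aux_weight=0.1, ignore_index=-100):
    """head_logits[k] has shape [batch, seq, vocab] and predicts the token k+1 steps ahead."""
    total = 0.0
    for k, logits in enumerate(head_logits):
        shift = k + 1                      # head 0 is the ordinary next-token head
        pred = logits[:, :-shift, :]       # drop positions with no target that far ahead
        target = labels[:, shift:]
        loss_k = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)), target.reshape(-1), ignore_index=ignore_index
        )
        total = total + (loss_k if k == 0 else aux_weight * loss_k)
    return total
```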
```bash
git clone https://github.com/michaelchen-lab/caft-llm.git
cd caft-llm
pip install -e .
```
- Create a `.env` file with `HUGGINGFACE_TOKEN=<token>` and optionally `WANDB_TOKEN=<token>`.
- Add `train_set.jsonl` and `eval_set.jsonl` files to `scripts/datasets/`. Each instance should be of the format:
```json
{
  "id": "<int/str>", "status": "OK",
  "conversation": [
    {"role": "human", "content": "(prompt)"},
    {"role": "assistant", "content": "(ground truth answer)"}
  ]
}
```
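For reference, here is a minimal sketch of writing such instances to `scripts/datasets/train_set.jsonl`; the filenames match the layout above, but the example conversation is hypothetical:

```python
import json
from pathlib import Path

# Hypothetical example instances; replace with your own prompts and ground truth answers.
instances = [
    {
        "id": 1,
        "status": "OK",
        "conversation": [
            {"role": "human", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
        ],
    },
]

out_path = Path("scripts/datasets/train_set.jsonl")
out_path.parent.mkdir(parents=True, exist_ok=True)
with out_path.open("w") as f:
    for instance in instances:
        f.write(json.dumps(instance) + "\n")  # one JSON object per line
```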
Currently, only the auxiliary heads of `meta-llama/Llama-3.1-8B-Instruct` have been pretrained.
Method 1: Use the provided training script `scripts/train.py`
```bash
torchrun --nproc-per-node 1 scripts/train.py -ftm lora
torchrun --nproc-per-node 1 scripts/train.py -ftm lora -ft-heads -hpretrain
torchrun --nproc-per-node 1 scripts/train.py -ftm sft -lr 5e-6 -fr-unembed
torchrun --nproc-per-node 1 scripts/train.py -ftm sft -lr 5e-6 -fr-unembed -ft-heads -hpretrain
```
Selected Arguments:

- `--model-name-or-path -model`: Currently only `meta-llama/Llama-3.1-8B-Instruct` is supported.
- `--model-max-length -maxlen`
- `--finetune-method -ftm`: `lora` or `sft` (full fine-tuning)
- `--learning-rate -lr`
- `--epochs -e`
- `--freeze-unembedding -fr-unembed`: Only applicable for full fine-tuning. Recommended: `True`
- `--per-device-batch-size -micro-bs`
- `--gradient-accumulation-steps -grad-acc`
- `--heads-pretraining -hpretrain`: Train the auxiliary heads on your dataset for 1 epoch before applying CAFT to your model. `-ft-heads` must also be set to `True`.
The full list of arguments can be found using this command:
```bash
python scripts/train.py --help
```
Method 2: Integrate CAFT into your existing Transformers fine-tuning pipeline
```python
import transformers
from caft import *

# Import your pretrained Transformers model, tokenizer, TrainingArguments, and data_module

add_auxiliary_heads(model)   # attach the auxiliary multi-token prediction heads to the model
add_caft_loss(transformers)  # add the auxiliary head losses to the Transformers training loss

# The additional CAFT functions below track and save the auxiliary losses
trainer = transformers.trainer.Trainer(
    model=model, tokenizer=tokenizer, args=model_training_args,
    callbacks=[CAFTSaveLogging],
    compute_metrics=caft_compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    **data_module
)
```
Please refer to `scripts/train.py` for a complete implementation example.
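For context, here is a minimal sketch of what the surrounding setup might look like; the model loading, `TrainingArguments` values, and `data_module` shape below are illustrative assumptions, not the exact code in `scripts/train.py`:

```python
import transformers
from caft import *  # assumes caft-llm is installed (pip install -e .)

# Illustrative setup (an assumption): load the supported base model and its tokenizer.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

model_training_args = transformers.TrainingArguments(
    output_dir="outputs/caft-run",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=5e-6,
)

# data_module is whatever your existing pipeline already builds, e.g.:
# data_module = {"train_dataset": ..., "eval_dataset": ..., "data_collator": ...}
```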
- Download the train and validation dataset from this Huggingface repo and save to `scripts/datasets`
- Run the following command:

```bash
torchrun --nproc-per-node 4 scripts/train_aux_heads.py
```
We welcome community contributions and feature requests for `caft-llm`. Feel free to open an issue or submit a pull request. If you have any questions or wish to collaborate, please contact michaelchenkj@gmail.com.
- Support all model architectures. Currently, the `LlamaDecoderLayer` is used to create auxiliary heads; in other words, only Llama-based models are supported. Edit `core.py` to copy the last hidden layer of the given model instead of inserting `LlamaDecoderLayer`, then reinitialize the weights (see the sketch after this list).
- Support speculative decoding. Speculative decoding can be implemented using the same method as Gloeckle et al. (2024) and Stern et al. (2018).
- Support FSDP and DeepSpeed.
This codebase adapts code from several amazing projects, including Medusa and Facebook Multi-Token.