Yixin Chen¹, Shuai Zhang², Boran Han², Tong He² and Bo Li²,³
¹The Chinese University of Hong Kong, ²Amazon Web Services, ³University of Chicago
We received an SAC Award at ACL 2024.
CaMML is a lightweight module that integrates multimodal contextual samples into large models, enabling the model to draw knowledge from analogous, domain-specific, and up-to-date information and to make grounded inferences.
- Clone this repository and navigate to the CaMML folder
git clone camml.git
cd camml
- Install packages
conda create -n camml python=3.10 -y
conda activate camml
pip install --upgrade pip # enable PEP 660 support
bash install.sh
- Install additional packages for training
pip install flash-attn --no-build-isolation
ScienceQA-finetuned models:
Model | Image Size | LLM | Vision Encoder | CaMML Retrieval Model | CaMML Retrieval Data | Train Data | Finetuning Schedule | Model Download |
---|---|---|---|---|---|---|---|---|
CaMML-7B | 224 | Vicuna-7B-v1.3 | CLIP-ViT-L-14-224px | ImageBind-Huge | ScienceQA train | ScienceQA train | ft_12epochs_2e-5 | checkpoint |
CaMML-13B | 224 | Vicuna-13B-v1.3 | CLIP-ViT-L-14-224px | ImageBind-Huge | ScienceQA train | ScienceQA train | ft_12epochs_2e-5 | checkpoint |
Instruction-tuned models:
Model | Image Size | LLM | Vision Encoder | CaMML Retrieval Model | CaMML Retrieval Data | Train Data | Finetuning Schedule | Model Download |
---|---|---|---|---|---|---|---|---|
CaMML-7B | 336 | Vicuna-7B-v1.5 | CLIP-ViT-L-14-336px | ImageBind-Huge | LLaVA-v1.5-665K | LLaVA-v1.5-665K | ft_1epoch_2e-5 | checkpoint |
CaMML-13B | 336 | Vicuna-13B-v1.5 | CLIP-ViT-L-14-336px | ImageBind-Huge | LLaVA-v1.5-665K | LLaVA-v1.5-665K | ft_1epoch_2e-5 | checkpoint |
CaMML is finetuned on the ScienceQA dataset.
- Follow the ScienceQA repository to set up the dataset.
- Convert the data to the LLaVA format:
python scripts/convert_sqa_to_llava.py \
convert_to_llava \
--base-dir /path/to/ScienceQA/data/scienceqa \
--prompt-format "QCM-LEPA" \
--split {train,val,minival,test,minitest}
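For orientation, the LLaVA data format stores each sample as a JSON record with an id, an optional image path, and a conversations list of alternating human/gpt turns, where an <image> token marks where the image goes. The sketch below builds one such record. The helper to_llava_record and its prompt wording are hypothetical; they illustrate the record shape only and do not reproduce the exact QCM-LEPA template produced by scripts/convert_sqa_to_llava.py.

```python
import json


def to_llava_record(sample_id, question, choices, answer, image=None):
    """Build one LLaVA-style conversation record (hypothetical helper).

    The prompt and answer wording here is illustrative; the real QCM-LEPA
    template written by scripts/convert_sqa_to_llava.py may differ.
    """
    # Label the multiple-choice options (A), (B), ...
    options = " ".join(f"({chr(ord('A') + i)}) {c}" for i, c in enumerate(choices))
    prompt = f"Question: {question}\nOptions: {options}"
    if image is not None:
        # LLaVA marks the position of the image features with an <image> token.
        prompt = "<image>\n" + prompt
    record = {
        "id": sample_id,
        "conversations": [
            {"from": "human", "value": prompt},
            {"from": "gpt", "value": f"The answer is {answer}."},
        ],
    }
    if image is not None:
        # Image paths are relative to the --image_folder passed at training time.
        record["image"] = image
    return record


rec = to_llava_record("1", "Which is a mammal?", ["frog", "whale"], "B", image="1/image.png")
print(json.dumps(rec, indent=2))
```

Text-only questions simply omit the image key and the <image> token.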
CaMML is instruction-finetuned on the LLaVA-v1.5-665K dataset. Follow the LLaVA instructions to download the annotation file llava_v1_5_mix665k.json, and download the images from the constituent datasets:
- COCO: train2017
- GQA: images
- OCR-VQA: download script
- TextVQA: train_val_images
- VisualGenome: part1, part2
The CaMML Retriever is built on ImageBind models and the AutoFaiss indexing tool. Each data entry is encoded with the pre-trained ImageBind-Huge checkpoint, and the embeddings are stored in an AutoFaiss index. We provide the processed Faiss indexes together with their corresponding data JSON files.
To use your own dataset as the CaMML Retriever source, we provide scripts and examples for generating the embeddings and the index:
python scripts/retriever/retriever_embed_llava665k.py
python scripts/retriever/build_autofaiss_index.py
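The provided index files are named *_vision_flatIP.index, which suggests a flat (brute-force) inner-product Faiss index. As a rough mental model only, the pure-Python sketch below mimics what such an index does: store embeddings and return the k entries with the highest inner product to a query. The class FlatIPIndex is a hypothetical toy stand-in, not the repository's retriever, and L2-normalizing the embeddings (so inner product equals cosine similarity) is an assumption about how the embeddings are prepared.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


class FlatIPIndex:
    """Toy brute-force inner-product index (hypothetical stand-in for a flat IP Faiss index)."""

    def __init__(self):
        self.vectors = []

    def add(self, embeddings):
        # Assumption: embeddings are normalized so inner product == cosine similarity.
        self.vectors.extend(l2_normalize(e) for e in embeddings)

    def search(self, query, k):
        """Return (scores, ids) of the k most similar stored vectors."""
        q = l2_normalize(query)
        scored = [(sum(a * b for a, b in zip(q, v)), i) for i, v in enumerate(self.vectors)]
        # Highest score first; ties broken by insertion order.
        scored.sort(key=lambda s: (-s[0], s[1]))
        top = scored[:k]
        return [s for s, _ in top], [i for _, i in top]


index = FlatIPIndex()
index.add([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
scores, ids = index.search([1.0, 0.1], k=2)
print(ids, [round(s, 3) for s in scores])
```

A flat index scans every stored vector, so search cost grows linearly with the retrieval source; the returned ids would then be used to look up the matching entries in the accompanying JSON file.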
We use the following models for initialization:
- Vicuna-7B-v1.3
- Vicuna-13B-v1.3
- Vicuna-7B-v1.5
- Vicuna-13B-v1.5
- CLIP-ViT-L-14
- CLIP-ViT-L-14-336px
- LLaVA-MM-Projectors
Following the LLaVA evaluation preparation, we test on 11 tasks (MME, MMBench, GQA, etc.) and organize the data under ./data/eval.
We also provide evaluation on COCO captioning, Flickr30k captioning, OKVQA/A-OKVQA, and RefCOCO/+/g visual grounding; download these datasets and add them to ./data/eval.
data
├── llava_665k_vision_flatIP.index
├── llava_665k_memory_metadata.json
├── sqa_vision_flatIP.index
├── sqa_train_post_memory_answer.json
├── llava
│ └── llava_665k
│ ├── coco
│ │ └── train2017
│ ├── gqa
│ │ └── images
│ ├── ocr_vqa
│ │ └── images
│ ├── textvqa
│ │ └── train_images
│ └── vg
│ ├── VG_100K
│ └── VG_100K_2
├── scienceqa
│ ├── images
│ ├── llava_train_QCM-LEPA.json
│ ├── llava_val_QCM-LEPA.json
│ ├── llava_test_QCM-LEPA.json
│ └── llava_test_CQA-A.json
└── eval
├── MME
├── mm-vet
├── mmbench
├── mmbench_cn
├── pope
├── scienceqa
├── seed_bench
├── vizwiz
├── vqav2
├── textvqa
├── gqa
├── cococap
├── flickr30k
├── okvqa
├── aokvqa
├── refcoco
├── refcocop
└── refcocog
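Before launching training, it can help to confirm that the data layout matches the tree above. The sketch below checks a subset of the listed paths relative to ./data and reports any that are missing; the EXPECTED list is transcribed from the tree, while the helper missing_paths is a hypothetical convenience, not a script shipped with the repository.

```python
from pathlib import Path

# Subset of the entries in the directory tree above, relative to ./data.
EXPECTED = [
    "llava_665k_vision_flatIP.index",
    "llava_665k_memory_metadata.json",
    "sqa_vision_flatIP.index",
    "sqa_train_post_memory_answer.json",
    "llava/llava_665k/coco/train2017",
    "llava/llava_665k/gqa/images",
    "llava/llava_665k/ocr_vqa/images",
    "llava/llava_665k/textvqa/train_images",
    "llava/llava_665k/vg/VG_100K",
    "llava/llava_665k/vg/VG_100K_2",
    "scienceqa/images",
    "scienceqa/llava_train_QCM-LEPA.json",
    "eval",
]


def missing_paths(root="data"):
    """Return the expected entries that do not exist under `root` (hypothetical helper)."""
    base = Path(root)
    return [p for p in EXPECTED if not (base / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:\n  " + "\n  ".join(missing))
    else:
        print("All expected paths found.")
```

Extend EXPECTED with the individual ./data/eval subfolders if you only plan to run some of the benchmarks.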
To finetune CaMML-7B on ScienceQA, run:
bash scripts/train_camml_7B_sqa.sh
An example of the underlying training command:
torchrun --nproc_per_node=$GPUS_PER_NODE --master_port=$RANDOM \
llava/train/train_camml_sqa.py \
--deepspeed "zero3.json" \
--model_name_or_path ./checkpoints/vicuna-7b-v1.3 \
--version v1 \
--data_path ./data/scienceqa/llava_train_QCM-LEPA.json \
--image_folder ./data/scienceqa/images/ \
--vision_tower ./checkpoints/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-pretrain-vicuna-7b-v1.3/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--perceiver_hidden_size 768 \
--perceiver_querys 64 \
--perceiver_layers 2 \
--icl_num 1 \
--random_shots_training True \
--image_aspect_ratio pad \
--group_by_modality_length True \
--fp16 True \
--output_dir ./checkpoints/$file \
--num_train_epochs 12 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 0 \
--lazy_preprocess True
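The schedule in this command is easier to reason about in optimizer steps. On a single node, the global batch is GPUS_PER_NODE × per_device_train_batch_size × gradient_accumulation_steps, and the Hugging Face Trainer derives warmup steps as ceil(total_steps × warmup_ratio). The sketch below does this arithmetic; the function schedule_summary and the dataset size in the example are hypothetical, and the step count is approximate because distributed sampling and --group_by_modality_length can change the exact number of batches.

```python
import math


def schedule_summary(num_examples, num_gpus, per_device_batch=8, grad_accum=1,
                     epochs=12, warmup_ratio=0.03):
    """Approximate step arithmetic for a single-node data-parallel run (hypothetical helper).

    Defaults mirror the ScienceQA command above; num_examples is user-supplied.
    """
    # Samples consumed per optimizer step across all GPUs.
    global_batch = num_gpus * per_device_batch * grad_accum
    # The last partial batch still counts as a step.
    steps_per_epoch = math.ceil(num_examples / global_batch)
    total_steps = steps_per_epoch * epochs
    # Warmup length as a fraction of all steps, rounded up.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return global_batch, total_steps, warmup_steps


# Hypothetical example: 10,000 training samples on 8 GPUs.
print(schedule_summary(10000, 8))
```

For the instruction-tuning command below, pass epochs=1 and the size of the 665K mixture instead.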
To instruction-tune CaMML-7B on LLaVA-v1.5-665K, run:
bash scripts/train_camml_7B_665K.sh
An example of the underlying training command:
torchrun --nproc_per_node=$GPUS_PER_NODE --master_port=$RANDOM \
camml/train/train_camml.py \
--deepspeed "zero3.json" \
--model_name_or_path ./checkpoints/vicuna-7b-v1.5 \
--version v1 \
--data_path ./data/llava/llava_665k/llava_v1_5_mix665k.json \
--image_folder ./data/llava/llava_665k/ \
--vision_tower ./checkpoints/clip-vit-large-patch14-336 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--perceiver_hidden_size 768 \
--perceiver_querys 128 \
--perceiver_layers 2 \
--random_shots_training True \
--image_aspect_ratio pad \
--group_by_modality_length True \
--mm_projector_type mlp2x_gelu \
--fp16 True \
--output_dir ./checkpoints/$file \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 0 \
--lazy_preprocess True
By default, CaMML uses the LLaVA-v1.5-665K dataset as the retriever source. To run inference on a single image and question:
python camml/eval/run_camml.py --query $QUESTION --image-file $IMAGE_PATH
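To query several images in a row, the single-sample command can be wrapped in a small script. The sketch below builds the same command as an argument list (only the --query and --image-file flags from the command above are used) and either prints it or runs it. The helpers build_camml_command and run_batch, and any example paths, are hypothetical.

```python
import shlex
import subprocess


def build_camml_command(question, image_path, script="camml/eval/run_camml.py"):
    """Assemble the single-sample inference command as an argument list."""
    return ["python", script, "--query", question, "--image-file", image_path]


def run_batch(pairs, dry_run=True):
    """Print (dry run) or execute one inference call per (question, image_path) pair."""
    for question, image_path in pairs:
        cmd = build_camml_command(question, image_path)
        if dry_run:
            # shlex.join shows the exact shell-quoted command that would run.
            print(shlex.join(cmd))
        else:
            subprocess.run(cmd, check=True)


run_batch([("What is shown in the image?", "examples/demo.jpg")])
```

Passing an argument list rather than one shell string avoids quoting problems when questions contain spaces, quotes, or question marks.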
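ScienceQA results of the finetuned models (IMG: questions with image context; TXT: questions with text context):
Model | AVG. | IMG | TXT |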
---|---|---|---|
CaMML-7B-sqa-FT | 91.32 | 89.24 | 93.21 |
CaMML-13B-sqa-FT | 92.03 | 89.94 | 93.84 |
Results of the instruction-tuned models on the 11 LLaVA benchmarks:
Model | LLM | VQAv2 | GQA | VizWiz | SQA(I) | TextVQA | POPE | MME | MMBench | MMBench-CN | SEED | MM-Vet |
---|---|---|---|---|---|---|---|---|---|---|---|---|
CaMML-7B | Vicuna-7B | 79.4 | 62.7 | 51.2 | 67.9 | 58.0 | 86.4 | 1506.9 | 66.9 | 60.6 | 60.4 | 32.2 |
CaMML-13B | Vicuna-13B | 80.2 | 63.7 | 57.4 | 72.3 | 59.9 | 86.7 | 1588.7 | 70.2 | 63.6 | 62.3 | 36.4 |
Results on the additional captioning, knowledge-based VQA, and visual grounding benchmarks:
Model | LLM | COCO Cap (CIDEr) | Flickr30K Cap (CIDEr) | OKVQA (Acc) | AOKVQA (MC-Acc) | RefCOCO (Acc) | RefCOCO+ (Acc) | RefCOCOg (Acc) |
---|---|---|---|---|---|---|---|---|
CaMML-7B | Vicuna-7B | 111.4 | 82.7 | 64.7 | 81.1 | 66.6 | 60.3 | 57.6 |
CaMML-13B | Vicuna-13B | 116.8 | 84.5 | 66.3 | 82.0 | 70.6 | 65.9 | 60.5 |
CaMML supports evaluation on up to 19 tasks; the evaluation scripts are in scripts/evaluation.
For example, to evaluate the ScienceQA-finetuned CaMML (the second argument is the number of shots):
# CaMML-7B, 1 shot
bash scripts/evaluation/sqa_ft_camml.sh camml_7b_sqa_ft 1
# CaMML-13B, 3 shots
bash scripts/evaluation/sqa_ft_camml.sh camml_13b_sqa_ft 3
To evaluate the instruction-tuned CaMML on VQAv2 (or another task):
TASK="vqav2" # mme, mmvet, mmbench, etc.
bash scripts/evaluation/${TASK}_camml.sh camml_7b
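To sweep several benchmarks for one checkpoint, the per-task scripts can be called in a loop. The sketch below assumes every task script follows the scripts/evaluation/${TASK}_camml.sh naming pattern shown above; apart from vqav2, the task names are taken from the comment in the command and are assumptions about the actual filenames, so check scripts/evaluation before running. The helpers eval_commands and run_all are hypothetical.

```python
import subprocess

# Assumed task names; only vqav2_camml.sh is named explicitly above.
TASKS = ["vqav2", "mme", "mmvet", "mmbench"]


def eval_commands(model_name, tasks=TASKS):
    """Build one evaluation command per task, following the ${TASK}_camml.sh pattern."""
    return [["bash", f"scripts/evaluation/{task}_camml.sh", model_name] for task in tasks]


def run_all(model_name, tasks=TASKS, dry_run=True):
    """Print (dry run) or execute the evaluation scripts one after another."""
    for cmd in eval_commands(model_name, tasks):
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True stops the sweep at the first failing task.
            subprocess.run(cmd, check=True)


run_all("camml_7b")
```

Running the tasks sequentially keeps GPU usage predictable; drop check=True if you prefer the sweep to continue past a failing task.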
If you find CaMML useful for your research and applications, please cite using this BibTeX:
@misc{camml,
title={CaMML: Context-Aware Multimodal Learner for Large Models},
author={Yixin Chen and Shuai Zhang and Boran Han and Tong He and Bo Li},
year={2024},
journal={The 62nd Annual Meeting of the Association for Computational Linguistics},
eprint={2401.03149},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
This repository builds upon the following codebases: