Uni-Med: A Unified Medical Generalist Foundation Model For Multi-Task Learning Via Connector-MoE
Xun Zhu, Ying Hu, Fanbin Mo, Miao Li, Ji Wu
Accepted by the Thirty-Eighth Annual Conference on Neural Information Processing Systems (NeurIPS 2024) [Poster]
1. Environment
Git clone our repository, create a Python environment, and activate it:
conda env create -f environment.yml
conda activate uni_med
Note: If you need to perform an evaluation, please install pycocoevalcap from here.
2. Dataset
To download the raw data, you can follow the links below:
| Dataset | Download path | Dataset | Download path |
|---|---|---|---|
| MedQA | Download | PubMedQA | Download |
| Slake | Download | Path-VQA | Download |
| MIMIC-CXR | images / captions | MPx (MedPix) | images / captions |
| SA-Med2D-20M | Download | MedMNIST | Download |
Alternatively, you can download the processed data (e.g., Slake-VQA/Slake-REC/Slake-REG and SA-Med2D-REC/SA-Med2D-REG) from figshare, which can be used directly for training.
Set the dataset paths in uni_med/configs/datasets/.
3. Pretrained Model Weights
- EVA-CLIP ViT-G: Download
- Llama 2 Chat 7B: Download
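The Download links above are the intended source. Purely as an illustration (not part of the repository), if you fetch the LLM from the Hugging Face Hub instead, huggingface_hub can download it; the repo id and local directory below are assumptions, and Llama 2 is gated, so accept Meta's license and log in first:

```python
# Illustrative sketch only: fetch Llama 2 Chat 7B from the Hugging Face Hub.
# Run `huggingface-cli login` beforehand (the repo is gated), then point
# llm_model_path in the train/eval configs at the downloaded directory.
from huggingface_hub import snapshot_download

llm_dir = snapshot_download(
    repo_id="meta-llama/Llama-2-7b-chat-hf",   # assumed mirror of the linked weights
    local_dir="checkpoints/llama2-7b-chat",    # hypothetical local path
)
print("Llama 2 Chat 7B downloaded to", llm_dir)
```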
4. Training
Uni-Med achieves joint training on six distinct medical tasks and 12 datasets, requiring only one-stage training on a single A800 GPU and no task- or dataset-specific fine-tuning.
(1) Train config file setup (train_configs/uni_med.yaml)
Set resample_rate and resample_method (projection/avgpool/maxpool) for visual feature aggregation.
Set projector_type (linear/mlp2x_gelu/moe_linear/moe_mlp), num_expert, router_method (router_task_token/router_token/router_task), num_task_tokens, task_token_c, and router_type (soft/hard/constant/sparse) for the connector setting (a minimal sketch of the MoE connector follows this list).
Set llm_model_name and llm_model_path for loading the LLaMA model.
Set sft_type for the finetuning strategy (lora/full/none).
Set lora_target_modules, lora_r, and lora_alpha for the LoRA setting.
Set output_dir for saving the model.
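To make the connector options above concrete, here is a minimal, self-contained sketch (not the repository implementation) of what a moe_mlp projector with a soft per-token router does: each expert is a two-layer GELU MLP that maps visual features into the LLM embedding space, and the router mixes the expert outputs token by token. The dimensions assume EVA-CLIP ViT-G features (1408-d) and Llama 2 7B embeddings (4096-d); all class and variable names are illustrative, and task-token-based routing (router_task_token/router_task) is omitted.

```python
# Minimal sketch of a soft-routed Connector-MoE projector (illustrative names/dims).
import torch
import torch.nn as nn

class ConnectorMoESketch(nn.Module):
    def __init__(self, vis_dim=1408, llm_dim=4096, num_expert=4):
        super().__init__()
        # Each expert is an mlp2x_gelu-style projector: visual space -> LLM space.
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(vis_dim, llm_dim), nn.GELU(), nn.Linear(llm_dim, llm_dim))
            for _ in range(num_expert)
        ])
        # A soft router assigns each visual token a weight over the experts.
        self.router = nn.Linear(vis_dim, num_expert)

    def forward(self, visual_tokens):                              # (B, N, vis_dim)
        weights = self.router(visual_tokens).softmax(dim=-1)       # (B, N, E)
        expert_out = torch.stack(
            [expert(visual_tokens) for expert in self.experts], dim=-2
        )                                                          # (B, N, E, llm_dim)
        return (weights.unsqueeze(-1) * expert_out).sum(dim=-2)    # (B, N, llm_dim)

# Usage: project a batch of visual tokens into the LLM embedding space.
tokens = torch.randn(2, 256, 1408)
print(ConnectorMoESketch()(tokens).shape)                          # torch.Size([2, 256, 4096])
```

In general MoE terms, a hard or sparse router would keep only the top-scoring expert(s) per token and a constant router would use fixed mixture weights; the soft version above is the simplest to read.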
(2) Run
CUDA_VISIBLE_DEVICES=0 torchrun --master-port 295XX --nproc-per-node 1 train.py --cfg-path train_configs/uni_med.yaml
5. Evaluation
Set the checkpoint, model parameters, save path, and test set paths in eval_configs/uni_med_benchmark_evaluation.yaml
(1) Evaluating Visual Question Answering
python eval_vqa.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset slakevqa_en
python eval_vqa.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset path_vqa
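eval_vqa.py handles the scoring; for reference, two metrics commonly used in medical VQA work are exact-match accuracy on closed-ended answers and token-level recall on open-ended ones. The sketch below is an independent illustration of those metrics, not the repository's evaluation code:

```python
# Common medical-VQA style scoring (illustration only; eval_vqa.py may differ).
def exact_match(pred: str, gold: str) -> float:
    # Closed-ended answers (e.g. yes/no): case-insensitive exact match.
    return float(pred.strip().lower() == gold.strip().lower())

def token_recall(pred: str, gold: str) -> float:
    # Open-ended answers: fraction of ground-truth tokens recovered by the prediction.
    pred_tokens = set(pred.lower().split())
    gold_tokens = gold.lower().split()
    return sum(t in pred_tokens for t in gold_tokens) / max(len(gold_tokens), 1)

print(exact_match("Yes", "yes"))                   # 1.0
print(token_recall("the left lung", "left lung"))  # 1.0
```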
(2) Evaluating Referring Expression Comprehension
python eval_ref.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset ref_slake
python eval_ref.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset ref_sa_med
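Referring expression comprehension outputs a bounding box per query, and accuracy is typically reported as the fraction of predictions whose IoU with the ground-truth box exceeds a threshold (commonly 0.5). Below is a small, repository-independent sketch of that metric; the (x1, y1, x2, y2) box format and the threshold are assumptions:

```python
# Accuracy@IoU>=0.5 for predicted vs. ground-truth boxes in (x1, y1, x2, y2) format.
def iou(box_a, box_b):
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-9)

def rec_accuracy(preds, gts, thresh=0.5):
    hits = sum(iou(p, g) >= thresh for p, g in zip(preds, gts))
    return hits / max(len(gts), 1)

print(rec_accuracy([(10, 10, 50, 50)], [(12, 8, 48, 52)]))  # IoU ~0.83 -> counted as a hit
```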
(3) Evaluating Referring Expression Generation
python eval_identify.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset invref_slake
python eval_identify.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset invref_sa_med
(4) Evaluating Report Generation
python eval_identify.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset mimic_caption
python eval_identify.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset medpix_single
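This is where the pycocoevalcap dependency from Section 1 comes in: generated reports are usually scored with captioning metrics such as BLEU and CIDEr. The sketch below calls the library directly on hand-made prediction/reference pairs; the ids and texts are invented, and the repository's evaluation script may organize this differently:

```python
# Scoring generated reports with pycocoevalcap (BLEU-1..4 and CIDEr).
# Both dicts map an image id to a list of (pre-tokenized) strings.
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider

refs = {"img1": ["no acute cardiopulmonary abnormality ."],
        "img2": ["mild cardiomegaly without pleural effusion ."]}
hyps = {"img1": ["no acute cardiopulmonary process ."],
        "img2": ["mild cardiomegaly . no pleural effusion ."]}

bleu_scores, _ = Bleu(4).compute_score(refs, hyps)   # [BLEU-1, BLEU-2, BLEU-3, BLEU-4]
cider_score, _ = Cider().compute_score(refs, hyps)
print("BLEU-4:", bleu_scores[3], "CIDEr:", cider_score)
```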
(5) Evaluating Image Classification
python eval_identify.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset medmnist_2d_derma
python eval_identify.py --cfg-path eval_configs/uni_med_benchmark_evaluation.yaml --dataset medmnist_2d_organs
6. Acknowledgement
- MiniGPT-4: The standard model architecture of Uni-Med follows MiniGPT-v2. Don't forget to check out this great open-source work if you haven't seen it before!
7. Citation
If you're using Uni-Med in your research or applications, please cite it using this BibTeX:
@article{zhu2024uni,
title={Uni-Med: A Unified Medical Generalist Foundation Model For Multi-Task Learning Via Connector-MoE},
author={Zhu, Xun and Hu, Ying and Mo, Fanbin and Li, Miao and Wu, Ji},
journal={arXiv preprint arXiv:2409.17508},
year={2024}
}