zihanghliu/ModelBalancing

Model Balancing Helps Low-data Training and Fine-tuning [EMNLP 2024 Main Track]

Zihang Liu, Yuanzhe Hu, Tianyu Pang, Yefan Zhou, Pu Ren, Yaoqing Yang

Paper

Introduction

In this work, we show that the quality of model training with limited data can be interpreted from a Heavy-Tailed Self-Regularization (HT-SR) perspective. We analyze the empirical spectral density (ESD) of individual layers and use a shape metric of these ESDs (PL_Alpha_Hill) to quantify the quality of each layer. We then assign different learning rates to different layers based on their PL_Alpha_Hill, a method we call TempBalance. We show that TempBalance achieves better layer-wise quality alignment, which improves low-data training in NLP and SciML tasks.
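To make the idea concrete, the following is a minimal sketch of the two steps described above: estimating a power-law exponent from a layer's eigenvalues with a Hill estimator, then mapping per-layer exponents to learning rates. It is not the repository's implementation. The function names, the choice of `k` (half of the eigenvalues), the linear mapping, and the scaling range `s_min`/`s_max` are all assumptions made for illustration, as is the direction of the mapping (layers with a larger exponent receive a larger learning rate). The input eigenvalues stand in for the ESD of a layer, that is, the eigenvalues of the layer's weight correlation matrix.

```python
import math


def hill_alpha(eigenvalues, k=None):
    """Hill estimate of the power-law exponent of an ESD tail.

    Illustrative only: uses the k largest eigenvalues (default: half of them)
    and the eigenvalue just below them as the tail threshold.
    """
    lam = sorted(eigenvalues)
    n = len(lam)
    if k is None:
        k = n // 2  # assumed default, not taken from the repository
    if not 1 <= k < n:
        raise ValueError("k must satisfy 1 <= k < len(eigenvalues)")
    threshold = lam[n - k - 1]
    if threshold <= 0:
        raise ValueError("eigenvalues in the tail must be positive")
    log_sum = sum(math.log(x / threshold) for x in lam[n - k:])
    if log_sum == 0:
        raise ValueError("tail eigenvalues are all equal to the threshold")
    return 1.0 + k / log_sum


def layer_learning_rates(alphas, base_lr, s_min=0.5, s_max=1.5):
    """Map per-layer alphas to learning rates by linear interpolation.

    Hypothetical schedule: the layer with the smallest alpha gets
    s_min * base_lr and the layer with the largest gets s_max * base_lr.
    """
    lo, hi = min(alphas), max(alphas)
    if hi == lo:
        return [base_lr for _ in alphas]
    return [
        base_lr * (s_min + (a - lo) / (hi - lo) * (s_max - s_min))
        for a in alphas
    ]


if __name__ == "__main__":
    # Toy eigenvalue spectra for three hypothetical layers.
    spectra = [
        [0.1, 0.2, 0.5, 1.0, 4.0, 20.0],
        [0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
        [0.2, 0.3, 0.6, 1.0, 2.0, 5.0],
    ]
    alphas = [hill_alpha(s) for s in spectra]
    for a, lr in zip(alphas, layer_learning_rates(alphas, base_lr=1e-3)):
        print(f"alpha={a:.3f}  lr={lr:.2e}")
```

In a real training loop, the resulting rates would typically be applied through per-layer optimizer parameter groups and recomputed periodically as the weights change.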

Main Result on LLM Fine-tuning: Full Fine-tuning RoBERTa-base with TempBalance (TB) on GLUE Tasks

Environment Setup

# create conda env and install packages from requirements.txt
conda create -n ww_finetune python=3.8
conda activate ww_finetune
pip install -r requirements.txt

# set up the transformers library from local source
conda activate ww_finetune
cd transformers
pip install -e .
pip install accelerate==0.28.0

Reproducing the Results on QNLI Dataset

# Full FT baseline
bash ./bash_scripts/run_glue_baseline_roberta_base_ratio.sh

# Our method TempBalance
bash ./bash_scripts/run_glue_block_tb_sigmoid_roberta_base_ratio.sh

Main Result on Neural PDE: Training FNO Model on 2D Compressible Navier-Stokes Dataset

Environment Setup

cd sciml
conda create -n ww_sciml python=3.8
conda activate ww_sciml
pip install -r requirements_sciml.txt

Dataset Preparation

  1. Download data from external source
cd ./sciml/pdebench/data_download
python download_direct.py --root_folder $your_proj_home_dir/sciml/data --pde_name 2d_cfd_tbv2
  2. Customize the config file ./sciml/pdebench/models/config/args/config_2DCFD_TB.yaml for your own directory layout. Change the following lines in the file:
train_log_path: '$your_proj_home_dir/sciml/pdebench/logs/'
checkpoint_path: '$your_proj_home_dir/sciml/pdebench/checkpoints/'
data_path: '$your_proj_home_dir/sciml/data/2D/CFD/2D_Train_Rand/'

(Note that loading the training data can take several minutes.)
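As a convenience, the three config values can be generated from a single project directory instead of being typed by hand. The helper below is a hypothetical sketch and not part of the repository: the function name `config_paths` is invented, and it only assumes the directory layout shown in the config lines above.

```python
from pathlib import PurePosixPath


def config_paths(project_home):
    """Build the three path values for config_2DCFD_TB.yaml.

    Hypothetical helper: mirrors the layout from the README lines above and
    keeps the trailing slash used there.
    """
    root = PurePosixPath(project_home) / "sciml"
    return {
        "train_log_path": f"{root / 'pdebench' / 'logs'}/",
        "checkpoint_path": f"{root / 'pdebench' / 'checkpoints'}/",
        "data_path": f"{root / 'data' / '2D' / 'CFD' / '2D_Train_Rand'}/",
    }


if __name__ == "__main__":
    # Replace the placeholder with your own project directory.
    for key, value in config_paths("/path/to/ModelBalancing").items():
        print(f"{key}: '{value}'")
```

The printed lines can then be pasted over the corresponding entries in the YAML file.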

Reproducing the Results of Training FNO Model on 2D Compressible Navier-Stokes Dataset

bash ./bash_scripts/reduce_batch_fno.sh

About

[EMNLP 2024 Oral] Model Balancing Helps Low Data Training and Fine-tuning
