We provide code for all the experiments presented in our paper.
The code is organized as follows:
- Source code is in the `src` directory, which also contains the `requirements.txt` file listing all dependencies to install with pip.
- Bash files required to run some experiments, along with all the hyperparameters used, are in the `scripts` directory.
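For reference, here is the layout implied by the paths used in this README (only files and directories named in this README are shown; `checkpoints` is created at training time):

```
.
├── src/
│   ├── requirements.txt
│   ├── train.py
│   ├── curve_merging.py      # contains the CurveConfig class
│   └── models/
│       ├── fcmodel.py
│       └── mlpnet.py
├── scripts/
│   ├── run_fuse_fc_models.sh
│   └── (CL notebook)
└── checkpoints/
    └── seed_{seed}/model_{A or B}/final_model.pth
```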
Our model merging results:
| MNIST | MLPNet | MLPLarge | MLPHuge |
|---|---|---|---|
| Joint Model | | | |
| Model A | | | |
| Model B | | | |
| AVG | | | |
| OT | | | |
| MPF (ours) | | | |
| CIFAR-10 | MLPNet | MLPLarge | MLPHuge |
|---|---|---|---|
| Joint Model | | | |
| Model A | | | |
| Model B | | | |
| AVG | | | |
| OT | | | |
| MPF (ours) | | | |
Our analysis shows that MPF effectively combines knowledge from both Model A and Model B, without the catastrophic forgetting observed with the two baseline methods (AVG and OT).
The main dependencies for running the code are:
- pytorch
- torchvision
- tqdm
- PIL (Pillow)
- numpy
- Python Optimal Transport (POT)
- tensorboard (to check the logs)
- for the continual learning (CL) experiments, the requirements from https://github.com/sidak/otfusion
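A minimal setup sketch, assuming a standard pip workflow; the location of a requirements file inside the otfusion repository is an assumption, so check that repository before installing:

```bash
# Install the core dependencies pinned in src/requirements.txt
pip install -r src/requirements.txt

# For the CL experiments only: install the otfusion requirements
# (assumes that repository ships a requirements.txt at its root)
git clone https://github.com/sidak/otfusion
pip install -r otfusion/requirements.txt
```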
Next, we provide detailed instructions on running each experiment.
In general, each experiment has a bash file in the `scripts` directory that records the hyperparameters and random seeds used in the experiment. Uncomment the corresponding command in the relevant bash file before running it. For most of the code, the commands and argument names are self-explanatory.
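For example, the typical workflow looks like this (the script name is taken from the fusion instructions below; which command to uncomment depends on the experiment):

```bash
# 1. Open the relevant bash file and uncomment the command for your experiment.
# 2. Run it from the repository root, e.g.:
bash scripts/run_fuse_fc_models.sh
```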
First, all of the base models for the fusion experiments need to be trained.
Their hyperparameters are all located in the `CurveConfig` class in `src/curve_merging.py`.
The code for the model classes is in `src/models/fcmodel.py` and `src/models/mlpnet.py`.
Running training and fusion:
- Check the `CurveConfig` class in `src/curve_merging.py` to modify the parameters and hyperparameters as you wish: model used, dataset, etc.
- Then run `src/train.py`. This script trains the base models and then merges them following the AVG, OT, and MPF procedures detailed in the paper.
- The trained models are saved to `checkpoints/seed_{seed}/<model_{A or B}>/final_model.pth`.
- If you already have the checkpoints of the base models and just want to merge them and save the fused model, use `bash scripts/run_fuse_fc_models.sh`. To choose the type of fusion, change the `fusion_type` variable to one of `"ot"`, `"avg"`, or `"curve"` (see the sketch after this list).
- In that script, `input_dim` should be 3072 for CIFAR-10 and 784 for MNIST. `hidden_dims` should be 400, 200, 100 for MLPNet; 800, 400, 200 for MLPLarge; and 1024, 512, 256 for MLPHuge. `model_path_list` should contain pairs of strings `{model_architecture}, model_checkpoint_path`, one pair per model you want to merge.
- The statistics for the experiment are dumped to the `model_accuracies.csv` file and printed to the terminal. The model with the best validation accuracy is saved as `best_val_acc_model.pth`, while the final model at the end of training is saved as `final_model.pth`.
NOTE: We use the final model for our experiments. All the required model training and merging can be done using this script.
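As referenced above, here is a hypothetical sketch of the variables one might set inside `scripts/run_fuse_fc_models.sh` to merge two MLPNet models trained on MNIST. The exact variable syntax and the architecture identifier (`mlpnet` below) are assumptions; check the script itself for the real names and format:

```bash
fusion_type="curve"          # MPF; alternatives: "ot", "avg"
input_dim=784                # 784 for MNIST, 3072 for CIFAR-10
hidden_dims="400, 200, 100"  # MLPNet; MLPLarge: 800, 400, 200; MLPHuge: 1024, 512, 256
# Pairs of {model_architecture}, model_checkpoint_path, one pair per model to merge.
# "mlpnet" is an assumed architecture identifier and seed_0 an example seed directory.
model_path_list="mlpnet, checkpoints/seed_0/model_A/final_model.pth, \
mlpnet, checkpoints/seed_0/model_B/final_model.pth"
```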
For the CL experiments, the notebook in the `scripts` directory already contains the experiments with their cells executed.
We thank the authors of "Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs", "Wasserstein Barycenter-based Model Fusion and Linear Mode Connectivity of Neural Networks", and "Model Fusion via Optimal Transport" for sharing their code. From the first paper, we reused the minimum-loss curve-finding algorithm to derive our merged model. From the second and third papers, we reused their implementations of OT.