Weight-conditional diffusion-based training data reconstruction
This project, undertaken by Eric Saikali under the guidance of Dr. Ana-Maria Cretu at the SPRING lab at EPFL, explores the privacy risks associated with neural networks by investigating how much information they leak about their training datasets in white-box settings.
The study adapts diffusion models to perform weight-conditional generation, an approach that uses neural network weights as the conditioning signal for a generative model. The aim is to understand the extent to which models trained on subsets of the CIFAR-10 dataset reveal details about their training data.
The initial phase of the project involves training 2,000 classification models on varied subsets of CIFAR-10. Subsequently, a Denoising Diffusion Probabilistic Model (DDPM) is developed, conditioned on the encoded weights of these models, analogous to class-conditional generative techniques.
The primary objective of this project is to extend the scope of diffusion models by introducing weight-conditional generation and applying it to reconstruction attacks. The ultimate aim is to develop a generator network that, when conditioned on the neural network weights $W$, can output samples resembling the dataset $D$ used to train these weights.
This approach extends diffusion models to a new task; they have traditionally been designed for:
- Unconditional generation: Producing random samples from a data distribution.
- Class-conditional generation: Generating samples specific to a predefined class (e.g., digits).
- Text-conditional generation: Creating outputs based on textual descriptions.
To the best of our knowledge, diffusion (and more generally, generative) models have not explored conditioning on trained neural network weights to reconstruct data from the original training set. This project pioneers this exploration and evaluates its implications for privacy and security.
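As a sketch of what weight-conditional generation could look like formally, assuming the standard DDPM noise-prediction objective, the weights enter the denoiser as one more input. The symbols $\epsilon_\theta$ (denoiser), $\bar{\alpha}_t$ (cumulative noise schedule) and $e_\phi$ (weight encoder) are notation introduced here for illustration and are not taken from the report:

```latex
\begin{align}
x_t &= \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
  \qquad \epsilon \sim \mathcal{N}(0, I), \quad x_0 \in D \\
\mathcal{L}(\theta, \phi) &= \mathbb{E}_{(x_0, W),\, t,\, \epsilon}
  \left[ \left\| \epsilon - \epsilon_\theta\big(x_t,\, t,\, e_\phi(W)\big) \right\|^2 \right]
\end{align}
```

Here each pair $(x_0, W)$ is assumed to consist of weights $W$ and an image $x_0$ from the subset $D$ those weights were trained on. Dropping the third argument gives the unconditional case, and replacing $e_\phi(W)$ with a class embedding gives the class-conditional case.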
- Create a virtual Python environment with `python3 -m venv venv`, using your local Python 3.9 base environment. If you don't have Python 3.9, install `virtualenv` in your base environment via `pip install virtualenv` and then run `virtualenv venv --python=python3.9`. You can check the version with `python3 --version`.
- Activate the virtual Python environment with `source venv/bin/activate`.
- Install the dependencies listed in `requirements.txt` into your environment with `pip install -r requirements.txt` while in the project's folder.
- Train 2,000 single-layer perceptron models on subsets of the CIFAR-10 dataset with varied configurations and data partitions. For the purpose of this experiment, each model was trained on 100 samples, for 100 epochs.
- Reproduction: run `python3 train_mlps.py` to launch the training with all default parameters.
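The data partitioning described above can be sketched as follows. This is a minimal illustration and not the logic of `train_mlps.py`: the per-model seeding scheme and the helper name `subset_indices` are assumptions, and only the counts (2,000 models, 100 samples each) come from the text.

```python
import random

NUM_TRAIN_IMAGES = 50_000   # size of the CIFAR-10 training split
NUM_MODELS = 2_000          # number of classifiers, as stated above
SUBSET_SIZE = 100           # samples per classifier, as stated above


def subset_indices(model_id: int, base_seed: int = 0) -> list[int]:
    """Return the image indices assigned to one classifier (hypothetical scheme)."""
    # One independent, reproducible random stream per model.
    rng = random.Random(base_seed + model_id)
    # Sampling without replacement guarantees 100 distinct images.
    return sorted(rng.sample(range(NUM_TRAIN_IMAGES), SUBSET_SIZE))


# Map every model id to the subset it would be trained on.
partitions = {m: subset_indices(m) for m in range(NUM_MODELS)}
print(len(partitions), len(partitions[0]))
```

Deriving the subset from the model id alone means the same image-to-model mapping can be recomputed later, which matters when a generator has to be trained on matching pairs of weights and images.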
- The large diffusion model $DM_1$ is based on Tushar Kumar's code, which originally used config files; its files are located at the root of this project.
- The small diffusion model $DM_2$ is based on Tim Pearce's code, which was refactored to accept command-line arguments; its files are located in the `small_model` folder.
Please note that while the report only covers the most critical experiments, more were run: first to fine-tune the architecture, as the number of configurations in `src/config` shows, and then to test and gain insight into how the diffusion models compare. Also note that to reproduce the results, you must set the `weights_dir` parameter to the directory of the trained MLPs, so that training uses the same order of images and classes as the weight-conditional training.
- Baseline Training:
  - Unconditional DDPM:
    - The diffusion model is trained without any additional conditioning signals.
    - The goal is to learn the data distribution directly, predicting the noise at each timestep given only the noisy input.
    - Reproduction: to reproduce the results of the large model $DM_{1n}$, run `python3 train_ddpm_cond.py --config src/config/cifar10_no_cond-double-lr-0.0001.yaml`.
    - Reproduction: to reproduce the results of the small model $DM_{2n}$, run `python3 small_model/cifar10_training.py --weight_dir 'data/result/models_' --class_cond False --weight_cond False --save_dir 'small-model-no-cond/' --guidance_scale 2.0 --n_epoch 200`.
  - Class-Conditional DDPM:
    - The model is trained with class labels as conditioning information.
    - Classes are represented as one-hot encoded vectors, which are added to the timestep embeddings. This enables the model to generate samples specific to a given class.
    - Reproduction: for $DM_{1c}$, run `python3 train_ddpm_cond.py --config src/config/cifar10_class_cond-double-lr-0.0001.yaml`.
    - Reproduction: for $DM_{2c}$, run `python3 small_model/cifar10_training.py --weight_dir 'data/result/models_' --class_cond True --weight_cond False --save_dir 'small-model-class/' --guidance_scale 2.0 --n_epoch 200`.
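The class conditioning described above, where a one-hot class vector is added to the timestep embedding, can be sketched in plain Python. This is an illustrative sketch under assumptions: the sinusoidal embedding, the embedding size `EMB_DIM`, and the random `class_proj` matrix (standing in for a learned projection) are not taken from either code base.

```python
import math
import random


def timestep_embedding(t: int, dim: int) -> list[float]:
    """Sinusoidal embedding of an integer timestep (dim must be even)."""
    half = dim // 2
    freqs = [math.exp(-math.log(10_000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


def one_hot(label: int, num_classes: int = 10) -> list[float]:
    """One-hot vector for a CIFAR-10 class label."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def project(vec, weight):
    """Linear map without bias: weight has shape (out_dim, in_dim)."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


EMB_DIM = 8  # hypothetical embedding size, kept small for readability
rng = random.Random(0)
# Stand-in for a learned projection from 10 classes to the embedding size.
class_proj = [[rng.gauss(0.0, 0.02) for _ in range(10)] for _ in range(EMB_DIM)]

t_emb = timestep_embedding(t=50, dim=EMB_DIM)
c_emb = project(one_hot(3), class_proj)
# The class signal is added to the time signal before it reaches the denoiser.
cond_emb = [a + b for a, b in zip(t_emb, c_emb)]
```

Because the class vector is one-hot, the projection simply selects one column of `class_proj`, so each class ends up with its own additive offset on the timestep embedding.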
- Attack Training:
  - Weight-Conditional DDPM:
    - The diffusion model is conditioned on the weights of an external neural network.
    - Neural network weights are compressed into an embedding vector using Neural Functional Networks (NFN), a technique that respects the permutation symmetries of the weights.
    - The resulting embedding is incorporated into the timestep embedding, analogous to how class conditioning is applied.
    - Reproduction: for $DM_{1w}$, run `python3 train_ddpm_cond.py --config src/config/cifar10_weight_cond-double-lr-0.0001.yaml`.
    - Reproduction: for $DM_{2w}$, run `python3 small_model/cifar10_training.py --weight_dir 'data/result/models_' --class_cond False --weight_cond True --save_dir 'small-model-weight/' --guidance_scale 4.0 --n_epoch 90 --batch_size 32`.
  - Weight + Class Conditional DDPM:
    - This diffusion model setup extends conditioning to include both neural network weights and class labels.
    - Thanks to these two types of conditioning, the model can generate data that is both class-specific and influenced by the characteristics of a particular neural network.
    - Reproduction: for $DM_{1cw}$, run `python3 train_ddpm_cond.py --config src/config/cifar10_weight_class_cond-double-lr-0.0001.yaml`.
    - Reproduction: for $DM_{2cw}$, run `python3 small_model/cifar10_training.py --weight_dir 'data/result/models_' --class_cond True --weight_cond True --save_dir 'small-model-weight-class/' --guidance_scale 2.0 --n_epoch 90 --batch_size 32`.
    - Reproduction: for $DM_{1NFN}$, run `python3 train_ddpm_cond.py --config src/config/cifar10_weight_class_cond-double-lr-0.0001-nfn_disabled.yaml`.
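To illustrate why a symmetry-aware encoder is useful when conditioning on weights, the toy example below shows that reordering the hidden neurons of a network changes its weight matrices but not the function it computes. It is not the NFN encoder used in this project: the two-layer architecture and the hand-written `invariant_embedding` summary are assumptions made purely for illustration.

```python
import random


def mlp_forward(x, w1, w2):
    """Two-layer MLP with ReLU: w1 is (hidden, in), w2 is (out, hidden)."""
    hidden = [max(0.0, sum(w * v for w, v in zip(row, x))) for row in w1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


def permute_hidden(w1, w2, perm):
    """Reorder hidden neurons: rows of w1 and columns of w2 move together."""
    return [w1[p] for p in perm], [[row[p] for p in perm] for row in w2]


def invariant_embedding(w1, w2):
    """Toy summary that ignores neuron order: one sorted feature pair per neuron."""
    per_neuron = []
    for j, row in enumerate(w1):
        outgoing = [out_row[j] for out_row in w2]
        per_neuron.append((sum(row), sum(outgoing)))
    return sorted(per_neuron)


rng = random.Random(0)
w1 = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(3)]
w2 = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(2)]
x = [rng.gauss(0, 1) for _ in range(4)]

w1_p, w2_p = permute_hidden(w1, w2, perm=[2, 0, 1])

# Same function up to floating-point summation order, different raw weights.
same_function = all(
    abs(a - b) < 1e-9
    for a, b in zip(mlp_forward(x, w1, w2), mlp_forward(x, w1_p, w2_p))
)
same_embedding = invariant_embedding(w1, w2) == invariant_embedding(w1_p, w2_p)
print(same_function, same_embedding, w1 == w1_p)
```

A naive flattening of the weights would treat the two networks as different conditioning inputs, whereas an invariant summary maps both to the same embedding.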
- The L2 distance was used to assess the similarity between the generated samples and the original dataset. This metric quantifies how closely a generated image matches a target image pixel by pixel.
- The fidelity of the generated samples was evaluated using a two-step process:
- Similarity Scoring: The MSE loss was computed for each generated image against all samples from the dataset.
- Reconstruction Accuracy: The sample with the lowest MSE loss (i.e., the most visually similar image) was selected, and its reconstruction score was recorded to gauge how accurately the model replicated the original training data.
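The two-step scoring above can be sketched as a nearest-neighbour search under MSE. This is a minimal illustration on tiny stand-in data, not the project's evaluation code; the function names `mse` and `reconstruction_score` are hypothetical.

```python
def mse(a, b):
    """Mean squared error between two equally sized flattened images."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def reconstruction_score(generated, dataset):
    """Return (index, mse) of the dataset image closest to `generated`."""
    # Step 1: similarity scoring against every image in the dataset.
    scores = [mse(generated, target) for target in dataset]
    # Step 2: keep the lowest loss as the reconstruction score.
    best = min(range(len(dataset)), key=scores.__getitem__)
    return best, scores[best]


# Tiny stand-in data: three "images" of four pixels each, values in [0, 1].
dataset = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],
    [0.5, 0.5, 0.0, 0.0],
]
generated = [0.5, 0.5, 0.0, 0.5]

best_index, best_mse = reconstruction_score(generated, dataset)
print(best_index, best_mse)  # prints: 2 0.0625
```

A lower score means the generated sample is closer to some real training image, so comparing scores on the train and test sets indicates whether the generator reproduces training data specifically.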
Testing on the training set: set the `weights_dir` parameter to the directory of the trained MLPs. Also note that you might need to change the device if you are not using `cuda`.
You can find all final scores in `final_score.jsonl` within the directory given by the `save_dir` parameter.
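A JSON Lines file such as `final_score.jsonl` can be read one record per line. The sketch below assumes each line is a JSON object holding a numeric value under a key named `score`; that key name and the demo records are hypothetical, since the actual file layout is not documented here.

```python
import json
import statistics
import tempfile
from pathlib import Path


def load_scores(path, key="score"):
    """Read one JSON object per line and collect the value stored under `key`."""
    values = []
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if line:  # skip blank lines
                values.append(float(json.loads(line)[key]))
    return values


# Demo with a temporary file standing in for the real save_dir.
with tempfile.TemporaryDirectory() as save_dir:
    demo_path = Path(save_dir) / "final_score.jsonl"
    demo_path.write_text('{"score": 0.02}\n{"score": 0.04}\n', encoding="utf-8")
    scores = load_scores(demo_path)

# Aggregate the per-record scores into a small summary.
summary = {"mean": statistics.mean(scores), "min": min(scores)}
print(summary)
```

To use it on real output, point `load_scores` at the file inside your own `save_dir` and set `key` to whichever field the evaluation script actually writes.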
To reproduce the results for each of the models on the train set:
python3 eval_ddpm_cond.py --config src/config/cifar10_no_cond-double-lr-0.0001.yaml --model 'storage-no-cond-double-lr-0.0001/no_conditional_diff/no_conditional_diff_49.pt' --is_train True
python3 eval_ddpm_cond.py --config src/config/cifar10_class_cond-double-lr-0.0001.yaml --model 'storage-class-cond-double-lr-0.0001/class_conditional_diff/class_conditional_diff_49.pt' --is_train True
python3 eval_ddpm_cond.py --config src/config/cifar10_weight_cond-double-lr-0.0001.yaml --model 'storage-weight-cond-double-lr-0.0001/weight_conditional_diff/weight_conditional_diff_118.pt' --is_train True
python3 eval_ddpm_cond.py --config src/config/cifar10_weight_class_cond-double-lr-0.0001.yaml --model 'storage-weight-class-cond-double-lr-0.0001/class_weight_conditional_diff/class_weight_conditional_diff_36.pt' --is_train True
python3 eval_ddpm_cond.py --config src/config/cifar10_weight_class_cond-double-lr-0.0001-nfn_disabled.yaml --model 'nfn_disabled_storage-weight-class-cond-double-lr-0.0001/nfn_disabled_class_weight_conditional_diff/nfn_disabled_class_weight_conditional_diff_36.pt' --is_train True
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-no-cond/model_200ep_256feat.pth' --class_cond False --weight_cond False --guidance 0.0 --is_train True
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-class/model_200ep_256feat.pth' --class_cond True --weight_cond False --guidance 2.0 --is_train True
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-weight/model_84ep_256feat.pth' --class_cond False --weight_cond True --guidance 4.0 --is_train True
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-weight-class/model_87ep_256feat.pth' --class_cond True --weight_cond True --guidance 2.0 --is_train True
To reproduce the results for each of the models on the test set:
python3 eval_ddpm_cond.py --config src/config/cifar10_no_cond-double-lr-0.0001.yaml --model 'storage-no-cond-double-lr-0.0001/no_conditional_diff/no_conditional_diff_49.pt' --is_train False
python3 eval_ddpm_cond.py --config src/config/cifar10_class_cond-double-lr-0.0001.yaml --model 'storage-class-cond-double-lr-0.0001/class_conditional_diff/class_conditional_diff_49.pt' --is_train False
python3 eval_ddpm_cond.py --config src/config/cifar10_weight_cond-double-lr-0.0001.yaml --model 'storage-weight-cond-double-lr-0.0001/weight_conditional_diff/weight_conditional_diff_118.pt' --is_train False
python3 eval_ddpm_cond.py --config src/config/cifar10_weight_class_cond-double-lr-0.0001.yaml --model 'storage-weight-class-cond-double-lr-0.0001/class_weight_conditional_diff/class_weight_conditional_diff_36.pt' --is_train False
python3 eval_ddpm_cond.py --config src/config/cifar10_weight_class_cond-double-lr-0.0001-nfn_disabled.yaml --model 'nfn_disabled_storage-weight-class-cond-double-lr-0.0001/nfn_disabled_class_weight_conditional_diff/nfn_disabled_class_weight_conditional_diff_36.pt' --is_train False
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-no-cond/model_200ep_256feat.pth' --class_cond False --weight_cond False --guidance 0.0 --is_train False
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-class/model_200ep_256feat.pth' --class_cond True --weight_cond False --guidance 2.0 --is_train False
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-weight/model_84ep_256feat.pth' --class_cond False --weight_cond True --guidance 4.0 --is_train False
python3 small_model/cifar10_eval.py --weight_dir 'data/result/models_' --save_dir 'small-model-eval/' --load_path 'small-model-weight-class/model_87ep_256feat.pth' --class_cond True --weight_cond True --guidance 2.0 --is_train False
If you wish to recreate the graphs and tables and you are using a remote cluster, first copy all model scores with the `fetch_eval_scores.sh` script; then, given the correct directories and filenames, `evaluate-reconstruction.ipynb` will compute the scores per model and their aggregates.
Otherwise, run `evaluate-reconstruction.ipynb` directly; it will save the graphs and tables.
- In the root folder, the file `sample_ddpm_cond.py` can be run standalone to sample from a given checkpoint; make sure to provide the `--config` and `--model` arguments when running it.
- The `notebook` folder contains Jupyter notebooks that are useful for plotting and testing:
  - `analyse_nfn_weight_distribution.ipynb` checks quantitatively whether the NFN has been trained by comparing its initial state to its trained state.
  - `eval_ddpm_cond.ipynb` enables a small-scale evaluation of the reconstruction over the models $DM_1$ on the train set $D_{aux}$. This notebook is practical for displaying class-conditional samples, weight-class-conditional samples, and the true images that the weights were trained on.
  - `eval_ddpm_cond_test.ipynb` enables a small-scale evaluation of the reconstruction over the first model on the test set $D$. This notebook is practical for displaying class-conditional samples, weight-class-conditional samples, and the true images of the weights it was tested on.
  - `evaluate-reconstruction.ipynb` enables, after copying the results of the above assessment, a detailed analysis of this evaluation compared to a baseline. It requires the paths to `storage-for-eval-{split}-samples-99-num-samples-100--model-{model_type}_cond-double-models_scores.jsonl`.
  - `merge_epochs_best_img.ipynb` concatenates generated images across epochs regardless of settings.
  - `plot_results_from_MLP.ipynb` re-plots the losses and accuracies stored by each of the trained MLPs.
  - `sample_ddpm_cond.ipynb` is a sampling notebook that enables simple sampling given the configuration and a hard-coded path to the PyTorch model state dict.
  - `verify_ddpm_mlp_mapping_correctness.ipynb` checks that the data sampling in the training loop matches the one in the MLP training, ensuring that weights and images are retrieved jointly and correctly.
  - `verify_ddpm_weight_batching_correctness.ipynb` checks that the batching which transforms a weight state dict into the weight-space features expected by the NFN model is correct.
- In the `small_model` folder, the file `cifar10_sampling.py` can be run standalone to sample from a given checkpoint: `python3 cifar10_sampling.py --batch_size 10 --load_path 'small-model-class-cond/model_200ep_256feat.pth' --weight_dir 'data/result/models_' --class_cond True --weight_cond False --guidance_scale 2.0`
- The `scripts` folder contains practical bash commands to fetch scores and images obtained from a remote server:
  - `fetch_eval_scores.sh` retrieves, from a given output of `eval_ddpm_cond.py`, all the scores needed by the `evaluate-reconstruction.ipynb` notebook.
  - `fetch_img.sh` retrieves, from a given output, all the generated images across epochs, which can then be turned into a GIF using `merge_epochs_best_img.ipynb`.
- Earlier research in this project also explored creating a reduced representation of the weights through an auto-encoder. These files remain in the repository because they provide a way to invert the NFN forward passes:
  - `classifierWeightDataset.py` is a dataset that provides a single weight as input.
  - `nfn_auto_encoder.py` is the encoder-decoder training script, which compresses the weights.
  - `compress_models_weights.py` provides a script for compressing weights using the trained encoder-decoder.
  - `nfn_reverse/hnp_unpool.py` contains the `HNPUnpool` module, whose purpose is to invert the NFN `HNPool` module.
  - `nfn_reverse/np_linear_decoder.py` contains the `NPLinearDecoder` module, whose purpose is to invert the NFN `NPLinear` module.
I would like to thank the SPRING lab at EPFL for all the resources it provided and for making this research project possible. I am especially thankful to Dr. Ana-Maria Cretu for her outstanding guidance and dedicated mentorship, which allowed me to develop and improve my work.