This repository contains the official implementation of the paper *Generative Fractional Diffusion Models* (GFDM), which introduces a continuous-time diffusion model driven by a Markov approximation of fractional Brownian motion in place of the standard Brownian motion used in traditional diffusion models.
Depending on the Hurst index, fractional Brownian motion interpolates between the rough paths of Brownian-driven SDEs and the smooth paths underlying the integration of probability flow ODEs, while also offering even rougher paths.

Our experiments demonstrate that, compared to purely Brownian dynamics, the super-diffusive (smooth) regime of the Markov approximation of fractional Brownian motion achieves higher image quality with fewer score-model evaluations, improved pixel-wise diversity, and better distribution coverage.
To run this code, install the project conda environment specified in `gfdm.yml` via

```bash
conda env create -f gfdm.yml
```
You can use our repository to train GFDM on `mnist`, `fashionmnist`, and `cifar10`. To train on your custom dataset, add your dataset, named `yourdataset`, to the `constructor` dictionary in `train.get_dataset`:
```python
constructor = {
    "mnist": vision_datasets.MNIST,
    "fashionmnist": vision_datasets.FASHIONMNIST,
    "cifar10": vision_datasets.CIFAR10,
    "yourdataset": CustomDataset,
}
```
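For reference, here is a minimal sketch of such a class. It assumes `vision_datasets.TVData` can be constructed without required arguments and that items are `(image, label)` pairs; check the `TVData` class in this repository for its actual interface:

```python
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

import vision_datasets  # adjust the import path to where TVData lives in this repo


class CustomDataset(vision_datasets.TVData):
    """Hypothetical example; check TVData for the actual constructor signature."""

    def __init__(self, root, train=True, transform=None):
        super().__init__()  # assumption: TVData needs no required arguments
        # (image, label) pairs read from root/<class_name>/*.png
        self.data = ImageFolder(root, transform=transform or T.ToTensor())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
```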
Additionally, you need to add `yourdataset` to the available choices for `--data_name` in `args.py`. To use our code out of the box, your `CustomDataset` must inherit from the `vision_datasets.TVData` class. To train with Hurst index `H` and `K` augmenting processes on images of shape `(c, size, size)` with `num_classes` classes, use the following command:
```bash
python train.py --data_name yourdataset --channels c --image_size size --num_classes num_classes --log_model_every_n 100000 --val_check_interval 10000 --hurst H --num_aug K --dynamics fvp --train_steps 1000000
```
Depending on the size of your images, consider adjusting the following default arguments of `unet.UNetModel`: `--model_channels 128`, `--num_res_blocks 4`, `--attn_resolutions 8`, `--channel_mult 1,2,2,2`.
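As a rough guide for choosing these values, the sketch below (an illustration, not code from this repository) shows how the per-level feature-map resolutions follow from `--image_size` and `--channel_mult`, assuming the UNet halves the resolution once per level, as is common in DDPM-style UNets; `--attn_resolutions` should then name resolutions that actually occur:

```python
def unet_level_resolutions(image_size, channel_mult):
    """Feature-map resolution at each UNet level, assuming one 2x downsample per level."""
    return [image_size // (2 ** level) for level in range(len(channel_mult))]

# Defaults for 32x32 images: levels at 32, 16, 8, and 4 pixels,
# so --attn_resolutions 8 attends at the 8x8 feature maps.
print(unet_level_resolutions(32, channel_mult=(1, 2, 2, 2)))  # [32, 16, 8, 4]
```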
For conditional image generation, we observe the best performance on MNIST and CIFAR-10 using Fractional Variance Preserving (FVP) dynamics with a Hurst index of `H = 0.9`, as in the training commands below.
To train on MNIST we used the following parameters:

```bash
python train.py --data_name mnist --channels 1 --image_size 28 --hurst 0.9 --num_aug 3 --dynamics fvp --model_channels 64 --num_res_blocks 3 --attn_resolutions 4,2 --channel_mult 1,2,4 --dropout 0.0 --use_ema False --log_model_every_n 50000 --val_check_interval 500 --lr 1e-4 --batch_size 1024 --train_steps 50000
```
To train on CIFAR10 we used the following parameters:

```bash
python train.py --data_name cifar10 --channels 3 --image_size 32 --hurst 0.9 --num_aug 2 --dynamics fvp --model_channels 128 --num_res_blocks 4 --attn_resolutions 8 --channel_mult 1,2,2,2 --use_ema True --log_model_every_n 100000 --val_check_interval 10000 --lr 2e-4 --batch_size 128 --train_steps 1000000
```
To sample from a trained model stored at `./runs/{id}_{data_name}_H{hurst}_K{num_aug}/model/model-{version}.pth`, run the following with either `--mode sde` or `--mode ode`:

```bash
python generate.py --run_id id --version version --n_samples M --batch_size batch_size --steps N --data_name data_name --hurst hurst --num_aug num_aug --mode sde
```
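As a small illustration of the path template above, and assuming the placeholders are filled in verbatim, the checkpoint for the pretrained CIFAR-10 model referenced below (`run_id` 1299582, `H = 0.9`, `K = 2`, version `v5`) resolves to:

```python
run_id, data_name, hurst, num_aug, version = 1299582, "cifar10", 0.9, 2, "v5"
path = f"./runs/{run_id}_{data_name}_H{hurst}_K{num_aug}/model/model-{version}.pth"
print(path)  # ./runs/1299582_cifar10_H0.9_K2/model/model-v5.pth
```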
To download the best-performing models with FVP dynamics, run the following in `gfdm/runs`:

```bash
gdown https://drive.google.com/uc?id=1OySDmN2vXe5ox4egkxLR1qZuocIlAGz3
```
After unzipping `pretrained_models.zip`, you can sample via:

```bash
python generate.py --run_id 1299582 --hurst 0.9 --num_aug 2 --dynamics fvp --mode sde --version v5 --n_samples 50000 --batch_size 1000 --steps 1000 --data_name cifar10
python generate.py --run_id 1299581 --hurst 0.7 --num_aug 2 --dynamics fvp --mode sde --version v5 --n_samples 50000 --batch_size 1000 --steps 1000 --data_name cifar10
python generate.py --run_id 1299543 --hurst 0.5 --num_aug 0 --dynamics fvp --mode sde --version v4 --n_samples 50000 --batch_size 1000 --steps 1000 --data_name cifar10
```
Our code uses Weights & Biases for logging. For online logging, specify your personal key via `--wb_key your_wandb_key`.
The code runs on as many GPUs as are available. Consider adjusting `--batch_size` and `--accumulate_grad_batches` when switching from one GPU to multiple GPUs to keep the set-up equivalent.
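As an illustration (not code from this repository), and assuming `--batch_size` is the per-GPU batch size as in typical data-parallel training, the effective batch size scales as follows, so keeping it constant across hardware means rescaling one of the two flags:

```python
def effective_batch_size(batch_size, accumulate_grad_batches, n_gpus):
    """Samples contributing to each optimizer step under data-parallel training."""
    return batch_size * accumulate_grad_batches * n_gpus

# The single-GPU CIFAR10 setting above (batch_size 128, no accumulation)...
print(effective_batch_size(128, 1, 1))  # 128
# ...matches four GPUs with a per-GPU batch size of 32.
print(effective_batch_size(32, 1, 4))   # 128
```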
We kindly ask that you cite our paper when using this code:
```bibtex
@inproceedings{
nobis2024generative,
title={Generative Fractional Diffusion Models},
author={Gabriel Nobis and Maximilian Springenberg and Marco Aversa and Michael Detzel and Rembert Daems and Roderick Murray-Smith and Shinichi Nakajima and Sebastian Lapuschkin and Stefano Ermon and Tolga Birdal and Manfred Opper and Christoph Knochenhauer and Luis Oala and Wojciech Samek},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=B9qg3wo75g}
}
```