This repository is the official implementation of Training Class-Imbalanced Diffusion Model Via Overlap Optimization (arXiv 2024, in submission).
[Project Page] [Arxiv] [OpenReview] [Slides] [Poster]
Authors: Liang Yan, Lu Qi, Vincent Tao Hu, Ming-Hsuan Yang, Meng Tang
Diffusion models have recently made significant advances in high-quality image synthesis and related tasks. However, diffusion models trained on real-world datasets, which often follow long-tailed distributions, yield inferior fidelity for tail classes: deep generative models, including diffusion models, are biased towards classes with abundant training images. To address the observed appearance overlap between images synthesized for tail (rare) classes and those of head classes, we propose a contrastive-learning-based method that minimizes the overlap between the distributions of synthetic images for different classes. Variants of our probabilistic contrastive learning method can be applied to any class-conditional diffusion model, and our loss yields significant improvements in image synthesis on multiple datasets with long-tailed distributions. Extensive experimental results demonstrate that the proposed method effectively handles imbalanced data for both diffusion-based generation and classification models.
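As a toy illustration of what "overlap between distributions" means (this is not the paper's actual loss), the closed-form Bhattacharyya coefficient between two 1-D Gaussians is a standard overlap measure: it is 1 for identical distributions and approaches 0 as they separate, so a contrastive-style objective could penalize it. All names below are illustrative assumptions:

```python
import math

def bhattacharyya_coefficient(mu1, var1, mu2, var2):
    """Closed-form Bhattacharyya coefficient between two 1-D Gaussians.

    BC = 1 means the two distributions coincide (maximal overlap);
    BC -> 0 means they are well separated. A contrastive-style
    objective could push BC between different classes toward 0.
    """
    # Bhattacharyya distance for Gaussians N(mu1, var1), N(mu2, var2)
    bd = 0.25 * (mu1 - mu2) ** 2 / (var1 + var2) \
        + 0.5 * math.log((var1 + var2) / (2.0 * math.sqrt(var1 * var2)))
    return math.exp(-bd)

# Identical Gaussians fully overlap; distant means barely overlap.
print(bhattacharyya_coefficient(0.0, 1.0, 0.0, 1.0))  # 1.0
print(bhattacharyya_coefficient(0.0, 1.0, 5.0, 1.0))  # ~0.044
```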
The repo is implemented based on https://github.com/w86763777/pytorch-ddpm. It currently supports training on four datasets: CIFAR10, CIFAR10-LT, CIFAR100, and CIFAR100-LT.
- Regular (conditional or unconditional) diffusion model training
- Class-balancing model training
- Class-balancing model finetuning based on a regular diffusion model
We mainly provide scripts for training and evaluating on the CIFAR100-LT dataset. To run the code, change the 'root' argument to the path where the dataset is downloaded.
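For example (the script name and flag spellings below are assumptions for illustration, not the repo's exact CLI):

```shell
# Illustrative invocation; point --root at your local dataset directory.
python main.py \
  --root /path/to/datasets \
  --dataset cifar100lt \
  --train
```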
Please find the features for CIFAR100 and CIFAR10 used in the precision/recall/F_beta metrics, and put them in the stats folder; the code will then be ready to run. Note that these two metrics are evaluated only when the number of samples is 50k; otherwise they return 0.
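A minimal sketch of that sample-count guard (the function name and stats path are assumptions, not the repo's actual API):

```python
from pathlib import Path

STATS_DIR = Path("stats")  # put the downloaded feature files here

def should_eval_prf(num_samples: int) -> bool:
    """Precision/recall/F_beta are evaluated only at exactly 50k samples;
    for any other sample count the metrics are reported as 0."""
    return num_samples == 50_000

print(should_eval_prf(50_000))  # True
print(should_eval_prf(10_000))  # False
```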
All algorithms and models are implemented in Python and PyTorch. Experiments are conducted on a server with 8 NVIDIA V100 GPUs (32 GB memory each) and an Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz.
This implementation is based on / inspired by:
- https://github.com/w86763777/pytorch-ddpm
- https://github.com/crowsonkb/k-diffusion/blob/master/train.py (we refer to the implementation of ADA augmentation in the k-diffusion model).
Feel free to cite this work if you find it useful!
@article{yan2024training,
  title={Training Class-Imbalanced Diffusion Model Via Overlap Optimization},
  author={Yan, Divin and Qi, Lu and Hu, Vincent Tao and Yang, Ming-Hsuan and Tang, Meng},
  journal={arXiv preprint arXiv:2402.10821},
  year={2024}
}