The Code of DiffROP: Training Class-Imbalanced Diffusion Model Via Overlap Optimization

This repository is the official implementation of Training Class-Imbalanced Diffusion Model Via Overlap Optimization (arXiv 2024, in submission).

[Project Page] [arXiv] [OpenReview] [Slides] [Poster]

Authors: Liang Yan, Lu Qi, Vincent Tao Hu, Ming-Hsuan Yang, Meng Tang

Introduction

Diffusion models have recently made significant advances in high-quality image synthesis and related tasks. However, diffusion models trained on real-world datasets, which often follow long-tailed distributions, yield inferior fidelity for tail classes. Deep generative models, including diffusion models, are biased towards classes with abundant training images. Motivated by the observed appearance overlap between synthesized images of head classes and tail classes, we propose a method based on contrastive learning that minimizes the overlap between the distributions of synthetic images for different classes. We show that variants of our probabilistic contrastive learning method can be applied to any class-conditional diffusion model. Our loss brings significant improvements in image synthesis on multiple datasets with long-tailed distributions, and extensive experimental results demonstrate that the proposed method effectively handles imbalanced data for both diffusion-based generation and classification models.
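As a rough illustration of the idea only (this is not the repository's actual loss; the function and variable names below are hypothetical), a contrastive overlap penalty can be sketched as penalizing cross-class similarity of the model's per-class features, so that minimizing it pushes the class distributions apart:

```python
import numpy as np

def overlap_penalty(feats_a, feats_b):
    """Toy contrastive-style overlap penalty (illustrative only).

    Measures the mean pairwise cosine similarity between the features
    of two classes; minimizing it separates the class distributions.
    """
    a = feats_a / np.linalg.norm(feats_a, axis=1, keepdims=True)
    b = feats_b / np.linalg.norm(feats_b, axis=1, keepdims=True)
    return float((a @ b.T).mean())

rng = np.random.default_rng(0)
head = rng.normal(loc=1.0, size=(16, 8))   # features of a head class
tail = rng.normal(loc=-1.0, size=(16, 8))  # features of a tail class

# Overlapping distributions score higher than well-separated ones.
print(overlap_penalty(head, head) > overlap_penalty(head, tail))
```

In the paper's setting the penalty acts on distributions of class-conditional diffusion outputs rather than raw feature vectors, but the direction of the optimization is the same: reduce cross-class overlap.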

About this repository

The repo is implemented based on https://github.com/w86763777/pytorch-ddpm. It currently supports training on four datasets, namely CIFAR10, CIFAR10-LT, CIFAR100, and CIFAR100-LT, and offers three training modes:

  1. Regular (conditional or unconditional) diffusion model training
  2. Class-balancing model training
  3. Class-balancing model finetuning based on a regular diffusion model
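The three modes above might be launched along these lines. Note that the script name and every flag below are illustrative assumptions, not the repo's verified CLI:

```shell
DATASET=cifar100lt   # one of: cifar10, cifar10lt, cifar100, cifar100lt

# 1. Regular conditional diffusion training (hypothetical flags).
echo "python main.py --train --dataset ${DATASET} --conditional"

# 2. Class-balancing training from scratch (hypothetical flags).
echo "python main.py --train --dataset ${DATASET} --conditional --cb_loss"

# 3. Class-balancing finetuning from a regular checkpoint (hypothetical).
echo "python main.py --train --dataset ${DATASET} --conditional --cb_loss \
  --resume ./logs/regular/ckpt.pt"
```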

(TODO) Running the Experiments

We mainly provide scripts for training and evaluating on the CIFAR100-LT dataset. To run the code, change the 'root' argument to the path where the dataset is downloaded.
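For example (the entry script and flag spelling here are assumptions for illustration, not the repo's verified interface), pointing 'root' at a local dataset directory could look like:

```shell
ROOT=/data/datasets/cifar100lt   # wherever CIFAR100-LT was downloaded

# Hypothetical invocation; substitute the repo's actual entry script.
echo "python main.py --train --dataset cifar100lt --root ${ROOT}"
```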

(TODO) Checkpoint Release

Files used in evaluation

Please find the features for CIFAR-100 and CIFAR-10 used by the precision/recall/F-beta metrics. Put them in the stats folder and the code will be ready to run. Note that these two metrics are only evaluated when the number of samples is 50k; otherwise they return 0.
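The 50k-sample condition described above amounts to a guard like the following (a sketch only; the function and constant names are hypothetical, not the repo's code):

```python
REQUIRED_SAMPLES = 50_000  # precision/recall/F-beta require exactly 50k samples

def guarded_prf(samples, compute_prf):
    """Return precision/recall/F-beta only for a full 50k-sample run.

    `compute_prf` is any callable producing the real metrics; for any
    other sample count the metrics are reported as zeros instead.
    """
    if len(samples) != REQUIRED_SAMPLES:
        return 0.0, 0.0, 0.0
    return compute_prf(samples)

# With fewer than 50k samples the metrics are skipped and read as zeros.
print(guarded_prf(range(10_000), lambda s: (1.0, 1.0, 1.0)))
```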

Configuration

All algorithms and models are implemented in Python and PyTorch. Experiments were conducted on a server with 8 NVIDIA V100 GPUs (32 GB memory each) and an Intel Xeon Platinum 8255C CPU @ 2.50 GHz.

Acknowledgements

This implementation is based on / inspired by pytorch-ddpm (https://github.com/w86763777/pytorch-ddpm).

Cite Us

Feel free to cite this work if you find it useful!

@article{yan2024training,
  title={Training Class-Imbalanced Diffusion Model Via Overlap Optimization},
  author={Yan, Divin and Qi, Lu and Hu, Vincent Tao and Yang, Ming-Hsuan and Tang, Meng},
  journal={arXiv preprint arXiv:2402.10821},
  year={2024}
}

About

[arXiv 2024] The official code of "Training Class-Imbalanced Diffusion Model Via Overlap Optimization".
