
From Reflection to Perfection:
Scaling Inference-Time Optimization for Text-to-Image Diffusion Models via Reflection Tuning

[arXiv] [Website] [HF Dataset: ReflectionFlow]
¹CUHK MMLab  ²KAUST  ³Hugging Face  ⁴Shanghai AI Lab

Overall pipeline of the ReflectionFlow framework with qualitative and quantitative results of scaling compute at inference time.

🔥 News

  • [2025/6/25] Our paper is accepted to ICCV 2025!
  • [2025/5/23] Released the code for our image verifier.
  • [2025/4/23] Released the paper.
  • [2025/4/20] Released the GenRef dataset, model checkpoints, and the training and inference code.

✨ Quick Start

Installation

  1. Environment setup
conda create -n ReflectionFlow python=3.10
conda activate ReflectionFlow
  2. Requirements installation
pip install -r requirements.txt

🚀 Models and Datasets

Datasets

Name | Description | Link
--- | --- | ---
GenRef-wds | WebDataset format of the full GenRef dataset | HuggingFace
GenRef-CoT | Chain-of-Thought reflection dataset | HuggingFace
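
For a quick look at the reflection data, the datasets library can load it directly. This is a minimal sketch; the repo id below follows the naming used for GenRef-wds and is an assumption, so use the actual link from the table above.

from datasets import load_dataset

# Assumed repo id; substitute the GenRef-CoT link from the table above.
genref_cot = load_dataset("diffusion-cot/GenRef-CoT", split="train")
print(genref_cot[0])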

Models

Name | Description | Finetuning Data | Link
--- | --- | --- | ---
FLUX Corrector | Main FLUX-based "text + image -> image" model | GenRef-wds | HuggingFace
Reflection Generator | Qwen-based reflection generator | GenRef-CoT | HuggingFace
Image Verifier | Qwen-based image verifier | GenRef-CoT | HuggingFace

🤖 Reflection Tuning

train_flux/config.yaml exposes all the arguments that control the training-time configuration.

First, get the data. You can either download the webdataset shards from diffusion-cot/GenRef-wds or directly pass URLs.

When using local paths, set path under [train][dataset] to a glob pattern: DATA_DIR/genref_*.tar. The current config.yaml is configured to stream from the diffusion-cot/GenRef-wds repository. For easier debugging, you can also reduce the number of tars to stream: change genref_{0..208}.tar to something like genref_{0..4}.tar, depending on how many shards you want to use.
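
If you prefer to train from local shards instead of streaming, here is a minimal sketch for downloading them with huggingface_hub (the allow_patterns filter is an assumption based on the genref_*.tar shard naming above; adjust it to the shards you need):

from huggingface_hub import snapshot_download

# Fetch (a subset of) the GenRef-wds shards for local training.
data_dir = "DATA_DIR"
snapshot_download(
    repo_id="diffusion-cot/GenRef-wds",
    repo_type="dataset",
    local_dir=data_dir,
    allow_patterns=["genref_*.tar"],  # drop this to download all shards
)
# Then point path under [train][dataset] in train_flux/config.yaml to DATA_DIR/genref_*.tar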

Run the following command for training the FLUX Corrector:

bash train_flux/train.sh

We tested our implementation on a single node of 8× 80GB A100s as well as H100s. We acknowledge that there are opportunities for optimization, but we did not prioritize them in this release.

Note

Validation during training is yet to be implemented.

⚡ Inference Time Scaling

Introduction

We provide the code for inference-time scaling of our reflection-tuned models. Currently, we support naive noise scaling, noise & prompt scaling, and full ReflectionFlow refinement, with either a GPT-based or an NVILA verifier (see the Run section below).

Setup

First, you need to set up the following:

export OPENAI_API_KEY=your_api_key
# if you want to use NVILA as verifier
pip install transformers==4.46
pip install git+https://github.com/bfshi/scaling_on_scales.git

Then set FLUX_PATH and LORA_PATH in the config file of your choice under tts/config. FLUX_PATH should point to a local copy of black-forest-labs/FLUX.1-dev, which can be downloaded like so:

from huggingface_hub import snapshot_download

local_dir = "SOME_DIR"
snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", local_dir=local_dir)

LORA_PATH is the path to our corrector model.
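
If you want a local copy of the corrector checkpoint, a snapshot_download sketch analogous to the one above works as well; the repo id below is a placeholder, substitute the FLUX Corrector repository linked in the Models table:

from huggingface_hub import snapshot_download

# Placeholder repo id: replace with the actual FLUX Corrector repository.
lora_dir = "LORA_DIR"
snapshot_download(repo_id="FLUX_CORRECTOR_REPO_ID", local_dir=lora_dir)
# Point LORA_PATH in your tts/config file to this directory.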

If you want to use our finetuned reflection generator, first install LLaMA-Factory. Then download the model from here and change model_name_or_path in tts/config/our_reflectionmodel.yaml to the reflection generator path; specifically, the path should look like Reflection-Generator/infer/30000. Next, host the model with:

API_PORT=8001 CUDA_VISIBLE_DEVICES=0 llamafactory-cli api configs/our_reflectionmodel.yaml

Then change reflection_args in the pipeline config file (for example, tts/configs/flux.1_dev_gptscore.json) to ours.
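
The llamafactory-cli api command serves an OpenAI-compatible endpoint, so you can sanity-check that the reflection generator is reachable before launching the scaling scripts. A minimal sketch is below; the model name and message are placeholders, and the scaling scripts construct the actual reflection requests internally.

from openai import OpenAI

# The LLaMA-Factory server started above listens on API_PORT (8001 here)
# and speaks the OpenAI chat-completions protocol.
client = OpenAI(base_url="http://localhost:8001/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="reflection-generator",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)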

Note

When using our reflection generator, please consider using at least two GPUs so that resources can be allocated more effectively.

Run

First, run tts_t2i_noise_scaling.py to generate naive noise scaling results:

export OUTPUT_DIR=output_dir
cd tts
python tts_t2i_noise_scaling.py --output_dir=$OUTPUT_DIR --meta_path=geneval/evaluation_metadata.jsonl --pipeline_config_path=configs/flux.1_dev_gptscore.json 

Next, you can run the following command to generate the results of reflection tuning:

export NEW_OUTPUT_DIR=reflection_tuning_dir
python tts_reflectionflow.py --imgpath=$OUTPUT_DIR --pipeline_config_path=configs/flux.1_dev_gptscore.json --output_dir=$NEW_OUTPUT_DIR

We also provide code for noise & prompt scaling only:

python tts_t2i_noise_prompt_scaling.py --output_dir=$OUTPUT_DIR --meta_path=geneval/evaluation_metadata.jsonl --pipeline_config_path=configs/flux.1_dev_gptscore.json 

You can also change to tts/configs/flux.1_dev_nvilascore.json to use the NVILA verifier.

By default, we use the prompts from tts/config/geneval/evaluation_metadata.jsonl. If you don't want to use all of them, specify the --start_index and --end_index CLI args.

NVILA Verifier Filter

After generation, we provide code that uses the NVILA verifier to filter the outputs and obtain results for different numbers of samples.

python verifier_filter.py --imgpath=$OUTPUT_DIR --pipeline_config_path=configs/flux.1_dev_nvilascore.json 

Our Image Verifier

We provide simple starter code for our image verifier. To run it, first upgrade transformers; we currently use transformers version 4.51.3.

pip install transformers==4.51.3

Then you can run the following code to get the score of the image:

from reward_modeling.test_reward import ImageVLMRewardInference
import torch

# Image to score and the prompt it was generated from.
imgname = IMG_PATH
original_prompt = ORIGINAL_PROMPT

# Load the image verifier from MODEL_PATH (checkpoint step 10080).
score_verifier = ImageVLMRewardInference(MODEL_PATH, load_from_pretrained_step=10080, device="cuda", dtype=torch.bfloat16)
scores = score_verifier.reward([imgname], [original_prompt], use_norm=True)
print(scores[0]['VQ'])

MODEL_PATH is the path to the model checkpoint, and scores[0]['VQ'] is the score of the text-image pair; higher is better.
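
Since reward() takes lists of images and prompts, the same verifier can rank several candidates generated for one prompt and keep the best one. A small sketch reusing the objects defined above (it assumes one score dict is returned per image, as in the single-image call):

# Score several candidate images for the same prompt and keep the best one.
candidates = ["img_0.png", "img_1.png", "img_2.png"]  # example paths
prompts = [original_prompt] * len(candidates)

scores = score_verifier.reward(candidates, prompts, use_norm=True)
best_idx = max(range(len(candidates)), key=lambda i: scores[i]['VQ'])
print("best candidate:", candidates[best_idx], "score:", scores[best_idx]['VQ'])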

🤝 Acknowledgement

We are deeply grateful to the following GitHub repositories, as their valuable code and efforts have been incredibly helpful:

✏️ Citation

If you find ReflectionFlow useful for your research and applications, please cite using this BibTeX:

@misc{zhuo2025reflectionperfectionscalinginferencetime,
    title={From Reflection to Perfection: Scaling Inference-Time Optimization for Text-to-Image Diffusion Models via Reflection Tuning}, 
    author={Le Zhuo and Liangbing Zhao and Sayak Paul and Yue Liao and Renrui Zhang and Yi Xin and Peng Gao and Mohamed Elhoseiny and Hongsheng Li},
    year={2025},
    eprint={2504.16080},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2504.16080}, 
}
