
DanceGRPO

DanceGRPO is the first unified RL-based framework for visual generation.

This is the official implementation of the paper DanceGRPO: Unleashing GRPO on Visual Generation. We develop DanceGRPO on top of FastVideo, a scalable and efficient framework for video and image generation.

Key Features

DanceGRPO supports the following models:

  • Stable Diffusion
  • FLUX
  • HunyuanVideo
  • SkyReels-I2V
  • Qwen-Image
  • Qwen-Image-Edit

Updates

  • [2025.05.12]: 🔥 We released the paper on arXiv!
  • [2025.05.28]: 🔥 We released the training scripts for FLUX and Stable Diffusion!
  • [2025.07.03]: 🔥 We released the training scripts for HunyuanVideo!
  • [2025.08.30]: 🔥 We released the training scripts for SkyReels-I2V!
  • [2025.09.04]: 🔥 We released the training scripts for Qwen-Image & Qwen-Image-Edit!

We have shared this work at many research labs; the example slides can be found here. The trained FLUX checkpoints can be found here.

DanceGRPO is also a project dedicated to inspiring the community. If you have any research or engineering questions, feel free to open an issue or email us directly at xuezeyue@connect.hku.hk.

Getting Started

Downloading checkpoints

Create these folders with "mkdir" first; a command sketch follows the download lists below.

For image generation,

  1. Download the Stable Diffusion v1.4 checkpoints from here to "./data/stable-diffusion-v1-4".
  2. Download the FLUX checkpoints from here to "./data/flux".
  3. Download the HPS-v2.1 checkpoint (HPS_v2.1_compressed.pt) from here to "./hps_ckpt".
  4. Download the CLIP H-14 checkpoint (open_clip_pytorch_model.bin) from here to "./hps_ckpt".

For video generation,

  1. Download the HunyuanVideo checkpoints from here to "./data/HunyuanVideo".
  2. Download the SkyReels-I2V checkpoints from here to "./data/SkyReels-I2V".
  3. Download the Qwen2-VL-2B-Instruct checkpoints from here to "./Qwen2-VL-2B-Instruct".
  4. Download the VideoAlign checkpoints from here to "./videoalign_ckpt".
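
A minimal sketch of the folder setup, using the paths listed above:

# create the download targets for image and video generation
mkdir -p ./data/stable-diffusion-v1-4 ./data/flux ./hps_ckpt
mkdir -p ./data/HunyuanVideo ./data/SkyReels-I2V
mkdir -p ./Qwen2-VL-2B-Instruct ./videoalign_ckpt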

Installation

./env_setup.sh fastvideo

Training

# for Stable Diffusion, with 8 H800 GPUs
bash scripts/finetune/finetune_sd_grpo.sh   
# for FLUX, preprocessing with 8 H800 GPUs
bash scripts/preprocess/preprocess_flux_rl_embeddings.sh
# for FLUX, training with 16 H800 GPUs for better convergence;
# you can instead use finetune_flux_grpo_8gpus.sh with 8 H800 GPUs, at the cost of slower convergence,
# or try the LoRA version, which takes ~20GB VRAM per GPU on one node (8 GPUs)
bash scripts/finetune/finetune_flux_grpo.sh   

For the open-source image generation version, we use the prompts from the HPD dataset for training, as shown in "./assets/prompts.txt".

# for HunyuanVideo, preprocessing with 8 H800 GPUs
bash scripts/preprocess/preprocess_hunyuan_rl_embeddings.sh
# for HunyuanVideo, training with 16/32 H800 GPUs
bash scripts/finetune/finetune_hunyuan_grpo.sh   

For the open-source text-to-video generation version, we filter the prompts from the VidProM dataset for training, as shown in "./assets/video_prompts.txt".

# for SkyReels-I2V, preprocessing with 8 H800 GPUs
bash scripts/preprocess/preprocess_skyreels_rl_embeddings.sh
# for SkyReels-I2V, training with 32 H800 GPUs
# we use FLUX to generate the reference image, so please download the FLUX checkpoints to "./data/flux"
bash scripts/finetune/finetune_skyreels_i2v.sh   

For the open-source image-to-video generation version, we filter the prompts from the ConsistID dataset for training, as shown in "./assets/consist-id.txt".

About Qwen-Image

Download the Qwen-Image checkpoints to "./data/qwenimage". We also use HPS-v2.1 to train the model; the reward increases from ~0.25 to ~0.33 within 200 iterations.

# for Qwen-Image, preprocessing with 8 H800 GPUs
bash scripts/preprocess/preprocess_qwen_image_rl_embeddings.sh
# for Qwen-Image, training with 8 H800 GPUs
bash scripts/finetune/finetune_qwenimage_grpo.sh   

About Qwen-Image-Edit

Download the Qwen-Image-Edit checkpoints to "./data/qwenimage_edit".

Since there is no open-source reward model specifically for image editing, we still use HPS-v2.1 for Qwen-Image-Edit; this implementation serves only as a reference.

Download this dataset to "./data/SEED-Data-Edit-Part2-3", then extract the images:
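
# unpack the editing images in place
cd ./data/SEED-Data-Edit-Part2-3/real_editing/images
tar -xzf images.tar.gz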

The HPS-v2.1 reward will increase from ~0.23 to ~0.27 within about 150 iterations.

# for Qwen-Image-Edit, preprocessing with 8 H800 GPUs
bash scripts/preprocess/preprocess_qwen_image_edit_rl_embeddings.sh
# for Qwen-Image-Edit, training with 8 H800 GPUs
bash scripts/finetune/finetune_qwenimage_edit_grpo.sh   

Image Generation Rewards

We show the (moving average) reward curves of Stable Diffusion (left/upper) and FLUX (right/lower); the raw values are also in reward.txt or hps_reward.txt. FLUX training (200 iterations) completes within 12 hours on 16 H800 GPUs.

  1. We provide more visualization examples (base, 80-iteration RLHF, 160-iteration RLHF) in "./assets/flux_visualization".
  2. Here is the visualization script "./scripts/visualization/vis_flux.py" for FLUX. First, run rm -rf ./data/flux/transformer/* to clear the directory, then copy the files from a trained checkpoint (e.g., checkpoint-160-0) into ./data/flux/transformer; after that, you can run the visualization (see the command sketch after this list). Results for a model trained for 160 iterations are already provided in this repo.
  3. More discussion on FLUX can be found in "./fastvideo/README.md".
  4. Thanks to a community contribution from @Jinfa Huang: if you change train_batch_size and train_sp_batch_size from 1 to 2 and gradient_accumulation_steps from 4 to 12, you can train FLUX with 8 H800 GPUs and finish training within a day. If you experience a reward collapse similar to this, please reduce max_grad_norm.
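
A minimal command sketch of step 2 above; the /path/to/checkpoints placeholder is hypothetical (substitute your training output directory), and we assume vis_flux.py runs standalone:

# clear the original transformer weights
rm -rf ./data/flux/transformer/*
# copy in a trained checkpoint (checkpoint-160-0 is an example name;
# replace /path/to/checkpoints with your training output directory)
cp -r /path/to/checkpoints/checkpoint-160-0/* ./data/flux/transformer/
# run the FLUX visualization script
python ./scripts/visualization/vis_flux.py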

Video Generation Rewards

We show the (moving average) reward curves of HunyuanVideo with 16/32 H800 GPUs; the raw values are also in vq_reward.txt.

With 16 H800 GPUs,

With 32 H800 GPUs,

  1. For the open-source version, our mission is to reduce the training cost, so we reduce the number of frames, sampling steps, and GPUs compared with the settings in the paper. The reward curves will therefore look different, but the VQ improvements are similar (50%~60%).
  2. For visualization, run rm -rf ./data/HunyuanVideo/transformer/* to clear the directory, then copy the files from a trained checkpoint (e.g., checkpoint-100-0) into ./data/HunyuanVideo/transformer. After that, you can run the visualization script "./scripts/visualization/vis_hunyuanvideo.sh" (see the command sketch after this list).
  3. Although training with 16 H800 GPUs yields rewards similar to 32 H800 GPUs, we find that 32 H800 GPUs leads to better visualization results.
  4. We plot de-normalized rewards using the formula VQ = VQ * 2.2476 + 3.6757, following here.
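
The same pattern as the FLUX sketch works for step 2 here (again, /path/to/checkpoints is a hypothetical placeholder for your training output directory):

# clear the original transformer weights
rm -rf ./data/HunyuanVideo/transformer/*
# copy in a trained checkpoint (checkpoint-100-0 is an example name)
cp -r /path/to/checkpoints/checkpoint-100-0/* ./data/HunyuanVideo/transformer/
# run the HunyuanVideo visualization script
bash ./scripts/visualization/vis_hunyuanvideo.sh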

For SkyReels-I2V,

  1. We plot de-normalized rewards using the formula MQ = MQ * 1.3811 + 1.1646, following here.

Multi-reward Training

The Multi-reward training code and reward curves can be found here.

Important Discussion and Results with More Reward Models for FLUX

Thanks to an issue from @Yi-Xuan XU, results with more reward models and better visualizations (including how to avoid grid patterns) on FLUX can be found here. We also support PickScore for FLUX via --use_pickscore.

We support EMA for FLUX via --use_ema with --ema_decay 0.995; enabling EMA helps produce better visualizations.

Pref-GRPO also discusses how to avoid reward hacking. FlowCPS provides a better SDE for RLHF training.

How to Support Custom Models

  1. For preprocessing, modify preprocess_flux_embedding.py and latent_flux_rl_datasets.py based on your text encoder.
  2. For FSDP and the dataloader, modify fsdp_util.py and communications_flux.py; we prefer FSDP over DeepSpeed since FSDP is easier to debug.
  3. Modify train_grpo_flux.py.

How to Debug

  1. Print the probability ratio, reward, and advantage for each sample; the ratio should be 1.0 before the gradient update, and you can verify the advantage on your own. Set the rollout inference batch size and training batch size to 1, otherwise the ratio will not be exactly 1.0.
  2. Gradient accumulation should follow the sample dimension: if you use 20 sampling steps, the gradient accumulation should be accumulate_samples*20.
  3. In our experience, the learning rate should be set between 5e-6 and 2e-5; setting it to 1e-6 always led to training failure in our settings.
  4. Make sure the batch size is large enough; you can follow our flux_8gpus setting.
  5. More importantly, if you enable CFG, the gradient accumulation should be set to a large number. In our experience, we always set it to num_generations*20, which means the gradients are updated only once per rollout.

Training Acceleration

  1. You can reduce the sampling steps, resolution, or timestep selection ratio.
  2. Outstanding follow-up works such as MixGRPO and BranchGRPO also target training acceleration. SRPO also explores differentiable reward models to accelerate training.

Acknowledgement

We learned and reused code from several open-source projects, including FastVideo.

We thank the authors for their contributions to the community!

Citation

If you use DanceGRPO for your research, please cite our paper:

@article{xue2025dancegrpo,
  title={DanceGRPO: Unleashing GRPO on Visual Generation},
  author={Xue, Zeyue and Wu, Jie and Gao, Yu and Kong, Fangyuan and Zhu, Lingting and Chen, Mengzhao and Liu, Zhiheng and Liu, Wei and Guo, Qiushan and Huang, Weilin and others},
  journal={arXiv preprint arXiv:2505.07818},
  year={2025}
}
