
Fix vae dtype #11228


Closed
wants to merge 8 commits into from

Conversation

jiqing-feng
Contributor

Hi @sayakpaul. Since the vae can be manipulated outside the pipeline by users, we should revert it to the vae's original dtype instead of the latent dtype. Please review this PR. Thanks!

sayakpaul requested review from DN6 and yiyixuxu on April 8, 2025
@jiqing-feng
Contributor Author

Hi @DN6 @yiyixuxu , do you mind reviewing this change? Thanks.

@jiqing-feng
Contributor Author

Hi @DN6 @yiyixuxu , do you have time to take a look at this PR? Thanks!

@asomoza
Member

asomoza commented Apr 11, 2025

Hi, I want to understand your use case here. If you are using the vae outside the pipeline, can't you just store the original dtype outside too and revert it after using the pipeline?

@jiqing-feng
Contributor Author

> Hi, I want to understand your use case here. If you are using the vae outside the pipeline, can't you just store the original dtype outside too and revert it after using the pipeline?

I run the following script

import requests
import torch
from PIL import Image, ImageOps

from diffusers import StableDiffusionXLImg2ImgPipeline


IMG_URL = "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg"
image = Image.open(requests.get(IMG_URL, stream=True).raw)
image = ImageOps.exif_transpose(image)
image = image.convert("RGB")
prompt = "turn him into cyborg"

model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
)
# Manually upcast the VAE to fp32 to avoid overflow and precision loss.
pipe.vae.to(torch.float32)

new_image = pipe(prompt=prompt, image=image).images[0]

I convert the vae to fp32 after loading the pipeline to avoid overflow and precision errors. But in the current main code, the vae is converted to the input dtype, which is float16 here. The fp16 vae then triggers the force cast here.

There are 2 logical errors here; even though they won't break the e2e pipeline, they will confuse users:

  1. The vae should keep its original dtype instead of being changed to the input dtype.
  2. I convert the vae to fp32 to avoid the internal conversion, which hurts performance, but the code converts the vae to the input dtype (fp16), so it triggers the force cast anyway.

Please let me know if I didn't make it clear.
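
A minimal sketch (not the actual diffusers source) of the decode-time dtype handling being proposed: remember the vae's own dtype before any forced upcast and restore that dtype afterwards, instead of reverting the vae to the latents' dtype. vae and latents are placeholders here.

import torch

def decode_with_dtype_restore(vae, latents):
    # Simplified sketch, not the pipeline code itself: keep the vae's own dtype
    # before any forced upcast and restore it after decoding, rather than
    # casting the vae to the latents' dtype.
    original_vae_dtype = vae.dtype
    needs_upcasting = vae.dtype == torch.float16 and getattr(vae.config, "force_upcast", False)
    if needs_upcasting:
        vae.to(torch.float32)
        latents = latents.to(torch.float32)
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    if needs_upcasting:
        vae.to(original_vae_dtype)  # revert to the vae's original dtype
    return image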

@asomoza
Member

asomoza commented Apr 11, 2025

Thanks a lot for explaining.

The SDXL original vae is an exception, though. In all the other pipelines, the vae has the same dtype as the model; if you want a vae with a different precision than the model and the pipeline, it should be something the user manages, like you're doing.

To be clearer: if the SDXL 1.0 original vae didn't have the overflow problem, the original vae precision you're referring to would be the same as the model's, not fp32. Since the vae has this flaw, whenever you use it with the SDXL pipelines it will be automatically upcast. The only exception is if you're using it with custom code, and in that case you can just manually upcast it like you're doing.

I don't see a problem with what you're suggesting in this PR, but if we added the options everyone needs for their special use cases, the pipelines would become huge, so I'll let the others make this decision.

@jiqing-feng
Contributor Author

jiqing-feng commented Apr 11, 2025

Hi @asomoza, thanks! I get your point! In that case, can we force the vae to float32 and log a message to notify users that the vae will be forced to float32 regardless of the SDXL model dtype (to avoid overflow)?

@asomoza
Member

asomoza commented Apr 14, 2025

> can we force the vae to float32

IMO that would be a waste of VRAM, because of this vae (the fp16-fixed one), which runs in fp16 without a problem and which most popular finetuned models already use instead of the original one.

Still, I don't think what you're suggesting in this PR is a bad idea, so let's wait and see what @DN6 and @yiyixuxu think; they're kind of busy right now, so we need a little patience.

@yiyixuxu
Collaborator

Hey @jiqing-feng,
we have a force_upcast config for the vae; if it is a vae that needs to run in float32, it will be upcast to float32 automatically inside the pipeline.
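
For illustration, a small hedged snippet for inspecting that flag; pipe is assumed to be the StableDiffusionXLImg2ImgPipeline loaded in the script above.

# Inspect the config flag mentioned above on the loaded pipeline.
print(pipe.vae.config.force_upcast)  # True for the original SDXL VAE checkpoint
print(pipe.vae.dtype)                # whatever dtype the user last cast the vae to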

@jiqing-feng
Contributor Author

Hi @asomoza @yiyixuxu. Thanks for your review. There are 2 things that confuse me:

  1. For the vae, I can see you disable fp16 here, so why don't we just force the vae to float32 (or bf16) at the init stage? Casting the model at runtime is not good for performance.
  2. For this PR, I just revert the vae to its original dtype; I can't see any risk in this. In my view, casting the vae to the input dtype while ignoring its original dtype will confuse users, because it assumes the vae has the same dtype as the inputs, and I don't know whether that assumption is right; reverting the vae to its original dtype, on the other hand, is always correct.

Would like to get your feedback. Thanks!

@asomoza
Member

asomoza commented Apr 15, 2025

@jiqing-feng we don't disable fp16. What we're checking here is whether the vae is in fp16 and has force_upcast in its config; only then do we upcast it to fp32. force_upcast is only enabled in the original vae and not in the one I posted before, so if you use the fixed one, you can just use fp16 without any problem.

On point 2 I agree, but it has been a long time since the release of SDXL and these pipelines, and this is the only time someone has opened an issue or PR about it; mostly, everyone else just uses the fp16-fixed vae or doesn't use custom code with the original vae.
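
As a reference point, swapping in an fp16-safe VAE looks roughly like the sketch below; the repo id madebyollin/sdxl-vae-fp16-fix is an assumption about which "fp16 fixed" vae is meant here.

import torch
from diffusers import AutoencoderKL, StableDiffusionXLImg2ImgPipeline

# Assumption: the "fp16 fixed" vae refers to a community checkpoint such as
# madebyollin/sdxl-vae-fp16-fix, which has force_upcast disabled and runs in fp16.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
)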

As I said before, your issue is very specific to your use case, and if we allowed every user to add/remove/refactor the pipelines for their very specific use case, the pipelines would become very complex, with an incredible amount of LoC.

Also, let me tell you that we appreciate your feedback, and I really like that you're discussing this in a very respectful tone (which nowadays is very rare), but the easiest solution here is to just handle it outside of the pipelines in your own code.
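
Concretely, the user-side workaround suggested here could look like this minimal sketch, reusing pipe, prompt, and image from the script posted earlier in the thread.

# Minimal sketch of managing the vae dtype outside the pipeline, as suggested above.
original_vae_dtype = pipe.vae.dtype  # torch.float32 after the manual upcast
new_image = pipe(prompt=prompt, image=image).images[0]
pipe.vae.to(original_vae_dtype)      # restore it if the pipeline changed it internally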

@jiqing-feng
Contributor Author

I see. Thanks for your clarification! I assumed every model's force_upcast is True, which is wrong. Closing this PR.
