Fix vae dtype #11228
Conversation
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Hi, I want to understand your use case here. If you are using the vae outside the pipeline, can't you just store the original dtype?
I run the following script:

```python
import PIL.Image
import PIL.ImageOps
import requests
import torch

from diffusers import StableDiffusionXLImg2ImgPipeline

IMG_URL = "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/main/imgs/example.jpg"

image = PIL.Image.open(requests.get(IMG_URL, stream=True).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")

prompt = "turn him into cyborg"
model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
)

# To avoid overflow and precision loss.
pipe.vae.to(torch.float32)

new_image = pipe(prompt=prompt, image=image).images[0]
```

I convert the vae to fp32 after loading the pipeline to avoid overflow and precision loss. But with the current main code, the vae is converted to the input dtype, which is float16 here. The fp16 vae then triggers the force cast here. These are two logical errors: even though they don't break the end-to-end pipeline, they will confuse users.
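To make the two points concrete, here is a simplified sketch of the encode path I am describing (not the actual diffusers source; the helper name `encode_image` is only illustrative), assuming a vae with `config.force_upcast=True` inside a pipeline running in float16:

```python
import torch

def encode_image(vae, image, pipeline_dtype=torch.float16):
    # The user may have manually upcast the vae, e.g. pipe.vae.to(torch.float32).
    user_dtype = vae.dtype

    if vae.config.force_upcast:
        # Temporary upcast for numerical stability.
        image = image.float()
        vae.to(torch.float32)

    latents = vae.encode(image).latent_dist.sample()

    if vae.config.force_upcast:
        # Point 1: the vae is cast back to the *pipeline/input* dtype (fp16),
        # silently discarding the user's manual float32 cast (user_dtype).
        vae.to(pipeline_dtype)

    # Point 2: because the vae is now fp16 again, the decode step later sees
    # `vae.dtype == torch.float16 and vae.config.force_upcast` and performs
    # another forced upcast.
    return latents
```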
Please let me know if I didn't make it clear.
Thanks a lot for explaining. The SDXL original vae is an exception though; in all the other pipelines, the vae has the same dtype as the model. If you want a vae with a different precision than the model and the pipeline, it should be something the user manages, like you're doing. To be clear, if the SDXL 1.0 original vae didn't have the overflow problem, then I wouldn't see a problem with what you're suggesting in this PR, but if we add the options that everyone needs for their special use case, the pipelines would be huge, so I'll let the others take this decision.
Hi @asomoza, thanks! I get your point! In that case, can we force the vae to float32 and log a message to notify users that the vae will be forced to float32 regardless of the SDXL model dtype (to avoid overflow)?
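A rough sketch of what I mean (the helper `warn_and_upcast_vae` is hypothetical, just to illustrate the idea):

```python
import logging
import torch

logger = logging.getLogger(__name__)

def warn_and_upcast_vae(vae):
    # Hypothetical helper: always keep the vae in float32 and tell the user,
    # regardless of the dtype the rest of the SDXL pipeline runs in.
    if vae.dtype != torch.float32:
        logger.warning(
            "The vae is being forced to float32 to avoid overflow, "
            "independent of the model dtype."
        )
        vae.to(torch.float32)
```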
IMO that would be a waste of VRAM because of this vae, which runs in fp16 without a problem and which most popular finetuned models already use instead of the original one. Still, I don't think what you're suggesting in this PR is a bad idea, so let's wait to see what the opinions of @DN6 and @yiyixuxu are; they're kind of busy right now, so we need a little patience.
hey @jiqing-feng
Hi @asomoza @yiyixuxu. Thanks for your review. There are two things that confused me:
I would like to get your feedback. Thanks!
@jiqing-feng On point 2, I agree we don't disable it, but it has been a long time since the release of SDXL and the use of these pipelines, and this is the only time someone has opened an issue or PR about it, so mostly everyone else just uses the fp16 fixed vae or doesn't use custom code with the original one. As I said before, your issue is very specific to you and only you, and if we allow every user to add/remove/refactor the pipelines for their very specific use case, the pipelines would become very complex, with an incredible amount of LoC. Also, let me tell you that we appreciate your feedback and I really like that you're discussing this in a very respectful tone (which nowadays is very rare), but the easiest solution here is to just do it outside of the pipelines, in your code.
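For reference, swapping in the fp16-fixed vae looks roughly like this (a sketch assuming the commonly used `madebyollin/sdxl-vae-fp16-fix` checkpoint; substitute whichever vae you actually want):

```python
import torch
from diffusers import AutoencoderKL, StableDiffusionXLImg2ImgPipeline

# Load the fp16-fixed SDXL vae so no upcasting is needed at all.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
)
```

Alternatively, if you want to keep the original vae in fp32, you can re-apply `pipe.vae.to(torch.float32)` after each pipeline call in your own code.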
I see. Thanks for your clarification! I assume all model's
Hi @sayakpaul. Since the vae can be manipulated outside the pipeline by users, we should revert it to the vae's original dtype instead of the latents dtype. Please review this PR. Thanks!
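A minimal sketch of the change this PR is going for (not the exact diff; function and variable names are illustrative):

```python
import torch

def encode_image_with_original_dtype(vae, image):
    # Remember the dtype the user actually set on the vae.
    original_vae_dtype = vae.dtype

    if vae.config.force_upcast:
        image = image.float()
        vae.to(torch.float32)

    latents = vae.encode(image).latent_dist.sample()

    # Restore the vae's own dtype (fp32 stays fp32) instead of casting it
    # to the latents/pipeline dtype.
    vae.to(original_vae_dtype)
    return latents
```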