ensure dtype match between diffused latents and vae weights #8391

Merged (1 commit) on Apr 7, 2025

heyalexchoi
Contributor

What does this PR do?

A simple fix for the dtype of the diffused latents not matching the dtype of the VAE weights. See the error below. I hit this issue when loading the pipeline in bfloat16 and using Accelerate.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/PixArt-sigma/diffusion/utils/image_evaluation.py", line 150, in generate_images
    batch_images = pipeline(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py", line 866, in __call__
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  File "/workspace/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 305, in decode
    decoded = self._decode(z, return_dict=False)[0]
  File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 277, in _decode
    z = self.post_quant_conv(z)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
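The same class of error can be reproduced without diffusers. This is a minimal sketch, assuming a bare `torch.nn.Conv2d` cast to bfloat16 as a hypothetical stand-in for the VAE's `post_quant_conv` (the channel count and spatial size are arbitrary). The exact error text may vary across PyTorch versions and devices, so the sketch only records whether a `RuntimeError` was raised.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the VAE's post_quant_conv: weights and bias in bf16.
post_quant_conv = nn.Conv2d(4, 4, kernel_size=1).to(torch.bfloat16)

# Latents left in torch's default dtype (fp32), as in the traceback above.
latents = torch.randn(1, 4, 8, 8)


def try_conv(z):
    """Run the conv and return the error message, or None on success."""
    try:
        post_quant_conv(z)
    except RuntimeError as err:
        return str(err)
    return None


error = try_conv(latents)                   # fp32 input vs bf16 parameters
ok = try_conv(latents.to(torch.bfloat16))   # matching dtypes
print(error)
print(ok)
```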

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul
Member

Thanks for your PR. Does it only happen when using the Sigma pipeline? Would something like this be more prudent to implement?

needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
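To see which cases this condition covers, here is a minimal sketch that evaluates the same expression as a standalone function. The function name and its explicit arguments are hypothetical; in the pipelines the values come from `self.vae.dtype` and `self.vae.config.force_upcast`.

```python
import torch


def needs_upcasting(vae_dtype: torch.dtype, force_upcast: bool) -> bool:
    # Same condition as the pipeline line quoted above, with the VAE
    # attributes passed in explicitly (hypothetical helper for illustration).
    return vae_dtype == torch.float16 and force_upcast


# fp16 VAE with force_upcast set: the upcasting path is taken.
fp16_case = needs_upcasting(torch.float16, True)
# bf16 VAE: the condition is False, so this check alone does nothing
# to reconcile fp32 latents with bf16 VAE weights.
bf16_case = needs_upcasting(torch.bfloat16, True)
print(fp16_case, bf16_case)  # True False
```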

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bghira
Contributor

bghira commented Jun 4, 2024

this also occurs with SD 1.x/2.x and SDXL under accelerate: the default dtype for torch is fp32, but the vae dtype is bf16.
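A minimal sketch of where each dtype comes from, using plain PyTorch only: tensors created without an explicit dtype follow `torch.get_default_dtype()` (fp32 unless changed), while a module cast with `.to(torch.bfloat16)` keeps bf16 parameters. The toy conv layer is an illustrative assumption, not the actual VAE.

```python
import torch
import torch.nn as nn

default_dtype = torch.get_default_dtype()  # torch.float32 unless changed

# Toy layer cast to bf16, similar to loading a pipeline with torch_dtype=torch.bfloat16.
vae_layer = nn.Conv2d(4, 4, kernel_size=1).to(torch.bfloat16)
param_dtype = next(vae_layer.parameters()).dtype

# A tensor created without an explicit dtype picks up the default (fp32).
fresh_latents = torch.randn(1, 4, 8, 8)

mismatch = fresh_latents.dtype != param_dtype
print(default_dtype, param_dtype, fresh_latents.dtype, mismatch)
```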

here is an error seen when using the SDXL Refiner:

2024-06-05 00:41:54,010 [ERROR] (helpers.training.validation) Error generating validation image: Input type (float) and bias type (c10::BFloat16) should be the same, Traceback (most recent call last):
  File "/notebooks/SimpleTuner/helpers/training/validation.py", line 534, in validate_prompt
    validation_image_results = self.pipeline(**pipeline_kwargs).images
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py", line 1422, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 304, in decode
    decoded = self._decode(z).sample
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 274, in _decode
    z = self.post_quant_conv(z)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

@bghira
Contributor

bghira commented Jun 4, 2024

#7886 is the same or a similar issue

@heyalexchoi
Contributor Author

Thanks for your PR. Does it only when using the Sigma pipeline? Would something like this would be more prudent to implement?

needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

I don't know much about the background of the force_upcast config param. I do know I have hit this issue in the PixArt pipelines (maybe Alpha too?) a few times. This fix is simple and I don't see any downside.

@sayakpaul
Member

Will defer to @yiyixuxu for an opinion on how best to proceed. IMO, we should handle it the same way as:

needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

@bghira
Contributor

bghira commented Jun 5, 2024

do you mean adding a conditional check instead of unconditionally casting the latents to the vae's dtype? or do you mean we should set force_upcast in certain situations?

for the former, i'm curious what problems you foresee with doing it unconditionally. it's not that having a check would hurt, but i also don't see how it hurts anything to ensure the latents match the vae dtype before decode.

for the latter, this is a situation where upcasting the vae to match the latents is unnecessary. e.g. i am using the fp16-fixed SDXL VAE for decode, and upcasting would just waste resources. the problem is that the latents become fp32 after being modified by the pipeline just a few lines before the decode, while the vae itself is bf16.
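One plausible way latents end up in fp32 right before decode is PyTorch's type promotion. The sketch below only illustrates the promotion rules; it does not identify which pipeline line caused the switch in this report. The scaling factor 0.18215 and the per-channel fp32 tensor are illustrative assumptions.

```python
import torch

latents = torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)

# Dividing by a Python float (like vae.config.scaling_factor) keeps bf16.
scaled = latents / 0.18215
# A zero-dim fp32 tensor does not promote a dimensioned bf16 tensor.
with_0d = latents * torch.tensor(2.0)
# A dimensioned fp32 tensor (hypothetical per-channel values) promotes to fp32.
with_fp32 = latents * torch.ones(1, 4, 1, 1)

print(scaled.dtype, with_0d.dtype, with_fp32.dtype)
# torch.bfloat16 torch.bfloat16 torch.float32
```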

tl;dr i think casting to the vae dtype is the correct solution rather than upcasting vae to the latents dtype.
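A minimal sketch of the approach argued for here: cast the latents to the VAE's parameter dtype right before decoding, rather than upcasting the VAE. `ToyVAEDecoder` and `decode_latents` are hypothetical stand-ins, and this is not the exact diff merged in this PR.

```python
import torch
import torch.nn as nn


class ToyVAEDecoder(nn.Module):
    """Hypothetical stand-in for a VAE whose decode path starts with post_quant_conv."""

    def __init__(self):
        super().__init__()
        self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)

    @property
    def dtype(self) -> torch.dtype:
        # Report the parameter dtype, similar in spirit to diffusers' ModelMixin.dtype.
        return next(self.parameters()).dtype

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.post_quant_conv(z)


def decode_latents(vae: ToyVAEDecoder, latents: torch.Tensor, scaling_factor: float) -> torch.Tensor:
    # Cast to the VAE's dtype first, so fp32 latents work with a bf16 (or fp16) VAE.
    latents = latents.to(vae.dtype)
    return vae.decode(latents / scaling_factor)


vae = ToyVAEDecoder().to(torch.bfloat16)
latents = torch.randn(1, 4, 8, 8)  # fp32, as in the reports above
image = decode_latents(vae, latents, scaling_factor=0.18215)
print(image.dtype, tuple(image.shape))  # torch.bfloat16 (1, 4, 8, 8)
```

The cast is a no-op when the dtypes already match, which is why doing it unconditionally costs little.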

Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the stale label (Issues that haven't received updates) on Sep 14, 2024
@bghira
Contributor

bghira commented Apr 6, 2025

i thought (or hoped) it had been fixed, but when using the upstream vanilla diffusers pipelines for vae decode during training, i'm still hitting this issue (even with Accelerate)

@hlky
Contributor

hlky commented Apr 7, 2025

@bghira Can you share a minimal reproduction?

@bghira
Contributor

bghira commented Apr 7, 2025

nope, i am not sure what causes the dtype switch; i think it is HF Accelerate. either way, the latents are fp32 and the vae is bf16.

Collaborator

@yiyixuxu left a comment

sorry for the delay!

@yiyixuxu merged commit 5ded26c into huggingface:main on Apr 7, 2025
@yiyixuxu
Collaborator

yiyixuxu commented Apr 7, 2025

sorry for the delay! I somehow missed this PR. The fix is simple and should not cause any problems, and this is a different situation from the one where we need to upcast the VAE to fp32.

however, I would like to know more about when the latents get upcast to fp32 with Accelerate. @heyalexchoi or @bghira, if you could provide a minimal code example in a follow-up PR, that'd be great.

Labels
stale Issues that haven't received updates

6 participants