
Device combination bug when quantizing Flux with bitsandbytes #10798

@CyberVy

Description

Describe the bug

Quantizing Flux with bitsandbytes and then moving the pipeline to a CUDA device other than cuda:0 causes a device-combination error at inference time.

Everything works fine when the pipeline runs on cuda:0.

Reproduction

This reproduction shows how the components are quantized.
I have also added an easier reproduction in a later comment.

# First, load a quantized Flux pipeline.
# The loading code follows https://huggingface.co/docs/diffusers/quantization/bitsandbytes?bnb=4-bit

import torch
from diffusers import FluxPipeline

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
).to("cuda:1")

print([str(item.device) for item in list(pipe.components.values()) if hasattr(item,"device")])
# The output is ['cuda:1', 'cuda:1', 'cuda:1', 'cuda:1'], so at the component level
# everything appears to have been moved to cuda:1 successfully
# (a deeper per-tensor check is sketched after this snippet).

pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
# The error shown in the logs below is raised.
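
Since the component-level device check above passes, the tensor that stays behind is presumably below the module level. The following is a minimal diagnostic sketch (not part of the original script), assuming bitsandbytes attaches the 4-bit quantization state to each quantized parameter as a quant_state object carrying absmax/code tensors; it scans parameters, buffers, and that state, and reports anything not on cuda:1:

import torch

def find_stray_tensors(pipe, expected=torch.device("cuda", 1)):
    """Report every parameter, buffer, or bnb quant-state tensor not on `expected`."""
    strays = []
    for comp_name, component in pipe.components.items():
        if not isinstance(component, torch.nn.Module):
            continue  # skip tokenizers, scheduler, etc.
        for param_name, param in component.named_parameters():
            if param.device != expected:
                strays.append((f"{comp_name}.{param_name}", tuple(param.shape), str(param.device)))
            # Assumed bitsandbytes layout: Params4bit carries a quant_state with absmax/code tensors.
            quant_state = getattr(param, "quant_state", None)
            if quant_state is not None:
                for attr in ("absmax", "code"):
                    t = getattr(quant_state, attr, None)
                    if isinstance(t, torch.Tensor) and t.device != expected:
                        strays.append((f"{comp_name}.{param_name}.quant_state.{attr}", tuple(t.shape), str(t.device)))
        for buf_name, buf in component.named_buffers():
            if buf.device != expected:
                strays.append((f"{comp_name}.{buf_name}", tuple(buf.shape), str(buf.device)))
    return strays

print(find_stray_tensors(pipe))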

Logs

...
/usr/local/lib/python3.10/dist-packages/bitsandbytes/functional.py in is_on_gpu(tensors)
    467 
    468     if len(gpu_ids) > 1:
--> 469         raise RuntimeError(
    470             f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}",
    471         )

RuntimeError: Input tensors need to be on the same GPU, but found the following tensor and device combinations:
 [(torch.Size([393216, 1]), device(type='cuda', index=1)), (torch.Size([1, 256]), device(type='cuda', index=1)), (torch.Size([1, 3072]), device(type='cuda', index=1)), (torch.Size([12288]), device(type='cuda', index=1)), (torch.Size([16]), device(type='cuda', index=0))]
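
Since everything works when the pipeline lives on cuda:0, one possible workaround (a sketch, not a fix for the underlying bug) is to hide the other GPU with CUDA_VISIBLE_DEVICES so that the intended GPU is addressed as cuda:0 inside the process:

# Possible workaround sketch: expose only physical GPU 1 to the process,
# so it becomes cuda:0 from the script's point of view.
# This must be set before torch initializes CUDA.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch  # imported after the environment variable is set

# Then load text_encoder_2_4bit, transformer_4bit, and the pipeline exactly as in
# the reproduction above, replacing .to("cuda:1") with .to("cuda").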

System Info

  • 🤗 Diffusers version: 0.32.2
  • Platform: Linux-6.6.56+-x86_64-with-glibc2.35
  • Running on Google Colab?: Yes
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.5.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.8.5 (gpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.3
  • Accelerate version: 1.2.1
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.2
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: Tesla T4, 15360 MiB (×2)
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@sayakpaul @DN6
