Describe the bug
Quantizing Flux with bitsandbytes and then moving the pipeline to a CUDA device other than cuda:0 causes a device-mismatch error ("Input tensors need to be on the same GPU").
Everything works fine when the pipeline runs on cuda:0.
Reproduction
This reproduction shows how the components are quantized. A simpler reproduction has been added in the latest comment.
# First, load a quantized Flux pipeline.
# The loading procedure follows https://huggingface.co/docs/diffusers/quantization/bitsandbytes?bnb=4-bit
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel

# Quantize the T5 text encoder to 4-bit with bitsandbytes (transformers config).
quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Quantize the Flux transformer to 4-bit with bitsandbytes (diffusers config).
quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Assemble the pipeline and move it to cuda:1 instead of the default cuda:0.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
).to("cuda:1")

print([str(item.device) for item in pipe.components.values() if hasattr(item, "device")])
# The output is ['cuda:1', 'cuda:1', 'cuda:1', 'cuda:1'], so the components report
# that they were moved to cuda:1 successfully.

pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}
image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
# The error shown in the Logs section below is raised here.
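Not from the original report: a small diagnostic sketch that scans the quantized models for tensors left on a device other than cuda:1, including bitsandbytes' 4-bit quantization state (the absmax and codebook tensors), which is where the stray [16]-shaped tensor in the traceback most plausibly lives. The quant_state, absmax, and code attribute names follow recent bitsandbytes releases and are an assumption here.

# Diagnostic sketch (assumption: bitsandbytes Params4bit exposes `quant_state`
# with `absmax` and `code` tensors, as in recent releases).
expected = torch.device("cuda:1")

def stray_tensors(module):
    # Yield (name, shape, device) for any tensor that is not on the expected device.
    for name, p in module.named_parameters():
        if p.device != expected:
            yield name, tuple(p.shape), str(p.device)
        qs = getattr(p, "quant_state", None)
        if qs is not None:
            for attr in ("absmax", "code"):
                t = getattr(qs, attr, None)
                if t is not None and t.device != expected:
                    yield f"{name}.quant_state.{attr}", tuple(t.shape), str(t.device)
    for name, b in module.named_buffers():
        if b.device != expected:
            yield name, tuple(b.shape), str(b.device)

print(list(stray_tensors(pipe.transformer)))
print(list(stray_tensors(pipe.text_encoder_2)))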
Logs
...
/usr/local/lib/python3.10/dist-packages/bitsandbytes/functional.py in is_on_gpu(tensors)
467
468 if len(gpu_ids) > 1:
--> 469 raise RuntimeError(
470 f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}",
471 )
RuntimeError: Input tensors need to be on the same GPU, but found the following tensor and device combinations:
[(torch.Size([393216, 1]), device(type='cuda', index=1)), (torch.Size([1, 256]), device(type='cuda', index=1)), (torch.Size([1, 3072]), device(type='cuda', index=1)), (torch.Size([12288]), device(type='cuda', index=1)), (torch.Size([16]), device(type='cuda', index=0))]
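As a side note (an assumption, not something confirmed in this issue): making cuda:1 the current CUDA device before loading, so that the quantized weights and their statistics are created on the target GPU instead of being moved there afterwards, may avoid the mismatch. torch.cuda.set_device is a standard PyTorch call; the rest of the loading code would be identical to the reproduction above.

# Hedged workaround sketch: select the target GPU before any model is loaded,
# then run the reproduction above unchanged (the final .to("cuda:1") becomes a no-op).
import torch
torch.cuda.set_device("cuda:1")
# ... load text_encoder_2_4bit, transformer_4bit and the FluxPipeline exactly as above ...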
System Info
- 🤗 Diffusers version: 0.32.2
- Platform: Linux-6.6.56+-x86_64-with-glibc2.35
- Running on Google Colab?: Yes
- Python version: 3.10.12
- PyTorch version (GPU?): 2.5.1+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (gpu)
- Jax version: 0.4.33
- JaxLib version: 0.4.33
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.3
- Accelerate version: 1.2.1
- PEFT version: 0.14.0
- Bitsandbytes version: 0.45.2
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: Tesla T4, 15360 MiB
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
StephanAkkerman