Describe the bug
Using 8-bit quantization on pipeline modules results in un-freeable VRAM usage. This is probably more of a bitsandbytes issue, but I am wondering whether it can be resolved on the diffusers side.
Moving the pipe to CPU leaves the quantized modules on the GPU, because bitsandbytes does not support moving 8-bit modules with `.to()`. Deleting the pipe and garbage collecting does not release the layers from the GPU either.
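To quantify this without watching an external monitor, the CUDA allocator statistics can be printed directly. A minimal sketch: `torch.cuda.memory_allocated()` counts bytes held by live tensors, and `torch.cuda.memory_reserved()` counts bytes held by PyTorch's caching allocator.

import torch

def report_vram(tag: str) -> None:
    # Bytes held by live tensors vs. bytes reserved by the caching allocator
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB")

Called before and after the teardown in the reproduction below, this shows allocated memory staying high instead of returning to near zero.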
Reproduction
import gc
import time

import torch
from transformers import BitsAndBytesConfig, CLIPTextModel, CLIPTextModelWithProjection

from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel

# 8-bit quant config
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Base model
model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Text encoders
text_encoder = CLIPTextModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    quantization_config=bnb_config,
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    model_id,
    subfolder="text_encoder_2",
    quantization_config=bnb_config,
)

# UNet
unet = UNet2DConditionModel.from_pretrained(
    model_id,
    subfolder="unet",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

# Construct the pipeline manually from the quantized components
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    variant="fp16",
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    unet=unet,
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Generate
prompt = "test prompt"
image = pipe(prompt=prompt).images[0]
image.save("sdxl_8bit_quant_output.png")

# Teardown: none of this releases the quantized modules' VRAM
pipe.to("cpu")
del pipe
gc.collect()
torch.cuda.empty_cache()

while True:
    # Check VRAM usage on your system here (e.g. nvidia-smi);
    # the quantized modules are still on the GPU.
    time.sleep(1)
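One caveat about the teardown above: the module-level variables `text_encoder`, `text_encoder_2`, and `unet` still reference the quantized modules after `del pipe`, so Python cannot collect them regardless of what bitsandbytes does. A teardown that drops every reference first would look like the sketch below (the attribute names are the standard `StableDiffusionXLPipeline` components; this is an untested variant for comparison, not a confirmed fix):

# Since .to('cpu') is unsupported for 8-bit bitsandbytes modules, freeing
# the CUDA memory depends entirely on every Python reference being dropped.
pipe.unet = None
pipe.text_encoder = None
pipe.text_encoder_2 = None
del pipe, unet, text_encoder, text_encoder_2
gc.collect()
torch.cuda.empty_cache()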
Logs
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 49.52it/s]
The module 'CLIPTextModel' has been loaded in `bitsandbytes` 8bit and moving it to cuda via `.to()` is not supported. Module is still on cuda:0.
The module 'CLIPTextModelWithProjection' has been loaded in `bitsandbytes` 8bit and moving it to cuda via `.to()` is not supported. Module is still on cuda:0.
The module 'UNet2DConditionModel' has been loaded in `bitsandbytes` 8bit and moving it to cuda via `.to()` is not supported. Module is still on cuda:0.
100%|██████████| 50/50 [00:25<00:00, 1.98it/s]
The module 'CLIPTextModel' has been loaded in `bitsandbytes` 8bit and moving it to cpu via `.to()` is not supported. Module is still on cuda:0.
The module 'CLIPTextModelWithProjection' has been loaded in `bitsandbytes` 8bit and moving it to cpu via `.to()` is not supported. Module is still on cuda:0.
The module 'UNet2DConditionModel' has been loaded in `bitsandbytes` 8bit and moving it to cpu via `.to()` is not supported. Module is still on cuda:0.
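To confirm programmatically which components are still holding quantized weights, the parameter class can be inspected. A sketch, assuming bitsandbytes exposes `Int8Params` (the parameter type its `Linear8bitLt` layers use for weights):

import bitsandbytes as bnb
import torch

def bnb_int8_submodules(module: torch.nn.Module):
    # Yield the names of submodules whose weights are 8-bit bitsandbytes params
    for name, sub in module.named_modules():
        weight = getattr(sub, "weight", None)
        if isinstance(weight, bnb.nn.Int8Params):
            yield name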
System Info
Win11
python = 3.12.10
diffusers = 0.33.1
bitsandbytes = 0.45.5
transformers = 4.51.3
Who can help?