bitsandbytes 8bit quant memory leak #11329

@Teriks

Description

Describe the bug

Using 8-bit quantization on pipeline modules results in VRAM usage that cannot be freed.

This is probably more of a bitsandbytes issue, but I am wondering if there is a way to resolve it in the context of diffusers.

If you move the pipe to the CPU, the quantized modules are left on the GPU due to a bitsandbytes limitation, and if you delete the pipe and run garbage collection, bitsandbytes still leaves the layers on the GPU.
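
For what it's worth, the devices can be checked directly rather than relying on the warnings; a minimal sketch against the pipeline built in the reproduction below (the helper name is mine, the attribute names match the repro):

from diffusers import StableDiffusionXLPipeline

def print_module_devices(pipe: StableDiffusionXLPipeline) -> None:
    # bnb 8-bit modules ignore .to("cpu"), so these still report cuda:0
    for name in ("unet", "text_encoder", "text_encoder_2"):
        module = getattr(pipe, name)
        print(f"{name}: {next(module.parameters()).device}")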

Reproduction

import gc
import time

import torch
from transformers import BitsAndBytesConfig, CLIPTextModel, CLIPTextModelWithProjection
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel

# 8-bit quant config
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True
)

# Base model ID
model_id = "stabilityai/stable-diffusion-xl-base-1.0"


# Text encoders
text_encoder = CLIPTextModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    quantization_config=bnb_config
)

text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    model_id,
    subfolder="text_encoder_2",
    quantization_config=bnb_config
)


# UNet
unet = UNet2DConditionModel.from_pretrained(
    model_id,
    subfolder="unet",
    torch_dtype=torch.float16,
    quantization_config=bnb_config
)


# Construct pipeline manually
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    variant='fp16',
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    unet=unet,
    torch_dtype=torch.float16
)

pipe.to('cuda')


# Generate
prompt = "test prompt"
image = pipe(prompt=prompt).images[0]
image.save("sdxl_8bit_quant_output.png")

pipe.to('cpu')

# Drop every reference to the pipeline and its quantized modules,
# then garbage-collect before releasing the CUDA cache
del pipe, text_encoder, text_encoder_2, unet
gc.collect()
torch.cuda.empty_cache()

while True:
    # The quantized modules are still resident on the GPU;
    # allocated memory stays high even though everything was deleted
    print(f"allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")
    time.sleep(1)
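
Note that nvidia-smi reports the caching allocator's reserved pool, not live tensors. To tell cached-but-freeable memory apart from memory still held by leaked tensors, both counters can be printed (memory_allocated and memory_reserved are standard PyTorch APIs; the helper is just a sketch):

import torch

def vram_report() -> None:
    # allocated: bytes held by live tensors; reserved: bytes held by the allocator.
    # If allocated stays high after teardown, quantized tensors are still alive somewhere.
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")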

Logs

Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 49.52it/s]
The module 'CLIPTextModel' has been loaded in `bitsandbytes` 8bit and moving it to cuda via `.to()` is not supported. Module is still on cuda:0.
The module 'CLIPTextModelWithProjection' has been loaded in `bitsandbytes` 8bit and moving it to cuda via `.to()` is not supported. Module is still on cuda:0.
The module 'UNet2DConditionModel' has been loaded in `bitsandbytes` 8bit and moving it to cuda via `.to()` is not supported. Module is still on cuda:0.
100%|██████████| 50/50 [00:25<00:00,  1.98it/s]
The module 'CLIPTextModel' has been loaded in `bitsandbytes` 8bit and moving it to cpu via `.to()` is not supported. Module is still on cuda:0.
The module 'CLIPTextModelWithProjection' has been loaded in `bitsandbytes` 8bit and moving it to cpu via `.to()` is not supported. Module is still on cuda:0.
The module 'UNet2DConditionModel' has been loaded in `bitsandbytes` 8bit and moving it to cpu via `.to()` is not supported. Module is still on cuda:0.

System Info

Windows 11

python = 3.12.10
diffusers = 0.33.1
bitsandbytes = 0.45.5
transformers = 4.51.3

Who can help?

@yiyixuxu @DN6
