Quantization is slow with FLUX.1-dev AND no effective LoRA support #11215

@cryptexis

Description

Hi everyone,

I have the following use case: a single base model, e.g. FLUX.1-dev, onto which I would like to dynamically load/unload different LoRAs.
I have tried different quantization methods, and to be honest they were a pain to use.
Quanto generated pure noise half of the time and was not reliable at all. And when I finally got it working, LoRAs would not load because layers or keys were missing (depending on the LoRA).
Torchao was very promising, but at the LoRA loading step it reported that only int8 is supported.
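
To make the use case concrete, the swap flow I am after looks roughly like the sketch below (assuming pipe is the FluxPipeline built in the code further down; the repos and adapter names here are just placeholders, not real checkpoints):

# Sketch of the desired dynamic LoRA workflow (placeholder repos/adapters).
pipe.load_lora_weights("some-user/style-a-lora", adapter_name="style_a")
pipe.load_lora_weights("some-user/style-b-lora", adapter_name="style_b")

# Switch the active adapter per request without reloading the base model.
pipe.set_adapters(["style_a"], adapter_weights=[1.0])
image_a = pipe("prompt in style A", num_inference_steps=20).images[0]

pipe.set_adapters(["style_b"], adapter_weights=[1.0])
image_b = pipe("prompt in style B", num_inference_steps=20).images[0]

# Drop adapters when they are no longer needed.
pipe.delete_adapters(["style_a", "style_b"])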

All I want is to replicate the simple behaviour of ComfyUI (load an fp8 checkpoint, get similar speed and memory usage), yet my journey literally turned into hell.

At that point, the only quantization method left from the documentation was bitsandbytes. Let me post the code first:

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import T5EncoderModel

from rich import print
from time import time

model_id = "black-forest-labs/FLUX.1-dev"
torch_dtype = torch.float16

# 8-bit transformer
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch_dtype,
    cache_dir="./flux",
)

# 8-bit T5 text encoder
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch_dtype,
    cache_dir="./flux",
)

pipe = FluxPipeline.from_pretrained(
    model_id, 
    transformer=transformer_8bit, 
    text_encoder_2=text_encoder_8bit,
    torch_dtype=torch_dtype,
    device_map="balanced",
    cache_dir="./flux",
)

# First run: quantized pipeline without any LoRA
with torch.no_grad():
    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        guidance_scale=3.5,
        height=1280,
        width=720
    ).images[0]

# Load the Hyper-FLUX 8-step LoRA and activate it at a reduced weight
pipe.load_lora_weights(
    "./",
    weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors",
    adapter_name="speed",
)
pipe.set_adapters(["speed"], adapter_weights=[0.125])

# Second run: same prompt with the LoRA active
with torch.no_grad():
    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=8,
        guidance_scale=3.5,
        height=1280,
        width=720,
        generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]

When I instead create the plain, unquantized pipeline:

pipe = FluxPipeline.from_pretrained(
    model_id, 
    torch_dtype=torch_dtype,
    device_map="balanced",
    cache_dir="./flux",
)

I get the vanilla version, with memory usage between 30 and 40 GB.

Without LoRA I get 2.12 it/s.
With LoRA it drops to 1.87 it/s (a 12% decrease in speed), which is understandable (more parameters, slower model).

Now, if I run with the quantization setup above, memory usage is between 22 and 26 GB, however:

Without LoRA I get 1.56 it/s (a 26% decrease from vanilla).
With LoRA it drops to 1.05 it/s (another 32% decrease in speed).

In total, the quantized model with LoRA ends up about 43% slower than the vanilla model with LoRA. I am using an L40S GPU.
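
For reference, the it/s numbers above are rough wall-clock measurements around the pipeline call, along the lines of the helper below (an illustration, not the exact script I used; it also counts text encoding and VAE decoding, so it reads slightly lower than tqdm's per-step rate):

import time
import torch

def measure_its(pipe, prompt, steps):
    # Rough throughput: denoising steps divided by wall-clock time of one full call.
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe(prompt, num_inference_steps=steps, guidance_scale=3.5, height=1280, width=720)
    torch.cuda.synchronize()
    return steps / (time.perf_counter() - start)

print(f"{measure_its(pipe, 'A cat holding a sign that says hello world', 20):.2f} it/s")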

torch==2.6.0
diffusers==0.32.2
accelerate==1.6.0
transformers==4.50.3
peft==0.15.1
protobuf==6.30.2
sentencepiece==0.2.0

Is there anything that can be done in general to dynamically load LoRAs into quantized models?
And also to avoid losing so much performance?
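
For example, would fusing the LoRA into the base weights be expected to help here, so that the per-step LoRA matmuls disappear? A sketch of what I mean (I have not verified that fuse_lora()/unfuse_lora() work on top of bitsandbytes-quantized layers):

# Hypothetical workaround: fuse the active LoRA so inference pays no per-step LoRA overhead.
# Unverified on bitsandbytes-quantized layers.
pipe.load_lora_weights(
    "./",
    weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors",
    adapter_name="speed",
)
pipe.fuse_lora(lora_scale=0.125)

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=8,
    guidance_scale=3.5,
).images[0]

pipe.unfuse_lora()          # restore the base weights before switching LoRAs
pipe.unload_lora_weights()  # remove the adapter entirely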
