Description
Hi everyone,
I have the following use-case: I have one base model, e.g. flux-dev, and I would like to dynamically load/unload different LoRAs.
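To make this concrete, the per-request workflow I am after looks roughly like this (just a sketch using the standard diffusers LoRA APIs; the repo paths, file names and adapter names are placeholders):

# `pipe` is a FluxPipeline that stays loaded between requests
# Request 1: apply one LoRA
pipe.load_lora_weights("some/lora-repo", weight_name="style_a.safetensors", adapter_name="style_a")
pipe.set_adapters(["style_a"], adapter_weights=[0.8])
image_a = pipe("first prompt", num_inference_steps=20).images[0]

# Request 2: swap to a different LoRA
pipe.delete_adapters("style_a")
pipe.load_lora_weights("some/lora-repo", weight_name="style_b.safetensors", adapter_name="style_b")
pipe.set_adapters(["style_b"], adapter_weights=[1.0])
image_b = pipe("second prompt", num_inference_steps=20).images[0]

# Or drop all adapters and fall back to the base model
pipe.unload_lora_weights()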
I have tried different quantization methods, and to be honest it was a pain to use them.
Quanto
was generating pure noise half of the time and was not reliable at all. And when I finally made it work, LoRAs wouldn't load because layers or keys were missing (depending on the LoRA).
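For reference, the Quanto attempt was roughly along these lines (a sketch using optimum-quanto's quantize/freeze API; the exact details of what I ran may have differed slightly):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto import quantize, freeze, qint8

model_id = "black-forest-labs/FLUX.1-dev"

transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.float16
)
# Quantize the transformer weights in place, then freeze them
quantize(transformer, weights=qint8)
freeze(transformer)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")

# This is where it broke down for me: depending on the LoRA,
# loading failed with missing layers/keys
pipe.load_lora_weights(
    "./", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors", adapter_name="speed"
)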
Torchao
was very promising, but at the LoRA loading step it reported that only int8 is supported.
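The torchao attempt looked roughly like this (a sketch using diffusers' TorchAoConfig; again, details may differ from what I actually ran):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"

# int8 weight-only shown here; what I actually wanted was one of the fp8 variants,
# which is what triggered the "only int8 is supported" complaint when loading the LoRA
quant_config = TorchAoConfig("int8wo")

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

pipe.load_lora_weights(
    "./", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors", adapter_name="speed"
)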
All I wanted was to replicate the simple behaviour of ComfyUI: load an fp8 checkpoint and get comparable speed and memory usage. Instead, my journey literally turned into hell.
That left the BitsAndBytes method, the remaining quantization approach described in the documentation. Let me post the code first:
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import T5EncoderModel
from rich import print
from time import time

model_id = "black-forest-labs/FLUX.1-dev"
torch_dtype = torch.float16

# 8-bit transformer
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch_dtype,
    cache_dir="./flux",
)

# 8-bit T5 text encoder
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch_dtype,
    cache_dir="./flux",
)

pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_8bit,
    torch_dtype=torch_dtype,
    device_map="balanced",
    cache_dir="./flux",
)

# Run without LoRA
with torch.no_grad():
    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        guidance_scale=3.5,
        height=1280,
        width=720,
    ).images[0]

# Load the Hyper-FLUX 8-step LoRA and run again
pipe.load_lora_weights(
    "./",
    weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors",
    adapter_name="speed",
)
pipe.set_adapters(["speed"], adapter_weights=[0.125])

with torch.no_grad():
    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=8,
        guidance_scale=3.5,
        height=1280,
        width=720,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
When I create the pipeline without quantization:
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    device_map="balanced",
    cache_dir="./flux",
)
the vanilla version uses between 30-40 GB of memory.
Without LoRA - I get 2.12 it/sec.
With LoRA - it drops to 1.87 it/sec (a 12% decrease in speed) - understandable (more parameters -> slower model).
Now if I run it with the quantization above, memory usage is between 22-26 GB, however:
Without LoRA - I get 1.56 it/sec (a 26% decrease from vanilla).
With LoRA - it drops to 1.05 it/sec (another 32% decrease in speed).
Overall, the quantized run with LoRA ends up 43% slower than the vanilla run with LoRA. I am using an L40S GPU.
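(The it/sec figures above are read off the diffusers progress bar; they can be double-checked with a rough wall-clock measurement like the hypothetical snippet below, reusing pipe and prompt from the script above, keeping in mind that it also includes text encoding and VAE decode.)

from time import time

steps = 20
start = time()
_ = pipe(prompt, num_inference_steps=steps, guidance_scale=3.5, height=1280, width=720).images[0]
elapsed = time() - start
# Rough throughput; the progress bar's it/sec only covers the denoising loop
print(f"~{steps / elapsed:.2f} it/sec over {elapsed:.1f}s")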
torch==2.6.0
diffusers==0.32.2
accelerate==1.6.0
transformers==4.50.3
peft==0.15.1
protobuf==6.30.2
sentencepiece==0.2.0
Is there anything that could be done in general for dynamically loading LoRAs into quantized models?
And also for not losing so much performance?
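One idea I have not verified (and I am not sure it is even supported on bitsandbytes-quantized layers) is fusing the adapter into the weights to avoid the per-step LoRA overhead, roughly:

# Hypothetical: fuse the LoRA into the base weights, generate, then unfuse
# before swapping to a different adapter. Unverified on 8-bit bnb layers.
pipe.load_lora_weights(
    "./", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors", adapter_name="speed"
)
pipe.fuse_lora(lora_scale=0.125)
image = pipe(prompt, num_inference_steps=8, guidance_scale=3.5).images[0]

pipe.unfuse_lora()
pipe.unload_lora_weights()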