Skip to content

Commit 89a652c

Browse files
committed
Got FLUX schnell working with 8-bit quantization. Still lots of rough edges to clean up.
1 parent b227b90 commit 89a652c

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import json
12
from pathlib import Path
23
from typing import Literal
34

45
import torch
56
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
67
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
7-
from diffusers.pipelines.flux import FluxPipeline
8+
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
9+
from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
810
from PIL import Image
11+
from safetensors.torch import load_file, save_file
912
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
1013

1114
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -29,6 +32,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
2932
"""Text-to-image generation using a FLUX model."""
3033

3134
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
35+
use_8bit: bool = InputField(
36+
default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision."
37+
)
3238
positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
3339
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
3440
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
@@ -110,7 +116,10 @@ def _run_diffusion(
110116
clip_embeddings: torch.Tensor,
111117
t5_embeddings: torch.Tensor,
112118
):
113-
scheduler = FlowMatchEulerDiscreteScheduler()
119+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
120+
121+
# HACK(ryand): Manually empty the cache.
122+
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
114123

115124
transformer_path = flux_model_dir / "transformer"
116125
with context.models.load_local_model(
@@ -144,7 +153,7 @@ def _run_vae_decoding(
144153
self,
145154
context: InvocationContext,
146155
flux_model_dir: Path,
147-
latent: torch.Tensor,
156+
latents: torch.Tensor,
148157
) -> Image.Image:
149158
vae_path = flux_model_dir / "vae"
150159
with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
@@ -166,8 +175,9 @@ def _run_vae_decoding(
166175
latents = (
167176
latents / flux_pipeline_with_vae.vae.config.scaling_factor
168177
) + flux_pipeline_with_vae.vae.config.shift_factor
178+
latents = latents.to(dtype=vae.dtype)
169179
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
170-
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")
180+
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
171181

172182
assert isinstance(image, Image.Image)
173183
return image
@@ -184,9 +194,38 @@ def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
184194
assert isinstance(model, T5EncoderModel)
185195
return model
186196

187-
@staticmethod
188-
def _load_flux_transformer(path: Path) -> FluxTransformer2DModel:
189-
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
197+
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
198+
if self.use_8bit:
199+
model_8bit_path = path / "quantized"
200+
model_8bit_weights_path = model_8bit_path / "weights.safetensors"
201+
model_8bit_map_path = model_8bit_path / "quantization_map.json"
202+
if model_8bit_path.exists():
203+
# The quantized model exists, load it.
204+
with torch.device("meta"):
205+
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
206+
assert isinstance(model, FluxTransformer2DModel)
207+
208+
state_dict = load_file(model_8bit_weights_path)
209+
with open(model_8bit_map_path, "r") as f:
210+
quant_map = json.load(f)
211+
requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
212+
else:
213+
# The quantized model does not exist yet, quantize and save it.
214+
model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
215+
assert isinstance(model, FluxTransformer2DModel)
216+
217+
quantize(model, weights=qfloat8)
218+
freeze(model)
219+
220+
model_8bit_path.mkdir(parents=True, exist_ok=True)
221+
save_file(model.state_dict(), model_8bit_weights_path)
222+
with open(model_8bit_map_path, "w") as f:
223+
json.dump(quantization_map(model), f)
224+
else:
225+
model = FluxTransformer2DModel.from_pretrained(
226+
path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
227+
)
228+
190229
assert isinstance(model, FluxTransformer2DModel)
191230
return model
192231

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,17 @@ dependencies = [
4545
"onnx==1.15.0",
4646
"onnxruntime==1.16.3",
4747
"opencv-python==4.9.0.80",
48+
"optimum-quanto==0.2.4",
4849
"pytorch-lightning==2.1.3",
4950
"safetensors==0.4.3",
5051
# sentencepiece is required to load T5TokenizerFast (used by FLUX).
5152
"sentencepiece==0.2.0",
5253
"spandrel==0.3.4",
5354
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
54-
"torch==2.2.2",
55+
"torch==2.4.0",
5556
"torchmetrics==0.11.4",
5657
"torchsde==0.2.6",
57-
"torchvision==0.17.2",
58+
"torchvision==0.19.0",
5859
"transformers==4.41.1",
5960

6061
# Core application dependencies, pinned for reproducible builds.

0 commit comments

Comments
 (0)