Commit 766ddc1

Make 8-bit quantization save/reload work for the FLUX transformer. Reload is still very slow with the current optimum.quanto implementation.
1 parent e6ff748

File changed: invokeai/app/invocations/flux_text_to_image.py (19 additions, 23 deletions)

@@ -1,14 +1,13 @@
-import json
 from pathlib import Path
 from typing import Literal
 
 import torch
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
+from optimum.quanto import qfloat8
+from optimum.quanto.models import QuantizedDiffusersModel
 from PIL import Image
-from safetensors.torch import load_file, save_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -21,6 +20,10 @@
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
 
 
+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
@@ -202,23 +205,16 @@ def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
-            model_8bit_weights_path = model_8bit_path / "weights.safetensors"
-            model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                # TODO(ryand): Make loading from quantized model work properly.
-                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
-                model = FluxTransformer2DModel.from_pretrained(
-                    path,
-                    local_files_only=True,
-                )
-                assert isinstance(model, FluxTransformer2DModel)
-                model = model.to(device=torch.device("meta"))
-
-                state_dict = load_file(model_8bit_weights_path)
-                with open(model_8bit_map_path, "r") as f:
-                    quant_map = json.load(f)
-                requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
+
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a FluxTransformer2DModel from this function.
+                model = q_model._wrapped
             else:
                 # The quantized model does not exist yet, quantize and save it.
                 # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
@@ -227,13 +223,13 @@ def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)
 
-                quantize(model, weights=qfloat8)
-                freeze(model)
+                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
 
                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                save_file(model.state_dict(), model_8bit_weights_path)
-                with open(model_8bit_map_path, "w") as f:
-                    json.dump(quantization_map(model), f)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
         else:
             model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
 
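For reference, the quantize/save/reload flow that this commit switches to can also be exercised on its own. The sketch below mirrors the branch structure in the diff above using optimum.quanto's QuantizedDiffusersModel wrapper; the function name load_or_create_quantized_transformer and the transformer_path argument are illustrative only and are not part of this change.

from pathlib import Path

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel


def load_or_create_quantized_transformer(transformer_path: Path) -> FluxTransformer2DModel:
    """Load the qfloat8 transformer if it was saved before, otherwise quantize and save it."""
    quantized_path = transformer_path / "quantized"

    if quantized_path.exists():
        # Reload path: from_pretrained(...) rebuilds the model and requantizes the weights
        # (currently slow, per the TODO in the diff).
        q_model = QuantizedFluxTransformer2DModel.from_pretrained(quantized_path)
    else:
        # First run: load in bfloat16 (float16 was observed to produce NaNs), quantize to
        # qfloat8, then persist the quantized weights for later runs.
        model = FluxTransformer2DModel.from_pretrained(
            transformer_path, local_files_only=True, torch_dtype=torch.bfloat16
        )
        q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
        quantized_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(quantized_path)

    # Return the wrapped FluxTransformer2DModel so callers always get a consistent type.
    return q_model._wrapped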
