Commit 74d6fce

Add support for 8-bit quantization of the FLUX T5XXL text encoder.

1 parent 766ddc1

1 file changed

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 35 additions & 4 deletions
@@ -6,9 +6,10 @@
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
-from optimum.quanto.models import QuantizedDiffusersModel
+from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
 from PIL import Image
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers.models.auto import AutoModelForTextEncoding

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
@@ -24,6 +25,10 @@ class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
     base_class = FluxTransformer2DModel


+class QuantizedModelForTextEncoding(QuantizedTransformersModel):
+    auto_class = AutoModelForTextEncoding
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
@@ -196,9 +201,35 @@ def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
         assert isinstance(model, CLIPTextModel)
         return model

-    @staticmethod
-    def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
-        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
+        if self.use_8bit:
+            model_8bit_path = path / "quantized"
+            if model_8bit_path.exists():
+                # The quantized model exists, load it.
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
+
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a T5EncoderModel from this function.
+                model = q_model._wrapped
+            else:
+                # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): dtype?
+                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+                assert isinstance(model, T5EncoderModel)
+
+                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
+
+                model_8bit_path.mkdir(parents=True, exist_ok=True)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
+        else:
+            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+
         assert isinstance(model, T5EncoderModel)
         return model
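
For context, the quantize-once-then-cache flow added in the final hunk can be exercised on its own. The sketch below mirrors the calls introduced in this commit (QuantizedTransformersModel.quantize, save_pretrained, from_pretrained, and the private _wrapped attribute); the checkpoint directory ./t5xxl and the load_t5_encoder_8bit helper are hypothetical stand-ins, not part of the InvokeAI code.

# A minimal, standalone sketch of the caching pattern used in
# _load_flux_text_encoder_2 above. "./t5xxl" is a hypothetical local checkpoint.
from pathlib import Path

from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedTransformersModel
from transformers import T5EncoderModel
from transformers.models.auto import AutoModelForTextEncoding


class QuantizedModelForTextEncoding(QuantizedTransformersModel):
    # Tells optimum-quanto which transformers auto-class to use when
    # reloading the serialized quantized model.
    auto_class = AutoModelForTextEncoding


def load_t5_encoder_8bit(path: Path) -> T5EncoderModel:  # hypothetical helper
    quantized_path = path / "quantized"
    if quantized_path.exists():
        # Reload the cached 8-bit weights (requantize runs inside from_pretrained,
        # which the TODO in the diff flags as slow).
        q_model = QuantizedModelForTextEncoding.from_pretrained(quantized_path)
    else:
        # First run: quantize the full-precision encoder to float8 and cache it.
        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
        q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
        quantized_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(quantized_path)
    # Unwrap so callers always get a T5EncoderModel (see the comment in the diff
    # about reaching into the private _wrapped attribute).
    model = q_model._wrapped
    assert isinstance(model, T5EncoderModel)
    return model


encoder = load_t5_encoder_8bit(Path("./t5xxl"))

Caching the serialized qfloat8 weights trades a one-time quantization cost for cheaper subsequent loads, at the price of extra disk space under the model directory.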
