6
6
from diffusers .models .transformers .transformer_flux import FluxTransformer2DModel
7
7
from diffusers .pipelines .flux .pipeline_flux import FluxPipeline
8
8
from optimum .quanto import qfloat8
9
- from optimum .quanto .models import QuantizedDiffusersModel
9
+ from optimum .quanto .models import QuantizedDiffusersModel , QuantizedTransformersModel
10
10
from PIL import Image
11
11
from transformers import CLIPTextModel , CLIPTokenizer , T5EncoderModel , T5TokenizerFast
12
+ from transformers .models .auto import AutoModelForTextEncoding
12
13
13
14
from invokeai .app .invocations .baseinvocation import BaseInvocation , invocation
14
15
from invokeai .app .invocations .fields import InputField , WithBoard , WithMetadata
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    """optimum.quanto wrapper that enables loading/saving a quantized FluxTransformer2DModel.

    Subclassing QuantizedDiffusersModel and setting ``base_class`` is the
    mechanism optimum.quanto uses to know which diffusers model class to
    reconstruct in ``from_pretrained(...)``.
    """

    # The diffusers model class that the quantized wrapper deserializes into.
    base_class = FluxTransformer2DModel
25
26
26
27
28
class QuantizedModelForTextEncoding(QuantizedTransformersModel):
    """optimum.quanto wrapper that enables loading/saving a quantized text encoder.

    ``auto_class`` tells optimum.quanto which transformers Auto* class to use
    when reconstructing the underlying model in ``from_pretrained(...)``.
    """

    # Auto class used to instantiate the wrapped text-encoder model.
    auto_class = AutoModelForTextEncoding
27
32
@invocation (
28
33
"flux_text_to_image" ,
29
34
title = "FLUX Text to Image" ,
@@ -196,9 +201,35 @@ def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
196
201
assert isinstance (model , CLIPTextModel )
197
202
return model
198
203
199
- @staticmethod
200
- def _load_flux_text_encoder_2 (path : Path ) -> T5EncoderModel :
201
- model = T5EncoderModel .from_pretrained (path , local_files_only = True )
204
+ def _load_flux_text_encoder_2 (self , path : Path ) -> T5EncoderModel :
205
+ if self .use_8bit :
206
+ model_8bit_path = path / "quantized"
207
+ if model_8bit_path .exists ():
208
+ # The quantized model exists, load it.
209
+ # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
210
+ # something that we should be able to make much faster.
211
+ q_model = QuantizedModelForTextEncoding .from_pretrained (model_8bit_path )
212
+
213
+ # Access the underlying wrapped model.
214
+ # We access the wrapped model, even though it is private, because it simplifies the type checking by
215
+ # always returning a T5EncoderModel from this function.
216
+ model = q_model ._wrapped
217
+ else :
218
+ # The quantized model does not exist yet, quantize and save it.
219
+ # TODO(ryand): dtype?
220
+ model = T5EncoderModel .from_pretrained (path , local_files_only = True )
221
+ assert isinstance (model , T5EncoderModel )
222
+
223
+ q_model = QuantizedModelForTextEncoding .quantize (model , weights = qfloat8 )
224
+
225
+ model_8bit_path .mkdir (parents = True , exist_ok = True )
226
+ q_model .save_pretrained (model_8bit_path )
227
+
228
+ # (See earlier comment about accessing the wrapped model.)
229
+ model = q_model ._wrapped
230
+ else :
231
+ model = T5EncoderModel .from_pretrained (path , local_files_only = True )
232
+
202
233
assert isinstance (model , T5EncoderModel )
203
234
return model
204
235
0 commit comments