1
+ import json
1
2
from pathlib import Path
2
3
from typing import Literal
3
4
4
5
import torch
5
6
from diffusers import AutoencoderKL , FlowMatchEulerDiscreteScheduler
6
7
from diffusers .models .transformers .transformer_flux import FluxTransformer2DModel
7
- from diffusers .pipelines .flux import FluxPipeline
8
+ from diffusers .pipelines .flux .pipeline_flux import FluxPipeline
9
+ from optimum .quanto import freeze , qfloat8 , quantization_map , quantize , requantize
8
10
from PIL import Image
11
+ from safetensors .torch import load_file , save_file
9
12
from transformers import CLIPTextModel , CLIPTokenizer , T5EncoderModel , T5TokenizerFast
10
13
11
14
from invokeai .app .invocations .baseinvocation import BaseInvocation , invocation
@@ -29,6 +32,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
29
32
"""Text-to-image generation using a FLUX model."""
30
33
31
34
model : TFluxModelKeys = InputField (description = "The FLUX model to use for text-to-image generation." )
35
+ use_8bit : bool = InputField (
36
+ default = False , description = "Whether to quantize the T5 model and transformer model to 8-bit precision."
37
+ )
32
38
positive_prompt : str = InputField (description = "Positive prompt for text-to-image generation." )
33
39
width : int = InputField (default = 1024 , multiple_of = 16 , description = "Width of the generated image." )
34
40
height : int = InputField (default = 1024 , multiple_of = 16 , description = "Height of the generated image." )
@@ -110,7 +116,10 @@ def _run_diffusion(
110
116
clip_embeddings : torch .Tensor ,
111
117
t5_embeddings : torch .Tensor ,
112
118
):
113
- scheduler = FlowMatchEulerDiscreteScheduler ()
119
+ scheduler = FlowMatchEulerDiscreteScheduler .from_pretrained (flux_model_dir / "scheduler" , local_files_only = True )
120
+
121
+ # HACK(ryand): Manually empty the cache.
122
+ context .models ._services .model_manager .load .ram_cache .make_room (24 * 2 ** 30 )
114
123
115
124
transformer_path = flux_model_dir / "transformer"
116
125
with context .models .load_local_model (
@@ -144,7 +153,7 @@ def _run_vae_decoding(
144
153
self ,
145
154
context : InvocationContext ,
146
155
flux_model_dir : Path ,
147
- latent : torch .Tensor ,
156
+ latents : torch .Tensor ,
148
157
) -> Image .Image :
149
158
vae_path = flux_model_dir / "vae"
150
159
with context .models .load_local_model (model_path = vae_path , loader = self ._load_flux_vae ) as vae :
@@ -166,8 +175,9 @@ def _run_vae_decoding(
166
175
latents = (
167
176
latents / flux_pipeline_with_vae .vae .config .scaling_factor
168
177
) + flux_pipeline_with_vae .vae .config .shift_factor
178
+ latents = latents .to (dtype = vae .dtype )
169
179
image = flux_pipeline_with_vae .vae .decode (latents , return_dict = False )[0 ]
170
- image = flux_pipeline_with_vae .image_processor .postprocess (image , output_type = "pil" )
180
+ image = flux_pipeline_with_vae .image_processor .postprocess (image , output_type = "pil" )[ 0 ]
171
181
172
182
assert isinstance (image , Image .Image )
173
183
return image
@@ -184,9 +194,38 @@ def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
184
194
assert isinstance (model , T5EncoderModel )
185
195
return model
186
196
187
- @staticmethod
188
- def _load_flux_transformer (path : Path ) -> FluxTransformer2DModel :
189
- model = FluxTransformer2DModel .from_pretrained (path , local_files_only = True )
197
def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
    """Load the FLUX transformer stored at `path`, optionally quantized to 8-bit.

    When `self.use_8bit` is False the model is loaded directly with the dtype
    returned by `TorchDevice.choose_torch_dtype()`.

    When `self.use_8bit` is True a quantized copy is cached under
    `path / "quantized"` as `weights.safetensors` plus `quantization_map.json`:
      - if both files are present, the cached weights are loaded via `requantize`;
      - otherwise the model is loaded in bfloat16, quantized to qfloat8, frozen,
        and the result is written to the cache directory.

    Args:
        path: Directory containing the pretrained transformer.

    Returns:
        The loaded (and possibly quantized) `FluxTransformer2DModel`.
    """
    if not self.use_8bit:
        model = FluxTransformer2DModel.from_pretrained(
            path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
        )
        assert isinstance(model, FluxTransformer2DModel)
        return model

    model_8bit_path = path / "quantized"
    model_8bit_weights_path = model_8bit_path / "weights.safetensors"
    model_8bit_map_path = model_8bit_path / "quantization_map.json"

    # Require both files rather than just the directory: the directory is created
    # before anything is written, so an interrupted quantization run would otherwise
    # leave a cache that can never be loaded. The map is written last, so its
    # presence indicates that the weights were saved first.
    if model_8bit_weights_path.is_file() and model_8bit_map_path.is_file():
        # The quantized model exists, load it.
        # Instantiate under the meta device; the quantized weights are supplied by
        # `requantize` below (presumably to avoid materializing full-precision
        # weights - TODO confirm).
        with torch.device("meta"):
            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, FluxTransformer2DModel)

        state_dict = load_file(model_8bit_weights_path)
        with open(model_8bit_map_path, "r", encoding="utf-8") as f:
            quant_map = json.load(f)
        requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
    else:
        # The quantized model does not exist yet (or is incomplete), quantize and save it.
        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
        assert isinstance(model, FluxTransformer2DModel)

        quantize(model, weights=qfloat8)
        freeze(model)

        model_8bit_path.mkdir(parents=True, exist_ok=True)
        save_file(model.state_dict(), model_8bit_weights_path)
        # Write the map last so that it doubles as the "cache is complete" marker.
        with open(model_8bit_map_path, "w", encoding="utf-8") as f:
            json.dump(quantization_map(model), f)

    assert isinstance(model, FluxTransformer2DModel)
    return model
192
231
0 commit comments