import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
- from invokeai.app.invocations.model import ModelIdentifierField
+ from invokeai.app.invocations.model import TransformerField, VAEField
from optimum.quanto import qfloat8
from PIL import Image
from safetensors.torch import load_file
@@ -52,17 +52,14 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

-    flux_model: ModelIdentifierField = InputField(
-        description="The Flux model",
-        input=Input.Any,
-        ui_type=UIType.FluxMainModel
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.unet,
+        input=Input.Connection,
+        title="Transformer",
    )
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
-        default="raw", description="The type of quantization to use for the transformer model."
-    )
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
    )
    positive_text_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -78,70 +75,38 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
-        # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
-        flux_transformer_path = context.models.download_and_cache_model(
-            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
-        )
-        flux_ae_path = context.models.download_and_cache_model(
-            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
-        )

        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

-        latents = self._run_diffusion(
-            context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
-        )
-        image = self._run_vae_decoding(context, flux_ae_path, latents)
+        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
-        flux_transformer_path: Path,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ):
-        inference_dtype = TorchDevice.choose_torch_dtype()
-
-        # Prepare input noise.
-        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
-        # CPU RNG?
-        x = get_noise(
-            num_samples=1,
-            height=self.height,
-            width=self.width,
-            device=TorchDevice.choose_torch_device(),
-            dtype=inference_dtype,
-            seed=self.seed,
-        )
-
-        img, img_ids = self._prepare_latent_img_patches(x)
-
-        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in str(flux_transformer_path)
-        timesteps = get_schedule(
-            num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
-            shift=not is_schnell,
-        )
-
-        bs, t5_seq_len, _ = t5_embeddings.shape
-        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
+        scheduler_info = context.models.load(self.transformer.scheduler)
+        transformer_info = context.models.load(self.transformer.transformer)

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

-        with context.models.load_local_model(
-            model_path=flux_transformer_path, loader=self._load_flux_transformer
-        ) as transformer:
-            assert isinstance(transformer, Flux)
+        with (
+            transformer_info as transformer,
+            scheduler_info as scheduler
+        ):
+            assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

            x = denoise(
                model=transformer,
@@ -185,75 +150,25 @@ def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.T
    def _run_vae_decoding(
        self,
        context: InvocationContext,
-        flux_ae_path: Path,
        latents: torch.Tensor,
    ) -> Image.Image:
-        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
-            assert isinstance(vae, AutoEncoder)
-            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
-                img = vae.decode(latents)
+        vae_info = context.models.load(self.vae.vae)
+        with vae_info as vae:
+            assert isinstance(vae, AutoencoderKL)

        img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

-        return img_pil
-
-    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
-        inference_dtype = TorchDevice.choose_torch_dtype()
-        if self.quantization_type == "raw":
-            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-            params = flux_configs["flux-schnell"].params
-
-            # Initialize the model on the "meta" device.
-            with accelerate.init_empty_weights():
-                model = Flux(params).to(inference_dtype)
-
-            state_dict = load_file(path)
-            # TODO(ryand): Cast the state_dict to the appropriate dtype?
-            model.load_state_dict(state_dict, strict=True, assign=True)
-        elif self.quantization_type == "NF4":
-            model_path = path.parent / "bnb_nf4.safetensors"
-
-            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-            params = flux_configs["flux-schnell"].params
-            # Initialize the model on the "meta" device.
-            with accelerate.init_empty_weights():
-                model = Flux(params)
-                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
-
-            # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
-            # this on GPUs without bfloat16 support.
-            state_dict = load_file(model_path)
-            model.load_state_dict(state_dict, strict=True, assign=True)
-
-        elif self.quantization_type == "llm_int8":
-            raise NotImplementedError("LLM int8 quantization is not yet supported.")
-            # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            # with accelerate.init_empty_weights():
-            #     empty_model = FluxTransformer2DModel.from_config(model_config)
-            # assert isinstance(empty_model, FluxTransformer2DModel)
-            # model_int8_path = path / "bnb_llm_int8"
-            # assert model_int8_path.exists()
-            # with accelerate.init_empty_weights():
-            #     model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
-
-            # sd = load_file(model_int8_path / "model.safetensors")
-            # model.load_state_dict(sd, strict=True, assign=True)
-        else:
-            raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
-
-        assert isinstance(model, FluxTransformer2DModel)
-        return model
-
-    @staticmethod
-    def _load_flux_vae(path: Path) -> AutoEncoder:
-        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        ae_params = flux_configs["flux1-schnell"].ae_params
-        with accelerate.init_empty_weights():
-            ae = AutoEncoder(ae_params)
-
-        state_dict = load_file(path)
-        ae.load_state_dict(state_dict, strict=True, assign=True)
-        return ae
+            latents = flux_pipeline_with_vae._unpack_latents(
+                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
+            )
+            latents = (
+                latents / flux_pipeline_with_vae.vae.config.scaling_factor
+            ) + flux_pipeline_with_vae.vae.config.shift_factor
+            latents = latents.to(dtype=vae.dtype)
+            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
+            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
+
+            assert isinstance(image, Image.Image)
+            return image
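Note: the new _run_vae_decoding body above references a flux_pipeline_with_vae helper that is not defined anywhere in this hunk. Below is a minimal sketch of how that wrapper could be constructed around the loaded AutoencoderKL. It assumes diffusers' FluxPipeline accepts None for the components it does not need; the variable name and construction are assumptions rather than part of this diff, since the pipeline object is only used for its _unpack_latents, vae_scale_factor, and image_processor helpers.

# Assumed construction of flux_pipeline_with_vae (not shown in this diff):
# a diffusers FluxPipeline that carries only the VAE, used purely for its
# latent-unpacking and image post-processing helpers.
flux_pipeline_with_vae = FluxPipeline(
    scheduler=None,
    vae=vae,
    text_encoder=None,
    tokenizer=None,
    text_encoder_2=None,
    tokenizer_2=None,
    transformer=None,
)

With such a wrapper in place, the added lines above unpack the packed latents back to image shape, undo the VAE's scaling_factor and shift_factor, decode with vae.decode(..., return_dict=False), and post-process the result into a PIL image.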