 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
 from PIL import Image
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 from transformers.models.auto import AutoModelForTextEncoding
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
+from invokeai.app.invocations.fields import (
+    ConditioningField,
+    FieldDescriptions,
+    Input,
+    InputField,
+    WithBoard,
+    WithMetadata,
+)
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 
 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
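The newly imported FLUXConditioningInfo is the payload that an upstream text-encoder node now produces. Judging from the flux_conditioning.clip_embeds / flux_conditioning.t5_embeds accesses in invoke() below, it pairs the pooled CLIP embedding with the T5 sequence embeddings; a minimal sketch under that assumption (not the actual definition, which lives in invokeai.backend.stable_diffusion.diffusion.conditioning_data):

from dataclasses import dataclass

import torch


@dataclass
class FLUXConditioningInfo:
    clip_embeds: torch.Tensor  # pooled CLIP text embedding (batch x hidden)
    t5_embeds: torch.Tensor  # T5 sequence embeddings (batch x seq_len x hidden)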
@@ -44,7 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     use_8bit: bool = InputField(
         default=False, description="Whether to quantize the transformer model to 8-bit precision."
     )
-    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
+    positive_text_conditioning: ConditioningField = InputField(
+        description=FieldDescriptions.positive_cond, input=Input.Connection
+    )
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     num_steps: int = InputField(default=4, description="Number of diffusion steps.")
@@ -58,66 +66,17 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
 
-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
-        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+
+        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
         image = self._run_vae_decoding(context, model_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
-
-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
-        with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
-        ):
-            assert isinstance(clip_text_encoder, CLIPTextModel)
-            assert isinstance(t5_text_encoder, T5EncoderModel)
-            pipeline = FluxPipeline(
-                scheduler=None,
-                vae=None,
-                text_encoder=clip_text_encoder,
-                tokenizer=clip_tokenizer,
-                text_encoder_2=t5_text_encoder,
-                tokenizer_2=t5_tokenizer,
-                transformer=None,
-            )
-
-            # prompt_embeds: T5 embeddings
-            # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
-                prompt=self.positive_prompt,
-                prompt_2=self.positive_prompt,
-                device=TorchDevice.choose_torch_device(),
-                max_sequence_length=max_seq_len,
-            )
-
-            assert isinstance(prompt_embeds, torch.Tensor)
-            assert isinstance(pooled_prompt_embeds, torch.Tensor)
-            return prompt_embeds, pooled_prompt_embeds
-
     def _run_diffusion(
         self,
         context: InvocationContext,
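Prompt encoding now happens upstream, and the removed _encode_prompt reduced to plain diffusers usage. For reference, a standalone sketch of the equivalent call, assuming a FLUX.1-schnell checkpoint; only the two text encoders are needed, so the other pipeline components are passed as None, exactly as the removed code did (prompt text and model path are illustrative):

from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

model_dir = "black-forest-labs/FLUX.1-schnell"  # or a local snapshot path
pipeline = FluxPipeline(
    scheduler=None,
    vae=None,
    text_encoder=CLIPTextModel.from_pretrained(model_dir, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer"),
    text_encoder_2=T5EncoderModel.from_pretrained(model_dir, subfolder="text_encoder_2"),
    tokenizer_2=T5TokenizerFast.from_pretrained(model_dir, subfolder="tokenizer_2"),
    transformer=None,
)

# prompt_embeds are the T5 sequence embeddings; pooled_prompt_embeds is the
# pooled CLIP embedding. 256 is the T5 max sequence length for flux-schnell.
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
    prompt="a photo of a forest",
    prompt_2="a photo of a forest",
    max_sequence_length=256,
)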
@@ -199,44 +158,6 @@ def _run_vae_decoding(
         assert isinstance(image, Image.Image)
         return image
 
-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model
-
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
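Both the removed _load_flux_text_encoder_2 and the surviving _load_flux_transformer follow the same quantize-on-first-use pattern: load full weights, quantize to qfloat8, cache the quantized copy under <model>/quantized, and reload that cache on later runs. The caching half goes through the repo's Quantized* wrapper classes; the quantization half can be sketched directly against optimum.quanto's public quantize/freeze API. A hypothetical helper under that assumption, with the on-disk caching step omitted:

from pathlib import Path

from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel


def load_t5_qfloat8(path: Path) -> T5EncoderModel:
    """Hypothetical helper: load a T5 encoder and quantize it to qfloat8 in place."""
    model = T5EncoderModel.from_pretrained(path, local_files_only=True)
    quantize(model, weights=qfloat8)  # swap linear weights for quantized equivalents
    freeze(model)  # materialize the quantized weights, dropping the fp copies
    return model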