from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
- from invokeai.app.invocations.fields import InputField
- from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
+ from invokeai.app.invocations.model import CLIPField, T5EncoderField
+ from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
+ from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
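The import changes above track the node's new wiring: instead of selecting a model by a hard-coded `TFluxModelKeys` key, the invocation now receives `CLIPField` and `T5EncoderField` connections from the graph. As a rough orientation only (the real definitions live in `invokeai.app.invocations.model` and carry additional members), these fields bundle model-manager identifiers for a tokenizer/encoder pair:

```python
# Simplified sketch (an assumption, not the real definitions) of the field
# shapes this node relies on; only the members used below are shown.
from pydantic import BaseModel

class ModelIdentifierField(BaseModel):
    key: str  # opaque key resolved by the model manager

class CLIPField(BaseModel):
    tokenizer: ModelIdentifierField
    text_encoder: ModelIdentifierField

class T5EncoderField(BaseModel):
    tokenizer: ModelIdentifierField
    text_encoder: ModelIdentifierField
```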
    version="1.0.0",
)
class FluxTextEncoderInvocation(BaseInvocation):
-     model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-     use_8bit: bool = InputField(
-         default=False, description="Whether to quantize the transformer model to 8-bit precision."
+     clip: CLIPField = InputField(
+         title="CLIP",
+         description=FieldDescriptions.clip,
+         input=Input.Connection,
+     )
+     t5Encoder: T5EncoderField = InputField(
+         title="T5EncoderField",
+         description=FieldDescriptions.t5Encoder,
+         input=Input.Connection,
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")

    # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
    # compatible with other ConditioningOutputs.
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ConditioningOutput:
-         model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

-         t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
+         t5_embeddings, clip_embeddings = self._encode_prompt(context)
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return ConditioningOutput.build(conditioning_name)
+
+     def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
+         # TODO: Determine the T5 max sequence length based on the model.
+         # if self.model == "flux-schnell":
+         max_seq_len = 256
+         # elif self.model == "flux-dev":
+         #     max_seq_len = 512
+         # else:
+         #     raise ValueError(f"Unknown model: {self.model}")
+
+         # Load CLIP.
+         clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+         clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+         # Load T5.
+         t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
+         t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)

-     def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-         # Determine the T5 max sequence length based on the model.
-         if self.model == "flux-schnell":
-             max_seq_len = 256
-         # elif self.model == "flux-dev":
-         #     max_seq_len = 512
-         else:
-             raise ValueError(f"Unknown model: {self.model}")
-
-         # Load the CLIP tokenizer.
-         clip_tokenizer_path = flux_model_dir / "tokenizer"
-         clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-         assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-         # Load the T5 tokenizer.
-         t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-         t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-         assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-         clip_text_encoder_path = flux_model_dir / "text_encoder"
-         t5_text_encoder_path = flux_model_dir / "text_encoder_2"
        with (
-             context.models.load_local_model(
-                 model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-             ) as clip_text_encoder,
-             context.models.load_local_model(
-                 model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-             ) as t5_text_encoder,
+             clip_text_encoder_info as clip_text_encoder,
+             t5_text_encoder_info as t5_text_encoder,
+             clip_tokenizer_info as clip_tokenizer,
+             t5_tokenizer_info as t5_tokenizer,
        ):
            assert isinstance(clip_text_encoder, CLIPTextModel)
            assert isinstance(t5_text_encoder, T5EncoderModel)
+             assert isinstance(clip_tokenizer, CLIPTokenizer)
+             assert isinstance(t5_tokenizer, T5TokenizerFast)
+
            pipeline = FluxPipeline(
                scheduler=None,
                vae=None,
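The hunk below skips the rest of the `FluxPipeline` constructor call. For context, the pattern in play is a text-encoding-only pipeline: every heavy component is passed as `None`, so diffusers' `encode_prompt()` helper can run without a transformer or VAE in memory. A hedged sketch of the collapsed call, with argument names per diffusers' `FluxPipeline`:

```python
# Sketch of the collapsed constructor call (assumed from the visible context):
# only the tokenizers and text encoders are real objects. encode_prompt()
# never touches the other components, so None is safe here.
pipeline = FluxPipeline(
    scheduler=None,
    vae=None,
    text_encoder=clip_text_encoder,
    tokenizer=clip_tokenizer,
    text_encoder_2=t5_text_encoder,
    tokenizer_2=t5_tokenizer,
    transformer=None,
)
```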
@@ -85,7 +89,7 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:

            # prompt_embeds: T5 embeddings
            # pooled_prompt_embeds: CLIP embeddings
-             prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+             prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                prompt=self.positive_prompt,
                prompt_2=self.positive_prompt,
                device=TorchDevice.choose_torch_device(),
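The next hunk skips the tail of this `encode_prompt()` call. For orientation: `FluxPipeline.encode_prompt()` also accepts `max_sequence_length`, which is presumably where the `max_seq_len` computed above is forwarded, and it returns a third value (`text_ids`) that the new code discards. A hedged sketch of the full call:

```python
# Sketch of the complete call; the collapsed lines are assumed to forward
# max_seq_len. Returns per-token T5 embeddings, a pooled CLIP embedding,
# and text position ids, which this node ignores.
prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
    prompt=self.positive_prompt,
    prompt_2=self.positive_prompt,
    device=TorchDevice.choose_torch_device(),
    max_sequence_length=max_seq_len,  # 256 for flux-schnell, per the TODO above
)
```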
@@ -95,41 +99,3 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return prompt_embeds, pooled_prompt_embeds
-
-     @staticmethod
-     def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-         model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-         assert isinstance(model, CLIPTextModel)
-         return model
-
-     def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-         if self.use_8bit:
-             model_8bit_path = path / "quantized"
-             if model_8bit_path.exists():
-                 # The quantized model exists, load it.
-                 # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                 # something that we should be able to make much faster.
-                 q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                 # Access the underlying wrapped model.
-                 # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                 # always returning a T5EncoderModel from this function.
-                 model = q_model._wrapped
-             else:
-                 # The quantized model does not exist yet, quantize and save it.
-                 # TODO(ryand): dtype?
-                 model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                 assert isinstance(model, T5EncoderModel)
-
-                 q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                 q_model.save_pretrained(model_8bit_path)
-
-                 # (See earlier comment about accessing the wrapped model.)
-                 model = q_model._wrapped
-         else:
-             model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-         assert isinstance(model, T5EncoderModel)
-         return model
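The deleted `_load_flux_text_encoder_2` helper is worth a note: the new code path loads models through the model manager instead, but the helper's quantize-and-cache pattern is a useful technique in its own right. On first use it quantized the T5 encoder to `qfloat8` with optimum-quanto and saved the result next to the weights; later runs reloaded the cached quantized copy. A standalone sketch of that pattern, using only the calls visible in the removed code (`QuantizedModelForTextEncoding` is the project's quanto wrapper, not a library class):

```python
# Standalone sketch of the removed quantize-and-cache flow; it mirrors the
# deleted helper rather than adding new behavior.
from pathlib import Path

from optimum.quanto import qfloat8
from transformers import T5EncoderModel

def load_t5_encoder_8bit(path: Path) -> T5EncoderModel:
    model_8bit_path = path / "quantized"
    if model_8bit_path.exists():
        # Cached quantized weights exist: reload them. (Still slow, because
        # from_pretrained requantizes on load, per the TODO in the removed code.)
        q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
    else:
        # First run: load fp weights, quantize to qfloat8, and cache to disk.
        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
        q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
        model_8bit_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(model_8bit_path)
    # Unwrap so callers always receive a plain T5EncoderModel.
    return q_model._wrapped
```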