1
1
import copy
2
- from typing import List , Optional
2
+ from time import sleep
3
+ from typing import List , Optional , Literal , Dict
3
4
4
5
from pydantic import BaseModel , Field
5
6
13
14
from invokeai .app .invocations .fields import FieldDescriptions , Input , InputField , OutputField , UIType
14
15
from invokeai .app .services .shared .invocation_context import InvocationContext
15
16
from invokeai .app .shared .models import FreeUConfig
16
- from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelType , SubModelType
17
+ from invokeai .app .services .model_records import ModelRecordChanges
18
+ from invokeai .backend .model_manager .config import AnyModelConfig , BaseModelType , ModelType , SubModelType , ModelFormat
17
19
18
20
19
21
class ModelIdentifierField (BaseModel ):
@@ -62,7 +64,6 @@ class CLIPField(BaseModel):
62
64
63
65
class TransformerField(BaseModel):
    """Identifies the transformer submodel to load for a FLUX pipeline."""

    # Pointer to the Transformer submodel of the selected main model.
    transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
67
68
68
69
class T5EncoderField (BaseModel ):
@@ -131,6 +132,30 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
131
132
132
133
return ModelIdentifierOutput (model = self .model )
133
134
135
# UI-facing choices for the T5 text-encoder variant.
T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]

# Maps each T5 encoder option to the repo ids, install names, and model format
# used to locate (or install on demand) its text encoder and tokenizer.
# Every member of T5_ENCODER_OPTIONS must be a key here — _get_model indexes
# this map directly with the user's selection.
# NOTE(review): the quantized repo ids ("hf_repo1"/"hf_repo2") look like
# placeholders, and both quantized entries use ModelFormat.T5Encoder8b —
# confirm the real repos/formats before release.
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
    "base": {
        "text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
        "tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
        "text_encoder_name": "FLUX.1-schnell_text_encoder_2",
        "tokenizer_name": "FLUX.1-schnell_tokenizer_2",
        "format": ModelFormat.T5Encoder,
    },
    "8b_quantized": {
        "text_encoder_repo": "hf_repo1",
        "tokenizer_repo": "hf_repo1",
        "text_encoder_name": "hf_repo1",
        "tokenizer_name": "hf_repo1",
        "format": ModelFormat.T5Encoder8b,
    },
    # Bug fix: this entry was keyed "4b_quantized", which is not a member of
    # T5_ENCODER_OPTIONS — choosing "16b_quantized" in the UI raised a KeyError
    # and the "4b_quantized" entry was unreachable.
    "16b_quantized": {
        "text_encoder_repo": "hf_repo2",
        "tokenizer_repo": "hf_repo2",
        "text_encoder_name": "hf_repo2",
        "tokenizer_name": "hf_repo2",
        "format": ModelFormat.T5Encoder8b,
    },
}
134
159
135
160
@invocation_output ("flux_model_loader_output" )
136
161
class FluxModelLoaderOutput (BaseInvocationOutput ):
@@ -151,29 +176,55 @@ class FluxModelLoaderInvocation(BaseInvocation):
151
176
ui_type = UIType .FluxMainModel ,
152
177
input = Input .Direct ,
153
178
)
179

# Which T5 text-encoder variant to load; the value is used as a key into T5_ENCODER_MAP.
t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")
181
155
182
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
    """Resolve all submodels of the selected FLUX main model and return them as loadable fields.

    Auxiliary models (CLIP, T5 encoder, VAE) are installed on demand by the
    _get_model/_install_model helpers. Raises Exception if the selected main
    model key is unknown to the model manager.
    """
    model_key = self.model.key

    if not context.models.exists(model_key):
        raise Exception(f"Unknown model: {model_key}")

    # Resolve each required submodel in a fixed order.
    parts = {
        sub: self._get_model(context, sub)
        for sub in (
            SubModelType.Transformer,
            SubModelType.Tokenizer,
            SubModelType.Tokenizer2,
            SubModelType.TextEncoder,
            SubModelType.TextEncoder2,
        )
    }
    # The VAE is a fixed auxiliary model, installed directly rather than via _get_model.
    vae = self._install_model(
        context,
        SubModelType.VAE,
        "FLUX.1-schnell_ae",
        "black-forest-labs/FLUX.1-schnell::ae.safetensors",
        ModelFormat.Checkpoint,
        ModelType.VAE,
        BaseModelType.Flux,
    )

    return FluxModelLoaderOutput(
        transformer=TransformerField(transformer=parts[SubModelType.Transformer]),
        clip=CLIPField(
            tokenizer=parts[SubModelType.Tokenizer],
            text_encoder=parts[SubModelType.TextEncoder],
            loras=[],
            skipped_layers=0,
        ),
        t5Encoder=T5EncoderField(
            tokenizer=parts[SubModelType.Tokenizer2],
            text_encoder=parts[SubModelType.TextEncoder2],
        ),
        vae=VAEField(vae=vae),
    )
176
200
201
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
    """Return an identifier for the requested FLUX submodel, installing auxiliary models on demand.

    Raises Exception for submodel types a FLUX pipeline does not use.
    """
    # The transformer ships inside the selected main model itself.
    if submodel == SubModelType.Transformer:
        return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
    # CLIP tokenizer/text encoder come from the shared CLIP-L model.
    if submodel in (SubModelType.Tokenizer, SubModelType.TextEncoder):
        return self._install_model(
            context,
            submodel,
            "clip-vit-large-patch14",
            "openai/clip-vit-large-patch14",
            ModelFormat.Diffusers,
            ModelType.CLIPEmbed,
            BaseModelType.Any,
        )
    # T5 encoder/tokenizer depend on the variant the user selected.
    if submodel in (SubModelType.TextEncoder2, SubModelType.Tokenizer2):
        entry = T5_ENCODER_MAP[self.t5_encoder]
        part = "text_encoder" if submodel == SubModelType.TextEncoder2 else "tokenizer"
        return self._install_model(
            context,
            submodel,
            entry[f"{part}_name"],
            entry[f"{part}_repo"],
            ModelFormat(entry["format"]),
            ModelType.T5Encoder,
            BaseModelType.Any,
        )
    raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
213
+
214
def _install_model(self, context: InvocationContext, submodel: SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
    """Return an identifier for the model with the given attributes, installing it from repo_id if absent.

    The returned ModelIdentifierField is tagged with `submodel`. Raises
    Exception when the name matches multiple installed models or when the
    install job finishes without producing a config.
    """
    matches = context.models.search_by_attrs(name=name, base=base, type=type)
    if matches:
        if len(matches) != 1:
            raise Exception(f"Multiple models detected for selected model with name {name}")
        config = matches[0]
    else:
        # Not installed yet: download, register, then poll the install job
        # until it reaches a terminal state.
        model_path = context.models.download_and_cache_model(repo_id)
        changes = ModelRecordChanges(name=name, base=base, type=type, format=format)
        job = context.models.import_local_model(model_path=model_path, config=changes)
        while not job.in_terminal_state:
            sleep(0.01)
        if not job.config_out:
            raise Exception(f"Failed to install {name}")
        config = job.config_out
    return ModelIdentifierField.from_config(config).model_copy(update={"submodel_type": submodel})
177
228
178
229
@invocation (
179
230
"main_model_loader" ,
0 commit comments