Skip to content

Commit da766f5

Browse files
committed
Fix support for 8b quantized t5 encoders, update exception messages in flux loaders
1 parent 120e1cf commit da766f5

File tree

3 files changed

+42
-10
lines changed

3 files changed

+42
-10
lines changed

invokeai/app/invocations/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
143143
"format": ModelFormat.T5Encoder,
144144
},
145145
"8b_quantized": {
146-
"repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
146+
"repo": "invokeai/flux_schnell::t5_xxl_encoder/optimum_quanto_qfloat8",
147147
"name": "t5_8b_quantized_encoder",
148-
"format": ModelFormat.T5Encoder,
148+
"format": ModelFormat.T5Encoder8b,
149149
},
150150
}
151151

invokeai/backend/model_manager/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,14 @@ def get_tag() -> Tag:
225225
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
226226

227227

228+
class T5Encoder8bConfig(T5EncoderConfigBase):
229+
format: Literal[ModelFormat.T5Encoder8b] = ModelFormat.T5Encoder8b
230+
231+
@staticmethod
232+
def get_tag() -> Tag:
233+
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder8b.value}")
234+
235+
228236
class LoRALyCORISConfig(LoRAConfigBase):
229237
"""Model config for LoRA/Lycoris models."""
230238

@@ -460,6 +468,7 @@ def get_model_discriminator_value(v: Any) -> str:
460468
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
461469
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
462470
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
471+
Annotated[T5Encoder8bConfig, T5Encoder8bConfig.get_tag()],
463472
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
464473
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
465474
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828
MainBnbQuantized4bCheckpointConfig,
2929
MainCheckpointConfig,
3030
T5EncoderConfig,
31+
T5Encoder8bConfig,
3132
VAECheckpointConfig,
3233
)
3334
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
3435
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
3536
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
3637
from invokeai.backend.util.silence_warnings import SilenceWarnings
38+
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
3739

3840
app_config = get_config()
3941

@@ -82,15 +84,36 @@ def _load_model(
8284
submodel_type: Optional[SubModelType] = None,
8385
) -> AnyModel:
8486
if not isinstance(config, CLIPEmbedDiffusersConfig):
85-
raise Exception("Only Checkpoint Flux models are currently supported.")
87+
raise Exception("Only CLIPEmbedDiffusersConfig models are currently supported here.")
8688

8789
match submodel_type:
8890
case SubModelType.Tokenizer:
8991
return CLIPTokenizer.from_pretrained(config.path, max_length=77)
9092
case SubModelType.TextEncoder:
9193
return CLIPTextModel.from_pretrained(config.path)
9294

93-
raise Exception("Only Checkpoint Flux models are currently supported.")
95+
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
96+
97+
98+
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
99+
class T5Encoder8bCheckpointModel(GenericDiffusersLoader):
100+
"""Class to load main models."""
101+
102+
def _load_model(
103+
self,
104+
config: AnyModelConfig,
105+
submodel_type: Optional[SubModelType] = None,
106+
) -> AnyModel:
107+
if not isinstance(config, T5Encoder8bConfig):
108+
raise Exception("Only T5Encoder8bConfig models are currently supported here.")
109+
110+
match submodel_type:
111+
case SubModelType.Tokenizer2:
112+
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
113+
case SubModelType.TextEncoder2:
114+
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
115+
116+
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
94117

95118

96119
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@@ -103,7 +126,7 @@ def _load_model(
103126
submodel_type: Optional[SubModelType] = None,
104127
) -> AnyModel:
105128
if not isinstance(config, T5EncoderConfig):
106-
raise Exception("Only Checkpoint Flux models are currently supported.")
129+
raise Exception("Only T5EncoderConfig models are currently supported here.")
107130

108131
match submodel_type:
109132
case SubModelType.Tokenizer2:
@@ -113,7 +136,7 @@ def _load_model(
113136
Path(config.path) / "text_encoder_2"
114137
) # TODO: Fix hf subfolder install
115138

116-
raise Exception("Only Checkpoint Flux models are currently supported.")
139+
raise Exception("Only Tokenizer and TextEncoder submodels are currently supported.")
117140

118141

119142
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -126,7 +149,7 @@ def _load_model(
126149
submodel_type: Optional[SubModelType] = None,
127150
) -> AnyModel:
128151
if not isinstance(config, CheckpointConfigBase):
129-
raise Exception("Only Checkpoint Flux models are currently supported.")
152+
raise Exception("Only CheckpointConfigBase models are currently supported here.")
130153
legacy_config_path = app_config.legacy_conf_path / config.config_path
131154
config_path = legacy_config_path.as_posix()
132155
with open(config_path, "r") as stream:
@@ -139,7 +162,7 @@ def _load_model(
139162
case SubModelType.Transformer:
140163
return self._load_from_singlefile(config, flux_conf)
141164

142-
raise Exception("Only Checkpoint Flux models are currently supported.")
165+
raise Exception("Only Transformer submodels are currently supported.")
143166

144167
def _load_from_singlefile(
145168
self,
@@ -171,7 +194,7 @@ def _load_model(
171194
submodel_type: Optional[SubModelType] = None,
172195
) -> AnyModel:
173196
if not isinstance(config, CheckpointConfigBase):
174-
raise Exception("Only Checkpoint Flux models are currently supported.")
197+
raise Exception("Only CheckpointConfigBase models are currently supported here.")
175198
legacy_config_path = app_config.legacy_conf_path / config.config_path
176199
config_path = legacy_config_path.as_posix()
177200
with open(config_path, "r") as stream:
@@ -184,7 +207,7 @@ def _load_model(
184207
case SubModelType.Transformer:
185208
return self._load_from_singlefile(config, flux_conf)
186209

187-
raise Exception("Only Checkpoint Flux models are currently supported.")
210+
raise Exception("Only Transformer submodels are currently supported.")
188211

189212
def _load_from_singlefile(
190213
self,

0 commit comments

Comments (0)