Skip to content

Commit 8b0b496

Browse files
committed
More flux loader cleanup
1 parent ada483f commit 8b0b496

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _load_model(
6464
with SilenceWarnings():
6565
model = AutoEncoder(params)
6666
sd = load_file(model_path)
67-
model.load_state_dict(sd, strict=False, assign=True)
67+
model.load_state_dict(sd, assign=True)
6868

6969
return model
7070

@@ -83,11 +83,11 @@ def _load_model(
8383

8484
match submodel_type:
8585
case SubModelType.Tokenizer:
86-
return CLIPTokenizer.from_pretrained(config.path, max_length=77)
86+
return CLIPTokenizer.from_pretrained(config.path)
8787
case SubModelType.TextEncoder:
8888
return CLIPTextModel.from_pretrained(config.path)
8989

90-
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
90+
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
9191

9292

9393
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder8b)
@@ -108,7 +108,7 @@ def _load_model(
108108
case SubModelType.TextEncoder2:
109109
return FastQuantizedTransformersModel.from_pretrained(Path(config.path) / "text_encoder_2")
110110

111-
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
111+
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
112112

113113

114114
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
@@ -131,7 +131,7 @@ def _load_model(
131131
Path(config.path) / "text_encoder_2"
132132
) # TODO: Fix hf subfolder install
133133

134-
raise ValueError("Only Tokenizer and TextEncoder submodels are currently supported.")
134+
raise ValueError(f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
135135

136136

137137
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -154,15 +154,14 @@ def _load_model(
154154
case SubModelType.Transformer:
155155
return self._load_from_singlefile(config, flux_conf)
156156

157-
raise ValueError("Only Transformer submodels are currently supported.")
157+
raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
158158

159159
def _load_from_singlefile(
160160
self,
161161
config: AnyModelConfig,
162162
flux_conf: Any,
163163
) -> AnyModel:
164164
assert isinstance(config, MainCheckpointConfig)
165-
params = None
166165
model_path = Path(config.path)
167166
dataclass_fields = {f.name for f in fields(FluxParams)}
168167
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
@@ -171,7 +170,7 @@ def _load_from_singlefile(
171170
with SilenceWarnings():
172171
model = Flux(params)
173172
sd = load_file(model_path)
174-
model.load_state_dict(sd, strict=False, assign=True)
173+
model.load_state_dict(sd, assign=True)
175174
return model
176175

177176

@@ -195,15 +194,14 @@ def _load_model(
195194
case SubModelType.Transformer:
196195
return self._load_from_singlefile(config, flux_conf)
197196

198-
raise ValueError("Only Transformer submodels are currently supported.")
197+
raise ValueError(f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}")
199198

200199
def _load_from_singlefile(
201200
self,
202201
config: AnyModelConfig,
203202
flux_conf: Any,
204203
) -> AnyModel:
205204
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
206-
params = None
207205
model_path = Path(config.path)
208206
dataclass_fields = {f.name for f in fields(FluxParams)}
209207
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
@@ -214,5 +212,5 @@ def _load_from_singlefile(
214212
model = Flux(params)
215213
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
216214
sd = load_file(model_path)
217-
model.load_state_dict(sd, strict=False, assign=True)
215+
model.load_state_dict(sd, assign=True)
218216
return model

invokeai/backend/model_manager/probe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
224224

225225
for key in [str(k) for k in ckpt.keys()]:
226226
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")):
227+
# Keys starting with double_blocks are associated with Flux models
227228
return ModelType.Main
228229
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
229230
return ModelType.VAE

0 commit comments

Comments
 (0)