Skip to content

Commit b03073d

Browse files
lsteinLincoln Stein
andauthored
[MM] Add support for probing and loading SDXL VAE checkpoint files (#6524)
* add support for probing and loading SDXL VAE checkpoint files * broaden regexp probe for SDXL VAEs --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
1 parent a43d602 commit b03073d

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

invokeai/backend/model_manager/load/model_loaders/vae.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222

2323

2424
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
25-
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
26-
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
25+
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
2726
class VAELoader(GenericDiffusersLoader):
2827
"""Class to load VAE models."""
2928

@@ -40,12 +39,8 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path:
4039
return True
4140

4241
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
43-
# TODO(MM2): check whether sdxl VAE models convert.
44-
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
45-
raise Exception(f"VAE conversion not supported for model type: {config.base}")
46-
else:
47-
assert isinstance(config, CheckpointConfigBase)
48-
config_file = self._app_config.legacy_conf_path / config.config_path
42+
assert isinstance(config, CheckpointConfigBase)
43+
config_file = self._app_config.legacy_conf_path / config.config_path
4944

5045
if model_path.suffix == ".safetensors":
5146
checkpoint = safetensors_load_file(model_path, device="cpu")

invokeai/backend/model_manager/probe.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,16 @@ def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
451451

452452
class VaeCheckpointProbe(CheckpointProbeBase):
453453
def get_base_type(self) -> BaseModelType:
454-
# I can't find any standalone 2.X VAEs to test with!
455-
return BaseModelType.StableDiffusion1
454+
# VAEs of all base types have the same structure, so we wimp out and
455+
# guess using the name.
456+
for regexp, basetype in [
457+
(r"xl", BaseModelType.StableDiffusionXL),
458+
(r"sd2", BaseModelType.StableDiffusion2),
459+
(r"vae", BaseModelType.StableDiffusion1),
460+
]:
461+
if re.search(regexp, self.model_path.name, re.IGNORECASE):
462+
return basetype
463+
raise InvalidModelConfigException("Cannot determine base type")
456464

457465

458466
class LoRACheckpointProbe(CheckpointProbeBase):

0 commit comments

Comments
 (0)