Skip to content

Commit 2a4254c

Browse files
author
Lincoln Stein
committed
merge with main
2 parents 74f0c31 + b03073d commit 2a4254c

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

invokeai/backend/model_manager/load/model_loaders/vae.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818

1919

2020
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
21-
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
22-
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
21+
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
2322
class VAELoader(GenericDiffusersLoader):
2423
"""Class to load VAE models."""
2524

invokeai/backend/model_manager/probe.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,16 @@ def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
451451

452452
class VaeCheckpointProbe(CheckpointProbeBase):
453453
def get_base_type(self) -> BaseModelType:
454-
# I can't find any standalone 2.X VAEs to test with!
455-
return BaseModelType.StableDiffusion1
454+
# VAEs of all base types have the same structure, so we wimp out and
455+
# guess using the name.
456+
for regexp, basetype in [
457+
(r"xl", BaseModelType.StableDiffusionXL),
458+
(r"sd2", BaseModelType.StableDiffusion2),
459+
(r"vae", BaseModelType.StableDiffusion1),
460+
]:
461+
if re.search(regexp, self.model_path.name, re.IGNORECASE):
462+
return basetype
463+
raise InvalidModelConfigException("Cannot determine base type")
456464

457465

458466
class LoRACheckpointProbe(CheckpointProbeBase):

0 commit comments

Comments
 (0)