File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed
invokeai/backend/model_manager Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change 18
18
19
19
20
20
@ModelLoaderRegistry .register (base = BaseModelType .Any , type = ModelType .VAE , format = ModelFormat .Diffusers )
21
- @ModelLoaderRegistry .register (base = BaseModelType .StableDiffusion1 , type = ModelType .VAE , format = ModelFormat .Checkpoint )
22
- @ModelLoaderRegistry .register (base = BaseModelType .StableDiffusion2 , type = ModelType .VAE , format = ModelFormat .Checkpoint )
21
+ @ModelLoaderRegistry .register (base = BaseModelType .Any , type = ModelType .VAE , format = ModelFormat .Checkpoint )
23
22
class VAELoader (GenericDiffusersLoader ):
24
23
"""Class to load VAE models."""
25
24
Original file line number Diff line number Diff line change @@ -451,8 +451,16 @@ def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
451
451
452
452
class VaeCheckpointProbe (CheckpointProbeBase ):
453
453
def get_base_type (self ) -> BaseModelType :
454
- # I can't find any standalone 2.X VAEs to test with!
455
- return BaseModelType .StableDiffusion1
454
+ # VAEs of all base types have the same structure, so we wimp out and
455
+ # guess using the name.
456
+ for regexp , basetype in [
457
+ (r"xl" , BaseModelType .StableDiffusionXL ),
458
+ (r"sd2" , BaseModelType .StableDiffusion2 ),
459
+ (r"vae" , BaseModelType .StableDiffusion1 ),
460
+ ]:
461
+ if re .search (regexp , self .model_path .name , re .IGNORECASE ):
462
+ return basetype
463
+ raise InvalidModelConfigException ("Cannot determine base type" )
456
464
457
465
458
466
class LoRACheckpointProbe (CheckpointProbeBase ):
You can’t perform that action at this time.
0 commit comments