
Commit f81b8bc

Lincoln Stein authored and psychedelicious committed
add support for generic loading of diffusers directories
1 parent a9962fd commit f81b8bc

File tree: 8 files changed (+44, -27 lines)

invokeai/app/services/model_load/model_load_base.py

Lines changed: 4 additions & 3 deletions
@@ -8,7 +8,7 @@
 from torch import Tensor

 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
-from invokeai.backend.model_manager.load import LoadedModel
+from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

@@ -38,7 +38,7 @@ def convert_cache(self) -> ModelConvertCacheBase:
     @abstractmethod
     def load_model_from_path(
         self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
-    ) -> LoadedModel:
+    ) -> LoadedModelWithoutConfig:
         """
         Load the model file or directory located at the indicated Path.

@@ -47,7 +47,8 @@ def load_model_from_path(
         memory. Otherwise the method will call safetensors.torch.load_file() or
         torch.load() as appropriate to the file suffix.

-        Be aware that the LoadedModel object will have a `config` attribute of None.
+        Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
+        LoadedModel, but without the config attribute.

         Args:
           model_path: A pathlib.Path to a checkpoint-style models file

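The practical effect of the signature change is that callers can no longer expect a `config` on the result. A minimal sketch of consuming the new return type; `loader_service` is a hypothetical stand-in for any concrete implementation of this base class:

```python
from pathlib import Path

# `loader_service` is hypothetical: any concrete ModelLoadServiceBase.
loaded = loader_service.load_model_from_path(Path("models/test_embedding.safetensors"))

# LoadedModelWithoutConfig still works as the RAM<->VRAM context manager,
# but there is no `config` attribute to consult.
with loaded as model:
    # `model` is whatever the loader produced; for checkpoint files it is
    # a dict[str, Tensor] state dict.
    ...

assert not hasattr(loaded, "config")
```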
invokeai/app/services/model_load/model_load_default.py

Lines changed: 8 additions & 5 deletions
@@ -14,6 +14,7 @@
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import (
     LoadedModel,
+    LoadedModelWithoutConfig,
     ModelLoaderRegistry,
     ModelLoaderRegistryBase,
 )

@@ -85,12 +86,12 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
         return loaded_model

     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
-    ) -> LoadedModel:
+        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
         ram_cache = self.ram_cache
         try:
-            return LoadedModel(_locker=ram_cache.get(key=cache_key))
+            return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
         except IndexError:
             pass

@@ -113,11 +114,13 @@ def diffusers_load_directory(directory: Path) -> AnyModel:

         if loader is None:
             loader = (
-                torch_load_file
+                diffusers_load_directory
+                if model_path.is_dir()
+                else torch_load_file
                 if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
                 else lambda path: safetensors_load_file(path, device="cpu")
             )

         raw_model = loader(model_path)
         ram_cache.put(key=cache_key, model=raw_model)
-        return LoadedModel(_locker=ram_cache.get(key=cache_key))
+        return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))

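The chained conditional expression above is compact but easy to misread. A sketch of the same dispatch unrolled into statements; the directory loader is passed in as a parameter because the body of `diffusers_load_directory` lies outside the hunks shown here:

```python
from pathlib import Path
from typing import Any, Callable

import torch
from safetensors.torch import load_file as safetensors_load_file


def pick_loader(
    model_path: Path,
    diffusers_load_directory: Callable[[Path], Any],
) -> Callable[[Path], Any]:
    # Directories are treated as diffusers-format models (the new case).
    if model_path.is_dir():
        return diffusers_load_directory
    # Known pickle/checkpoint suffixes go through torch.load.
    if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
        return torch.load
    # Everything else is assumed to be a safetensors file.
    return lambda path: safetensors_load_file(path, device="cpu")
```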
invokeai/app/services/shared/invocation_context.py

Lines changed: 4 additions & 4 deletions
@@ -16,7 +16,7 @@
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
-from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

@@ -461,7 +461,7 @@ def load_and_cache_model(
         self,
         source: Path | str | AnyHttpUrl,
         loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
-    ) -> LoadedModel:
+    ) -> LoadedModelWithoutConfig:
         """
         Download, cache, and load the model file located at the indicated URL.

@@ -470,14 +470,14 @@ def load_and_cache_model(
         If the a loader callable is provided, it will be invoked to load the model. Otherwise,
         `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.

-        Be aware that the LoadedModel object will have a `config` attribute of None.
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute

         Args:
           source: A model Path, URL, or repoid.
           loader: A Callable that expects a Path and returns a dict[str|int, Any]

         Returns:
-          A LoadedModel object.
+          A LoadedModelWithoutConfig object.
         """

         if isinstance(source, Path):

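From a node author's perspective, the call shape is unchanged; only the return type narrows. A hedged usage sketch inside an invocation's `invoke()` method, with a placeholder URL (`context.models.load_and_cache_model` is the API exercised by the tests below):

```python
# Inside an invocation's invoke() method; the URL is a placeholder.
loaded = context.models.load_and_cache_model(
    "https://example.com/models/some_embedding.safetensors"
)

with loaded as model:
    ...  # use the raw model; note that `loaded.config` no longer exists
```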
invokeai/backend/model_manager/load/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
 from pathlib import Path

 from .convert_cache.convert_cache_default import ModelConvertCache
-from .load_base import LoadedModel, ModelLoaderBase
+from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from .load_default import ModelLoader
 from .model_cache.model_cache_default import ModelCache
 from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase

@@ -19,6 +19,7 @@
 __all__ = [
     "LoadedModel",
+    "LoadedModelWithoutConfig",
     "ModelCache",
     "ModelConvertCache",
     "ModelLoaderBase",

invokeai/backend/model_manager/load/load_base.py

Lines changed: 8 additions & 2 deletions
@@ -20,11 +20,10 @@


 @dataclass
-class LoadedModel:
+class LoadedModelWithoutConfig:
     """Context manager object that mediates transfer from RAM<->VRAM."""

     _locker: ModelLockerBase
-    config: Optional[AnyModelConfig] = None

     def __enter__(self) -> AnyModel:
         """Context entry."""

@@ -41,6 +40,13 @@ def model(self) -> AnyModel:
         return self._locker.model


+@dataclass
+class LoadedModel(LoadedModelWithoutConfig):
+    """Context manager object that mediates transfer from RAM<->VRAM."""
+
+    config: Optional[AnyModelConfig] = None
+
+
 # TODO(MM2):
 # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
 # know about. I think the problem may be related to this class being an ABC.

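The refactor relies on plain dataclass inheritance: the base carries `_locker`, and the subclass appends the optional `config`. A minimal sketch of why this ordering is legal in Python (the locker type is stubbed out for illustration):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class LoadedModelWithoutConfig:
    _locker: Any  # stand-in for ModelLockerBase


@dataclass
class LoadedModel(LoadedModelWithoutConfig):
    # Legal because the inherited field (_locker) has no default, so a
    # defaulted field may follow it in the generated __init__ signature:
    # LoadedModel(_locker, config=None)
    config: Optional[Any] = None


# Existing call sites constructing LoadedModel(_locker=..., config=...) keep
# working, and every LoadedModel passes isinstance checks against the base.
assert issubclass(LoadedModel, LoadedModelWithoutConfig)
```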
invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py

Lines changed: 3 additions & 6 deletions
@@ -65,14 +65,11 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
         else:
             try:
                 config = self._load_diffusers_config(model_path, config_name="config.json")
-                class_name = config.get("_class_name", None)
-                if class_name:
+                if class_name := config.get("_class_name"):
                     result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
-                if config.get("model_type", None) == "clip_vision_model":
-                    class_name = config.get("architectures")
-                    assert class_name is not None
+                elif class_name := config.get("architectures"):
                     result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
-                if not class_name:
+                else:
                     raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
             except KeyError as e:
                 raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

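The three independent `if` blocks become one `if`/`elif`/`else` chain using assignment expressions, which also drops the `assert` and the `model_type == "clip_vision_model"` gate on the transformers fallback. The bare control flow, isolated into a hypothetical helper:

```python
def resolve_load_class(config: dict) -> tuple[str, str]:
    """Hypothetical isolation of the chain above: returns (module, class_name)."""
    # diffusers-style configs carry `_class_name` as a string.
    if class_name := config.get("_class_name"):
        return ("diffusers", class_name)
    # transformers-style configs carry `architectures` as a list of names.
    elif architectures := config.get("architectures"):
        return ("transformers", architectures[0])
    else:
        raise ValueError("Unable to decipher load class from config.json")


assert resolve_load_class({"_class_name": "AutoencoderTiny"}) == ("diffusers", "AutoencoderTiny")
assert resolve_load_class({"architectures": ["CLIPVisionModel"]}) == ("transformers", "CLIPVisionModel")
```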
tests/app/services/model_load/test_load_api.py

Lines changed: 11 additions & 6 deletions
@@ -2,11 +2,12 @@

 import pytest
 import torch
+from diffusers import AutoencoderTiny

 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_manager import ModelManagerServiceBase
 from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
-from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403

@@ -43,30 +44,34 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
         "https://www.test.foo/download/test_embedding.safetensors"
     )
     loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
-    assert isinstance(loaded_model_1, LoadedModel)
+    assert isinstance(loaded_model_1, LoadedModelWithoutConfig)

     loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
-    assert isinstance(loaded_model_2, LoadedModel)
+    assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model

     loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
-    assert isinstance(loaded_model_3, LoadedModel)
+    assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
     assert loaded_model_1.model is not loaded_model_3.model
     assert isinstance(loaded_model_1.model, dict)
     assert isinstance(loaded_model_3.model, dict)
     assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])

+def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
+    loaded_model = mock_context.models.load_and_cache_model(vae_directory)
+    assert isinstance(loaded_model, LoadedModelWithoutConfig)
+    assert isinstance(loaded_model.model, AutoencoderTiny)

 def test_download_and_load(mock_context: InvocationContext) -> None:
     loaded_model_1 = mock_context.models.load_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
-    assert isinstance(loaded_model_1, LoadedModel)
+    assert isinstance(loaded_model_1, LoadedModelWithoutConfig)

     loaded_model_2 = mock_context.models.load_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
-    assert isinstance(loaded_model_2, LoadedModel)
+    assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model  # should be cached copy

tests/backend/model_manager/model_manager_fixtures.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ def mm2_model_files(tmp_path_factory) -> Path:
 def embedding_file(mm2_model_files: Path) -> Path:
     return mm2_model_files / "test_embedding.safetensors"

+@pytest.fixture
+def vae_directory(mm2_model_files: Path) -> Path:
+    return mm2_model_files / "taesdxl"
+

 @pytest.fixture
 def diffusers_dir(mm2_model_files: Path) -> Path:
