
Commit dc13493

Lincoln Stein authored and psychedelicious committed
replace load_and_cache_model() with load_remote_model() and load_local_model()
1 parent 9f93796 commit dc13493

12 files changed (+107 -70 lines changed)


docs/contributing/MODEL_MANAGER.md

Lines changed: 30 additions & 12 deletions
@@ -1585,9 +1585,9 @@ Within invocations, the following methods are available from the
 
 ### context.download_and_cache_model(source) -> Path
 
-This method accepts a `source` of a model, downloads and caches it
-locally, and returns a Path to the local model. The source can be a
-local file or directory, a URL, or a HuggingFace repo_id.
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.
 
 In the case of HuggingFace repo_id, the following variants are
 recognized:
@@ -1602,16 +1602,34 @@ directory using this syntax:
 
 * stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
 
-### context.load_and_cache_model(source, [loader]) -> LoadedModel
+### context.load_local_model(model_path, [loader]) -> LoadedModel
 
-This method takes a model source, downloads it, caches it, and then
-loads it into the RAM cache for use in inference. The optional loader
-is a Callable that accepts a Path to the object, and returns a
-`Dict[str, torch.Tensor]`. If no loader is provided, then the method
-will use `torch.load()` for a .ckpt or .bin checkpoint file,
-`safetensors.torch.load_file()` for a safetensors checkpoint file, or
-`*.from_pretrained()` for a directory that looks like a
-diffusers directory.
+This method loads a local model from the indicated path, returning a
+`LoadedModel`. The optional loader is a Callable that accepts a Path
+to the object, and returns an `AnyModel` object. If no loader is
+provided, then the method will use `torch.load()` for a .ckpt or .bin
+checkpoint file, `safetensors.torch.load_file()` for a safetensors
+checkpoint file, or `cls.from_pretrained()` for a directory that looks
+like a diffusers directory.
+
+### context.load_remote_model(source, [loader]) -> LoadedModel
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, loads it, and returns a `LoadedModel`. The source can be a
+direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
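Taken together, the two new methods split the old `load_and_cache_model()` by source type. A minimal invocation-side sketch, assuming `context` is the `InvocationContext` passed to `invoke()`; the URL and local path below are placeholders:

```python
from pathlib import Path

# Remote source: downloaded into the models cache on first use, then loaded.
with context.models.load_remote_model(
    source="https://example.com/weights/model.safetensors",  # placeholder URL
) as model:
    ...  # run inference with `model`

# Local source: loaded straight from disk. With no explicit loader, the
# default picks torch.load(), safetensors.torch.load_file(), or
# from_pretrained() based on the path, as documented above.
with context.models.load_local_model(
    model_path=Path("/path/to/model.ckpt"),  # placeholder path
) as model:
    ...
```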

invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 1 addition & 1 deletion
@@ -611,7 +611,7 @@ def loader(model_path: Path):
                 model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
             )
 
-        with self._context.models.load_and_cache_model(
+        with self._context.models.load_remote_model(
             source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
         ) as model:
             depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())

invokeai/app/invocations/infill.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""
 
     def infill(self, image: Image.Image):
-        with self._context.models.load_and_cache_model(
+        with self._context.models.load_remote_model(
             source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
             loader=LaMA.load_jit_model,
         ) as model:

invokeai/app/invocations/upscale.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
             context.logger.error(msg)
             raise ValueError(msg)
 
-        loadnet = context.models.load_and_cache_model(
+        loadnet = context.models.load_remote_model(
             source=ESRGAN_MODEL_URLS[self.model_name],
         )

invokeai/app/services/model_install/model_install_base.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,8 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+from pydantic.networks import AnyHttpUrl
+
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.download import DownloadQueueServiceBase
 from invokeai.app.services.events.events_base import EventServiceBase
@@ -241,7 +243,7 @@ def sync_model_path(self, key: str) -> AnyModelConfig:
         """
 
     @abstractmethod
-    def download_and_cache_model(self, source: str) -> Path:
+    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.

invokeai/app/services/model_install/model_install_default.py

Lines changed: 4 additions & 3 deletions
@@ -15,6 +15,7 @@
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
+from pydantic_core import Url
 from requests import Session
 
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -374,7 +375,7 @@ def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: Invoke
 
     def download_and_cache_model(
         self,
-        source: str,
+        source: str | AnyHttpUrl,
     ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
         model_path = self._download_cache_path(str(source), self._app_config)
@@ -388,7 +389,7 @@ def download_and_cache_model(
             return contents[0]
 
         model_path.mkdir(parents=True, exist_ok=True)
-        model_source = self._guess_source(source)
+        model_source = self._guess_source(str(source))
         remote_files, _ = self._remote_files_from_source(model_source)
         job = self._multifile_download(
             dest=model_path,
@@ -447,7 +448,7 @@ def _guess_source(self, source: str) -> ModelSource:
             )
         elif re.match(r"^https?://[^/]+", source):
             source_obj = URLModelSource(
-                url=AnyHttpUrl(source),
+                url=Url(source),
             )
         else:
            raise ValueError(f"Unsupported model source: '{source}'")
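For reference, `pydantic_core.Url` validates and normalizes a URL at construction time and converts back with `str()`, which is why `download_and_cache_model()` can accept `str | AnyHttpUrl` and simply call `str(source)` internally. A small hedged sketch with a placeholder URL:

```python
from pydantic_core import Url

url = Url("https://example.com/checkpoints/model.safetensors")  # placeholder URL
assert url.scheme == "https"  # parsed components are available on the Url object
print(str(url))  # str() yields the URL text back, so str(source) handles both forms
```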

invokeai/app/services/model_load/model_load_base.py

Lines changed: 2 additions & 4 deletions
@@ -3,9 +3,7 @@
 
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Callable, Dict, Optional
-
-from torch import Tensor
+from typing import Callable, Optional
 
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
@@ -37,7 +35,7 @@ def convert_cache(self) -> ModelConvertCacheBase:
 
     @abstractmethod
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         """
         Load the model file or directory located at the indicated Path.
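Under the new signature a custom loader is any callable that takes a `Path` and returns an `AnyModel`; a raw state dict qualifies. A minimal sketch (the function name is illustrative):

```python
from pathlib import Path

from safetensors.torch import load_file


def state_dict_loader(path: Path):
    # Returning the raw state dict is enough; the load service wraps the
    # result in a LoadedModelWithoutConfig via the RAM cache.
    return load_file(path, device="cpu")
```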

invokeai/app/services/model_load/model_load_default.py

Lines changed: 14 additions & 17 deletions
@@ -2,11 +2,10 @@
 """Implementation of model loader service."""
 
 from pathlib import Path
-from typing import Callable, Dict, Optional, Type
+from typing import Callable, Optional, Type
 
 from picklescan.scanner import scan_file_path
 from safetensors.torch import load_file as safetensors_load_file
-from torch import Tensor
 from torch import load as torch_load
 
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -86,7 +85,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
         return loaded_model
 
     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
     ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
         ram_cache = self.ram_cache
@@ -95,11 +94,11 @@ def load_model_from_path(
         except IndexError:
             pass
 
-        def torch_load_file(checkpoint: Path) -> Dict[str, Tensor]:
+        def torch_load_file(checkpoint: Path) -> AnyModel:
             scan_result = scan_file_path(checkpoint)
             if scan_result.infected_files != 0:
                 raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
-            result: Dict[str, Tensor] = torch_load(checkpoint, map_location="cpu")
+            result = torch_load(checkpoint, map_location="cpu")
             return result
 
         def diffusers_load_directory(directory: Path) -> AnyModel:
@@ -109,18 +108,16 @@ def diffusers_load_directory(directory: Path) -> AnyModel:
                 ram_cache=self._ram_cache,
                 convert_cache=self.convert_cache,
             ).get_hf_load_class(directory)
-            result: AnyModel = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
-            return result
-
-        if loader is None:
-            loader = (
-                diffusers_load_directory
-                if model_path.is_dir()
-                else torch_load_file
-                if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
-                else lambda path: safetensors_load_file(path, device="cpu")
-            )
-
+            return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
+
+        loader = loader or (
+            diffusers_load_directory
+            if model_path.is_dir()
+            else torch_load_file
+            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
+            else lambda path: safetensors_load_file(path, device="cpu")
+        )
+        assert loader is not None
         raw_model = loader(model_path)
         ram_cache.put(key=cache_key, model=raw_model)
         return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
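The chained conditional expression that selects the default loader is compact but dense; it is equivalent to this if/elif sketch (reusing the `torch_load_file` and `diffusers_load_directory` helpers defined in the hunk above):

```python
from pathlib import Path


def pick_default_loader(model_path: Path):
    if model_path.is_dir():
        return diffusers_load_directory  # diffusers-style model directory
    elif model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
        return torch_load_file  # pickle-based checkpoint (scanned for malware first)
    else:
        return lambda path: safetensors_load_file(path, device="cpu")  # safetensors file
```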

invokeai/app/services/shared/invocation_context.py

Lines changed: 38 additions & 14 deletions
@@ -15,7 +15,14 @@
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -449,21 +456,42 @@ def download_and_cache_model(
         installed, the cached path will be returned. Otherwise it will be downloaded.
 
         Args:
-            source: A model path, URL or repo_id.
+            source: A URL that points to the model, or a huggingface repo_id.
 
         Returns:
             Path to the downloaded model
         """
-
         return self._services.model_manager.install.download_and_cache_model(source=source)
 
-    def load_and_cache_model(
+    def load_local_model(
         self,
-        source: Path | str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
     ) -> LoadedModelWithoutConfig:
         """
-        Download, cache, and load the model file located at the indicated URL.
+        Load the model file located at the indicated path
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute
+
+        Args:
+            path: A model Path
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+    def load_remote_model(
+        self,
+        source: str | AnyHttpUrl,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Download, cache, and load the model file located at the indicated URL or repo_id.
 
         If the model is already downloaded, it will be loaded from the cache.
 
@@ -473,18 +501,14 @@ def load_and_cache_model(
         Be aware that the LoadedModelWithoutConfig object has no `config` attribute
 
         Args:
-            source: A model Path, URL, or repoid.
+            source: A URL or huggingface repoid.
             loader: A Callable that expects a Path and returns a dict[str|int, Any]
 
         Returns:
             A LoadedModelWithoutConfig object.
         """
-
-        if isinstance(source, Path):
-            return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
-        else:
-            model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
-            return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
 
 
 class ConfigInterface(InvocationContextInterface):
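For node authors migrating off the removed method, the `isinstance` dispatch that `load_and_cache_model()` performed internally now happens at the call site; a hedged before/after sketch:

```python
from pathlib import Path

# Before (removed API):
#     loaded = context.models.load_and_cache_model(source=source)

# After: pick the method by source type.
if isinstance(source, Path):
    loaded = context.models.load_local_model(model_path=source)
else:
    loaded = context.models.load_remote_model(source=source)
```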

invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py

Lines changed: 1 addition & 3 deletions
@@ -59,14 +59,12 @@ def _remove_unused_core_models(self) -> None:
 
 def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
     """
-    Build the migration from database version 9 to 10.
+    Build the migration from database version 10 to 11.
 
     This migration does the following:
     - Moves "core" models previously downloaded with download_with_progress_bar() into new
       "models/.download_cache" directory.
     - Renames "models/.cache" to "models/.convert_cache".
-    - Adds `error_type` and `error_message` columns to the session queue table.
-    - Renames the `error` column to `error_traceback`.
     """
     migration_11 = Migration(
         from_version=10,
