Commit fde58ce

Merge remote-tracking branch 'origin/main' into lstein/feat/simple-mm2-api

2 parents: dc13493 + 6d067e5
42 files changed: +1659 −828 lines

Large commits have some content hidden by default; only a subset of the changed files is shown below.

docs/contributing/MODEL_MANAGER.md

Lines changed: 37 additions & 13 deletions
@@ -1366,30 +1366,54 @@ the in-memory loaded model:
 | `model` | AnyModel | The instantiated model (details below) |
 | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
 
-Because the loader can return multiple model types, it is typed to
-return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
-`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
-`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
-models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
-models. The others are obvious.
+### get_model_by_key(key, [submodel]) -> LoadedModel
+
+The `get_model_by_key()` method will retrieve the model using its
+unique database key. For example:
+
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+
+`get_model_by_key()` may raise any of the following exceptions:
+
+* `UnknownModelException` -- key not in database
+* `ModelNotFoundException` -- key in database but model not found at path
+* `NotImplementedException` -- the loader doesn't know how to load this type of model
+
+### Using the Loaded Model in Inference
 
 `LoadedModel` acts as a context manager. The context loads the model
 into the execution device (e.g. VRAM on CUDA systems), locks the model
 in the execution device for the duration of the context, and returns
 the model. Use it like this:
 
 ```
-model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with model_info as vae:
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model as vae:
     image = vae.decode(latents)[0]
 ```
 
-`get_model_by_key()` may raise any of the following exceptions:
+The object returned by the LoadedModel context manager is an
+`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are obvious.
+
+In addition, you may call `LoadedModel.model_on_device()`, a context
+manager that returns a tuple of the model's state dict in CPU and the
+model itself in VRAM. It is used to optimize the LoRA patching and
+unpatching process:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model.model_on_device() as (state_dict, vae):
+    image = vae.decode(latents)[0]
+```
+
+Since not all models have state dicts, the `state_dict` return value
+can be None.
+
 
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
-
 ### Emitting model loading events
 
 When the `context` argument is passed to `load_model_*()`, it will
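
As a worked example of the documentation added above, the sketch below guards `get_model_by_key()` against the three listed exceptions and then decodes latents through the newer `model_on_device()` API. This is an editorial sketch, not part of the commit; the `loader` wiring and the import location of `SubModelType` are assumptions.

```python
# Illustrative sketch only: the import path and the `loader` object are assumptions.
from invokeai.backend.model_manager import SubModelType  # assumed export location


def decode_with_vae(loader, key: str, latents):
    """Decode `latents` with the VAE submodel stored under `key`."""
    try:
        loaded_model = loader.get_model_by_key(key, SubModelType("vae"))
    except Exception as exc:
        # Covers UnknownModelException, ModelNotFoundException, NotImplementedException.
        raise RuntimeError(f"Cannot load VAE for key {key}") from exc

    # Newer API: yields (state_dict, model); state_dict may be None for some model types.
    with loaded_model.model_on_device() as (state_dict, vae):
        return vae.decode(latents)[0]
```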

invokeai/app/invocations/compel.py

Lines changed: 13 additions & 4 deletions
@@ -81,9 +81,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            ModelPatcher.apply_lora_text_encoder(
+                text_encoder,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -172,9 +176,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            ModelPatcher.apply_lora(
+                text_encoder,
+                loras=_lora_loader(),
+                prefix=lora_prefix,
+                model_state_dict=state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
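
The `compel.py` changes above follow one pattern: capture the CPU state dict via `model_on_device()` and pass it to the patcher, so that unpatching becomes a restore from that copy rather than a clone of the on-device weights. The sketch below illustrates the idea with plain PyTorch; the names are placeholders and do not mirror `ModelPatcher`'s actual signature.

```python
# Generic illustration of restoring from a pre-captured state dict; not InvokeAI's API.
from contextlib import contextmanager
from typing import Dict, Optional

import torch


@contextmanager
def patched(model: torch.nn.Module, cpu_state_dict: Optional[Dict[str, torch.Tensor]]):
    """Apply an in-place weight patch, then restore the original weights on exit."""
    try:
        with torch.no_grad():
            for param in model.parameters():
                param.add_(0.0)  # stand-in for adding LoRA deltas in place
        yield model
    finally:
        if cpu_state_dict is not None:
            # Copy the pre-captured CPU weights back onto the (possibly CUDA) model,
            # avoiding a second full copy of the patched weights just to undo them.
            model.load_state_dict(cpu_state_dict)
```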

invokeai/app/invocations/latent.py

Lines changed: 73 additions & 46 deletions
@@ -50,7 +50,7 @@
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
@@ -672,54 +672,52 @@ def prep_control_data(
 
         return controlnet_data
 
+    def prep_ip_adapter_image_prompts(
+        self,
+        context: InvocationContext,
+        ip_adapters: List[IPAdapterField],
+    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+        """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
+        image_prompts = []
+        for single_ip_adapter in ip_adapters:
+            with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
+                assert isinstance(ip_adapter_model, IPAdapter)
+                image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
+                # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
+                single_ipa_image_fields = single_ip_adapter.image
+                if not isinstance(single_ipa_image_fields, list):
+                    single_ipa_image_fields = [single_ipa_image_fields]
+
+                single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
+                with image_encoder_model_info as image_encoder_model:
+                    assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
+                    # Get image embeddings from CLIP and ImageProjModel.
+                    image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
+                        single_ipa_images, image_encoder_model
+                    )
+                    image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
+
+        return image_prompts
+
     def prep_ip_adapter_data(
         self,
         context: InvocationContext,
-        ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
+        ip_adapters: List[IPAdapterField],
+        image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
         exit_stack: ExitStack,
         latent_height: int,
         latent_width: int,
         dtype: torch.dtype,
-    ) -> Optional[list[IPAdapterData]]:
-        """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
-        to the `conditioning_data` (in-place).
-        """
-        if ip_adapter is None:
-            return None
-
-        # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-        if not isinstance(ip_adapter, list):
-            ip_adapter = [ip_adapter]
-
-        if len(ip_adapter) == 0:
-            return None
-
+    ) -> Optional[List[IPAdapterData]]:
+        """If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
         ip_adapter_data_list = []
-        for single_ip_adapter in ip_adapter:
-            ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.models.load(single_ip_adapter.ip_adapter_model)
-            )
-
-            image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
-            # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
-            single_ipa_image_fields = single_ip_adapter.image
-            if not isinstance(single_ipa_image_fields, list):
-                single_ipa_image_fields = [single_ipa_image_fields]
-
-            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
-
-            # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
-            # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
-            with image_encoder_model_info as image_encoder_model:
-                assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
-                # Get image embeddings from CLIP and ImageProjModel.
-                image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
-                    single_ipa_images, image_encoder_model
-                )
+        for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
+            ip_adapters, image_prompts, strict=True
+        ):
+            ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
 
-            mask = single_ip_adapter.mask
-            if mask is not None:
-                mask = context.tensors.load(mask.tensor_name)
+            mask_field = single_ip_adapter.mask
+            mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
             mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
 
             ip_adapter_data_list.append(
@@ -734,7 +732,7 @@ def prep_ip_adapter_data(
                 )
             )
 
-        return ip_adapter_data_list
+        return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
 
     def run_t2i_adapters(
         self,
@@ -855,6 +853,16 @@ def init_scheduler(
         # At some point, someone decided that schedulers that accept a generator should use the original seed with
         # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
         # reproducibility.
+        #
+        # These Invoke-supported schedulers accept a generator as of 2024-06-04:
+        # - DDIMScheduler
+        # - DDPMScheduler
+        # - DPMSolverMultistepScheduler
+        # - EulerAncestralDiscreteScheduler
+        # - EulerDiscreteScheduler
+        # - KDPM2AncestralDiscreteScheduler
+        # - LCMScheduler
+        # - TCDScheduler
         scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
         if isinstance(scheduler, TCDScheduler):
             scheduler_step_kwargs.update({"eta": 1.0})
@@ -912,6 +920,20 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
             do_classifier_free_guidance=True,
         )
 
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+        # a series of image conditioning embeddings. This is being done here rather than in the
+        # big model context below in order to use less VRAM on low-VRAM systems.
+        # The image prompts are then passed to prep_ip_adapter_data().
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
         # get the unet's config so that we can pass the base to dispatch_progress()
         unet_config = context.models.get_config(self.unet.unet.key)
@@ -930,11 +952,15 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info as unet,
+            unet_info.model_on_device() as (model_state_dict, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             set_seamless(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
@@ -970,7 +996,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
             ip_adapter_data = self.prep_ip_adapter_data(
                 context=context,
-                ip_adapter=self.ip_adapter,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
                 exit_stack=exit_stack,
                 latent_height=latent_height,
                 latent_width=latent_width,
@@ -1285,7 +1312,7 @@ def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTenso
     title="Blend Latents",
     tags=["latents", "blend"],
     category="latents",
-    version="1.0.2",
+    version="1.0.3",
 )
 class BlendLatentsInvocation(BaseInvocation):
     """Blend two latents using a given alpha. Latents must have same size."""
@@ -1364,7 +1391,7 @@ def slerp(
         TorchDevice.empty_cache()
 
         name = context.tensors.save(tensor=blended_latents)
-        return LatentsOutput.build(latents_name=name, latents=blended_latents)
+        return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
 
 
 # The Crop Latents node was copied from @skunkworxdark's implementation here:
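
Much of the `latent.py` refactor above is about ordering: the IP-Adapter image prompts are computed with the CLIP Vision encoders before the UNet is brought into VRAM, so the two never need to be resident at once. The stripped-down sketch below shows that ordering; every name in it is an invented stand-in rather than the invocation context's real API.

```python
# Hypothetical sketch of the VRAM-friendly ordering; all helper names are invented.
from typing import Callable, List, Tuple


def run_with_low_peak_vram(
    load_model: Callable,     # returns a LoadedModel-like context manager
    encode_images: Callable,  # runs a CLIP Vision encoder over the adapter's images
    denoise: Callable,        # runs the UNet denoising loop
    adapters: List,           # IP-Adapter fields, already normalized to a list
    unet_key: str,
    latents,
):
    # Phase 1: load each small vision encoder, compute embeddings, release it.
    image_prompts: List[Tuple] = []
    for adapter in adapters:
        with load_model(adapter.image_encoder_model) as vision_encoder:
            image_prompts.append(encode_images(adapter.images, vision_encoder))

    # Phase 2: only now move the large UNet into VRAM; the vision encoders are
    # already out of memory, which lowers the peak requirement.
    with load_model(unet_key).model_on_device() as (state_dict, unet):
        return denoise(unet, latents, image_prompts)
```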

invokeai/backend/model_manager/load/load_base.py

Lines changed: 50 additions & 2 deletions
@@ -4,10 +4,13 @@
 """
 
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Generator, Optional, Tuple
+
+import torch
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager.config import (
@@ -21,7 +24,42 @@
 
 @dataclass
 class LoadedModelWithoutConfig:
-    """Context manager object that mediates transfer from RAM<->VRAM."""
+    """
+    Context manager object that mediates transfer from RAM<->VRAM.
+
+    This is a context manager object that has two distinct APIs:
+
+    1. Older API (deprecated):
+       Use the LoadedModel object directly as a context manager.
+       It will move the model into VRAM (on CUDA devices), and
+       return the model in a form suitable for passing to torch.
+       Example:
+       ```
+       loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+       with loaded_model as vae:
+           image = vae.decode(latents)[0]
+       ```
+
+    2. Newer API (recommended):
+       Call the LoadedModel's `model_on_device()` method in a
+       context. It returns a tuple consisting of a copy of
+       the model's state dict in CPU RAM followed by a copy
+       of the model in VRAM. The state dict is provided to allow
+       LoRAs and other model patchers to return the model to
+       its unpatched state without expensive copy and restore
+       operations.
+
+       Example:
+       ```
+       loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+       with loaded_model.model_on_device() as (state_dict, vae):
+           image = vae.decode(latents)[0]
+       ```
+
+    The state_dict should be treated as a read-only object and
+    never modified. Also be aware that some loadable models do
+    not have a state_dict, in which case this value will be None.
+    """
 
     _locker: ModelLockerBase
 
@@ -34,6 +72,16 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None:
         """Context exit."""
         self._locker.unlock()
 
+    @contextmanager
+    def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
+        """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
+        locked_model = self._locker.lock()
+        try:
+            state_dict = self._locker.get_state_dict()
+            yield (state_dict, locked_model)
+        finally:
+            self._locker.unlock()
+
     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""

invokeai/backend/model_manager/load/model_cache/model_cache_base.py

Lines changed: 10 additions & 0 deletions
@@ -30,6 +30,11 @@ def unlock(self) -> None:
         """Unlock the contained model, and remove it from VRAM."""
         pass
 
+    @abstractmethod
+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        pass
+
     @property
     @abstractmethod
     def model(self) -> AnyModel:
@@ -56,6 +61,11 @@ class CacheRecord(Generic[T]):
     and then injected into the model. When the model is finished, the VRAM
     copy of the state dict is deleted, and the RAM version is reinjected
     into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
     """
 
     key: str
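
For implementers of the cache interface, the new abstract `get_state_dict()` is expected to expose the RAM copy of the weights that the `CacheRecord` already holds. The sketch below is a minimal, illustrative locker written against that reading of the interface; it assumes the record carries `model` and `state_dict` attributes and omits the device placement and bookkeeping a real `ModelLocker` performs.

```python
# Minimal illustrative locker; assumes `record` has `model` and `state_dict` attributes.
from typing import Dict, Optional

import torch


class SimpleLocker:  # would subclass ModelLockerBase in the real code base
    def __init__(self, record) -> None:
        self._record = record

    def lock(self):
        """Return the model (a real locker would also move it to the execution device)."""
        return self._record.model

    def unlock(self) -> None:
        """Release the execution-device copy (a no-op in this sketch)."""

    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
        """Expose the cached CPU state dict; callers must treat it as read-only."""
        return self._record.state_dict

    @property
    def model(self):
        return self._record.model
```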
