Commit 2871676

Authored by lstein (Lincoln Stein), hipsterusername, and RyanJDick
LoRA patching optimization (#6439)
* allow model patcher to optimize away the unpatching step when feasible
* remove lazy_offloading functionality
* do not save original weights if there is a CPU copy of state dict
* Update invokeai/backend/model_manager/load/load_base.py
  Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
* documentation fixes added during penultimate review

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
1 parent 1c5c3cd commit 2871676

File tree

7 files changed: +146, -48 lines

docs/contributing/MODEL_MANAGER.md

Lines changed: 37 additions & 13 deletions

@@ -1367,30 +1367,54 @@ the in-memory loaded model:
 | `model`  | AnyModel        | The instantiated model (details below) |
 | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
 
-Because the loader can return multiple model types, it is typed to
-return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
-`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
-`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
-models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
-models. The others are obvious.
+### get_model_by_key(key, [submodel]) -> LoadedModel
+
+The `get_model_by_key()` method will retrieve the model using its
+unique database key. For example:
+
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+
+`get_model_by_key()` may raise any of the following exceptions:
+
+* `UnknownModelException` -- key not in database
+* `ModelNotFoundException` -- key in database but model not found at path
+* `NotImplementedException` -- the loader doesn't know how to load this type of model
+
+### Using the Loaded Model in Inference
 
 `LoadedModel` acts as a context manager. The context loads the model
 into the execution device (e.g. VRAM on CUDA systems), locks the model
 in the execution device for the duration of the context, and returns
 the model. Use it like this:
 
 ```
-model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with model_info as vae:
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model as vae:
     image = vae.decode(latents)[0]
 ```
 
-`get_model_by_key()` may raise any of the following exceptions:
+The object returned by the LoadedModel context manager is an
+`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are obvious.
+
+In addition, you may call `LoadedModel.model_on_device()`, a context
+manager that returns a tuple of the model's state dict in CPU and the
+model itself in VRAM. It is used to optimize the LoRA patching and
+unpatching process:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model.model_on_device() as (state_dict, vae):
+    image = vae.decode(latents)[0]
+```
+
+Since not all models have state dicts, the `state_dict` return value
+can be None.
 
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
-
 ### Emitting model loading events
 
 When the `context` argument is passed to `load_model_*()`, it will
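Why this speeds up LoRA patching: with the CPU state dict in hand, saving an "original weight" for later unpatching becomes a dictionary lookup instead of a GPU-to-CPU tensor copy. A condensed sketch of the idea, mirroring the `apply_lora` change later in this commit (`module_key`, `module`, and `original_weights` are names taken from that context):

```
if model_state_dict is not None:
    # cheap: reuse the CPU tensor that already exists in the cached state dict
    original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
    # expensive fallback: synchronize and copy the weight off the execution device
    original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
```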

invokeai/app/invocations/compel.py

Lines changed: 13 additions & 4 deletions

@@ -81,9 +81,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            ModelPatcher.apply_lora_text_encoder(
+                text_encoder,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -172,9 +176,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            ModelPatcher.apply_lora(
+                text_encoder,
+                loras=_lora_loader(),
+                prefix=lora_prefix,
+                model_state_dict=state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (

invokeai/app/invocations/latent.py

Lines changed: 6 additions & 2 deletions

@@ -952,11 +952,15 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info as unet,
+            unet_info.model_on_device() as (model_state_dict, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             set_seamless(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)

invokeai/backend/model_manager/load/load_base.py

Lines changed: 50 additions & 2 deletions

@@ -4,10 +4,13 @@
 """
 
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass
 from logging import Logger
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Dict, Generator, Optional, Tuple
+
+import torch
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.model_manager.config import (
@@ -21,7 +24,42 @@
 
 @dataclass
 class LoadedModel:
-    """Context manager object that mediates transfer from RAM<->VRAM."""
+    """
+    Context manager object that mediates transfer from RAM<->VRAM.
+
+    This is a context manager object that has two distinct APIs:
+
+    1. Older API (deprecated):
+       Use the LoadedModel object directly as a context manager.
+       It will move the model into VRAM (on CUDA devices), and
+       return the model in a form suitable for passing to torch.
+       Example:
+       ```
+       loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+       with loaded_model as vae:
+           image = vae.decode(latents)[0]
+       ```
+
+    2. Newer API (recommended):
+       Call the LoadedModel's `model_on_device()` method in a
+       context. It returns a tuple consisting of a copy of
+       the model's state dict in CPU RAM followed by a copy
+       of the model in VRAM. The state dict is provided to allow
+       LoRAs and other model patchers to return the model to
+       its unpatched state without expensive copy and restore
+       operations.
+
+       Example:
+       ```
+       loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
+       with loaded_model.model_on_device() as (state_dict, vae):
+           image = vae.decode(latents)[0]
+       ```
+
+    The state_dict should be treated as a read-only object and
+    never modified. Also be aware that some loadable models do
+    not have a state_dict, in which case this value will be None.
+    """
 
     config: AnyModelConfig
     _locker: ModelLockerBase
@@ -35,6 +73,16 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None:
         """Context exit."""
         self._locker.unlock()
 
+    @contextmanager
+    def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
+        """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
+        locked_model = self._locker.lock()
+        try:
+            state_dict = self._locker.get_state_dict()
+            yield (state_dict, locked_model)
+        finally:
+            self._locker.unlock()
+
     @property
     def model(self) -> AnyModel:
         """Return the model without locking it."""

invokeai/backend/model_manager/load/model_cache/model_cache_base.py

Lines changed: 10 additions & 0 deletions

@@ -30,6 +30,11 @@ def unlock(self) -> None:
         """Unlock the contained model, and remove it from VRAM."""
         pass
 
+    @abstractmethod
+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        pass
+
     @property
     @abstractmethod
     def model(self) -> AnyModel:
@@ -56,6 +61,11 @@ class CacheRecord(Generic[T]):
     and then injected into the model. When the model is finished, the VRAM
     copy of the state dict is deleted, and the RAM version is reinjected
     into the model.
+
+    The state_dict should be treated as a read-only attribute. Do not attempt
+    to patch or otherwise modify it. Instead, patch the copy of the state_dict
+    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
+    context manager call `model_on_device()`.
     """
 
     key: str
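The CacheRecord docstring above describes a RAM/VRAM state-dict dance. A hypothetical sketch of that lifecycle (the helper names are invented here; the real logic lives in the model cache implementation, and `load_state_dict(..., assign=True)` requires PyTorch >= 2.1):

```
import torch

def load_to_device(record, device):
    # copy the RAM state dict into VRAM and inject it into the model
    if record.state_dict is not None:
        vram_sd = {k: v.to(device, copy=True) for k, v in record.state_dict.items()}
        record.model.load_state_dict(vram_sd, assign=True)
    record.model.to(device)

def offload_to_ram(record):
    # drop the VRAM copy and reinject the pristine RAM state dict
    if record.state_dict is not None:
        record.model.load_state_dict(record.state_dict, assign=True)
    record.model.to("cpu")
```

Because any LoRA patches were applied to the VRAM copy only, reinjecting the RAM version on offload returns the model to its unpatched state essentially for free.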

invokeai/backend/model_manager/load/model_cache/model_locker.py

Lines changed: 6 additions & 2 deletions

@@ -2,6 +2,8 @@
 Base class and implementation of a class that moves models in and out of VRAM.
 """
 
+from typing import Dict, Optional
+
 import torch
 
 from invokeai.backend.model_manager import AnyModel
@@ -27,6 +29,10 @@ def model(self) -> AnyModel:
         """Return the model without moving it around."""
         return self._cache_entry.model
 
+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        return self._cache_entry.state_dict
+
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
         if not hasattr(self.model, "to"):
@@ -37,10 +43,8 @@ def lock(self) -> AnyModel:
         try:
             if self._cache.lazy_offloading:
                 self._cache.offload_unlocked_models(self._cache_entry.size)
-
             self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
             self._cache_entry.loaded = True
-
             self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
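The locker contract that `model_on_device()` builds on, as a brief usage sketch (normally you would go through `LoadedModel` rather than touching the locker directly):

```
model = locker.lock()                     # moves the model to the execution device
try:
    state_dict = locker.get_state_dict()  # CPU copy of the weights, or None
    # ... patch and run inference ...
finally:
    locker.unlock()                       # releases the model for offloading
```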

invokeai/backend/model_patcher.py

Lines changed: 24 additions & 25 deletions

@@ -5,7 +5,7 @@
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -66,8 +66,14 @@ def apply_lora_unet(
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
     ) -> None:
-        with cls.apply_lora(unet, loras, "lora_unet_"):
+        with cls.apply_lora(
+            unet,
+            loras=loras,
+            prefix="lora_unet_",
+            model_state_dict=model_state_dict,
+        ):
             yield
 
     @classmethod
@@ -76,28 +82,9 @@ def apply_lora_text_encoder(
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
     ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_sdxl_lora_text_encoder2(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: List[Tuple[LoRAModelRaw, float]],
-    ) -> None:
-        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
             yield
 
     @classmethod
@@ -107,7 +94,16 @@ def apply_lora(
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-    ) -> None:
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Generator[Any, None, None]:
+        """
+        Apply one or more LoRAs to a model.
+
+        :param model: The model to patch.
+        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :param model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        """
         original_weights = {}
         try:
             with torch.no_grad():
@@ -133,7 +129,10 @@ def apply_lora(
                         dtype = module.weight.dtype
 
                         if module_key not in original_weights:
-                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
+                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
+                            else:
+                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0