
Commit a3cb5da

Authored by lstein (Lincoln Stein), with hipsterusername and RyanJDick
Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere (#6490)
* allow model patcher to optimize away the unpatching step when feasible
* remove lazy_offloading functionality
* do not save original weights if there is a CPU copy of the state dict
* Update invokeai/backend/model_manager/load/load_base.py (Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>)
* documentation fixes requested during penultimate review
* add non_blocking=True parameters to several torch.nn.Module.to() calls, for slight performance increases
* fix ruff errors
* prevent crash on non-cuda-enabled systems

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
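For background on the central change, here is a minimal sketch of what non_blocking=True buys; it is illustrative only and not part of the commit, and the tensor names and sizes are made up. A host-to-device copy can overlap with other CPU work when the source tensor sits in pinned (page-locked) memory, and the caller must synchronize before depending on the result.

import torch

if torch.cuda.is_available():
    # Pinned CPU memory allows a truly asynchronous copy to the GPU.
    cpu_weight = torch.randn(4096, 4096, dtype=torch.float16).pin_memory()
    gpu_weight = cpu_weight.to(torch.device("cuda"), non_blocking=True)  # queues the copy and returns
    # ... other host-side work can run while the copy is in flight ...
    torch.cuda.synchronize()  # ensure the copy has finished before gpu_weight is read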
1 parent 568a484 commit a3cb5da

File tree

7 files changed: +84 additions, -38 deletions


invokeai/backend/ip_adapter/ip_adapter.py

Lines changed: 7 additions & 4 deletions
@@ -125,13 +125,16 @@ def __init__(
             self.device, dtype=self.dtype
         )
 
-    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
-        self.device = device
+    def to(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
+    ):
+        if device is not None:
+            self.device = device
         if dtype is not None:
             self.dtype = dtype
 
-        self._image_proj_model.to(device=self.device, dtype=self.dtype)
-        self.attn_weights.to(device=self.device, dtype=self.dtype)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
 
     def calc_size(self):
         # workaround for circular import
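A hypothetical call site for the widened signature (the ip_adapter instance below is assumed, not taken from the diff): device and dtype are now both optional, and non_blocking defaults to False, so existing callers keep their synchronous behaviour.

ip_adapter.to(device=torch.device("cuda"), dtype=torch.float16, non_blocking=True)
ip_adapter.to(dtype=torch.float32)  # device is left unchanged; only the dtype is updated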

invokeai/backend/lora.py

Lines changed: 29 additions & 22 deletions
@@ -61,9 +61,10 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
+            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
     # TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
-        super().to(device=device, dtype=dtype)
+        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
+        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype)
+            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
+            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
@@ -273,19 +277,19 @@ def to(
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
+            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ):
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
     def calc_size(self) -> int:
         model_size = 0
@@ -514,7 +521,7 @@ def from_checkpoint(
             # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()
 
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=True)
             model.layers[layer_key] = layer
 
         return model
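Every layer class in this file follows the same shape: accept the new keyword in to() and forward it to each tensor move. Below is a condensed sketch of that delegation pattern with simplified class and attribute names; it is not the actual LoRALayer implementation.

from typing import Optional

import torch


class _LayerSketch:
    # Condensed illustration of the pattern: Tensor.to returns a new tensor,
    # so each attribute is reassigned, and non_blocking is threaded through.
    def __init__(self, up: torch.Tensor, down: torch.Tensor):
        self.up = up
        self.down = down

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)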

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 2 additions & 2 deletions
@@ -285,9 +285,9 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(torch.device(target_device), copy=True)
+                    new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
+            cache_entry.model.to(target_device, non_blocking=True)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
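A reduced sketch of the state-dict copy this hunk touches (the helper name and dictionary contents are placeholders): copy=True forces a fresh tensor even when source and target device already match, and non_blocking=True lets the individual transfers queue asynchronously instead of blocking one by one.

from typing import Dict

import torch


def copy_state_dict_to_device(state_dict: Dict[str, torch.Tensor], target_device: str) -> Dict[str, torch.Tensor]:
    # Illustrative helper: move every tensor of a CPU-resident state dict to the target device.
    return {
        k: v.to(torch.device(target_device), copy=True, non_blocking=True)
        for k, v in state_dict.items()
    }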

invokeai/backend/model_patcher.py

Lines changed: 8 additions & 8 deletions
@@ -67,7 +67,7 @@ def apply_lora_unet(
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
@@ -83,7 +83,7 @@ def apply_lora_text_encoder(
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
             yield
 
@@ -95,7 +95,7 @@ def apply_lora(
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[Any, None, None]:
+    ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
 
@@ -139,12 +139,12 @@
                     # We intentionally move to the target device first, then cast. Experimentally, this was found to
                     # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                     # same thing in a single call to '.to(...)'.
-                    layer.to(device=device)
-                    layer.to(dtype=torch.float32)
+                    layer.to(device=device, non_blocking=True)
+                    layer.to(dtype=torch.float32, non_blocking=True)
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(device=torch.device("cpu"))
+                    layer.to(device=torch.device("cpu"), non_blocking=True)
 
                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:
@@ -153,15 +153,15 @@
                         layer_weight = layer_weight.reshape(module.weight.shape)
 
                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                    module.weight += layer_weight.to(dtype=dtype)
+                    module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
 
             yield  # wait for context manager exit
 
         finally:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
 
     @classmethod
     @contextmanager
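The hunks above sit inside a save/patch/restore context manager. Below is a stripped-down sketch of that control flow; the function and variable names are placeholders rather than InvokeAI's apply_lora, and it assumes the module exposes a .weight parameter (e.g. nn.Linear). The original weight is stashed, the computed delta is added in place, and the original is copied back in the finally block even if the patched execution raises.

import contextlib

import torch


@contextlib.contextmanager
def patch_weight(module: torch.nn.Module, delta: torch.Tensor):
    # Placeholder illustration of the save/patch/restore control flow.
    original = module.weight.detach().clone()
    try:
        with torch.no_grad():
            module.weight += delta.to(dtype=module.weight.dtype)
        yield
    finally:
        with torch.no_grad():
            module.weight.copy_(original)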

invokeai/backend/onnx/onnx_runtime.py

Lines changed: 10 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 import numpy as np
 import onnx
+import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
 
@@ -188,6 +189,15 @@ def __call__(self, **kwargs):
         # return self.io_binding.copy_outputs_to_cpu()
         return self.session.run(None, inputs)
 
+    # compatability with RawModel ABC
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
+
     # compatability with diffusers load code
     @classmethod
     def from_pretrained(

invokeai/backend/raw_model.py

Lines changed: 16 additions & 2 deletions
@@ -10,6 +10,20 @@
 that adds additional methods and attributes.
 """
 
+from abc import ABC, abstractmethod
+from typing import Optional
 
-class RawModel:
-    """Base class for 'Raw' model wrappers."""
+import torch
+
+
+class RawModel(ABC):
+    """Abstract base class for 'Raw' model wrappers."""
+
+    @abstractmethod
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
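With RawModel now an ABC, every wrapper has to provide a to() with this signature, even when it is a no-op (as the ONNX wrapper above does). A hypothetical subclass, purely for illustration:

from typing import Optional

import torch


class _ExampleRawModel(RawModel):  # RawModel as defined in the diff above
    def __init__(self, weight: torch.Tensor):
        self.weight = weight

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)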

invokeai/backend/textual_inversion.py

Lines changed: 12 additions & 0 deletions
@@ -65,6 +65,18 @@ def from_checkpoint(
 
         return result
 
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        if not torch.cuda.is_available():
+            return
+        for emb in [self.embedding, self.embedding_2]:
+            if emb is not None:
+                emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
 
 class TextualInversionManager(BaseTextualInversionManager):
     """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
