
Commit 1109708

Merge branch 'main' into lstein/feat/load-one-file
2 parents e7b7737 + 785bb1d commit 1109708

File tree: 13 files changed (+95, -48 lines)


invokeai/app/invocations/constants.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 from typing import Literal
 
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
 
 LATENT_SCALE_FACTOR = 8
 """
@@ -15,3 +16,5 @@
 
 IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
 """A literal type for PIL image modes supported by Invoke"""
+
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
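Note: DEFAULT_PRECISION now resolves to a torch.dtype rather than a string, which is why the downstream comparisons in this commit change from == "float32" to == torch.float32. A minimal sketch of the intent, assuming TorchDevice.choose_torch_dtype() picks a dtype suited to the detected device:

import torch
from invokeai.backend.util.devices import TorchDevice

DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()  # e.g. torch.float16 on CUDA, torch.float32 on CPU

# Field defaults can compare dtypes directly instead of strings:
fp32_default = DEFAULT_PRECISION == torch.float32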

invokeai/app/invocations/create_denoise_mask.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
 from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
 from invokeai.app.invocations.model import VAEField
@@ -30,7 +30,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
     fp32: bool = InputField(
-        default=DEFAULT_PRECISION == "float32",
+        default=DEFAULT_PRECISION == torch.float32,
         description=FieldDescriptions.fp32,
         ui_order=4,
     )

invokeai/app/invocations/create_gradient_mask.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import (
     DenoiseMaskField,
     FieldDescriptions,
@@ -74,7 +74,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
     fp32: bool = InputField(
-        default=DEFAULT_PRECISION == "float32",
+        default=DEFAULT_PRECISION == torch.float32,
         description=FieldDescriptions.fp32,
         ui_order=9,
     )

invokeai/app/invocations/denoise_latents.py

Lines changed: 0 additions & 2 deletions
@@ -59,8 +59,6 @@
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField
 
-DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
-
 
 def get_scheduler(
     context: InvocationContext,

invokeai/app/invocations/image_to_latents.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     ImageField,
@@ -44,7 +44,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         input=Input.Connection,
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
-    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @staticmethod
     def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:

invokeai/app/invocations/latents_to_image.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -46,7 +46,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
-    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:

invokeai/backend/ip_adapter/ip_adapter.py

Lines changed: 7 additions & 4 deletions
@@ -125,13 +125,16 @@ def __init__(
             self.device, dtype=self.dtype
         )
 
-    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
-        self.device = device
+    def to(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
+    ):
+        if device is not None:
+            self.device = device
         if dtype is not None:
             self.dtype = dtype
 
-        self._image_proj_model.to(device=self.device, dtype=self.dtype)
-        self.attn_weights.to(device=self.device, dtype=self.dtype)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
 
     def calc_size(self):
         # workaround for circular import
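With the new signature, device and dtype are both optional, so callers can move or cast independently, and the non_blocking flag is forwarded to the underlying tensor moves. A hypothetical usage sketch (ip_adapter here stands for an already-loaded IPAdapter instance, not a name from the patch):

import torch

# Move weights to the GPU without changing dtype; the copy can be queued
# asynchronously when non_blocking=True and the source memory is pinned.
ip_adapter.to(device=torch.device("cuda"), non_blocking=True)

# Cast in place without moving devices.
ip_adapter.to(dtype=torch.float16)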

invokeai/backend/lora.py

Lines changed: 29 additions & 22 deletions
@@ -61,9 +61,10 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
+            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 # TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
-        super().to(device=device, dtype=dtype)
+        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
+        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype)
+            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
+            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
@@ -273,19 +277,19 @@ def to(
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
+            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ):
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ def to(
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
    ) -> None:
        # TODO: try revert if exception?
        for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
    def calc_size(self) -> int:
        model_size = 0
@@ -514,7 +521,7 @@ def from_checkpoint(
            # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()
 
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=True)
            model.layers[layer_key] = layer
 
    return model
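The non_blocking=True threaded through these layers follows standard PyTorch semantics: a host-to-device copy can only overlap with other work when the source tensor sits in pinned (page-locked) memory; with ordinary pageable memory the flag is effectively a no-op. An illustrative, self-contained sketch of the pattern (the helper name and structure here are not from the patch):

import torch

def transfer_tensors(tensors: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    # Queue copies of CPU tensors to the target device, asynchronously where possible.
    out = {}
    for name, t in tensors.items():
        if device.type == "cuda":
            # pin_memory() gives a page-locked host copy so the transfer can run asynchronously.
            out[name] = t.pin_memory().to(device, non_blocking=True)
        else:
            out[name] = t.to(device)
    return out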

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 2 additions & 2 deletions
@@ -285,9 +285,9 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(torch.device(target_device), copy=True)
+                    new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
+            cache_entry.model.to(target_device, non_blocking=True)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
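The cache move copies every tensor of the retained CPU state dict onto the target device (copy=True keeps the CPU originals intact for later offloads) before rebinding it with load_state_dict(..., assign=True). A condensed sketch of that copy pattern under the same assumptions:

from typing import Dict
import torch

def copy_state_dict_to_device(state_dict: Dict[str, torch.Tensor], device: torch.device) -> Dict[str, torch.Tensor]:
    # copy=True allocates new tensors on the target device, leaving the CPU
    # originals untouched so the model can be moved back cheaply later.
    return {k: v.to(device, copy=True, non_blocking=True) for k, v in state_dict.items()}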

invokeai/backend/model_patcher.py

Lines changed: 8 additions & 8 deletions
@@ -67,7 +67,7 @@ def apply_lora_unet(
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
@@ -83,7 +83,7 @@ def apply_lora_text_encoder(
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
             yield
 
@@ -95,7 +95,7 @@ def apply_lora(
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[Any, None, None]:
+    ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
 
@@ -139,12 +139,12 @@ def apply_lora(
                    # We intentionally move to the target device first, then cast. Experimentally, this was found to
                    # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                    # same thing in a single call to '.to(...)'.
-                    layer.to(device=device)
-                    layer.to(dtype=torch.float32)
+                    layer.to(device=device, non_blocking=True)
+                    layer.to(dtype=torch.float32, non_blocking=True)
                    # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                    # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                    layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(device=torch.device("cpu"))
+                    layer.to(device=torch.device("cpu"), non_blocking=True)
 
                    assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                    if module.weight.shape != layer_weight.shape:
@@ -153,15 +153,15 @@ def apply_lora(
                        layer_weight = layer_weight.reshape(module.weight.shape)
 
                    assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                    module.weight += layer_weight.to(dtype=dtype)
+                    module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
 
            yield  # wait for context manager exit
 
        finally:
            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
            with torch.no_grad():
                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
 
    @classmethod
    @contextmanager
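The comment in apply_lora records the reason for the two-step transfer: moving the (typically 16-bit) layer to the device first and casting to float32 afterwards was found to be faster than a single combined .to(device, dtype) call. A standalone illustration of that ordering (the timing claim is the patch author's, not re-measured here):

import torch

def move_then_cast(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Step 1: transfer the tensor to the target device (an fp16 payload is half the bytes of fp32).
    t = t.to(device=device, non_blocking=True)
    # Step 2: upcast on the device, where the cast is cheap, instead of casting on the CPU first.
    return t.to(dtype=torch.float32, non_blocking=True)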
