
Commit 9372647

improvements
1 parent 3d84b9e commit 9372647

2 files changed: +42 -31 lines changed

src/diffusers/hooks/layerwise_upcasting.py

Lines changed: 38 additions & 26 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import re
-from typing import List, Type
+from typing import Optional, Tuple, Type
 
 import torch
 
@@ -25,13 +25,13 @@
 
 
 # fmt: off
-_SUPPORTED_PYTORCH_LAYERS = [
+_SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
     torch.nn.Linear,
-]
+)
 
-_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
+_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
 # fmt: on
 
 
@@ -66,10 +66,11 @@ def apply_layerwise_upcasting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
+    skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
     non_blocking: bool = False,
-) -> torch.nn.Module:
+    _prefix: str = "",
+) -> None:
     r"""
     Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
     nn.Module using diffusers layers or pytorch primitives.
@@ -82,30 +83,45 @@ def apply_layerwise_upcasting(
             The dtype to cast the module to before/after the forward pass for storage.
         compute_dtype (`torch.dtype`):
             The dtype to cast the module to during the forward pass for computation.
-        skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
+        skip_modules_pattern (`Tuple[str]`, defaults to `("pos_embed", "patch_embed", "norm")`):
             A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
-        skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`):
+        skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
             A list of module classes to skip during the layerwise upcasting process.
         non_blocking (`bool`, defaults to `False`):
             If `True`, the weight casting operations are non-blocking.
     """
-    for name, submodule in module.named_modules():
-        if (
-            any(re.search(pattern, name) for pattern in skip_modules_pattern)
-            or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS))
-            or len(list(submodule.children())) > 0
-        ):
-            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
-            continue
-        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype, non_blocking)
-    return module
+    if skip_modules_classes is None and skip_modules_pattern is None:
+        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+        return
+
+    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
+        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
+    )
+    if should_skip:
+        logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
+        return
+
+    if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
+        logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
+        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+        return
+
+    for name, submodule in module.named_children():
+        layer_name = f"{_prefix}.{name}" if _prefix else name
+        apply_layerwise_upcasting(
+            submodule,
+            storage_dtype,
+            compute_dtype,
+            skip_modules_pattern,
+            skip_modules_classes,
+            non_blocking,
+            _prefix=layer_name,
+        )
 
 
 def apply_layerwise_upcasting_hook(
     module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
-) -> torch.nn.Module:
+) -> None:
     r"""
     Applies a `LayerwiseUpcastingHook` to a given module.
 
@@ -118,10 +134,6 @@ def apply_layerwise_upcasting_hook(
             The dtype to cast the module to during the forward pass.
         non_blocking (`bool`):
             If `True`, the weight casting operations are non-blocking.
-
-    Returns:
-        `torch.nn.Module`:
-            The same module, with the hook attached (the module is modified in place).
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
     hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
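
Note: after this commit, `apply_layerwise_upcasting` recurses over `named_children()`, attaches the hooks in place, and returns `None` instead of the module. A minimal usage sketch, assuming PyTorch with `torch.float8_e4m3fn` support (>= 2.1) and that the function is imported from `diffusers.hooks.layerwise_upcasting` at this revision; `TinyBlock` is a hypothetical toy module, not part of diffusers:

import torch
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting


class TinyBlock(torch.nn.Module):
    # Hypothetical toy module used only for illustration.
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(16, 32)
        self.norm_proj = torch.nn.Linear(32, 32)  # name matches the default "norm" skip pattern
        self.proj_out = torch.nn.Linear(32, 16)


block = TinyBlock()

# Modifies `block` in place: proj_in/proj_out get upcasting hooks, norm_proj is skipped by name.
apply_layerwise_upcasting(
    block,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)

for name, param in block.named_parameters():
    # Hooked Linear layers are expected to keep their weights in float8 between forward passes.
    print(name, param.dtype)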

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 5 deletions
@@ -376,18 +376,17 @@ def enable_layerwise_upcasting(
             skip_modules_pattern.extend(self._keep_in_fp32_modules)
         if self._always_upcast_modules is not None:
             skip_modules_pattern.extend(self._always_upcast_modules)
-        skip_modules_pattern = list(set(skip_modules_pattern))
+        skip_modules_pattern = tuple(set(skip_modules_pattern))
 
         if skip_modules_classes is None:
-            skip_modules_classes = []
+            skip_modules_classes = ()
         if is_peft_available():
             # By default, we want to skip all peft layers because they have a very low memory footprint.
             # If users want to apply layerwise upcasting on peft layers as well, they can utilize the
             # `~diffusers.hooks.layerwise_upcasting.apply_layerwise_upcasting` function which provides
             # them with more flexibility and control.
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            skip_modules_classes.append(BaseTunerLayer)
+            if "lora" not in skip_modules_pattern:
+                skip_modules_pattern += ("lora",)
 
         if compute_dtype is None:
             logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.")
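
Note: the peft-specific class skip is replaced by a name-based skip: instead of importing `BaseTunerLayer`, a `"lora"` pattern is appended so that the `re.search(pattern, _prefix)` check in `apply_layerwise_upcasting` skips LoRA submodules by their qualified names (peft registers them under names such as `lora_A` / `lora_B`). A small sketch of that check, using illustrative module names that are not taken from any specific model:

import re

# Default skip patterns plus the "lora" entry appended by enable_layerwise_upcasting.
skip_modules_pattern = ("pos_embed", "patch_embed", "norm", "lora")

# Illustrative qualified names, as built by the recursive `_prefix` bookkeeping.
names = [
    "transformer_blocks.0.attn.to_q",
    "transformer_blocks.0.attn.to_q.lora_A.default",
    "transformer_blocks.0.norm1",
]

for name in names:
    skipped = any(re.search(pattern, name) for pattern in skip_modules_pattern)
    print(name, "-> skipped" if skipped else "-> upcasting hook candidate")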
