# limitations under the License.

import re
-from typing import List, Type
+from typing import Optional, Tuple, Type

import torch

@@ -25,13 +25,13 @@


# fmt: off
-_SUPPORTED_PYTORCH_LAYERS = [
+_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
-]
+)

-_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
+_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")

# fmt: on

@@ -66,10 +66,11 @@ def apply_layerwise_upcasting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
+    skip_modules_pattern: Optional[Tuple[str, ...]] = _DEFAULT_SKIP_MODULES_PATTERN,
+    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    non_blocking: bool = False,
-) -> torch.nn.Module:
+    _prefix: str = "",
+) -> None:
    r"""
    Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
    nn.Module using diffusers layers or pytorch primitives.
@@ -82,30 +83,45 @@ def apply_layerwise_upcasting(
            The dtype to cast the module to before/after the forward pass for storage.
        compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass for computation.
-        skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
+        skip_modules_pattern (`Tuple[str, ...]`, defaults to `("pos_embed", "patch_embed", "norm")`):
            A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
-        skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`):
+        skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
            A list of module classes to skip during the layerwise upcasting process.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the weight casting operations are non-blocking.
    """
-    for name, submodule in module.named_modules():
-        if (
-            any(re.search(pattern, name) for pattern in skip_modules_pattern)
-            or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS))
-            or len(list(submodule.children())) > 0
-        ):
-            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
-            continue
-        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype, non_blocking)
-    return module
+    if skip_modules_classes is None and skip_modules_pattern is None:
+        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+        return
+
+    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
+        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
+    )
+    if should_skip:
+        logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
+        return
+
+    if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
+        logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
+        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
+        return
+
+    for name, submodule in module.named_children():
+        layer_name = f"{_prefix}.{name}" if _prefix else name
+        apply_layerwise_upcasting(
+            submodule,
+            storage_dtype,
+            compute_dtype,
+            skip_modules_pattern,
+            skip_modules_classes,
+            non_blocking,
+            _prefix=layer_name,
+        )


def apply_layerwise_upcasting_hook(
    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
-) -> torch.nn.Module:
+) -> None:
    r"""
    Applies a `LayerwiseUpcastingHook` to a given module.

@@ -118,10 +134,6 @@ def apply_layerwise_upcasting_hook(
            The dtype to cast the module to during the forward pass.
        non_blocking (`bool`):
            If `True`, the weight casting operations are non-blocking.
-
-    Returns:
-        `torch.nn.Module`:
-            The same module, with the hook attached (the module is modified in place).
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
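The rewritten `apply_layerwise_upcasting` walks the module tree depth-first via `named_children()`, skipping whole subtrees whose qualified name matches `skip_modules_pattern` or whose class appears in `skip_modules_classes`, and attaching the upcasting hook only to supported leaf layers. The sketch below is a rough, self-contained illustration of that traversal-plus-hook pattern using plain PyTorch forward hooks instead of the `HookRegistry`/`LayerwiseUpcastingHook` machinery from this file; the helper `attach_casting_hooks`, the `SUPPORTED`/`SKIP` constants, and the toy module are illustrative assumptions, not part of the PR, and float8 storage assumes a PyTorch build that provides `torch.float8_e4m3fn`.

```python
import re

import torch

SUPPORTED = (torch.nn.Linear, torch.nn.Conv2d)   # stand-in for _SUPPORTED_PYTORCH_LAYERS
SKIP = ("pos_embed", "patch_embed", "norm")      # stand-in for _DEFAULT_SKIP_MODULES_PATTERN


def attach_casting_hooks(module, storage_dtype, compute_dtype, prefix=""):
    # Skip whole subtrees whose qualified name matches one of the patterns.
    if any(re.search(pattern, prefix) for pattern in SKIP):
        return
    # Supported leaf layer: store weights in low precision, upcast only around forward.
    if isinstance(module, SUPPORTED):
        module.to(storage_dtype)

        def pre(m, args):
            m.to(compute_dtype)    # upcast weights just before computation

        def post(m, args, output):
            m.to(storage_dtype)    # return weights to the storage dtype afterwards

        module.register_forward_pre_hook(pre)
        module.register_forward_hook(post)
        return
    # Otherwise recurse into children, building the qualified name as we go.
    for name, child in module.named_children():
        attach_casting_hooks(child, storage_dtype, compute_dtype,
                             f"{prefix}.{name}" if prefix else name)


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(64)   # name matches the "norm" skip pattern
        self.proj = torch.nn.Linear(64, 64)  # supported leaf, gets the hooks

    def forward(self, x):
        return self.proj(self.norm(x))


model = ToyBlock().to(torch.bfloat16)
attach_casting_hooks(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

print(model.proj.weight.dtype)   # torch.float8_e4m3fn (stored in low precision)
print(model.norm.weight.dtype)   # torch.bfloat16      (skipped, unchanged)

out = model(torch.randn(2, 64, dtype=torch.bfloat16))
print(out.dtype, model.proj.weight.dtype)   # torch.bfloat16 torch.float8_e4m3fn
```

Between forward passes the hooked `Linear` keeps its weights in float8, while the `norm` submodule, matching the default skip pattern, stays in its original dtype; the recursive traversal in this diff implements the same idea with proper hook-registry bookkeeping.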