@@ -2109,14 +2109,15 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

     def test_layerwise_casting_inference_denoiser(self):
-        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
             if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                 patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
             for name, submodule in module.named_modules():
-                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2167,10 +2168,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
         See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
         """

+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
         from diffusers.hooks.layerwise_casting import (
             _PEFT_AUTOCAST_DISABLE_HOOK,
             DEFAULT_SKIP_MODULES_PATTERN,
-            SUPPORTED_PYTORCH_LAYERS,
             apply_layerwise_casting,
         )

@@ -2180,7 +2181,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
         def check_module(denoiser):
             # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
             for name, module in denoiser.named_modules():
-                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if any(re.search(pattern, name) for pattern in patterns_to_check):