Commit 1d84505

fix tests
1 parent 0517f0d commit 1d84505

File tree: 2 files changed, +8 -6 lines changed

tests/lora/utils.py

Lines changed: 5 additions & 4 deletions
@@ -2109,14 +2109,15 @@ def test_correct_lora_configs_with_different_ranks(self):
         self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
 
     def test_layerwise_casting_inference_denoiser(self):
-        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
             if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                 patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
             for name, submodule in module.named_modules():
-                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2167,10 +2168,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
         See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
         """
 
+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
         from diffusers.hooks.layerwise_casting import (
             _PEFT_AUTOCAST_DISABLE_HOOK,
             DEFAULT_SKIP_MODULES_PATTERN,
-            SUPPORTED_PYTORCH_LAYERS,
             apply_layerwise_casting,
         )
 
@@ -2180,7 +2181,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
         def check_module(denoiser):
             # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
             for name, module in denoiser.named_modules():
-                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if any(re.search(pattern, name) for pattern in patterns_to_check):

tests/models/test_modeling_common.py

Lines changed: 3 additions & 2 deletions
@@ -1530,7 +1530,8 @@ def test_fn(storage_dtype, compute_dtype):
 
     @torch.no_grad()
     def test_layerwise_casting_inference(self):
-        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
 
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1544,7 +1545,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype):
             if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                 patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
             for name, submodule in module.named_modules():
-                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                     continue
                 dtype_to_check = storage_dtype
                 if any(re.search(pattern, name) for pattern in patterns_to_check):
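
Both files apply the same rename: the tuple of layer types targeted by layerwise casting is now imported from diffusers.hooks._common as _GO_LC_SUPPORTED_PYTORCH_LAYERS instead of layerwise_casting.SUPPORTED_PYTORCH_LAYERS. Below is a minimal sketch of the dtype check the updated tests share, reconstructed from the hunks above. The final weight assertion and the toy model are assumptions for illustration: the hunks cut off before the body of check_linear_dtype completes, and the real tests run against diffusers denoisers.

import re

import torch
import torch.nn as nn

from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, apply_layerwise_casting


def assert_layerwise_dtypes(model, storage_dtype, compute_dtype):
    # Mirrors check_linear_dtype from the hunks above.
    patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
    if getattr(model, "_skip_layerwise_casting_patterns", None) is not None:
        patterns_to_check += tuple(model._skip_layerwise_casting_patterns)
    for name, submodule in model.named_modules():
        # Only layer types covered by layerwise casting are inspected.
        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
            continue
        dtype_to_check = storage_dtype
        # Submodules matching a skip pattern are expected to keep the compute dtype.
        if any(re.search(pattern, name) for pattern in patterns_to_check):
            dtype_to_check = compute_dtype
        # Assumed assertion: the hunks end before the check's body is fully shown.
        if getattr(submodule, "weight", None) is not None:
            assert submodule.weight.dtype == dtype_to_check


# Toy usage (hypothetical): apply casting, then verify storage dtypes.
model = nn.Sequential(nn.Linear(8, 8), nn.SiLU(), nn.Linear(8, 8))
apply_layerwise_casting(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
assert_layerwise_dtypes(model, torch.float8_e4m3fn, torch.bfloat16)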
