Skip to content

Commit 5df02fc

Browse files
authored
[tests] Fix group offloading and layerwise casting test interaction (#11796)
* update * update * update
1 parent 7392c8f commit 5df02fc

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
110110
self.patch_size = patch_size
111111
self.patch_method = patch_method
112112

113-
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
114-
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
113+
wavelets = _WAVELETS.get(patch_method).clone()
114+
arange = torch.arange(wavelets.shape[0])
115+
116+
self.register_buffer("wavelets", wavelets, persistent=False)
117+
self.register_buffer("_arange", arange, persistent=False)
115118

116119
def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
117120
dtype = hidden_states.dtype
@@ -185,12 +188,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
185188
self.patch_size = patch_size
186189
self.patch_method = patch_method
187190

188-
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
189-
self.register_buffer(
190-
"_arange",
191-
torch.arange(_WAVELETS[patch_method].shape[0]),
192-
persistent=False,
193-
)
191+
wavelets = _WAVELETS.get(patch_method).clone()
192+
arange = torch.arange(wavelets.shape[0])
193+
194+
self.register_buffer("wavelets", wavelets, persistent=False)
195+
self.register_buffer("_arange", arange, persistent=False)
194196

195197
def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
196198
device = hidden_states.device

tests/models/test_modeling_common.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype):
15281528
test_fn(torch.float8_e5m2, torch.float32)
15291529
test_fn(torch.float8_e4m3fn, torch.bfloat16)
15301530

1531+
@torch.no_grad()
15311532
def test_layerwise_casting_inference(self):
15321533
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
15331534

15341535
torch.manual_seed(0)
15351536
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1536-
model = self.model_class(**config).eval()
1537-
model = model.to(torch_device)
1538-
base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
1537+
model = self.model_class(**config)
1538+
model.eval()
1539+
model.to(torch_device)
1540+
base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
15391541

15401542
def check_linear_dtype(module, storage_dtype, compute_dtype):
15411543
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1573,6 +1575,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
15731575
test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
15741576

15751577
@require_torch_accelerator
1578+
@torch.no_grad()
15761579
def test_layerwise_casting_memory(self):
15771580
MB_TOLERANCE = 0.2
15781581
LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1706,10 +1709,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17061709
if not self.model_class._supports_group_offloading:
17071710
pytest.skip("Model does not support group offloading.")
17081711

1709-
torch.manual_seed(0)
1710-
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1711-
model = self.model_class(**init_dict)
1712-
17131712
torch.manual_seed(0)
17141713
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
17151714
model = self.model_class(**init_dict)
@@ -1725,7 +1724,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17251724
**additional_kwargs,
17261725
)
17271726
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
1728-
assert has_safetensors, "No safetensors found in the directory."
1727+
self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
17291728
_ = model(**inputs_dict)[0]
17301729

17311730
def test_auto_model(self, expected_max_diff=5e-5):

0 commit comments

Comments
 (0)