@@ -1528,14 +1528,16 @@ def test_fn(storage_dtype, compute_dtype):
1528
1528
test_fn (torch .float8_e5m2 , torch .float32 )
1529
1529
test_fn (torch .float8_e4m3fn , torch .bfloat16 )
1530
1530
1531
+ @torch .no_grad ()
1531
1532
def test_layerwise_casting_inference (self ):
1532
1533
from diffusers .hooks .layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN , SUPPORTED_PYTORCH_LAYERS
1533
1534
1534
1535
torch .manual_seed (0 )
1535
1536
config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1536
- model = self .model_class (** config ).eval ()
1537
- model = model .to (torch_device )
1538
- base_slice = model (** inputs_dict )[0 ].flatten ().detach ().cpu ().numpy ()
1537
+ model = self .model_class (** config )
1538
+ model .eval ()
1539
+ model .to (torch_device )
1540
+ base_slice = model (** inputs_dict )[0 ].detach ().flatten ().cpu ().numpy ()
1539
1541
1540
1542
def check_linear_dtype (module , storage_dtype , compute_dtype ):
1541
1543
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1573,6 +1575,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
1573
1575
test_layerwise_casting (torch .float8_e4m3fn , torch .bfloat16 )
1574
1576
1575
1577
@require_torch_accelerator
1578
+ @torch .no_grad ()
1576
1579
def test_layerwise_casting_memory (self ):
1577
1580
MB_TOLERANCE = 0.2
1578
1581
LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1706,10 +1709,6 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
1706
1709
if not self .model_class ._supports_group_offloading :
1707
1710
pytest .skip ("Model does not support group offloading." )
1708
1711
1709
- torch .manual_seed (0 )
1710
- init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1711
- model = self .model_class (** init_dict )
1712
-
1713
1712
torch .manual_seed (0 )
1714
1713
init_dict , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
1715
1714
model = self .model_class (** init_dict )
@@ -1725,7 +1724,7 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
1725
1724
** additional_kwargs ,
1726
1725
)
1727
1726
has_safetensors = glob .glob (f"{ tmpdir } /*.safetensors" )
1728
- assert has_safetensors , "No safetensors found in the directory."
1727
+ self . assertTrue ( len ( has_safetensors ) > 0 , "No safetensors found in the offload directory." )
1729
1728
_ = model (** inputs_dict )[0 ]
1730
1729
1731
1730
def test_auto_model (self , expected_max_diff = 5e-5 ):
0 commit comments