@@ -143,9 +143,9 @@ def test_pipeline_call_signature(self):
143
143
144
144
def _check_for_parameters (parameters , expected_parameters , param_type ):
145
145
remaining_parameters = {param for param in parameters if param not in expected_parameters }
146
- assert len ( remaining_parameters ) == 0 , (
147
- f"Required { param_type } parameters not present: { remaining_parameters } "
148
- )
146
+ assert (
147
+ len ( remaining_parameters ) == 0
148
+ ), f"Required { param_type } parameters not present: { remaining_parameters } "
149
149
150
150
_check_for_parameters (self .params , input_parameters , "input" )
151
151
_check_for_parameters (self .intermediate_params , intermediate_parameters , "intermediate" )
@@ -274,9 +274,9 @@ def test_to_device(self):
274
274
model_devices = [
275
275
component .device .type for component in pipe .components .values () if hasattr (component , "device" )
276
276
]
277
- assert all (device == torch_device for device in model_devices ), (
278
- "All pipeline components are not on accelerator device"
279
- )
277
+ assert all (
278
+ device == torch_device for device in model_devices
279
+ ), "All pipeline components are not on accelerator device"
280
280
281
281
def test_inference_is_not_nan_cpu (self ):
282
282
pipe = self .get_pipeline ()
@@ -318,3 +318,13 @@ def test_num_images_per_prompt(self):
318
318
images = pipe (** inputs , num_images_per_prompt = num_images_per_prompt , output = "images" )
319
319
320
320
assert images .shape [0 ] == batch_size * num_images_per_prompt
321
+
322
+ @require_accelerator
323
+ def test_components_auto_cpu_offload (self ):
324
+ base_pipe = self .get_pipeline ().to (torch_device )
325
+ for component in base_pipe .components :
326
+ assert component .device == torch_device
327
+
328
+ cm = ComponentsManager ()
329
+ cm .enable_auto_cpu_offload (device = torch_device )
330
+ offload_pipe = self .get_pipeline (components_manager = cm )
0 commit comments