
Commit 3aabef5

Commit message: update
Parent: 39be374

1 file changed: 16 additions and 6 deletions

tests/modular_pipelines/test_modular_pipelines_common.py

@@ -143,9 +143,9 @@ def test_pipeline_call_signature(self):
 
         def _check_for_parameters(parameters, expected_parameters, param_type):
             remaining_parameters = {param for param in parameters if param not in expected_parameters}
-            assert len(remaining_parameters) == 0, (
-                f"Required {param_type} parameters not present: {remaining_parameters}"
-            )
+            assert (
+                len(remaining_parameters) == 0
+            ), f"Required {param_type} parameters not present: {remaining_parameters}"
 
         _check_for_parameters(self.params, input_parameters, "input")
         _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
@@ -274,9 +274,9 @@ def test_to_device(self):
         model_devices = [
             component.device.type for component in pipe.components.values() if hasattr(component, "device")
         ]
-        assert all(device == torch_device for device in model_devices), (
-            "All pipeline components are not on accelerator device"
-        )
+        assert all(
+            device == torch_device for device in model_devices
+        ), "All pipeline components are not on accelerator device"
 
     def test_inference_is_not_nan_cpu(self):
         pipe = self.get_pipeline()
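
Both hunks above make the same mechanical change: the parentheses move from the assertion message onto the condition, so the condition rather than the message is what wraps across lines. Runtime behavior is identical; a minimal standalone sketch (names are hypothetical, not from the test file) showing both layouts fail with the same message:

# Minimal sketch: the two assert layouts in the hunks above are
# equivalent at runtime. `remaining` is a hypothetical stand-in for the
# test's set of unexpected parameters.
remaining = {"missing_param"}

def old_layout():
    # Old layout: bare condition, parenthesized message.
    assert len(remaining) == 0, (
        f"Required input parameters not present: {remaining}"
    )

def new_layout():
    # New layout: parenthesized condition, bare message.
    assert (
        len(remaining) == 0
    ), f"Required input parameters not present: {remaining}"

for layout in (old_layout, new_layout):
    try:
        layout()
    except AssertionError as err:
        # Both layouts print the identical message:
        # Required input parameters not present: {'missing_param'}
        print(err)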
@@ -318,3 +318,13 @@ def test_num_images_per_prompt(self):
         images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
 
         assert images.shape[0] == batch_size * num_images_per_prompt
+
+    @require_accelerator
+    def test_components_auto_cpu_offload(self):
+        base_pipe = self.get_pipeline().to(torch_device)
+        for component in base_pipe.components:
+            assert component.device == torch_device
+
+        cm = ComponentsManager()
+        cm.enable_auto_cpu_offload(device=torch_device)
+        offload_pipe = self.get_pipeline(components_manager=cm)
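
The added test stops after constructing the offload pipeline and asserts nothing about it yet. A hedged sketch of how it might be completed, reusing the device-check pattern from the test_to_device hunk above: ComponentsManager, enable_auto_cpu_offload, get_pipeline, and require_accelerator all appear in this diff, while the import path and the expectation that managed components idle on the CPU are assumptions, not anything this commit asserts.

# Hypothetical completion of test_components_auto_cpu_offload; not part
# of this commit. The diffusers import path is an assumption.
from diffusers import ComponentsManager


@require_accelerator
def test_components_auto_cpu_offload(self):
    cm = ComponentsManager()
    cm.enable_auto_cpu_offload(device=torch_device)  # as in the added test

    offload_pipe = self.get_pipeline(components_manager=cm)

    # Assumed behavior: with auto CPU offload enabled, managed components
    # idle on the CPU rather than the accelerator until they are needed.
    model_devices = [
        component.device.type
        for component in offload_pipe.components.values()
        if hasattr(component, "device")
    ]
    assert all(device == "cpu" for device in model_devices)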
