Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/diffusers/modular_pipelines/components_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def __call__(self, hooks, model_id, model, execution_device):
if len(hooks) == 0:
return []

current_module_size = model.get_memory_footprint()
try:
current_module_size = model.get_memory_footprint()
except AttributeError:
        raise AttributeError(f"Do not know how to compute the memory footprint of `{model.__class__.__name__}`.")

device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
Expand Down Expand Up @@ -703,19 +706,28 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
if not is_accelerate_available():
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")

# TODO: add a warning if mem_get_info isn't available on `device`.
if device is None:
device = get_device()
if not isinstance(device, torch.device):
device = torch.device(device)

device_type = device.type
device_module = getattr(torch, device_type, torch.cuda)
if not hasattr(device_module, "mem_get_info"):
    raise NotImplementedError(
        f"`enable_auto_cpu_offload()` relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
    )

if device.index is None:
device = torch.device(f"{device.type}:{0}")

for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
remove_hook_from_module(component, recurse=True)

self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
if device is None:
device = get_device()
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")

all_hooks = []
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):
Expand Down
Loading