
Commit 2c1ed50

Provide option to reduce CPU RAM usage in Group Offload (#11106)
* update
* update
* clean up
1 parent 15ad97f commit 2c1ed50
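
The new low_cpu_mem_usage flag trades transfer speed for resident CPU memory: when enabled, offloaded weights are kept as ordinary pageable CPU tensors and are pinned only transiently at onload time, instead of holding a page-locked copy of every parameter for the model's lifetime. A hedged usage sketch of the new option (the checkpoint and model class are illustrative, not part of this commit):

import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,  # new in this commit: skip pinning all weights up front
)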

2 files changed: +93, -55 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 84 additions & 54 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -56,23 +56,58 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = self._init_cpu_param_dict()
+
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
+        for module in self.modules:
+            for param in module.parameters():
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            for buffer in module.buffers():
+                cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -82,15 +117,30 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
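Copying out of a pinned staging tensor is what lets the non_blocking=True transfer overlap with compute on another stream; from pageable memory, CUDA quietly falls back to a synchronous copy. A minimal sketch of that overlap in plain PyTorch, assuming a CUDA device is available (not diffusers-specific code):

import torch

side_stream = torch.cuda.Stream()
staging = torch.randn(1024, 1024).pin_memory()  # pinned source enables a true async H2D copy

with torch.cuda.stream(side_stream):
    weight_gpu = staging.to("cuda", non_blocking=True)  # enqueued on side_stream, returns immediately

# ...compute on the default stream can proceed here while the copy is in flight...

torch.cuda.current_stream().wait_stream(side_stream)  # order the default stream after the copy
result = weight_gpu.sum()  # safe to consume now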
@@ -101,21 +151,18 @@ def offload_(self):
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
 
 class GroupOffloadingHook(ModelHook):
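
In the stream path, offload_ copies nothing back to the CPU: it repoints param.data at the persistent tensor already held in cpu_param_dict, so no per-step CPU allocation occurs. The same pointer-swap technique in isolation, on a toy layer (assumes a CUDA device):

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
cpu_master = {p: p.data.cpu() for p in layer.parameters()}  # CPU master copies, kept once

for p in layer.parameters():
    p.data = p.data.to("cuda")   # onload for the forward pass
for p in layer.parameters():
    p.data = cpu_master[p]       # offload: reuse the existing CPU tensor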
@@ -284,6 +331,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage=False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
        )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
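For reference, the updated entry point can also be invoked directly rather than through a model's enable_group_offload; a hedged sketch, where the Sequential module stands in for a real diffusers model (hypothetical, not from this commit; use_stream=True assumes a CUDA device):

import torch
import torch.nn as nn
from diffusers.hooks.group_offloading import apply_group_offloading

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))  # stand-in for a real model

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,  # forwarded to the block-level/leaf-level helpers above
)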
@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -425,7 +471,7 @@ def _apply_group_offloading_block_level(
                 onload_leader=current_modules[0],
                 non_blocking=non_blocking,
                 stream=stream,
-                cpu_param_dict=cpu_param_dict,
+                low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=stream is None,
             )
             matched_module_groups.append(group)
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 1 deletion
@@ -546,6 +546,7 @@ def enable_group_offload(
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        low_cpu_mem_usage=False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -584,7 +585,14 @@ def enable_group_offload(
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
     def save_pretrained(
