Commit 84a1e21

Merge branch 'main' into dependabot/pip/examples/server/jinja2-3.1.6
2 parents 8a556a6 + 1a04812 commit 84a1e21

11 files changed (+263, -20 lines)

.github/workflows/nightly_tests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -417,7 +417,7 @@ jobs:
           additional_deps: ["peft"]
         - backend: "gguf"
           test_location: "gguf"
-          additional_deps: []
+          additional_deps: ["peft"]
         - backend: "torchao"
           test_location: "torchao"
           additional_deps: []
```

docs/source/en/optimization/memory.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch
 # We can utilize the enable_group_offload method for Diffusers model implementations
 pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
 
+# Uncomment the following to also allow recording the current streams.
+# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
+
 # For any other model implementations, the apply_group_offloading function can be used
 apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
 apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
@@ -205,6 +208,7 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s
 - The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that the available CPU RAM is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
 - If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
 - The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
+- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this.
 
 For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
```
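
Taken together, the documented workflow looks roughly like the sketch below. It mirrors the snippet above and assumes a CUDA device and a diffusers build that includes the new `record_stream` flag; the dtype and model id follow the surrounding doc and are otherwise not prescribed by this commit.

```python
# Minimal sketch of the documented workflow (assumes CUDA and a diffusers
# version that ships the record_stream option added in this commit).
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Diffusers model implementations expose enable_group_offload directly.
# record_stream=True is only meaningful together with use_stream=True.
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    record_stream=True,
)

# Other components can use the functional API, as in the doc snippet above.
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
```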

src/diffusers/hooks/group_offloading.py

Lines changed: 62 additions & 4 deletions
```diff
@@ -56,6 +56,7 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
+        record_stream: Optional[bool] = False,
         low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
@@ -68,11 +69,14 @@ def __init__(
         self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
+        self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-
         self.cpu_param_dict = self._init_cpu_param_dict()
 
+        if self.stream is None and self.record_stream:
+            raise ValueError("`record_stream` cannot be True when `stream` is None.")
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -112,6 +116,8 @@ def _pinned_memory_tensors(self):
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+        current_stream = torch.cuda.current_stream() if self.record_stream else None
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
@@ -122,14 +128,22 @@ def onload_(self):
                     for group_module in self.modules:
                         for param in group_module.parameters():
                             param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                            if self.record_stream:
+                                param.data.record_stream(current_stream)
                         for buffer in group_module.buffers():
                             buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+                            if self.record_stream:
+                                buffer.data.record_stream(current_stream)
 
                     for param in self.parameters:
                         param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            param.data.record_stream(current_stream)
 
                     for buffer in self.buffers:
                         buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            buffer.data.record_stream(current_stream)
 
             else:
                 for group_module in self.modules:
@@ -143,11 +157,14 @@ def onload_(self):
 
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        buffer.data.record_stream(current_stream)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
+            if not self.record_stream:
+                torch.cuda.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -331,6 +348,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -378,6 +396,10 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
+            the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
+            for more details.
         low_cpu_mem_usage (`bool`, defaults to `False`):
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
@@ -417,11 +439,24 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+            module=module,
+            num_blocks_per_group=num_blocks_per_group,
+            offload_device=offload_device,
+            onload_device=onload_device,
+            non_blocking=non_blocking,
+            stream=stream,
+            record_stream=record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+            module=module,
+            offload_device=offload_device,
+            onload_device=onload_device,
+            non_blocking=non_blocking,
+            stream=stream,
+            record_stream=record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -434,6 +469,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -453,6 +489,14 @@ def _apply_group_offloading_block_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
+            the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
+            for more details.
+        low_cpu_mem_usage (`bool`, defaults to `False`):
+            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
+            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
+            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
     # Create module groups for ModuleList and Sequential blocks
@@ -475,6 +519,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
@@ -512,6 +557,7 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -545,6 +592,14 @@
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
+            the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
+            for more details.
+        low_cpu_mem_usage (`bool`, defaults to `False`):
+            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
+            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
+            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
     # Create module groups for leaf modules and apply group offloading hooks
@@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
@@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
@@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
+            record_stream=False,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
```
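
The key behavioral change above is in `offload_`: when `record_stream=True`, the hook skips the blocking `torch.cuda.current_stream().synchronize()` and instead relies on `torch.Tensor.record_stream` to tell the caching allocator which stream still uses the onloaded tensors. The standalone sketch below is plain PyTorch, not diffusers API, and assumes a CUDA device; it only illustrates the pattern the hook relies on.

```python
# Standalone sketch of the record_stream pattern: issue an H2D copy on a side
# stream, then mark the resulting GPU tensor as used by the compute stream so
# the caching allocator will not recycle its memory early -- no global
# synchronize needed before the memory is later freed/reused.
import torch

assert torch.cuda.is_available()

compute_stream = torch.cuda.current_stream()  # captured before switching streams
copy_stream = torch.cuda.Stream()
cpu_weight = torch.randn(1024, 1024, pin_memory=True)  # pinned for async copies

with torch.cuda.stream(copy_stream):
    gpu_weight = cpu_weight.to("cuda", non_blocking=True)
    # Tell the allocator this tensor is also consumed by the compute stream.
    gpu_weight.record_stream(compute_stream)

# Reads must still be ordered after the copy, but this is a stream-level wait,
# not a host-blocking synchronize of the whole compute stream.
compute_stream.wait_stream(copy_stream)
out = gpu_weight @ gpu_weight
```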

src/diffusers/loaders/lora_pipeline.py

Lines changed: 58 additions & 5 deletions
```diff
@@ -22,6 +22,8 @@
     USE_PEFT_BACKEND,
     deprecate,
     get_submodule_by_name,
+    is_bitsandbytes_available,
+    is_gguf_available,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -68,6 +70,49 @@
 _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
 
 
+def _maybe_dequantize_weight_for_expanded_lora(model, module):
+    if is_bitsandbytes_available():
+        from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+    if is_gguf_available():
+        from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+    is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+    if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
+    if is_gguf_quantized and not is_gguf_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+        )
+
+    weight_on_cpu = False
+    if not module.weight.is_cuda:
+        weight_on_cpu = True
+
+    if is_bnb_4bit_quantized:
+        module_weight = dequantize_bnb_weight(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+            state=module.weight.quant_state,
+            dtype=model.dtype,
+        ).data
+    elif is_gguf_quantized:
+        module_weight = dequantize_gguf_tensor(
+            module.weight.cuda() if weight_on_cpu else module.weight,
+        )
+        module_weight = module_weight.to(model.dtype)
+    else:
+        module_weight = module.weight.data
+
+    if weight_on_cpu:
+        module_weight = module_weight.cpu()
+
+    return module_weight
+
+
 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -2267,6 +2312,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         overwritten_params = {}
 
         is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        is_quantized = hasattr(transformer, "hf_quantizer")
         for name, module in transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 module_weight = module.weight.data
@@ -2291,9 +2337,7 @@
                 if tuple(module_weight_shape) == (out_features, in_features):
                     continue
 
-                # TODO (sayakpaul): We still need to consider if the module we're expanding is
-                # quantized and handle it accordingly if that is the case.
-                module_out_features, module_in_features = module_weight.shape
+                module_out_features, module_in_features = module_weight_shape
                 debug_message = ""
                 if in_features > module_in_features:
                     debug_message += (
@@ -2316,6 +2360,10 @@
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)
 
+                if is_quantized:
+                    module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+                # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
                 with torch.device("meta"):
                     expanded_module = torch.nn.Linear(
                         in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2327,7 +2375,7 @@
                     new_weight = torch.zeros_like(
                         expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
                     )
-                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    slices = tuple(slice(0, dim) for dim in module_weight_shape)
                     new_weight[slices] = module_weight
                     tmp_state_dict = {"weight": new_weight}
                     if module_bias is not None:
@@ -2416,7 +2464,12 @@ def _calculate_module_shape(
         base_weight_param_name: str = None,
     ) -> "torch.Size":
        def _get_weight_shape(weight: torch.Tensor):
-            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+            if weight.__class__.__name__ == "Params4bit":
+                return weight.quant_state.shape
+            elif weight.__class__.__name__ == "GGUFParameter":
+                return weight.quant_shape
+            else:
+                return weight.shape
 
         if base_module is not None:
             return _get_weight_shape(base_module.weight)
```
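
The practical effect of this diff is that shape-expanding LoRAs (for example Flux Control LoRAs, which widen `x_embedder`) can now be loaded into bnb-4bit or GGUF quantized transformers: the packed weight is dequantized by the new helper before the expanded `nn.Linear` is built, and shapes are read via `_calculate_module_shape` so `quant_state.shape` / `quant_shape` are respected. A hedged usage sketch follows; the GGUF checkpoint path and LoRA repo id are illustrative placeholders, not values taken from this commit.

```python
# Hedged sketch: loading a shape-expanding LoRA into a GGUF-quantized Flux transformer.
# The checkpoint URL and LoRA id below are placeholders; substitute your own.
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

gguf_ckpt = "https://huggingface.co/<user>/<repo>/blob/main/flux1-dev-Q4_K_S.gguf"  # placeholder

transformer = FluxTransformer2DModel.from_single_file(
    gguf_ckpt,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)

# Before this change, expanding a quantized module used the packed weight's shape;
# now the helper dequantizes the weight first and the expansion uses the logical shape.
pipe.load_lora_weights("<org>/<control-lora-repo>")  # placeholder LoRA id
```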

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -546,6 +546,7 @@ def enable_group_offload(
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        record_stream: bool = False,
         low_cpu_mem_usage=False,
     ) -> None:
         r"""
@@ -594,6 +595,7 @@ def enable_group_offload(
             num_blocks_per_group,
             non_blocking,
             use_stream,
+            record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
```