Skip to content

Commit de3af02

Browse files
authored
fix sequential offload for wan (#73)
* fix sequential offload * fix * fix
1 parent d3d2909 commit de3af02

File tree

9 files changed

+33
-84
lines changed

9 files changed

+33
-84
lines changed

diffsynth_engine/models/wan/wan_image_encoder.py

Lines changed: 1 addition & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -349,30 +349,6 @@ def __init__(
349349
embedding_dropout=embedding_dropout,
350350
norm_eps=norm_eps,
351351
)
352-
self.textual = None
353-
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
354-
355-
def forward(self, imgs, txt_ids):
356-
"""
357-
imgs: [B, 3, H, W] of torch.float32.
358-
- mean: [0.48145466, 0.4578275, 0.40821073]
359-
- std: [0.26862954, 0.26130258, 0.27577711]
360-
txt_ids: [B, L] of torch.long.
361-
Encoded by data.CLIPTokenizer.
362-
"""
363-
xi = self.visual(imgs)
364-
xt = self.textual(txt_ids)
365-
return xi, xt
366-
367-
def param_groups(self):
368-
groups = [
369-
{
370-
"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
371-
"weight_decay": 0.0,
372-
},
373-
{"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
374-
]
375-
return groups
376352

377353

378354
def _clip(
@@ -444,7 +420,7 @@ def _from_diffusers(self, state_dict):
444420
def _from_civitai(self, state_dict):
445421
state_dict_ = {}
446422
for name, param in state_dict.items():
447-
if name.startswith("textual."):
423+
if name.startswith(("textual.", "log_scale")):
448424
continue
449425
name = "model." + name
450426
state_dict_[name] = param

diffsynth_engine/models/wan/wan_text_encoder.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -147,8 +147,6 @@ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
147147

148148
def forward(self, lq, lk):
149149
device = self.embedding.weight.device
150-
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
151-
# torch.arange(lq).unsqueeze(1).to(device)
152150
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
153151
rel_pos = self._relative_position_bucket(rel_pos)
154152
rel_pos_embeds = self.embedding(rel_pos)

diffsynth_engine/pipelines/base.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -254,26 +254,26 @@ def init_parallel_config(
254254
"use_fsdp": use_fsdp,
255255
}
256256

257-
@staticmethod
258-
def validate_offload_mode(offload_mode: str | None):
259-
valid_offload_mode = (None, "cpu_offload", "sequential_cpu_offload")
257+
def enable_cpu_offload(self, offload_mode: str):
258+
valid_offload_mode = ("cpu_offload", "sequential_cpu_offload")
260259
if offload_mode not in valid_offload_mode:
261260
raise ValueError(f"offload_mode must be one of {valid_offload_mode}, but got {offload_mode}")
262-
263-
def enable_cpu_offload(self):
264261
if self.device == "cpu":
265262
logger.warning("must set an non cpu device for pipeline before calling enable_cpu_offload")
266263
return
264+
if offload_mode == "cpu_offload":
265+
self.enable_model_cpu_offload()
266+
elif offload_mode == "sequential_cpu_offload":
267+
self.enable_sequential_cpu_offload()
268+
269+
def enable_model_cpu_offload(self):
267270
for model_name in self.model_names:
268271
model = getattr(self, model_name)
269272
if model is not None:
270273
model.to("cpu")
271274
self.offload_mode = "cpu_offload"
272275

273276
def enable_sequential_cpu_offload(self):
274-
if self.device == "cpu":
275-
logger.warning("must set an non cpu device for pipeline before calling enable_sequential_cpu_offload")
276-
return
277277
for model_name in self.model_names:
278278
model = getattr(self, model_name)
279279
if model is not None:

diffsynth_engine/pipelines/flux_image.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -361,8 +361,6 @@ def from_pretrained(
361361
parallelism: int = 1,
362362
use_cfg_parallel: bool = False,
363363
) -> "FluxImagePipeline":
364-
cls.validate_offload_mode(offload_mode)
365-
366364
model_config = (
367365
model_path_or_config
368366
if isinstance(model_path_or_config, FluxModelConfig)
@@ -460,10 +458,8 @@ def from_pretrained(
460458
device=device,
461459
dtype=dtype,
462460
)
463-
if offload_mode == "cpu_offload":
464-
pipe.enable_cpu_offload()
465-
elif offload_mode == "sequential_cpu_offload":
466-
pipe.enable_sequential_cpu_offload()
461+
if offload_mode is not None:
462+
pipe.enable_cpu_offload(offload_mode)
467463
return pipe
468464

469465
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
@@ -751,7 +747,7 @@ def predict_multicontrolnet(
751747
# if current_step is not in the control range
752748
# skip thie controlnet
753749
continue
754-
if self.offload_mode == "sequential_cpu_offload" or self.offload_mode == "cpu_offload":
750+
if self.offload_mode is not None:
755751
empty_cache()
756752
param.model.to(self.device)
757753
double_block_output, single_block_output = param.model(
@@ -765,7 +761,7 @@ def predict_multicontrolnet(
765761
image_ids,
766762
text_ids,
767763
)
768-
if self.offload_mode == "sequential_cpu_offload" or self.offload_mode == "cpu_offload":
764+
if self.offload_mode is not None:
769765
empty_cache()
770766
param.model.to("cpu")
771767
double_block_output_results = accumulate(double_block_output_results, double_block_output)

diffsynth_engine/pipelines/sd_image.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -190,8 +190,6 @@ def from_pretrained(
190190
offload_mode: str | None = None,
191191
batch_cfg: bool = True,
192192
) -> "SDImagePipeline":
193-
cls.validate_offload_mode(offload_mode)
194-
195193
if isinstance(model_path_or_config, str):
196194
model_config = SDModelConfig(unet_path=model_path_or_config)
197195
else:
@@ -237,10 +235,8 @@ def from_pretrained(
237235
device=device,
238236
dtype=dtype,
239237
)
240-
if offload_mode == "cpu_offload":
241-
pipe.enable_cpu_offload()
242-
elif offload_mode == "sequential_cpu_offload":
243-
pipe.enable_sequential_cpu_offload()
238+
if offload_mode is not None:
239+
pipe.enable_cpu_offload(offload_mode)
244240
return pipe
245241

246242
@classmethod

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -164,8 +164,6 @@ def from_pretrained(
164164
offload_mode: str | None = None,
165165
batch_cfg: bool = True,
166166
) -> "SDXLImagePipeline":
167-
cls.validate_offload_mode(offload_mode)
168-
169167
if isinstance(model_path_or_config, str):
170168
model_config = SDXLModelConfig(
171169
unet_path=model_path_or_config, unet_dtype=dtype, clip_l_dtype=dtype, clip_g_dtype=dtype
@@ -225,10 +223,8 @@ def from_pretrained(
225223
device=device,
226224
dtype=dtype,
227225
)
228-
if offload_mode == "cpu_offload":
229-
pipe.enable_cpu_offload()
230-
elif offload_mode == "sequential_cpu_offload":
231-
pipe.enable_sequential_cpu_offload()
226+
if offload_mode is not None:
227+
pipe.enable_cpu_offload(offload_mode)
232228
return pipe
233229

234230
@classmethod

diffsynth_engine/pipelines/wan_video.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ def __init__(
159159
self.vae = vae
160160
self.image_encoder = image_encoder
161161
self.batch_cfg = batch_cfg
162-
self.model_names = ["text_encoder", "dit", "vae"]
162+
self.model_names = ["text_encoder", "dit", "vae", "image_encoder"]
163163

164164
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
165165
assert self.config.tp_degree is None, (
@@ -417,8 +417,6 @@ def from_pretrained(
417417
parallelism: int = 1,
418418
use_cfg_parallel: bool = False,
419419
) -> "WanVideoPipeline":
420-
cls.validate_offload_mode(offload_mode)
421-
422420
if isinstance(model_path_or_config, str):
423421
model_config = WanModelConfig(model_path=model_path_or_config)
424422
else:
@@ -523,10 +521,8 @@ def from_pretrained(
523521
dtype=dtype,
524522
)
525523
pipe.eval()
526-
if offload_mode == "cpu_offload":
527-
pipe.enable_cpu_offload()
528-
elif offload_mode == "sequential_cpu_offload":
529-
pipe.enable_sequential_cpu_offload()
524+
if offload_mode is not None:
525+
pipe.enable_cpu_offload(offload_mode)
530526
return pipe
531527

532528
def __del__(self):

diffsynth_engine/utils/offload.py

Lines changed: 13 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,44 +1,36 @@
1+
import torch
12
import torch.nn as nn
23

3-
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
4-
from diffsynth_engine.models.basic.relative_position_emb import RelativePositionEmbedding
5-
6-
7-
SUPPORTED_OFFLOAD_MODULES = (
8-
nn.Embedding,
9-
nn.Linear,
10-
nn.LayerNorm,
11-
nn.Conv2d,
12-
nn.GroupNorm,
13-
RMSNorm,
14-
RelativePositionEmbedding,
15-
)
16-
174

185
def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda:0"):
19-
if isinstance(module, SUPPORTED_OFFLOAD_MODULES):
20-
add_cpu_offload_hook(module, device)
6+
if len(list(module.children())) == 0:
7+
if len(list(module.parameters())) > 0: # leaf module with parameters
8+
add_cpu_offload_hook(module, device)
219
return
10+
if len(list(module.parameters(recurse=False))) > 0: # module with direct parameters
11+
add_cpu_offload_hook(module, device, recurse=False)
2212
for submodule in module.children():
2313
enable_sequential_cpu_offload(submodule, device)
2414

2515

26-
def add_cpu_offload_hook(module: nn.Module, device: str = "cuda:0"):
16+
# TODO: support module buffers
17+
def add_cpu_offload_hook(module: nn.Module, device: str = "cuda:0", recurse: bool = True):
2718
def _forward_pre_hook(module: nn.Module, input):
2819
offload_params = {}
29-
for name, param in module.named_parameters():
20+
for name, param in module.named_parameters(recurse=recurse):
3021
offload_params[name] = param.data
3122
param.data = param.data.to(device=device)
3223
setattr(module, "_offload_params", offload_params)
24+
return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input)
3325

3426
def _forward_hook(module: nn.Module, input, output):
3527
offload_params = getattr(module, "_offload_params", {})
36-
for name, param in module.named_parameters():
28+
for name, param in module.named_parameters(recurse=recurse):
3729
if name in offload_params:
3830
param.data = offload_params[name]
3931

40-
if getattr(module, "_sequential_cpu_offload_enabled", False):
32+
if getattr(module, "_cpu_offload_enabled", False):
4133
return
4234
module.register_forward_pre_hook(_forward_pre_hook)
4335
module.register_forward_hook(_forward_hook)
44-
setattr(module, "_sequential_cpu_offload_enabled", True)
36+
setattr(module, "_cpu_offload_enabled", True)

diffsynth_engine/utils/parallel.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -338,7 +338,6 @@ def __init__(
338338
):
339339
super().__init__()
340340
self.world_size = cfg_degree * sp_ulysses_degree * sp_ring_degree * tp_degree
341-
self.device = device
342341
self.queue_in = mp.Queue()
343342
self.queue_out = mp.Queue()
344343
self.ctx = mp.spawn(

0 commit comments

Comments
 (0)