
Commit baf5345

fix wan parallel & update examples (#116)
1 parent ac82a1c commit baf5345

9 files changed: +20 −36 lines

diffsynth_engine/models/wan/wan_dit.py

Lines changed: 3 additions & 2 deletions
@@ -334,9 +334,10 @@ def forward(
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
     ):
+        use_cfg = x.shape[0] > 1
         with (
             gguf_inference(),
-            cfg_parallel((x, context, timestep, clip_feature, y)),
+            cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
         ):
             t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
             t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
@@ -365,7 +366,7 @@ def forward(
         x = self.head(x, t)
         (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
-        (x,) = cfg_parallel_unshard((x,))
+        (x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
         return x

     @classmethod
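The fix derives `use_cfg` from the runtime batch size, so classifier-free-guidance (CFG) parallelism only engages when the batch actually holds a conditional/unconditional pair. Below is a minimal, self-contained sketch of that gating pattern; the helper names mirror the diff, but the internals are illustrative assumptions, not diffsynth_engine's actual implementation.

import contextlib

import torch
import torch.distributed as dist


@contextlib.contextmanager
def cfg_parallel_sketch(tensors, use_cfg: bool = True):
    # Gate on both the flag and the process group: with a single-sample
    # batch there is no cond/uncond pair to split across ranks.
    if not use_cfg or not dist.is_initialized():
        yield tensors
        return
    # Assumed layout: batch dim holds [cond, uncond]; a 2-rank group
    # keeps one half per rank for the duration of the block.
    rank = dist.get_rank() % 2
    yield tuple(t.chunk(2, dim=0)[rank] if isinstance(t, torch.Tensor) else t for t in tensors)


def cfg_parallel_unshard_sketch(tensors, use_cfg: bool = True):
    # Reassemble the full batch after the sharded forward pass.
    if not use_cfg or not dist.is_initialized():
        return tensors
    out = []
    for t in tensors:
        halves = [torch.empty_like(t), torch.empty_like(t)]
        dist.all_gather(halves, t.contiguous())  # assumes a 2-rank group
        out.append(torch.cat(halves, dim=0))
    return tuple(out)

With `use_cfg=False` both helpers degrade to identity operations, which is exactly what the guarded calls in the diff rely on for single-batch inference.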

diffsynth_engine/models/wan/wan_vae.py

Lines changed: 14 additions & 15 deletions
@@ -515,7 +515,7 @@ def convert(self, state_dict):
 class WanVideoVAE(PreTrainedModel):
     converter = WanVideoVAEStateDictConverter()

-    def __init__(self, z_dim=16, parallelism: int = 1, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
+    def __init__(self, z_dim=16, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
         super().__init__()

         mean = [
@@ -561,12 +561,11 @@ def __init__(self, z_dim=16, parallelism: int = 1, device: str = "cuda:0", dtype
         # init model
         self.model = VideoVAE(z_dim=z_dim).eval().requires_grad_(False)
         self.upsampling_factor = 8
-        self.parallelism = parallelism

     @classmethod
-    def from_state_dict(cls, state_dict, parallelism=1, device="cuda:0", dtype=torch.float32) -> "WanVideoVAE":
+    def from_state_dict(cls, state_dict, device="cuda:0", dtype=torch.float32) -> "WanVideoVAE":
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, parallelism=parallelism, device=device, dtype=dtype)
+            model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
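Note that the `parallelism` keyword has to disappear from `__init__` and from the `skip_init` call together: `torch.nn.utils.skip_init` forwards its keyword arguments straight to the constructor (after building the module on the meta device), so a stale kwarg would raise a `TypeError`. A toy illustration of that coupling, using a hypothetical `Toy` module rather than anything from the repo:

import torch


class Toy(torch.nn.Module):
    # skip_init requires the constructor to accept a device kwarg that it
    # passes through to the parameters created during construction.
    def __init__(self, dim: int = 4, device=None):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim, device=device)


m = torch.nn.utils.skip_init(Toy, dim=8)  # ok: Toy.__init__ accepts dim
# torch.nn.utils.skip_init(Toy, parallelism=2)  # TypeError: unexpected kwarg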
@@ -607,7 +606,7 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
                 h_, w_ = h + size_h, w + size_w
                 tasks.append((h, h_, w, w_))

-        data_device = device if self.parallelism > 1 else "cpu"
+        data_device = device if dist.is_initialized() else "cpu"
         computation_device = device

         out_T = T * 4 - 3
@@ -622,9 +621,9 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
             device=data_device,
         )

-        hide_progress_bar = self.parallelism > 1 and dist.get_rank() != 0
-        for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE DECODING", disable=hide_progress_bar)):
-            if self.parallelism > 1 and (i % dist.get_world_size() != dist.get_rank()):
+        hide_progress = dist.is_initialized() and dist.get_rank() != 0
+        for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE DECODING", disable=hide_progress)):
+            if dist.is_initialized() and (i % dist.get_world_size() != dist.get_rank()):
                 continue
             hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
             hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
@@ -654,11 +653,11 @@ def tiled_decode(self, hidden_states, device, tile_size, tile_stride, progress_c
                 target_h : target_h + hidden_states_batch.shape[3],
                 target_w : target_w + hidden_states_batch.shape[4],
             ] += mask
-            if progress_callback is not None and not hide_progress_bar:
+            if progress_callback is not None and not hide_progress:
                 progress_callback(i + 1, len(tasks), "VAE DECODING")
-        if progress_callback is not None and not hide_progress_bar:
+        if progress_callback is not None and not hide_progress:
             progress_callback(len(tasks), len(tasks), "VAE DECODING")
-        if self.parallelism > 1:
+        if dist.is_initialized():
             dist.all_reduce(values)
             dist.all_reduce(weight)
         values = values / weight
@@ -681,7 +680,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
                 h_, w_ = h + size_h, w + size_w
                 tasks.append((h, h_, w, w_))

-        data_device = device if self.parallelism > 1 else "cpu"
+        data_device = device if dist.is_initialized() else "cpu"
         computation_device = device

         out_T = (T + 3) // 4
@@ -696,9 +695,9 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
             device=data_device,
         )

-        hide_progress_bar = self.parallelism > 1 and dist.get_rank() != 0
+        hide_progress_bar = dist.is_initialized() and dist.get_rank() != 0
         for i, (h, h_, w, w_) in enumerate(tqdm(tasks, desc="VAE ENCODING", disable=hide_progress_bar)):
-            if self.parallelism > 1 and (i % dist.get_world_size() != dist.get_rank()):
+            if dist.is_initialized() and (i % dist.get_world_size() != dist.get_rank()):
                 continue
             hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
             hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
@@ -732,7 +731,7 @@ def tiled_encode(self, video, device, tile_size, tile_stride, progress_callback=
             progress_callback(i + 1, len(tasks), "VAE ENCODING")
         if progress_callback is not None and not hide_progress_bar:
             progress_callback(len(tasks), len(tasks), "VAE ENCODING")
-        if self.parallelism > 1:
+        if dist.is_initialized():
             dist.all_reduce(values)
             dist.all_reduce(weight)
         values = values / weight
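Every `self.parallelism > 1` test in the tiling loops becomes `dist.is_initialized()`, so the VAE no longer carries its own parallel state and simply follows whatever process group the caller has set up. The underlying idiom, round-robin tile ownership plus `all_reduce` of overlapping accumulators, can be sketched standalone; the tile geometry and `compute` callback here are simplified assumptions, not the library's API:

import torch
import torch.distributed as dist


def blend_tiles(tiles, compute, out_shape, device="cpu"):
    # Zero-initialized accumulators; overlapping tiles are averaged by weight.
    values = torch.zeros(out_shape, device=device)
    weight = torch.zeros(out_shape, device=device)
    world = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    for i, (h0, h1, w0, w1) in enumerate(tiles):
        if i % world != rank:
            continue  # round-robin: another rank owns this tile
        values[..., h0:h1, w0:w1] += compute(h0, h1, w0, w1)
        weight[..., h0:h1, w0:w1] += 1.0
    if dist.is_initialized():
        dist.all_reduce(values)  # sum the disjoint per-rank contributions
        dist.all_reduce(weight)
    return values / weight.clamp(min=1e-6)

Because the guard is the process-group state itself, the same code path serves single-GPU and multi-GPU runs alike, which is what lets the constructor shed its `parallelism` argument.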

examples/flux_parallel.py

Lines changed: 0 additions & 2 deletions
@@ -1,9 +1,7 @@
-import torch.multiprocessing as mp
 from diffsynth_engine import fetch_model, FluxImagePipeline


 if __name__ == "__main__":
-    mp.set_start_method("spawn")
     model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
     pipe = FluxImagePipeline.from_pretrained(model_path, parallelism=4, offload_mode="cpu_offload")
     image = pipe(prompt="a cat", seed=42)
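All of the examples drop the explicit `mp.set_start_method("spawn")` call, which implies the engine now handles the start method itself when it launches parallel workers. One conventional way to do that without mutating global interpreter state is a private spawn context; this is a hedged sketch of the idea, not necessarily how diffsynth_engine implements it:

import torch.multiprocessing as mp


def launch_workers(worker_fn, world_size: int):
    # A local "spawn" context avoids mp.set_start_method's one-shot global
    # side effect, so library users need no boilerplate in __main__.
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=worker_fn, args=(rank, world_size)) for rank in range(world_size)]
    for p in procs:
        p.start()
    return procs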

examples/wan_flf_to_video.py

Lines changed: 1 addition & 3 deletions
@@ -1,4 +1,3 @@
-import torch.multiprocessing as mp
 from PIL import Image

 from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
@@ -7,9 +6,8 @@


 if __name__ == "__main__":
-    mp.set_start_method("spawn")
     config = WanModelConfig(
-        model_path=fetch_model("muse/wan2.1-flf2v-14b-720p-bf16", path="dit.safetensors"),
+        model_path=fetch_model("MusePublic/wan2.1-flf2v-14b-720p-bf16", path="dit.safetensors"),
         t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
         vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
         image_encoder_path=fetch_model(

examples/wan_image_to_video.py

Lines changed: 1 addition & 3 deletions
@@ -1,4 +1,3 @@
-import torch.multiprocessing as mp
 from PIL import Image

 from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
@@ -7,9 +6,8 @@


 if __name__ == "__main__":
-    mp.set_start_method("spawn")
     config = WanModelConfig(
-        model_path=fetch_model("muse/wan2.1-i2v-14b-480p-bf16", path="dit.safetensors"),
+        model_path=fetch_model("MusePublic/wan2.1-i2v-14b-480p-bf16", path="dit.safetensors"),
         t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
         vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
         image_encoder_path=fetch_model(

examples/wan_lora.py

Lines changed: 0 additions & 3 deletions
@@ -1,12 +1,9 @@
-import torch.multiprocessing as mp
-
 from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
 from diffsynth_engine.utils.download import fetch_model
 from diffsynth_engine.utils.video import save_video


 if __name__ == "__main__":
-    mp.set_start_method("spawn")
     config = WanModelConfig(
         model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
         t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),

examples/wan_text_to_video.py

Lines changed: 1 addition & 4 deletions
@@ -1,14 +1,11 @@
-import torch.multiprocessing as mp
-
 from diffsynth_engine.pipelines import WanVideoPipeline, WanModelConfig
 from diffsynth_engine.utils.download import fetch_model
 from diffsynth_engine.utils.video import save_video


 if __name__ == "__main__":
-    mp.set_start_method("spawn")
     config = WanModelConfig(
-        model_path=fetch_model("muse/wan2.1-t2v-14b-bf16", path="dit.safetensors"),
+        model_path=fetch_model("MusePublic/wan2.1-14b-t2v", path="dit.safetensors"),
         t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
         vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
         use_fsdp=True,

tests/test_models/wan/test_wan_vae_parallel.py

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,4 @@
 import torch
-import torch.multiprocessing as mp
 import unittest
 import numpy as np

@@ -13,7 +12,6 @@
 class TestWanVAEParallel(VideoTestCase):
     @classmethod
     def setUpClass(cls):
-        mp.set_start_method("spawn")
         cls._vae_model_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
         loaded_state_dict = load_file(cls._vae_model_path)
         vae = WanVideoVAE.from_state_dict(loaded_state_dict, parallelism=4)

tests/test_pipelines/test_wan_video_parallel.py

Lines changed: 0 additions & 2 deletions
@@ -1,4 +1,3 @@
-import torch.multiprocessing as mp
 import unittest

 from tests.common.test_case import VideoTestCase
@@ -9,7 +8,6 @@
 class TestWanVideoTP(VideoTestCase):
     @classmethod
     def setUpClass(cls):
-        mp.set_start_method("spawn")
         config = WanModelConfig(
             model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
             t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
