Commit d3d2909
supports flux parallel (#70)

* supports flux parallel
* fix sequence parallel issue by padding

1 parent f9a925f commit d3d2909

20 files changed: +316 additions, -145 deletions

README.md
Lines changed: 2 additions & 2 deletions

@@ -45,7 +45,7 @@ Text to image
 ```python
 from diffsynth_engine import fetch_model, FluxImagePipeline

-model_path = fetch_model("muse/flux-with-vae", path="flux_with_vae.safetensors")
+model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
 pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')
 image = pipe(prompt="a cat")
 image.save("image.png")
@@ -54,7 +54,7 @@ Text to image with LoRA
 ```python
 from diffsynth_engine import fetch_model, FluxImagePipeline

-model_path = fetch_model("muse/flux-with-vae", path="flux_with_vae.safetensors")
+model_path = fetch_model("muse/flux-with-vae", path="flux1-dev-with-vae.safetensors")
 lora_path = fetch_model("DonRat/MAJICFLUS_SuperChinesestyleheongsam", path="麦橘超国风旗袍.safetensors")

 pipe = FluxImagePipeline.from_pretrained(model_path, device='cuda:0')

diffsynth_engine/models/basic/attention.py
Lines changed: 0 additions & 2 deletions

@@ -201,10 +201,8 @@ def long_context_attention(
     assert attn_impl in [
         None,
         "auto",
-        "eager",
         "flash_attn_2",
         "flash_attn_3",
-        "xformers",
         "sdpa",
         "sage_attn",
         "sparge_attn",

diffsynth_engine/models/flux/flux_dit.py
Lines changed: 65 additions & 25 deletions

@@ -13,11 +13,12 @@
 )
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
+from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
 from diffsynth_engine.utils import logging


@@ -198,7 +199,7 @@ def forward(self, image, text, rope_emb, image_emb):
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -286,7 +287,7 @@ def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
     def forward(self, x, rope_emb, image_emb):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)

@@ -324,6 +325,7 @@ def __init__(
         self,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -349,6 +351,8 @@ def __init__(
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)

+        self.use_usp = use_usp
+
     def patchify(self, hidden_states):
         hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
         return hidden_states
@@ -359,7 +363,8 @@ def unpatchify(self, hidden_states, height, width):
         )
         return hidden_states

-    def prepare_image_ids(self, latents):
+    @staticmethod
+    def prepare_image_ids(latents: torch.Tensor):
         batch_size, _, height, width = latents.shape
         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
@@ -389,7 +394,14 @@ def forward(
         controlnet_single_block_output=None,
         **kwargs,
     ):
-        height, width = hidden_states.shape[-2:]
+        h, w = hidden_states.shape[-2:]
+        controlnet_double_block_output = (
+            controlnet_double_block_output if controlnet_double_block_output is not None else ()
+        )
+        controlnet_single_block_output = (
+            controlnet_single_block_output if controlnet_single_block_output is not None else ()
+        )
+
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         with fp8_inference(fp8_linear_enabled), gguf_inference():
             if image_ids is None:
@@ -402,28 +414,54 @@ def forward(
                 guidance = guidance * 1000
                 conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
             conditioning += self.pooled_text_embedder(pooled_prompt_emb)
-            prompt_emb = self.context_embedder(prompt_emb)
             rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+            text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
+            image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
             hidden_states = self.patchify(hidden_states)
-            hidden_states = self.x_embedder(hidden_states)
-            for i, block in enumerate(self.blocks):
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
-                if controlnet_double_block_output is not None:
-                    interval_control = len(self.blocks) / len(controlnet_double_block_output)
-                    interval_control = int(np.ceil(interval_control))
-                    hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
-            hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
-            for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
-                if controlnet_single_block_output is not None:
-                    interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
-                    interval_control = int(np.ceil(interval_control))
-                    hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
-
-            hidden_states = hidden_states[:, prompt_emb.shape[1] :]
-            hidden_states = self.final_norm_out(hidden_states, conditioning)
-            hidden_states = self.final_proj_out(hidden_states)
-            hidden_states = self.unpatchify(hidden_states, height, width)
+
+            with sequence_parallel(
+                (
+                    hidden_states,
+                    prompt_emb,
+                    text_rope_emb,
+                    image_rope_emb,
+                    *controlnet_double_block_output,
+                    *controlnet_single_block_output,
+                ),
+                seq_dims=(
+                    1,
+                    1,
+                    2,
+                    2,
+                    *(1 for _ in controlnet_double_block_output),
+                    *(1 for _ in controlnet_single_block_output),
+                ),
+                enabled=self.use_usp,
+            ):
+                hidden_states = self.x_embedder(hidden_states)
+                prompt_emb = self.context_embedder(prompt_emb)
+                rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
+
+                for i, block in enumerate(self.blocks):
+                    hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                    if len(controlnet_double_block_output) > 0:
+                        interval_control = len(self.blocks) / len(controlnet_double_block_output)
+                        interval_control = int(np.ceil(interval_control))
+                        hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
+                hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+                for i, block in enumerate(self.single_blocks):
+                    hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                    if len(controlnet_single_block_output) > 0:
+                        interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
+                        interval_control = int(np.ceil(interval_control))
+                        hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]

+                hidden_states = hidden_states[:, prompt_emb.shape[1] :]
+                hidden_states = self.final_norm_out(hidden_states, conditioning)
+                hidden_states = self.final_proj_out(hidden_states)
+                (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+
+            hidden_states = self.unpatchify(hidden_states, h, w)
         return hidden_states

     @classmethod
@@ -434,6 +472,7 @@ def from_state_dict(
         dtype: torch.dtype,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -442,6 +481,7 @@ def from_state_dict(
                 dtype=dtype,
                 in_channel=in_channel,
                 attn_impl=attn_impl,
+                use_usp=use_usp,
             )
             model = model.requires_grad_(False)  # for loading gguf
             model.load_state_dict(state_dict, assign=True)
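The new `sequence_parallel` / `sequence_parallel_unshard` helpers live in `diffsynth_engine/utils/parallel.py` and their implementation is not shown in this diff. Per the commit message ("fix sequence parallel issue by padding"), the sequence dimension is padded so every rank receives an equal shard, and the unshard call trims the gathered result back to the true length (`h * w // 4` here, i.e. the number of 2x2 latent patches produced by `patchify`). The sketch below only illustrates that pad/shard/trim arithmetic locally; `pad_and_shard` and `unshard` are hypothetical names, not diffsynth_engine APIs, and no process group or collective is involved.

```python
import torch

def pad_and_shard(t: torch.Tensor, seq_dim: int, world_size: int, rank: int) -> torch.Tensor:
    """Pad the sequence dim to a multiple of world_size, then take this rank's slice."""
    seq_len = t.size(seq_dim)
    pad = (-seq_len) % world_size  # padding needed so every rank gets an equal shard
    if pad > 0:
        pad_shape = list(t.shape)
        pad_shape[seq_dim] = pad
        t = torch.cat([t, t.new_zeros(pad_shape)], dim=seq_dim)
    return t.chunk(world_size, dim=seq_dim)[rank]

def unshard(shards: list, seq_dim: int, seq_len: int) -> torch.Tensor:
    """Concatenate per-rank shards and drop the padding to recover the original length."""
    full = torch.cat(shards, dim=seq_dim)
    return full.narrow(seq_dim, 0, seq_len)

# Local demonstration with world_size=4 and a sequence length that is not divisible by 4.
x = torch.randn(1, 4096 + 3, 64)  # (batch, seq, dim)
shards = [pad_and_shard(x, seq_dim=1, world_size=4, rank=r) for r in range(4)]
restored = unshard(shards, seq_dim=1, seq_len=x.size(1))
assert torch.equal(restored, x)
```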

diffsynth_engine/models/wan/wan_dit.py
Lines changed: 17 additions & 42 deletions

@@ -2,12 +2,11 @@
 import json
 import torch
 import torch.nn as nn
-import torch.distributed as dist
 from typing import Tuple, Optional
 from einops import rearrange

 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
-from diffsynth_engine.models.basic.attention import attention, long_context_attention
+from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.constants import (
@@ -17,11 +16,7 @@
     WAN_DIT_14B_FLF2V_CONFIG_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.parallel import (
-    get_sp_group,
-    get_sp_world_size,
-    get_sp_rank,
-)
+from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard

 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2
@@ -90,20 +85,12 @@ def forward(self, x, freqs):
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
-        if getattr(self, "use_usp", False):
-            x = long_context_attention(
-                q=rope_apply(q, freqs),
-                k=rope_apply(k, freqs),
-                v=v,
-                attn_impl=self.attn_impl,
-            )
-        else:
-            x = attention(
-                q=rope_apply(q, freqs),
-                k=rope_apply(k, freqs),
-                v=v,
-                attn_impl=self.attn_impl,
-            )
+        x = attention_ops.attention(
+            q=rope_apply(q, freqs),
+            k=rope_apply(k, freqs),
+            v=v,
+            attn_impl=self.attn_impl,
+        )
         x = x.flatten(2)
         return self.o(x)

@@ -148,12 +135,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)

-        x = attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
+        x = attention_ops.attention(q, k, v, attn_impl=self.attn_impl).flatten(2)
         if self.has_image_input:
             k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
             k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
             v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
-            y = attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
+            y = attention_ops.attention(q, k_img, v_img, attn_impl=self.attn_impl).flatten(2)
             x = x + y
         return self.o(x)

@@ -316,10 +303,7 @@ def __init__(
         if has_image_input:
             self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype)  # clip_feature_dim = 1280

-        if use_usp:
-            setattr(self, "use_usp", True)
-            for block in self.blocks:
-                setattr(block.self_attn, "use_usp", True)
+        self.use_usp = use_usp

     def patchify(self, x: torch.Tensor):
         x = self.patch_embedding(x)  # b c f h w -> b 4c f h/2 w/2
@@ -368,21 +352,12 @@ def forward(
             .reshape(f * h * w, 1, -1)
             .to(x.device)
         )
-        if getattr(self, "use_usp", False):
-            s, p = x.size(1), get_sp_world_size()  # (sequence_length, parallelism)
-            split_size = [s // p + 1 if i < s % p else s // p for i in range(p)]
-            x = torch.split(x, split_size, dim=1)[get_sp_rank()]
-            freqs = torch.split(freqs, split_size, dim=0)[get_sp_rank()]
-
-        for block in self.blocks:
-            x = block(x, context, t_mod, freqs)
-        x = self.head(x, t)
-
-        if getattr(self, "use_usp", False):
-            b, d = x.size(0), x.size(2)  # (batch_size, out_dim)
-            xs = [torch.zeros((b, s, d), dtype=x.dtype, device=x.device) for s in split_size]
-            dist.all_gather(xs, x, group=get_sp_group())
-            x = torch.concat(xs, dim=1)
+
+        with sequence_parallel([x, freqs], seq_dims=(1, 0), enabled=self.use_usp):
+            for block in self.blocks:
+                x = block(x, context, t_mod, freqs)
+            x = self.head(x, t)
+            (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
         x = self.unpatchify(x, (f, h, w))
         return x

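Both DiT models also switch their call sites from `from ...attention import attention` to the module-qualified `attention_ops.attention(...)`, and the `use_usp` branch that previously selected `long_context_attention` is gone. The diff does not state the motivation; one plausible reading is that resolving the attribute at call time lets the sequence-parallel context swap the attention implementation on the module without every caller branching on `use_usp`. The snippet below is a generic, self-contained illustration of that late-binding pattern under this assumption; it is not diffsynth_engine code.

```python
import types

# Throwaway module object standing in for diffsynth_engine.models.basic.attention.
attention_ops = types.ModuleType("attention_ops")

def _plain_attention(q, k, v):
    return "plain"

attention_ops.attention = _plain_attention

# Early binding: the function object is resolved once, at "import" time.
attention = attention_ops.attention

def call_late():
    # Late binding: the module attribute is looked up on every call.
    return attention_ops.attention(None, None, None)

def call_early():
    return attention(None, None, None)

# A parallel context could patch the module attribute with a sequence-parallel variant.
attention_ops.attention = lambda q, k, v: "sequence-parallel"

print(call_late())   # "sequence-parallel" -> sees the patched implementation
print(call_early())  # "plain" -> still bound to the original function
```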

diffsynth_engine/pipelines/base.py
Lines changed: 58 additions & 6 deletions

@@ -26,14 +26,21 @@ def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[st
 class BasePipeline:
     lora_converter = LoRAStateDictConverter()

-    def __init__(self, vae_tiled, vae_tile_size, vae_tile_stride, device="cuda:0", dtype=torch.float16):
+    def __init__(
+        self,
+        vae_tiled: bool = False,
+        vae_tile_size: int = -1,
+        vae_tile_stride: int = -1,
+        device="cuda:0",
+        dtype=torch.float16,
+    ):
         super().__init__()
-        self.device = device
-        self.dtype = dtype
-        self.offload_mode = None
         self.vae_tiled = vae_tiled
         self.vae_tile_size = vae_tile_size
         self.vae_tile_stride = vae_tile_stride
+        self.device = device
+        self.dtype = dtype
+        self.offload_mode = None
         self.model_names = []

     @classmethod
@@ -199,8 +206,53 @@ def eval(self):
             model.eval()
         return self

-    def enable_fp8_linear(self):
-        raise NotImplementedError()
+    @staticmethod
+    def init_parallel_config(
+        parallelism: int,
+        use_cfg_parallel: bool,
+        model_config: ModelConfig,
+    ):
+        assert parallelism in (2, 4, 8), "parallelism must be 2, 4 or 8"
+        cfg_degree = 2 if use_cfg_parallel else 1
+        sp_ulysses_degree = getattr(model_config, "sp_ulysses_degree", None)
+        sp_ring_degree = getattr(model_config, "sp_ring_degree", None)
+        tp_degree = getattr(model_config, "tp_degree", None)
+        use_fsdp = getattr(model_config, "use_fsdp", False)
+
+        if tp_degree is not None:
+            assert sp_ulysses_degree is None and sp_ring_degree is None, (
+                "not allowed to enable sequence parallel and tensor parallel together; "
+                "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
+            )
+            assert use_fsdp is False, (
+                "not allowed to enable fully sharded data parallel and tensor parallel together; "
+                "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
+            )
+            assert parallelism == cfg_degree * tp_degree, (
+                f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * tp_degree ({tp_degree})"
+            )
+            sp_ulysses_degree = 1
+            sp_ring_degree = 1
+        elif sp_ulysses_degree is None and sp_ring_degree is None:
+            # use ulysses if not specified
+            sp_ulysses_degree = parallelism // cfg_degree
+            sp_ring_degree = 1
+            tp_degree = 1
+        elif sp_ulysses_degree is not None and sp_ring_degree is not None:
+            assert parallelism == cfg_degree * sp_ulysses_degree * sp_ring_degree, (
+                f"parallelism ({parallelism}) must be equal to cfg_degree ({cfg_degree}) * "
+                f"sp_ulysses_degree ({sp_ulysses_degree}) * sp_ring_degree ({sp_ring_degree})"
+            )
+            tp_degree = 1
+        else:
+            raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
+        return {
+            "cfg_degree": cfg_degree,
+            "sp_ulysses_degree": sp_ulysses_degree,
+            "sp_ring_degree": sp_ring_degree,
+            "tp_degree": tp_degree,
+            "use_fsdp": use_fsdp,
+        }

     @staticmethod
     def validate_offload_mode(offload_mode: str | None):
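The new `init_parallel_config` resolves a total `parallelism` degree plus the flags on the model config into per-dimension degrees (CFG, Ulysses, ring, tensor parallel). A hedged usage sketch: the import path is inferred from this file's location and may differ in the packaged API, and `SimpleNamespace` merely stands in for the real `ModelConfig`, since the method only reads attributes via `getattr`.

```python
from types import SimpleNamespace

# Assumed import path, inferred from diffsynth_engine/pipelines/base.py.
from diffsynth_engine.pipelines.base import BasePipeline

# Stand-in for ModelConfig; only these attributes are read via getattr().
cfg = SimpleNamespace(sp_ulysses_degree=None, sp_ring_degree=None, tp_degree=None, use_fsdp=False)

# 8 GPUs with CFG parallel: cfg_degree=2, and the remaining factor of 4 defaults to Ulysses.
print(BasePipeline.init_parallel_config(parallelism=8, use_cfg_parallel=True, model_config=cfg))
# {'cfg_degree': 2, 'sp_ulysses_degree': 4, 'sp_ring_degree': 1, 'tp_degree': 1, 'use_fsdp': False}

# Explicit Ulysses x ring split: parallelism must equal cfg_degree * sp_ulysses_degree * sp_ring_degree.
cfg = SimpleNamespace(sp_ulysses_degree=2, sp_ring_degree=2, tp_degree=None, use_fsdp=False)
print(BasePipeline.init_parallel_config(parallelism=4, use_cfg_parallel=False, model_config=cfg))
# {'cfg_degree': 1, 'sp_ulysses_degree': 2, 'sp_ring_degree': 2, 'tp_degree': 1, 'use_fsdp': False}
```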
