@@ -13,9 +13,10 @@
 )
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
+from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
 from diffsynth_engine.utils import logging
@@ -198,7 +199,7 @@ def forward(self, image, text, rope_emb, image_emb):
         k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1)
         v = torch.cat([v_b, v_a], dim=1)
         q, k = apply_rope(q, k, rope_emb)
-        attn_out = attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :]
         image_out, text_out = self.attention_callback(
@@ -286,7 +287,7 @@ def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb):
     def forward(self, x, rope_emb, image_emb):
         q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2)
         q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb)
-        attn_out = attention(q, k, v, attn_impl=self.attn_impl)
+        attn_out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl)
         attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype)
         return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb)
@@ -324,6 +325,7 @@ def __init__(
         self,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -349,6 +351,8 @@ def __init__(
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)

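+        # remember whether forward() should run with sequence parallelism (USP) enabled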
+        self.use_usp = use_usp
+
     def patchify(self, hidden_states):
         hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
         return hidden_states
@@ -359,7 +363,8 @@ def unpatchify(self, hidden_states, height, width):
         )
         return hidden_states

-    def prepare_image_ids(self, latents):
+    @staticmethod
+    def prepare_image_ids(latents: torch.Tensor):
         batch_size, _, height, width = latents.shape
         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
@@ -389,7 +394,14 @@ def forward(
         controlnet_single_block_output=None,
         **kwargs,
     ):
-        height, width = hidden_states.shape[-2:]
+        h, w = hidden_states.shape[-2:]
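+        # treat absent controlnet outputs as empty tuples so they can join the sharded-tensor list below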
+        controlnet_double_block_output = (
+            controlnet_double_block_output if controlnet_double_block_output is not None else ()
+        )
+        controlnet_single_block_output = (
+            controlnet_single_block_output if controlnet_single_block_output is not None else ()
+        )
+
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
         with fp8_inference(fp8_linear_enabled), gguf_inference():
             if image_ids is None:
@@ -402,28 +414,54 @@ def forward(
             guidance = guidance * 1000
             conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
             conditioning += self.pooled_text_embedder(pooled_prompt_emb)
-            prompt_emb = self.context_embedder(prompt_emb)
             rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
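+            # split the joint rope embedding so the text and image parts can each be sharded on their own sequence length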
+            text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
+            image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
             hidden_states = self.patchify(hidden_states)
-            hidden_states = self.x_embedder(hidden_states)
-            for i, block in enumerate(self.blocks):
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
-                if controlnet_double_block_output is not None:
-                    interval_control = len(self.blocks) / len(controlnet_double_block_output)
-                    interval_control = int(np.ceil(interval_control))
-                    hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
-            hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
-            for i, block in enumerate(self.single_blocks):
-                hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
-                if controlnet_single_block_output is not None:
-                    interval_control = len(self.single_blocks) / len(controlnet_double_block_output)
-                    interval_control = int(np.ceil(interval_control))
-                    hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
-
-            hidden_states = hidden_states[:, prompt_emb.shape[1] :]
-            hidden_states = self.final_norm_out(hidden_states, conditioning)
-            hidden_states = self.final_proj_out(hidden_states)
-            hidden_states = self.unpatchify(hidden_states, height, width)
+
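+            # with use_usp on, shard every sequence-dim tensor across ranks; each rank runs the blocks on its slice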
+            with sequence_parallel(
+                (
+                    hidden_states,
+                    prompt_emb,
+                    text_rope_emb,
+                    image_rope_emb,
+                    *controlnet_double_block_output,
+                    *controlnet_single_block_output,
+                ),
+                seq_dims=(
+                    1,
+                    1,
+                    2,
+                    2,
+                    *(1 for _ in controlnet_double_block_output),
+                    *(1 for _ in controlnet_single_block_output),
+                ),
+                enabled=self.use_usp,
+            ):
+                hidden_states = self.x_embedder(hidden_states)
+                prompt_emb = self.context_embedder(prompt_emb)
+                rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)
+
+                for i, block in enumerate(self.blocks):
+                    hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
+                    if len(controlnet_double_block_output) > 0:
+                        interval_control = len(self.blocks) / len(controlnet_double_block_output)
+                        interval_control = int(np.ceil(interval_control))
+                        hidden_states = hidden_states + controlnet_double_block_output[i // interval_control]
+                hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
+                for i, block in enumerate(self.single_blocks):
+                    hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
+                    if len(controlnet_single_block_output) > 0:
+                        interval_control = len(self.single_blocks) / len(controlnet_single_block_output)
+                        interval_control = int(np.ceil(interval_control))
+                        hidden_states = hidden_states + controlnet_single_block_output[i // interval_control]
+
+                hidden_states = hidden_states[:, prompt_emb.shape[1] :]
+                hidden_states = self.final_norm_out(hidden_states, conditioning)
+                hidden_states = self.final_proj_out(hidden_states)
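+                # gather the full sequence (h * w // 4 latent patches) back from all ranks before leaving the parallel region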
+                (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+
+            hidden_states = self.unpatchify(hidden_states, h, w)
         return hidden_states

     @classmethod
@@ -434,6 +472,7 @@ def from_state_dict(
         dtype: torch.dtype,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -442,6 +481,7 @@ def from_state_dict(
                 dtype=dtype,
                 in_channel=in_channel,
                 attn_impl=attn_impl,
+                use_usp=use_usp,
             )
             model = model.requires_grad_(False)  # for loading gguf
             model.load_state_dict(state_dict, assign=True)
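
For context, a minimal multi-GPU usage sketch for the new use_usp flag (not part of this diff). The class name FluxDiT, the import path, the device argument, and the checkpoint filename are assumptions for illustration; only the use_usp keyword and the from_state_dict parameters shown above are taken from the change itself. Each rank builds the same model with use_usp=True, and forward() then shards the patch sequence across ranks via sequence_parallel and gathers it with sequence_parallel_unshard before unpatchify:

    # launch (assumed): torchrun --nproc_per_node=2 infer_flux.py
    import torch
    import torch.distributed as dist
    from safetensors.torch import load_file
    from diffsynth_engine.models.flux.flux_dit import FluxDiT  # assumed module path

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    state_dict = load_file("flux-dit.safetensors")  # placeholder checkpoint
    model = FluxDiT.from_state_dict(
        state_dict,
        device=f"cuda:{rank}",   # assumed parameter, per the constructor context
        dtype=torch.bfloat16,
        use_usp=True,            # forward() shards sequence dims across ranks
    )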