 import mlx.nn as nn
 import numpy as np
 from argmaxtools.utils import get_logger
-from beartype.typing import Tuple
+from beartype.typing import Tuple, List, Dict, Optional
 
-from .config import MMDiTConfig
+from .config import MMDiTConfig, PositionalEncoding
 
 logger = get_logger(__name__)
 
@@ -26,18 +26,22 @@ def __init__(self, config: MMDiTConfig):
         super().__init__()
         self.config = config
 
-        # Check if use_pe is enabled
-        if config.use_pe:
-            self.pe_embedder = EmbedND(
-                dim=config.hidden_size // config.num_heads,
+        # Input adapters and embeddings
+        self.x_embedder = LatentImageAdapter(config)
+
+        if config.pos_embed_type == PositionalEncoding.LearnedInputEmbedding:
+            self.x_pos_embedder = LatentImagePositionalEmbedding(config)
+            self.pre_sdpa_rope = nn.Identity()
+        elif config.pos_embed_type == PositionalEncoding.PreSDPARope:
+            self.x_pos_embedder = None
+            self.pre_sdpa_rope = RoPE(
                 theta=10000,
-                axes_dim=config.axes_dim,
+                axes_dim=config.rope_axes_dim,
             )
+        else:
+            raise ValueError(f"Unsupported positional encoding type: {config.pos_embed_type}")
 
-        # Input adapters and embeddings
-        self.x_embedder = LatentImageAdapter(config)
-        self.x_pos_embedder = LatentImagePositionalEmbedding(config)
-        self.y_embedder = PooledTextEmbeddingAdater(config)
+        self.y_embedder = PooledTextEmbeddingAdapter(config)
         self.t_embedder = TimestepAdapter(config)
         self.context_embedder = nn.Linear(
             config.token_level_text_embed_dim,
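The constructor now dispatches on `config.pos_embed_type` instead of the old boolean `use_pe` flag: the learned-input-embedding path keeps the additive `LatentImagePositionalEmbedding`, while the pre-SDPA RoPE path defers positional information to the attention blocks. The enum and the `rope_axes_dim` field live in `.config`, which this diff does not show; the sketch below is only a rough guess at their shape (the name `MMDiTConfigSketch` and all default values are illustrative, not taken from the repository).

```python
# Rough sketch of the config fields the new branch relies on; the real
# definitions are in .config and may differ.
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class PositionalEncoding(Enum):
    LearnedInputEmbedding = "learned_input_embedding"
    PreSDPARope = "pre_sdpa_rope"


@dataclass
class MMDiTConfigSketch:
    hidden_size: int = 1536
    num_heads: int = 24
    pos_embed_type: PositionalEncoding = PositionalEncoding.LearnedInputEmbedding
    # Consumed only on the PreSDPARope path; the per-axis dims should sum to
    # hidden_size // num_heads so the rotated q/k keep their head dimension.
    rope_axes_dim: Optional[List[int]] = None
```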
@@ -75,16 +79,22 @@ def __call__(
         )
         token_level_text_embeddings = self.context_embedder(token_level_text_embeddings)
 
-        latent_image_embeddings = self.x_embedder(
-            latent_image_embeddings
-        ) + self.x_pos_embedder(latent_image_embeddings)
+        latent_image_embeddings = self.x_embedder(latent_image_embeddings)
+        if self.x_pos_embedder is not None:
+            latent_image_embeddings = latent_image_embeddings + \
+                self.x_pos_embedder(latent_image_embeddings)
+
         latent_image_embeddings = latent_image_embeddings.reshape(
             batch, -1, 1, self.config.hidden_size
         )
 
-        # TODO(arda): process `ids`
-        # ids = torch.cat((txt_ids, img_ids), dim=1)
-        # pe = self.pe_embedder(ids)
+        if self.config.pos_embed_type == PositionalEncoding.PreSDPARope:
+            positional_encodings = self.rope(
+                text_sequence_length=token_level_text_embeddings.shape[1],
+                image_sequence_length=latent_image_embeddings.shape[1],
+            )
+        else:
+            positional_encodings = None
 
         # MultiModalTransformer layers
         count = 0
@@ -104,6 +114,7 @@ def __call__(
                 latent_image_embeddings,
                 token_level_text_embeddings,
                 modulation_inputs,
+                positional_encodings=positional_encodings,
             )
             mx.eval(latent_image_embeddings)
             mx.eval(token_level_text_embeddings)
@@ -133,8 +144,12 @@ def __call__(
 
         for block in self.unimodal_transformer_blocks:
             latent_image_embeddings = block(
-                latent_image_embeddings, modulation_inputs, None
-            )  # FIXME(arda): positional_encodings
+                latent_image_embeddings,
+                modulation_inputs,
+                None,
+                # FIXME(atiorh): RoPE is only supported for MultiModalTransformerBlock
+                positional_encodings=positional_encodings
+            )
 
         # Final layer
         latent_image_embeddings = self.final_layer(
@@ -210,7 +225,7 @@ def __call__(self, x: mx.array) -> mx.array:
         return mx.repeat(w, repeats=b, axis=0)
 
 
-class PooledTextEmbeddingAdater(nn.Module):
+class PooledTextEmbeddingAdapter(nn.Module):
     def __init__(self, config: MMDiTConfig):
         super().__init__()
 
@@ -282,7 +297,10 @@ def __init__(self, config: MMDiTConfig, skip_post_sdpa: bool = False):
         if config.use_qk_norm:
             self.qk_norm = QKNorm(config.hidden_size // config.num_heads)
 
-    def pre_sdpa(self, tensor: mx.array, modulation_inputs: mx.array):
+    def pre_sdpa(self,
+                 tensor: mx.array,
+                 modulation_inputs: mx.array,
+                 ) -> Dict[str, mx.array]:
         # Project Adaptive LayerNorm modulation parameters
         modulation_params = self.adaLN_modulation(modulation_inputs)
         modulation_params = mx.split(
@@ -299,17 +317,14 @@ def pre_sdpa(self, tensor: mx.array, modulation_inputs: mx.array):
             residual_scale=post_norm1_residual_scale,
         )
 
-        results = {
-            "q": self.attn.q_proj(pre_attn),
-            "k": self.attn.k_proj(pre_attn),
-            "v": self.attn.v_proj(pre_attn),
-        }
+        q = self.attn.q_proj(pre_attn)
+        k = self.attn.k_proj(pre_attn)
+        v = self.attn.v_proj(pre_attn)
 
-        # Apply QKNorm if enabled
         if self.config.use_qk_norm:
-            results["q"], results["k"] = self.qk_norm(
-                results["q"], results["k"], results["v"]
-            )
+            q, k = self.qk_norm(q, k)
+
+        results = {"q": q, "k": k, "v": v}
 
         if len(modulation_params) > 2:
             results.update(
@@ -403,9 +418,11 @@ def rearrange_for_sdpa(t):
             "scale": 1.0 / np.sqrt(self.per_head_dim),
         }
 
-        # Apply rope to q, k if positional_encodings are provided
-        if positional_encodings is not None:
-            pass  # TODO(arda): Implement rope
+        # TESTME(atiorh)
+        if self.config.pos_embed_type == PositionalEncoding.PreSDPARope:
+            assert positional_encodings is not None
+            RoPE.apply(multimodal_sdpa_inputs["q"], positional_encodings)
+            RoPE.apply(multimodal_sdpa_inputs["k"], positional_encodings)
 
         # Compute multi-modal SDPA
         sdpa_outputs = (
@@ -489,21 +506,19 @@ def __call__(self, x: mx.array, vec: mx.array, pe: mx.array) -> mx.array:
         return x + mod.gate * output
 
 
-# FIXME(arda): check it
 class QKNorm(nn.Module):
     def __init__(self, head_dim):
         super().__init__()
         self.q_norm = nn.RMSNorm(head_dim, eps=1e-6)
         self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)
 
-    def __call__(
-        self, q: mx.array, k: mx.array, v: mx.array
-    ) -> Tuple[mx.array, mx.array]:
-        q = self.q_norm(q)
-        k = self.k_norm(k)
-        return q.astype(v.dtype), k.astype(v.dtype)
+    def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
+        q = self.q_norm(q.astype(mx.float32))
+        k = self.k_norm(k.astype(mx.float32))
+        return q, k
 
 
+# FIXME(arda): Reuse our dict impl above for modulation outputs
 @dataclass
 class ModulationOut:
     shift: mx.array
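The reworked `QKNorm` drops the value tensor from its signature and normalizes in float32, returning float32 outputs instead of casting back to the value dtype; any downcast is now left to the caller. A small standalone check of that behavior (the head dimension and input shapes below are arbitrary, chosen only for illustration):

```python
import mlx.core as mx
import mlx.nn as nn


class QKNorm(nn.Module):
    """Copy of the reworked QKNorm from the hunk above, for a standalone check."""
    def __init__(self, head_dim):
        super().__init__()
        self.q_norm = nn.RMSNorm(head_dim, eps=1e-6)
        self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)

    def __call__(self, q, k):
        q = self.q_norm(q.astype(mx.float32))
        k = self.k_norm(k.astype(mx.float32))
        return q, k


head_dim = 64  # arbitrary; the model uses config.hidden_size // config.num_heads
norm = QKNorm(head_dim)
q = mx.random.normal((1, 77, head_dim)).astype(mx.float16)
k = mx.random.normal((1, 77, head_dim)).astype(mx.float16)
q, k = norm(q, k)
print(q.dtype, k.dtype)  # both float32: casting back to float16 is the caller's job
```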
@@ -609,33 +624,71 @@ def __call__(self, inputs: mx.array) -> mx.array:
         return mx.fast.layer_norm(inputs, weight=None, bias=None, eps=self.eps)
 
 
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+class RoPE(nn.Module):
+    """Custom RoPE implementation for FLUX
+    """
+    def __init__(self, theta: int, axes_dim: List[int]) -> None:
         super().__init__()
-        self.dim = dim
         self.theta = theta
         self.axes_dim = axes_dim
 
-    def forward(self, ids: mx.array) -> mx.array:
-        n_axes = ids.shape[-1]
-        emb = mx.concatenate(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-
-        return emb.unsqueeze(1)
+        # Cache for consecutive identical calls
+        self.rope_embeddings = None
+        self.last_image_resolution = None
+        self.last_text_sequence_length = None
+
+    def _get_positions(self, latent_image_resolution: Tuple[int], text_sequence_length: int) -> mx.array:
+        h, w = latent_image_resolution
+        image_positions = mx.stack([
+            mx.zeros((h, w)),
+            mx.repeat(mx.arange(h)[:, None], w, axis=1),
+            mx.repeat(mx.arange(w)[None, :], h, axis=0),
+        ], axis=-1).flatten(0, 1)  # (h * w, 3)
+
+        text_and_image_positions = mx.concatenate([
+            mx.zeros((text_sequence_length, 3)),
+            image_positions,
+        ], axis=0)[None]  # (text_sequence_length + h * w, 3)
+
+        return text_and_image_positions
+
+    def rope(self, positions: mx.array, dim: int, theta: int = 10_000) -> mx.array:
+        def _rope_per_dim(positions, dim, theta):
+            scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
+            omega = 1.0 / (theta**scale)
+            out = mx.einsum("bn,d->bnd", positions, omega)
+            return mx.stack([
+                mx.cos(out), -mx.sin(out), mx.sin(out), mx.cos(out)
+            ], axis=-1).reshape(*positions.shape, dim // 2, 2, 2)
+
+        return mx.concatenate([
+            _rope_per_dim(
+                positions=positions[..., i],
+                dim=self.axes_dim[i],
+                theta=self.theta
+            ) for i in range(len(self.axes_dim))
+        ], axis=-3).astype(positions.dtype)
+
+    def __call__(self, latent_image_resolution: Tuple[int], text_sequence_length: int) -> mx.array:
+        identical_to_last_call = \
+            latent_image_resolution == self.last_image_resolution and \
+            text_sequence_length == self.last_text_sequence_length
+
+        if self.rope_embeddings is None or not identical_to_last_call:
+            self.last_image_resolution = latent_image_resolution
+            self.last_text_sequence_length = text_sequence_length
+            positions = self._get_positions(latent_image_resolution, text_sequence_length)
+            self.rope_embeddings = self.rope(positions, self.theta)
+        else:
+            print("Returning cached RoPE embeddings")
 
+        return self.rope_embeddings
 
-def rope(pos: mx.array, dim: int, theta: int) -> mx.array:
-    assert dim % 2 == 0
-    scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
-    omega = 1.0 / (theta**scale)
-    # TODO(arda): implement this
-    out = None
-    # out = mx.einsum("...n,d->...nd", pos, omega)
-    # out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-    # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
+    @staticmethod
+    def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
+        in_dtype = q_or_k.dtype
+        q_or_k = q_or_k.astype(mx.float32).reshape(*q_or_k.shape[:-1], -1, 1, 2)
+        return (rope[..., 0] * q_or_k[..., 0] + rope[..., 1] * q_or_k[..., 1]).astype(in_dtype)
 
 
 def affine_transform(
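Not part of the commit: a sketch of how the new `RoPE` module is meant to be driven end to end, assuming the `RoPE` class from the hunk above is in scope and using illustrative values for `axes_dim`, the latent resolution, and the text sequence length (the FLUX-style split `[16, 56, 56]` is an assumption, not read from this diff). The commit itself is tagged TESTME, so treat the shapes below as expected rather than verified; the snippet also shows the reassignment pattern for `RoPE.apply`, which returns a rotated copy of its input.

```python
import mlx.core as mx

# from mmdit import RoPE  # hypothetical import path; the class is defined in the hunk above

axes_dim = [16, 56, 56]              # illustrative per-axis dims; sum to head_dim = 128
latent_image_resolution = (32, 32)   # (h, w) of the latent grid
text_sequence_length = 77

rope_module = RoPE(theta=10000, axes_dim=axes_dim)
pe = rope_module(latent_image_resolution, text_sequence_length)
# Expected: (1, text_sequence_length + h * w, sum(axes_dim) // 2, 2, 2) == (1, 1101, 64, 2, 2)
print(pe.shape)

# A second identical call should return the cached table (and print the debug message above)
pe_cached = rope_module(latent_image_resolution, text_sequence_length)

# Dummy query, assumed here to be laid out as (batch, heads, sequence, head_dim)
q = mx.random.normal((1, 24, 77 + 32 * 32, 128)).astype(mx.float16)
q_rotated = RoPE.apply(q, pe)  # returns a new array, so reassign rather than discard
# Expected: float16, (1, 24, 1101, 64, 2); the trailing (64, 2) axes flatten back to head_dim
print(q_rotated.dtype, q_rotated.shape)
```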