
Commit 274c814

WIP RoPE embeddings, not integrated with UniModalTransformerBlock yet
1 parent 38e8ba6 commit 274c814

File tree: 2 files changed (+123 / -65 lines)

python/src/diffusionkit/mlx/config.py

Lines changed: 9 additions & 4 deletions
@@ -5,11 +5,17 @@
 # Copyright (C) 2024 Argmax, Inc. All Rights Reserved.
 #
 from dataclasses import dataclass
+from enum import Enum
 from typing import Optional, Tuple
 
 import mlx.core as mx
 
 
+class PositionalEncoding(Enum):
+    LearnedInputEmbedding = 1
+    PreSDPARope = 2
+
+
 @dataclass
 class MMDiTConfig:
     """Multi-modal Diffusion Transformer Configuration"""
@@ -21,6 +27,7 @@ class MMDiTConfig:
     mlp_ratio: int = 4
     vae_latent_dim: int = 16  # = in_channels = out_channels
     layer_norm_eps: float = 1e-6
+    pos_embed_type: PositionalEncoding = PositionalEncoding.LearnedInputEmbedding
 
     @property
     def hidden_size(self) -> int:
@@ -47,9 +54,6 @@ def hidden_size(self) -> int:
     use_qk_norm: bool = False
     qk_scale: float = 1.0
 
-    # positional encoding
-    use_pe: bool = False
-
     # axes_dim
     axes_dim: Tuple[int] = (16, 56, 56)
 
@@ -63,7 +67,8 @@ def hidden_size(self) -> int:
     depth=19,
     depth_unimodal=38,
     mlp_ratio=4,
-    patchify_via_reshape=True)
+    patchify_via_reshape=True,
+    pos_embed_type=PositionalEncoding.PreSDPARope)
 
 
 @dataclass
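
Note (not part of the diff above): pos_embed_type defaults to the existing learned input embedding, so RoPE is opt-in per config. A minimal sketch of flipping an already-built config over to the new path; it assumes only that MMDiTConfig is the dataclass defined above and is importable as diffusionkit.mlx.config, and the with_rope helper is hypothetical:

from dataclasses import replace

from diffusionkit.mlx.config import MMDiTConfig, PositionalEncoding


def with_rope(base_config: MMDiTConfig) -> MMDiTConfig:
    # Hypothetical helper: dataclasses.replace returns a copy of the config
    # with only pos_embed_type overridden
    return replace(base_config, pos_embed_type=PositionalEncoding.PreSDPARope)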

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 114 additions & 61 deletions
@@ -10,9 +10,9 @@
 import mlx.nn as nn
 import numpy as np
 from argmaxtools.utils import get_logger
-from beartype.typing import Tuple
+from beartype.typing import Tuple, List, Dict, Optional
 
-from .config import MMDiTConfig
+from .config import MMDiTConfig, PositionalEncoding
 
 logger = get_logger(__name__)
 
@@ -26,18 +26,22 @@ def __init__(self, config: MMDiTConfig):
         super().__init__()
         self.config = config
 
-        # Check if use_pe is enabled
-        if config.use_pe:
-            self.pe_embedder = EmbedND(
-                dim=config.hidden_size // config.num_heads,
+        # Input adapters and embeddings
+        self.x_embedder = LatentImageAdapter(config)
+
+        if config.pos_embed_type == PositionalEncoding.LearnedInputEmbedding:
+            self.x_pos_embedder = LatentImagePositionalEmbedding(config)
+            self.pre_sdpa_rope = nn.Identity()
+        elif config.pos_embed_type == PositionalEncoding.PreSDPARope:
+            self.x_pos_embedder = None
+            self.pre_sdpa_rope = RoPE(
                 theta=10000,
-                axes_dim=config.axes_dim,
+                axes_dim=config.rope_axes_dim,
             )
+        else:
+            raise ValueError(f"Unsupported positional encoding type: {config.pos_embed_type}")
 
-        # Input adapters and embeddings
-        self.x_embedder = LatentImageAdapter(config)
-        self.x_pos_embedder = LatentImagePositionalEmbedding(config)
-        self.y_embedder = PooledTextEmbeddingAdater(config)
+        self.y_embedder = PooledTextEmbeddingAdapter(config)
         self.t_embedder = TimestepAdapter(config)
         self.context_embedder = nn.Linear(
             config.token_level_text_embed_dim,
7579
)
7680
token_level_text_embeddings = self.context_embedder(token_level_text_embeddings)
7781

78-
latent_image_embeddings = self.x_embedder(
79-
latent_image_embeddings
80-
) + self.x_pos_embedder(latent_image_embeddings)
82+
latent_image_embeddings = self.x_embedder(latent_image_embeddings)
83+
if self.x_pos_embedder is not None:
84+
latent_image_embeddings = latent_image_embeddings + \
85+
self.x_pos_embedder(latent_image_embeddings)
86+
8187
latent_image_embeddings = latent_image_embeddings.reshape(
8288
batch, -1, 1, self.config.hidden_size
8389
)
8490

85-
# TODO(arda): process `ids`
86-
# ids = torch.cat((txt_ids, img_ids), dim=1)
87-
# pe = self.pe_embedder(ids)
91+
if self.config.pos_embed_type == PositionalEncoding.PreSDPARope:
92+
positional_encodings = self.rope(
93+
text_sequence_length=token_level_text_embeddings.shape[1],
94+
image_sequence_length=latent_image_embeddings.shape[1],
95+
)
96+
else:
97+
positional_encodings = None
8898

8999
# MultiModalTransformer layers
90100
count = 0
@@ -104,6 +114,7 @@ def __call__(
104114
latent_image_embeddings,
105115
token_level_text_embeddings,
106116
modulation_inputs,
117+
positional_encodings=positional_encodings,
107118
)
108119
mx.eval(latent_image_embeddings)
109120
mx.eval(token_level_text_embeddings)
@@ -133,8 +144,12 @@ def __call__(
133144

134145
for block in self.unimodal_transformer_blocks:
135146
latent_image_embeddings = block(
136-
latent_image_embeddings, modulation_inputs, None
137-
) # FIXME(arda): positional_encodings
147+
latent_image_embeddings,
148+
modulation_inputs,
149+
None,
150+
# FIXME(atiorh): RoPE is only supported for MultiModalTransformerBlock
151+
positional_encodings=positional_encodings
152+
)
138153

139154
# Final layer
140155
latent_image_embeddings = self.final_layer(
@@ -210,7 +225,7 @@ def __call__(self, x: mx.array) -> mx.array:
210225
return mx.repeat(w, repeats=b, axis=0)
211226

212227

213-
class PooledTextEmbeddingAdater(nn.Module):
228+
class PooledTextEmbeddingAdapter(nn.Module):
214229
def __init__(self, config: MMDiTConfig):
215230
super().__init__()
216231

@@ -282,7 +297,10 @@ def __init__(self, config: MMDiTConfig, skip_post_sdpa: bool = False):
282297
if config.use_qk_norm:
283298
self.qk_norm = QKNorm(config.hidden_size // config.num_heads)
284299

285-
def pre_sdpa(self, tensor: mx.array, modulation_inputs: mx.array):
300+
def pre_sdpa(self,
301+
tensor: mx.array,
302+
modulation_inputs: mx.array,
303+
) -> Dict[str, mx.array]:
286304
# Project Adaptive LayerNorm modulation parameters
287305
modulation_params = self.adaLN_modulation(modulation_inputs)
288306
modulation_params = mx.split(
@@ -299,17 +317,14 @@ def pre_sdpa(self, tensor: mx.array, modulation_inputs: mx.array):
299317
residual_scale=post_norm1_residual_scale,
300318
)
301319

302-
results = {
303-
"q": self.attn.q_proj(pre_attn),
304-
"k": self.attn.k_proj(pre_attn),
305-
"v": self.attn.v_proj(pre_attn),
306-
}
320+
q = self.attn.q_proj(pre_attn)
321+
k = self.attn.k_proj(pre_attn)
322+
v = self.attn.v_proj(pre_attn)
307323

308-
# Apply QKNorm if enabled
309324
if self.config.use_qk_norm:
310-
results["q"], results["k"] = self.qk_norm(
311-
results["q"], results["k"], results["v"]
312-
)
325+
q, k = self.qk_norm(q, k)
326+
327+
results = {"q": q, "k": k, "v": v}
313328

314329
if len(modulation_params) > 2:
315330
results.update(
@@ -403,9 +418,11 @@ def rearrange_for_sdpa(t):
             "scale": 1.0 / np.sqrt(self.per_head_dim),
         }
 
-        # Apply rope to q, k if positional_encodings are provided
-        if positional_encodings is not None:
-            pass  # TODO(arda): Implement rope
+        # TESTME(atiorh)
+        if self.config.pos_embed_type == PositionalEncoding.PreSDPARope:
+            assert positional_encodings is not None
+            RoPE.apply(multimodal_sdpa_inputs["q"], positional_encodings)
+            RoPE.apply(multimodal_sdpa_inputs["k"], positional_encodings)
 
         # Compute multi-modal SDPA
         sdpa_outputs = (
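
Note on the hunk above (an observation, not a change in this commit): RoPE.apply, as defined later in this diff, returns a new rotated array and does not modify its argument, so the two bare calls discard their results, consistent with the WIP status in the commit message. A sketch of what a wired-up call site could look like, assuming q and k are laid out as (batch, heads, sequence, head_dim); the reshape back to the original head dimension is an assumption borrowed from common RoPE implementations, not something this diff does:

# Hypothetical integration sketch, not part of the commit
q = multimodal_sdpa_inputs["q"]
k = multimodal_sdpa_inputs["k"]
multimodal_sdpa_inputs["q"] = RoPE.apply(q, positional_encodings).reshape(q.shape)
multimodal_sdpa_inputs["k"] = RoPE.apply(k, positional_encodings).reshape(k.shape)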
@@ -489,21 +506,19 @@ def __call__(self, x: mx.array, vec: mx.array, pe: mx.array) -> mx.array:
         return x + mod.gate * output
 
 
-# FIXME(arda): check it
 class QKNorm(nn.Module):
     def __init__(self, head_dim):
         super().__init__()
         self.q_norm = nn.RMSNorm(head_dim, eps=1e-6)
         self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)
 
-    def __call__(
-        self, q: mx.array, k: mx.array, v: mx.array
-    ) -> Tuple[mx.array, mx.array]:
-        q = self.q_norm(q)
-        k = self.k_norm(k)
-        return q.astype(v.dtype), k.astype(v.dtype)
+    def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
+        q = self.q_norm(q.astype(mx.float32))
+        k = self.k_norm(k.astype(mx.float32))
+        return q, k
 
 
+# FIXME(arda): Reuse our dict impl above for modulation outputs
 @dataclass
 class ModulationOut:
     shift: mx.array
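
Note (not part of the commit): the reworked QKNorm drops the v argument, runs RMSNorm in float32, and no longer casts back to the value dtype, so callers now receive float32 q/k; in this diff the downcast happens later (RoPE.apply restores the dtype of whatever it receives). A quick check of that behavior, assuming QKNorm is importable from diffusionkit.mlx.mmdit as defined above:

import mlx.core as mx

from diffusionkit.mlx.mmdit import QKNorm

# Half-precision inputs shaped (batch, heads, sequence, head_dim)
q = mx.random.normal((1, 2, 4, 8)).astype(mx.float16)
k = mx.random.normal((1, 2, 4, 8)).astype(mx.float16)

q_norm, k_norm = QKNorm(head_dim=8)(q, k)
# RMSNorm ran in float32 and no downcast was applied
assert q_norm.dtype == mx.float32 and k_norm.dtype == mx.float32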
@@ -609,33 +624,71 @@ def __call__(self, inputs: mx.array) -> mx.array:
         return mx.fast.layer_norm(inputs, weight=None, bias=None, eps=self.eps)
 
 
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+class RoPE(nn.Module):
+    """ Custom RoPE implementation for FLUX
+    """
+    def __init__(self, theta: int, axes_dim: List[int]) -> None:
         super().__init__()
-        self.dim = dim
         self.theta = theta
         self.axes_dim = axes_dim
 
-    def forward(self, ids: mx.array) -> mx.array:
-        n_axes = ids.shape[-1]
-        emb = mx.concatenate(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-
-        return emb.unsqueeze(1)
+        # Cache for consecutive identical calls
+        self.rope_embeddings = None
+        self.last_image_resolution = None
+        self.last_text_sequence_length = None
+
+    def _get_positions(self, latent_image_resolution: Tuple[int], text_sequence_length: int) -> mx.array:
+        h, w = latent_image_resolution
+        image_positions = mx.stack([
+            mx.zeros((h, w)),
+            mx.repeat(mx.arange(h)[:, None], w, axis=1),
+            mx.repeat(mx.arange(w)[None, :], h, axis=0),
+        ], axis=-1).flatten(0, 1)  # (h * w, 3)
+
+        text_and_image_positions = mx.concatenate([
+            mx.zeros((text_sequence_length, 3)),
+            image_positions,
+        ], axis=0)[None]  # (text_sequence_length + h * w, 3)
+
+        return text_and_image_positions
+
+    def rope(self, positions: mx.array, dim: int, theta: int = 10_000) -> mx.array:
+        def _rope_per_dim(positions, dim, theta):
+            scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
+            omega = 1.0 / (theta ** scale)
+            out = mx.einsum("bn,d->bnd", positions, omega)
+            return mx.stack([
+                mx.cos(out), -mx.sin(out), mx.sin(out), mx.cos(out)
+            ], axis=-1).reshape(*positions.shape, dim // 2, 2, 2)
+
+        return mx.concatenate([
+            _rope_per_dim(
+                positions=positions[..., i],
+                dim=self.axes_dim[i],
+                theta=self.theta
+            ) for i in range(len(self.axes_dim))
+        ], axis=-3).astype(positions.dtype)
+
+    def __call__(self, latent_image_resolution: Tuple[int], text_sequence_length: int) -> mx.array:
+        identical_to_last_call = \
+            latent_image_resolution == self.last_image_resolution and \
+            text_sequence_length == self.last_text_sequence_length
+
+        if self.rope_embeddings is None or not identical_to_last_call:
+            self.last_image_resolution = latent_image_resolution
+            self.last_text_sequence_length = text_sequence_length
+            positions = self._get_positions(latent_image_resolution, text_sequence_length)
+            self.rope_embeddings = self.rope(positions, self.theta)
+        else:
+            print("Returning cached RoPE embeddings")
 
+        return self.rope_embeddings
 
-def rope(pos: mx.array, dim: int, theta: int) -> mx.array:
-    assert dim % 2 == 0
-    scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
-    omega = 1.0 / (theta**scale)
-    # TODO(arda): implement this
-    out = None
-    # out = mx.einsum("...n,d->...nd", pos, omega)
-    # out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-    # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
+    @staticmethod
+    def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
+        in_dtype = q_or_k.dtype
+        q_or_k = q_or_k.astype(mx.float32).reshape(*q_or_k.shape[:-1], -1, 1, 2)
+        return (rope[..., 0] * q_or_k[..., 0] + rope[..., 1] * q_or_k[..., 1]).astype(in_dtype)
 
 
 def affine_transform(
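
Note (illustrative, not part of the commit): RoPE builds one table of 2x2 rotation blocks per token position, with all text tokens pinned to position zero and image tokens indexed by their (row, column) in the latent grid. A rough shape check under the FLUX-style axes_dim=(16, 56, 56) from config.py, assuming the class is importable as below; the resolution and text length are arbitrary:

import mlx.core as mx

from diffusionkit.mlx.mmdit import RoPE

rope_module = RoPE(theta=10000, axes_dim=[16, 56, 56])
embeddings = rope_module(latent_image_resolution=(32, 32), text_sequence_length=154)

# One (2, 2) rotation block per pair of head-dim channels:
# (1, text + h * w, sum(axes_dim) // 2, 2, 2) == (1, 154 + 1024, 64, 2, 2)
assert tuple(embeddings.shape) == (1, 1178, 64, 2, 2)

# A consecutive identical call returns the cached table instead of recomputing
_ = rope_module(latent_image_resolution=(32, 32), text_sequence_length=154)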
