Commit 3d1b7b7

Merge pull request #20 from argmaxinc/atila/flux_latency_opt
FLUX latency optimizations
2 parents: 6c5de95 + fc551f3

4 files changed (+22, -33 lines)

python/src/diffusionkit/mlx/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -467,7 +467,7 @@ def generate_image(
         logger.info(
             f"Pre decode active memory: {log['decoding']['pre']['active_memory']}GB"
         )
-        latents = latents.astype(mx.float32)
+        latents = latents.astype(self.activation_dtype)
         decoded = self.decode_latents_to_image(latents)
         mx.eval(decoded)

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 18 additions & 9 deletions
@@ -16,7 +16,7 @@

 logger = get_logger(__name__)

-SDPA_FLASH_ATTN_THRESHOLD = 1000
+SDPA_FLASH_ATTN_THRESHOLD = 1024


 class MMDiT(nn.Module):

@@ -218,8 +218,6 @@ def __call__(
             timestep,
             positional_encodings=positional_encodings,
         )
-        mx.eval(latent_image_embeddings)
-        mx.eval(token_level_text_embeddings)

         # UnifiedTransformerBlock layers
         if self.config.depth_unified > 0:

@@ -449,9 +447,10 @@ def pre_sdpa(
         # LayerNorm and modulate before SDPA
         try:
             modulated_pre_attention = affine_transform(
-                self.norm1(tensor),
+                tensor,
                 shift=post_norm1_shift,
                 residual_scale=post_norm1_residual_scale,
+                norm_module=self.norm1,
             )
         except Exception as e:
             logger.error(

@@ -531,9 +530,10 @@ def post_sdpa(
         # Apply separate modulation parameters and LayerNorm across attn and mlp
         mlp_out = self.mlp(
             affine_transform(
-                self.norm2(residual),
+                residual,
                 shift=post_norm2_shift,
                 residual_scale=post_norm2_residual_scale,
+                norm_module=self.norm2,
             )
         )
         return residual + post_mlp_scale * mlp_out

@@ -749,8 +749,9 @@ def __init__(self, head_dim):
         self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)

     def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
-        q = self.q_norm(q.astype(mx.float32))
-        k = self.k_norm(k.astype(mx.float32))
+        # Note: mlx.nn.RMSNorm has high precision accumulation (does not require upcasting)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         return q, k
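The QKNorm hunk above drops the explicit float32 upcast: as the new inline comment notes, mlx.nn.RMSNorm accumulates in high precision internally, so normalizing half-precision Q/K directly should give essentially the same result as the old upcast path. A minimal sketch of that claim, assuming a standalone nn.RMSNorm; the head_dim and tensor shapes below are illustrative and not taken from the diff:

import mlx.core as mx
import mlx.nn as nn

# Illustrative shapes, not taken from the diff.
head_dim = 128
norm = nn.RMSNorm(head_dim, eps=1e-6)

q = mx.random.normal((1, 24, 256, head_dim)).astype(mx.float16)

out_direct = norm(q)                                        # new path: no upcast
out_upcast = norm(q.astype(mx.float32)).astype(mx.float16)  # old path: explicit upcast

# The two paths should agree to within half-precision rounding.
print(mx.abs(out_direct.astype(mx.float32) - out_upcast.astype(mx.float32)).max())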

@@ -778,9 +779,10 @@ def __call__(

         shift, residual_scale = mx.split(modulation_params, 2, axis=-1)
         latent_image_embeddings = affine_transform(
-            self.norm_final(latent_image_embeddings),
+            latent_image_embeddings,
             shift=shift,
             residual_scale=residual_scale,
+            norm_module=self.norm_final,
         )
         return self.linear(latent_image_embeddings)

@@ -932,9 +934,16 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:


 def affine_transform(
-    x: mx.array, shift: mx.array, residual_scale: mx.array
+    x: mx.array,
+    shift: mx.array,
+    residual_scale: mx.array,
+    norm_module: nn.Module = None,
 ) -> mx.array:
     """Affine transformation (Used for Adaptive LayerNorm Modulation)"""
+    if norm_module is not None:
+        return mx.fast.layer_norm(
+            x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps
+        )
     return x * (1.0 + residual_scale) + shift
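The norm_module path added to affine_transform is the core of this change: instead of running a parameter-free LayerNorm and then applying the adaptive modulation x * (1 + residual_scale) + shift as a separate elementwise pass, it folds the modulation into the weight and bias arguments of a single mx.fast.layer_norm call, reusing the norm module's eps. A minimal equivalence sketch, assuming an affine-free nn.LayerNorm and per-channel modulation vectors; the dims, shapes, and tolerance below are illustrative:

import mlx.core as mx
import mlx.nn as nn

# Illustrative shapes, not taken from the diff.
dims, eps = 64, 1e-6
x = mx.random.normal((2, 16, dims))
shift = mx.random.normal((dims,))
residual_scale = mx.random.normal((dims,))

# Unfused path (old): normalize, then modulate elementwise.
norm = nn.LayerNorm(dims, eps=eps, affine=False)
unfused = norm(x) * (1.0 + residual_scale) + shift

# Fused path (new): one mx.fast.layer_norm call with the modulation
# folded into the norm's weight (1 + residual_scale) and bias (shift).
fused = mx.fast.layer_norm(x, 1.0 + residual_scale, shift, eps)

print(mx.allclose(unfused, fused, atol=1e-5))  # expected: True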

python/src/diffusionkit/mlx/vae.py

Lines changed: 2 additions & 22 deletions
@@ -84,17 +84,15 @@ def __init__(
             self.conv_shortcut = nn.Linear(in_channels, out_channels)

     def __call__(self, x, temb=None):
-        dtype = x.dtype
-
         if temb is not None:
             temb = self.time_emb_proj(nn.silu(temb))

-        y = self.norm1(x.astype(mx.float32)).astype(dtype)
+        y = self.norm1(x)
         y = nn.silu(y)
         y = self.conv1(y)
         if temb is not None:
             y = y + temb[:, None, None, :]
-        y = self.norm2(y.astype(mx.float32)).astype(dtype)
+        y = self.norm2(y)
         y = nn.silu(y)
         y = self.conv2(y)

@@ -386,37 +384,19 @@ def __init__(
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

     def __call__(self, x):
-        t = x.dtype
         x = self.conv_in(x)

         x = self.mid_blocks[0](x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after mid_blocks[0]")
-        x = x.astype(mx.float32)
         x = self.mid_blocks[1](x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after mid_blocks[1]")
-        x = x.astype(t)
         x = self.mid_blocks[2](x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after mid_blocks[2]")

         for l in reversed(self.up_blocks):
             x = l(x)
             mx.eval(x)

-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after up_blocks")
-
-        x = x.astype(mx.float32)
         x = self.conv_norm_out(x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after conv_norm_out")
-        x = x.astype(t)
         x = nn.silu(x)
         x = self.conv_out(x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after conv_out")

         return x
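The decoder previously ran mx.isnan(x).any() after nearly every block; turning that result into a Python boolean forces the lazy MLX graph to evaluate, so each check was an extra synchronization point on top of the float32 round-trips that are also removed here. If NaN checking is ever needed again while debugging, a gated helper along the following lines could restore it without slowing the default path. This is a hypothetical sketch, not part of the commit; the check_nan name and the DIFFUSIONKIT_DEBUG_NAN environment variable are invented for illustration:

import os

import mlx.core as mx


def check_nan(x: mx.array, where: str) -> mx.array:
    # Hypothetical debug helper (not in the commit): only inspects the
    # tensor when the environment flag is set, so the default decode
    # path gains no extra evaluation points.
    if os.environ.get("DIFFUSIONKIT_DEBUG_NAN"):
        if mx.isnan(x).any():
            raise ValueError(f"NaN detected in VAE Decoder after {where}")
    return x

Inside the decoder's __call__ this would be used as, for example, x = check_nan(self.mid_blocks[0](x), "mid_blocks[0]").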

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from setuptools import find_packages, setup
 from setuptools.command.install import install

-VERSION = "0.3.0"
+VERSION = "0.3.1"


 class VersionInstallCommand(install):
