Commit 8384f40

Final tests for SD3 and FLUX
Parent: eb12a15

2 files changed: 59 additions and 55 deletions


python/src/diffusionkit/mlx/__init__.py

Lines changed: 21 additions & 18 deletions
```diff
@@ -495,7 +495,7 @@ def read_image(self, image_path: str):
         # Make sure image shape is divisible by 64
         W, H = (dim - dim % 64 for dim in (img.width, img.height))
         if W != img.width or H != img.height:
-            print(
+            logger.warning(
                 f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}"
             )
             img = img.resize((W, H), Image.LANCZOS)  # use desired downsampling filter
```
```diff
@@ -629,13 +629,21 @@ def __init__(self, model: DiffusionPipeline):
         self.model = model
 
     def cache_modulation_params(self, pooled_text_embeddings, sigmas):
-        self.model.mmdit.cache_modulation_params(pooled_text_embeddings, sigmas)
+        self.model.mmdit.cache_modulation_params(
+            pooled_text_embeddings, sigmas.astype(self.model.activation_dtype)
+        )
 
     def clear_cache(self):
         self.model.mmdit.clear_modulation_params_cache()
 
     def __call__(
-        self, x_t, t, conditioning, cfg_weight: float = 7.5, pooled_conditioning=None
+        self,
+        x_t,
+        timestep,
+        sigma,
+        conditioning,
+        cfg_weight: float = 7.5,
+        pooled_conditioning=None,
     ):
         if cfg_weight <= 0:
             logger.debug("CFG Weight disabled")
```
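The signature change above means the caller now supplies both the precomputed timestep and the raw sigma, instead of a single noise level that the denoiser converted internally. A hedged before/after sketch of the call site this implies (`denoiser`, `pooled`, and the tensors themselves are placeholders; only the parameter order comes from the diff):

```python
# Old convention: one noise value `t` (a sigma), broadcast and converted inside __call__.
eps = denoiser(x_t, t, conditioning, cfg_weight=7.5, pooled_conditioning=pooled)

# New convention: timestep and sigma are passed separately; the sigma-to-timestep
# conversion now happens once in the sampler rather than on every denoiser call.
eps = denoiser(x_t, timestep, sigma, conditioning, cfg_weight=7.5, pooled_conditioning=pooled)
```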
```diff
@@ -644,20 +652,14 @@ def __call__(
         x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype(
             self.model.activation_dtype
         )
-        t_mmdit = mx.broadcast_to(t, [len(x_t_mmdit)])
-        timestep = self.model.sampler.timestep(t_mmdit).astype(
-            self.model.activation_dtype
-        )
         mmdit_input = {
             "latent_image_embeddings": x_t_mmdit,
             "token_level_text_embeddings": mx.expand_dims(conditioning, 2),
-            "timestep": timestep,
+            "timestep": mx.broadcast_to(timestep, [len(x_t_mmdit)]),
         }
 
         mmdit_output = self.model.mmdit(**mmdit_input)
-        eps_pred = self.model.sampler.calculate_denoised(
-            t_mmdit, mmdit_output, x_t_mmdit
-        )
+        eps_pred = self.model.sampler.calculate_denoised(sigma, mmdit_output, x_t_mmdit)
         if cfg_weight <= 0:
             return eps_pred
         else:
```
```diff
@@ -707,21 +709,22 @@ def to_d(x, sigma, denoised):
 def sample_euler(model: CFGDenoiser, x, sigmas, extra_args=None):
     """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
     extra_args = {} if extra_args is None else extra_args
-    s_in = mx.ones([x.shape[0]])
+
     from tqdm import trange
 
-    sigmas = mx.array([1.0, 0.75, 0.5, 0.25, 0.0], mx.bfloat16)  # FIXME
     t = trange(len(sigmas) - 1)
 
-    model.cache_modulation_params(extra_args.pop("pooled_conditioning"), sigmas)
+    timesteps = model.model.sampler.timestep(sigmas).astype(
+        model.model.activation_dtype
+    )
+    model.cache_modulation_params(extra_args.pop("pooled_conditioning"), timesteps)
 
     iter_time = []
     for i in t:
         start_time = t.format_dict["elapsed"]
-        sigma_hat = sigmas[i]
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, denoised)
-        dt = sigmas[i + 1] - sigma_hat
+        denoised = model(x, timesteps[i], sigmas[i], **extra_args)
+        d = to_d(x, sigmas[i], denoised)
+        dt = sigmas[i + 1] - sigmas[i]
         # Euler method
         x = x + d * dt
         mx.eval(x)
```
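Taken together, these hunks move the sigma-to-timestep conversion out of the per-step denoiser call: `sample_euler` now computes `timesteps` once, caches the modulation parameters against them, and passes `timesteps[i]` and `sigmas[i]` separately at each Euler step. A minimal, self-contained sketch of that pattern in MLX (toy schedule, toy cache, and a toy `denoise` function standing in for the real pipeline and MMDiT call; `to_d` is assumed to be the usual Karras derivative `(x - denoised) / sigma`):

```python
import mlx.core as mx

# Toy sigma schedule; the real one comes from the pipeline's sampler.
sigmas = mx.array([1.0, 0.75, 0.5, 0.25, 0.0])
timesteps = sigmas * 1000.0  # stand-in for model.model.sampler.timestep(sigmas)

# Per-timestep parameters cached once before the loop, keyed by the scalar timestep.
modulation_cache = {t.item(): mx.ones((1, 4)) * t for t in timesteps}

def denoise(x, timestep, sigma):
    # Stand-in for CFGDenoiser.__call__: uses the cached params, no per-step conversion.
    mod = modulation_cache[timestep.item()]
    return x * (1.0 - sigma) + 1e-3 * mod.mean()

x = mx.random.normal((1, 4))
for i in range(len(sigmas) - 1):
    denoised = denoise(x, timesteps[i], sigmas[i])
    d = (x - denoised) / sigmas[i]           # what to_d(x, sigma, denoised) computes
    x = x + d * (sigmas[i + 1] - sigmas[i])  # Euler step
    mx.eval(x)
```

In the real code the cached parameters live on the MMDiT blocks rather than in the sampler; the cache key scheme is what the mmdit.py changes below set up.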

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 38 additions & 37 deletions
```diff
@@ -3,12 +3,10 @@
 # Copyright (C) 2024 Argmax, Inc. All Rights Reserved.
 #
 
-import gc
 from functools import partial
 
 import mlx.core as mx
 import mlx.nn as nn
-import mlx.utils as utils
 import numpy as np
 from argmaxtools.utils import get_logger
 from beartype.typing import Dict, List, Optional, Tuple
```
```diff
@@ -77,35 +75,31 @@ def cache_modulation_params(
         by offloading all adaLN_modulation parameters
         """
         y_embed = self.y_embedder(pooled_text_embeddings)
+        batch_size = pooled_text_embeddings.shape[0]
 
         offload_size = 0
         to_offload = []
 
         for timestep in timesteps:
             final_timestep = timestep.item() == timesteps[-1].item()
-            modulation_inputs = y_embed + self.t_embedder(timestep[None] * 1000.0)
+            timestep_key = timestep.item()
+            modulation_inputs = y_embed[:, None, None, :] + self.t_embedder(
+                mx.repeat(timestep[None], batch_size, axis=0)
+            )
 
             for block in self.multimodal_transformer_blocks:
                 if not hasattr(block.image_transformer_block, "_modulation_params"):
                     block.image_transformer_block._modulation_params = dict()
                     block.text_transformer_block._modulation_params = dict()
 
                 block.image_transformer_block._modulation_params[
-                    (timestep * 1000).item()
+                    timestep_key
                 ] = block.image_transformer_block.adaLN_modulation(modulation_inputs)
                 block.text_transformer_block._modulation_params[
-                    (timestep * 1000).item()
+                    timestep_key
                 ] = block.text_transformer_block.adaLN_modulation(modulation_inputs)
-                mx.eval(
-                    block.image_transformer_block._modulation_params[
-                        (timestep * 1000).item()
-                    ]
-                )
-                mx.eval(
-                    block.text_transformer_block._modulation_params[
-                        (timestep * 1000).item()
-                    ]
-                )
+                mx.eval(block.image_transformer_block._modulation_params[timestep_key])
+                mx.eval(block.text_transformer_block._modulation_params[timestep_key])
 
                 if final_timestep:
                     offload_size += (
```
```diff
@@ -131,33 +125,34 @@ def cache_modulation_params(
                         ]
                     )
 
-            for block in self.unified_transformer_blocks:
-                if not hasattr(block.transformer_block, "_modulation_params"):
-                    block.transformer_block._modulation_params = dict()
-                block.transformer_block._modulation_params[
-                    (timestep * 1000).item()
-                ] = block.transformer_block.adaLN_modulation(modulation_inputs)
-                mx.eval(
-                    block.transformer_block._modulation_params[(timestep * 1000).item()]
-                )
-
-                if final_timestep:
-                    offload_size += (
-                        block.transformer_block.adaLN_modulation.layers[1].weight.size
-                        * block.transformer_block.adaLN_modulation.layers[
-                            1
-                        ].weight.dtype.size
-                    )
-                    to_offload.extend(
-                        [block.transformer_block.adaLN_modulation.layers[1]]
-                    )
+            if self.config.depth_unified > 0:
+                for block in self.unified_transformer_blocks:
+                    if not hasattr(block.transformer_block, "_modulation_params"):
+                        block.transformer_block._modulation_params = dict()
+                    block.transformer_block._modulation_params[
+                        timestep_key
+                    ] = block.transformer_block.adaLN_modulation(modulation_inputs)
+                    mx.eval(block.transformer_block._modulation_params[timestep_key])
+
+                    if final_timestep:
+                        offload_size += (
+                            block.transformer_block.adaLN_modulation.layers[
+                                1
+                            ].weight.size
+                            * block.transformer_block.adaLN_modulation.layers[
+                                1
+                            ].weight.dtype.size
+                        )
+                        to_offload.extend(
+                            [block.transformer_block.adaLN_modulation.layers[1]]
+                        )
 
             if not hasattr(self.final_layer, "_modulation_params"):
                 self.final_layer._modulation_params = dict()
             self.final_layer._modulation_params[
-                (timestep * 1000).item()
+                timestep_key
             ] = self.final_layer.adaLN_modulation(modulation_inputs)
-            mx.eval(self.final_layer._modulation_params[(timestep * 1000).item()])
+            mx.eval(self.final_layer._modulation_params[timestep_key])
 
             if final_timestep:
                 offload_size += (
```
```diff
@@ -246,6 +241,7 @@ def __call__(
             latent_image_embeddings,
             timestep,
         )
+
         if self.config.patchify_via_reshape:
             latent_image_embeddings = self.x_embedder.unpack(
                 latent_image_embeddings, (latent_height, latent_width)
```
```diff
@@ -437,7 +433,10 @@ def pre_sdpa(
         tensor: mx.array,
         timestep: mx.array,
     ) -> Dict[str, mx.array]:
+        if timestep.size > 1:
+            timestep = timestep[0]
         modulation_params = self._modulation_params[timestep.item()]
+
         modulation_params = mx.split(
             modulation_params, self.num_modulation_params, axis=-1
         )
```
```diff
@@ -771,6 +770,8 @@ def __call__(
         latent_image_embeddings: mx.array,
         timestep: mx.array,
     ) -> mx.array:
+        if timestep.size > 1:
+            timestep = timestep[0]
         modulation_params = self._modulation_params[timestep.item()]
 
         shift, residual_scale = mx.split(modulation_params, 2, axis=-1)
```
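On the mmdit.py side, the cache is now keyed by the timestep value the sampler supplies directly (`timestep.item()` rather than `(timestep * 1000).item()`), and both lookup sites first reduce the batch-broadcast timestep back to a single element. A small self-contained sketch of that cache-and-lookup pattern (the cached tensors are toys, not real adaLN_modulation outputs):

```python
import mlx.core as mx

# Built once per run: one cached tensor per timestep, keyed by the Python scalar.
timesteps = mx.array([1000.0, 750.0, 500.0, 250.0, 0.0])
cache = {t.item(): mx.ones((1, 1, 1, 6)) * t for t in timesteps}

def lookup(timestep: mx.array) -> mx.array:
    # The forward pass receives the timestep broadcast to the batch size,
    # so reduce it to one element before taking .item() for the dict lookup.
    if timestep.size > 1:
        timestep = timestep[0]
    return cache[timestep.item()]

batch_timestep = mx.broadcast_to(timesteps[1], [2])  # shape (2,), as CFGDenoiser passes it
print(lookup(batch_timestep).shape)  # (1, 1, 1, 6)
```

Keying on the raw timestep works because the same `timesteps` array drives both `cache_modulation_params` and the sampling loop, so every lookup key was seen when the cache was built.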
