Skip to content

Commit 1577412

Browse files
Apply suggestions from code review for flux dev
1 parent 1508f9f commit 1577412

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,16 @@ def __call__(
214214
else:
215215
positional_encodings = None
216216

217-
timestep_embedding = self.guidance_in(self.t_embedder(timestep))
217+
if config.guidance_embed:
218+
timestep = self.guidance_in(self.t_embedder(timestep))
218219

219220
# MultiModalTransformer layers
220221
if self.config.depth_multimodal > 0:
221222
for bidx, block in enumerate(self.multimodal_transformer_blocks):
222223
latent_image_embeddings, token_level_text_embeddings = block(
223224
latent_image_embeddings,
224225
token_level_text_embeddings,
225-
timestep_embedding,
226+
timestep,
226227
positional_encodings=positional_encodings,
227228
)
228229

@@ -245,7 +246,7 @@ def __call__(
245246

246247
latent_image_embeddings = self.final_layer(
247248
latent_image_embeddings,
248-
timestep_embedding
249+
timestep,
249250
)
250251

251252
if self.config.patchify_via_reshape:

python/src/diffusionkit/mlx/model_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@
4646
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
4747
"vae": "ae.safetensors",
4848
},
49-
"raoulritter/flux-dev-mlx": {
50-
"FLUX.1-dev": "flux1-dev-mlx.safetensors",
49+
"argmaxinc/mlx-FLUX.1-dev": {
50+
"argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors",
5151
"vae": "ae.safetensors",
5252
},
5353
}
@@ -79,7 +79,7 @@
7979
"vae_encoder": "encoder.",
8080
"vae_decoder": "decoder.",
8181
},
82-
"raoulritter/flux-dev-mlx": {
82+
"argmaxinc/mlx-FLUX.1-dev": {
8383
"vae_encoder": "encoder.",
8484
"vae_decoder": "decoder.",
8585
},

0 commit comments

Comments
 (0)