Skip to content

Commit 925ef5d

Browse files
authored
Merge pull request #41 from argmaxinc/sd3.5_large
Model support: stable-diffusion-3.5-large
2 parents e63501d + 4cfefc7 commit 925ef5d

File tree

6 files changed

+41
-10
lines changed

6 files changed

+41
-10
lines changed

python/src/diffusionkit/mlx/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@
3636

3737
MMDIT_CKPT = {
3838
"argmaxinc/mlx-stable-diffusion-3-medium": "argmaxinc/mlx-stable-diffusion-3-medium",
39-
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
39+
"argmaxinc/mlx-stable-diffusion-3.5-large": "argmaxinc/mlx-stable-diffusion-3.5-large",
4040
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
4141
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
4242
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
4343
}
4444

4545
T5_MAX_LENGTH = {
4646
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
47+
"argmaxinc/mlx-stable-diffusion-3.5-large": 512,
4748
"argmaxinc/mlx-FLUX.1-schnell": 256,
4849
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
4950
"argmaxinc/mlx-FLUX.1-dev": 512,

python/src/diffusionkit/mlx/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def hidden_size(self) -> int:
7171
guidance_embed: bool = False
7272

7373

74-
SD3_8b = MMDiTConfig(depth_multimodal=38, num_heads=3, upcast_multimodal_blocks=[35])
74+
SD3_8b = MMDiTConfig(
75+
depth_multimodal=38, num_heads=38, upcast_multimodal_blocks=[35], use_qk_norm=True
76+
)
7577

7678
SD3_2b = MMDiTConfig(
7779
depth_multimodal=24, num_heads=24, float16_dtype=mx.float16, dtype=mx.float16

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ def __init__(
818818
self.kv_proj_embed_dim = self.per_head_dim * n_heads
819819

820820
# Note: key bias is redundant due to softmax invariance
821-
self.k_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim)
821+
self.k_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim, bias=False)
822822
self.q_proj = nn.Linear(embed_dim, embed_dim)
823823
self.v_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim)
824824
self.o_proj = nn.Linear(embed_dim, embed_dim)

python/src/diffusionkit/mlx/model_io.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
AutoencoderConfig,
2121
CLIPTextModelConfig,
2222
SD3_2b,
23+
SD3_8b,
2324
VAEDecoderConfig,
2425
VAEEncoderConfig,
2526
)
@@ -50,6 +51,10 @@
5051
"argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors",
5152
"vae": "ae.safetensors",
5253
},
54+
"argmaxinc/mlx-stable-diffusion-3.5-large": {
55+
"argmaxinc/mlx-stable-diffusion-3.5-large": "sd3.5_large.safetensors",
56+
"vae": "sd3.5_large.safetensors",
57+
},
5358
}
5459
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
5560
_MODELS = {
@@ -83,17 +88,29 @@
8388
"vae_encoder": "encoder.",
8489
"vae_decoder": "decoder.",
8590
},
91+
"argmaxinc/mlx-stable-diffusion-3.5-large": {
92+
"vae_encoder": "first_stage_model.encoder.",
93+
"vae_decoder": "first_stage_model.decoder.",
94+
},
95+
}
96+
97+
_CONFIG = {
98+
"argmaxinc/mlx-stable-diffusion-3-medium": SD3_2b,
99+
"argmaxinc/mlx-FLUX.1-schnell": FLUX_SCHNELL,
100+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": FLUX_SCHNELL,
101+
"argmaxinc/mlx-FLUX.1-dev": FLUX_SCHNELL,
102+
"argmaxinc/mlx-stable-diffusion-3.5-large": SD3_8b,
86103
}
87104

88105
_FLOAT16 = mx.bfloat16
89106

90107
DEPTH = {
91108
"argmaxinc/mlx-stable-diffusion-3-medium": 24,
92-
"sd3-8b-unreleased": 38,
109+
"argmaxinc/mlx-stable-diffusion-3.5-large": 38,
93110
}
94111
MAX_LATENT_RESOLUTION = {
95112
"argmaxinc/mlx-stable-diffusion-3-medium": 96,
96-
"sd3-8b-unreleased": 192,
113+
"argmaxinc/mlx-stable-diffusion-3.5-large": 192,
97114
}
98115

99116
LOCAl_SD3_CKPT = None
@@ -321,6 +338,14 @@ def mmdit_state_dict_adjustments(state_dict, prefix=""):
321338
for k, v in state_dict.items()
322339
}
323340

341+
# Remap qk_norm
342+
state_dict = {
343+
k.replace(".attn.ln_q.", ".qk_norm.q_norm."): v for k, v in state_dict.items()
344+
}
345+
state_dict = {
346+
k.replace(".attn.ln_k.", ".qk_norm.k_norm."): v for k, v in state_dict.items()
347+
}
348+
324349
# Split qkv proj and rename:
325350
# *transformer_block.attn.qkv.{weigth/bias} -> transformer_block.attn.{q/k/v}_proj.{weigth/bias}
326351
# *transformer_block.attn.proj.{weigth/bias} -> transformer_block.attn.o_proj.{weight/bias}
@@ -347,6 +372,9 @@ def mmdit_state_dict_adjustments(state_dict, prefix=""):
347372
# Filter out VAE Decoder related tensors
348373
state_dict = {k: v for k, v in state_dict.items() if "decoder." not in k}
349374

375+
# Filter out VAE Encoder related tensors
376+
state_dict = {k: v for k, v in state_dict.items() if "encoder." not in k}
377+
350378
# Filter out k_proj.bias related tensors
351379
state_dict = {k: v for k, v in state_dict.items() if "k_proj.bias" not in k}
352380

@@ -676,7 +704,7 @@ def load_mmdit(
676704
"""Load the MM-DiT model from the checkpoint file."""
677705
"""only_modulation_dict: Only returns the modulation dictionary"""
678706
dtype = _FLOAT16 if float16 else mx.float32
679-
config = SD3_2b
707+
config = _CONFIG[key]
680708
config.low_memory_mode = low_memory_mode
681709
model = MMDiT(config)
682710

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@
1414
# Defaults
1515
HEIGHT = {
1616
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
17-
"sd3-8b-unreleased": 1024,
17+
"argmaxinc/mlx-stable-diffusion-3.5-large": 1024,
1818
"argmaxinc/mlx-FLUX.1-schnell": 512,
1919
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2020
"argmaxinc/mlx-FLUX.1-dev": 512,
2121
}
2222
WIDTH = {
2323
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
24-
"sd3-8b-unreleased": 1024,
24+
"argmaxinc/mlx-stable-diffusion-3.5-large": 1024,
2525
"argmaxinc/mlx-FLUX.1-schnell": 512,
2626
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2727
"argmaxinc/mlx-FLUX.1-dev": 512,
2828
}
2929
SHIFT = {
3030
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
31-
"sd3-8b-unreleased": 3.0,
31+
"argmaxinc/mlx-stable-diffusion-3.5-large": 3.0,
3232
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
3333
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
3434
"argmaxinc/mlx-FLUX.1-dev": 1.0,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from setuptools import find_packages, setup
44
from setuptools.command.install import install
55

6-
VERSION = "0.4.0"
6+
VERSION = "0.5.0"
77

88

99
class VersionInstallCommand(install):

0 commit comments

Comments
 (0)