Skip to content

Commit d737473

Browse files
authored
Merge pull request #43 from EduardoPach/add-sd3.5-4bit
Add SD3.5-large 4bit quantized
2 parents 925ef5d + 1d05fce commit d737473

File tree

3 files changed

+47
-11
lines changed

3 files changed

+47
-11
lines changed

python/src/diffusionkit/mlx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
MMDIT_CKPT = {
3838
"argmaxinc/mlx-stable-diffusion-3-medium": "argmaxinc/mlx-stable-diffusion-3-medium",
3939
"argmaxinc/mlx-stable-diffusion-3.5-large": "argmaxinc/mlx-stable-diffusion-3.5-large",
40+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized",
4041
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
4142
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
4243
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
@@ -45,6 +46,7 @@
4546
T5_MAX_LENGTH = {
4647
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
4748
"argmaxinc/mlx-stable-diffusion-3.5-large": 512,
49+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 512,
4850
"argmaxinc/mlx-FLUX.1-schnell": 256,
4951
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
5052
"argmaxinc/mlx-FLUX.1-dev": 512,

python/src/diffusionkit/mlx/model_io.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@
5555
"argmaxinc/mlx-stable-diffusion-3.5-large": "sd3.5_large.safetensors",
5656
"vae": "sd3.5_large.safetensors",
5757
},
58+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": {
59+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": "sd3.5_large_4bit_quantized.safetensors",
60+
"vae": "sd3.5_large_4bit_quantized.safetensors",
61+
},
5862
}
5963
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
6064
_MODELS = {
@@ -92,6 +96,10 @@
9296
"vae_encoder": "first_stage_model.encoder.",
9397
"vae_decoder": "first_stage_model.decoder.",
9498
},
99+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": {
100+
"vae_encoder": "first_stage_model.encoder.",
101+
"vae_decoder": "first_stage_model.decoder.",
102+
},
95103
}
96104

97105
_CONFIG = {
@@ -100,17 +108,20 @@
100108
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": FLUX_SCHNELL,
101109
"argmaxinc/mlx-FLUX.1-dev": FLUX_SCHNELL,
102110
"argmaxinc/mlx-stable-diffusion-3.5-large": SD3_8b,
111+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": SD3_8b,
103112
}
104113

105114
_FLOAT16 = mx.bfloat16
106115

107116
DEPTH = {
108117
"argmaxinc/mlx-stable-diffusion-3-medium": 24,
109118
"argmaxinc/mlx-stable-diffusion-3.5-large": 38,
119+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 38,
110120
}
111121
MAX_LATENT_RESOLUTION = {
112122
"argmaxinc/mlx-stable-diffusion-3-medium": 96,
113123
"argmaxinc/mlx-stable-diffusion-3.5-large": 192,
124+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 192,
114125
}
115126

116127
LOCAl_SD3_CKPT = None
@@ -712,12 +723,23 @@ def load_mmdit(
712723
mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights)
713724
hf_hub_download(key, "config.json")
714725
weights = mx.load(mmdit_weights_ckpt)
715-
weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.")
716-
weights = {k: v.astype(dtype) for k, v in weights.items()}
726+
prefix = "model.diffusion_model."
727+
728+
if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized":
729+
weights = mmdit_state_dict_adjustments(weights, prefix=prefix)
730+
else:
731+
nn.quantize(
732+
model, class_predicate=lambda _, module: isinstance(module, nn.Linear)
733+
)
734+
weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k}
735+
736+
weights = {
737+
k: v.astype(dtype) if v.dtype != mx.uint32 else v for k, v in weights.items()
738+
}
717739
if only_modulation_dict:
718740
weights = {k: v for k, v in weights.items() if "adaLN" in k}
719741
return tree_flatten(weights)
720-
model.update(tree_unflatten(tree_flatten(weights)))
742+
model.load_weights(list(weights.items()))
721743

722744
return model
723745

@@ -852,11 +874,15 @@ def load_vae_decoder(
852874
vae_weights = _MMDIT[key][model_key]
853875
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
854876
weights = mx.load(vae_weights_ckpt)
855-
weights = vae_decoder_state_dict_adjustments(
856-
weights, prefix=_PREFIX[key]["vae_decoder"]
857-
)
877+
prefix = _PREFIX[key]["vae_decoder"]
878+
879+
if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized":
880+
weights = vae_decoder_state_dict_adjustments(weights, prefix=prefix)
881+
else:
882+
weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k}
883+
858884
weights = {k: v.astype(dtype) for k, v in weights.items()}
859-
model.update(tree_unflatten(tree_flatten(weights)))
885+
model.load_weights(list(weights.items()))
860886

861887
return model
862888

@@ -880,11 +906,15 @@ def load_vae_encoder(
880906
vae_weights = _MMDIT[key][model_key]
881907
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
882908
weights = mx.load(vae_weights_ckpt)
883-
weights = vae_encoder_state_dict_adjustments(
884-
weights, prefix=_PREFIX[key]["vae_encoder"]
885-
)
909+
prefix = _PREFIX[key]["vae_encoder"]
910+
911+
if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized":
912+
weights = vae_encoder_state_dict_adjustments(weights, prefix=prefix)
913+
else:
914+
weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k}
915+
886916
weights = {k: v.astype(dtype) for k, v in weights.items()}
887-
model.update(tree_unflatten(tree_flatten(weights)))
917+
model.load_weights(list(weights.items()))
888918

889919
return model
890920

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,23 @@
1515
HEIGHT = {
1616
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
1717
"argmaxinc/mlx-stable-diffusion-3.5-large": 1024,
18+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 1024,
1819
"argmaxinc/mlx-FLUX.1-schnell": 512,
1920
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2021
"argmaxinc/mlx-FLUX.1-dev": 512,
2122
}
2223
WIDTH = {
2324
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
2425
"argmaxinc/mlx-stable-diffusion-3.5-large": 1024,
26+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 1024,
2527
"argmaxinc/mlx-FLUX.1-schnell": 512,
2628
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2729
"argmaxinc/mlx-FLUX.1-dev": 512,
2830
}
2931
SHIFT = {
3032
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
3133
"argmaxinc/mlx-stable-diffusion-3.5-large": 3.0,
34+
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 3.0,
3235
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
3336
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
3437
"argmaxinc/mlx-FLUX.1-dev": 1.0,
@@ -108,6 +111,7 @@ def cli():
108111
type=str,
109112
help="Path to the local mmdit checkpoint.",
110113
)
114+
111115
args = parser.parse_args()
112116

113117
args.w16 = True

0 commit comments

Comments
 (0)