55
55
"argmaxinc/mlx-stable-diffusion-3.5-large" : "sd3.5_large.safetensors" ,
56
56
"vae" : "sd3.5_large.safetensors" ,
57
57
},
58
+ "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" : {
59
+ "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" : "sd3.5_large_4bit_quantized.safetensors" ,
60
+ "vae" : "sd3.5_large_4bit_quantized.safetensors" ,
61
+ },
58
62
}
59
63
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
60
64
_MODELS = {
92
96
"vae_encoder" : "first_stage_model.encoder." ,
93
97
"vae_decoder" : "first_stage_model.decoder." ,
94
98
},
99
+ "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" : {
100
+ "vae_encoder" : "first_stage_model.encoder." ,
101
+ "vae_decoder" : "first_stage_model.decoder." ,
102
+ },
95
103
}
96
104
97
105
_CONFIG = {
100
108
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized" : FLUX_SCHNELL ,
101
109
"argmaxinc/mlx-FLUX.1-dev" : FLUX_SCHNELL ,
102
110
"argmaxinc/mlx-stable-diffusion-3.5-large" : SD3_8b ,
111
+ "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" : SD3_8b ,
103
112
}
104
113
105
114
_FLOAT16 = mx .bfloat16
106
115
107
116
DEPTH = {
108
117
"argmaxinc/mlx-stable-diffusion-3-medium" : 24 ,
109
118
"argmaxinc/mlx-stable-diffusion-3.5-large" : 38 ,
119
+ "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" : 38 ,
110
120
}
111
121
MAX_LATENT_RESOLUTION = {
112
122
"argmaxinc/mlx-stable-diffusion-3-medium" : 96 ,
113
123
"argmaxinc/mlx-stable-diffusion-3.5-large" : 192 ,
124
+ "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" : 192 ,
114
125
}
115
126
116
127
LOCAl_SD3_CKPT = None
@@ -712,12 +723,23 @@ def load_mmdit(
712
723
mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download (key , mmdit_weights )
713
724
hf_hub_download (key , "config.json" )
714
725
weights = mx .load (mmdit_weights_ckpt )
715
- weights = mmdit_state_dict_adjustments (weights , prefix = "model.diffusion_model." )
716
- weights = {k : v .astype (dtype ) for k , v in weights .items ()}
726
+ prefix = "model.diffusion_model."
727
+
728
+ if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" :
729
+ weights = mmdit_state_dict_adjustments (weights , prefix = prefix )
730
+ else :
731
+ nn .quantize (
732
+ model , class_predicate = lambda _ , module : isinstance (module , nn .Linear )
733
+ )
734
+ weights = {k .replace (prefix , "" ): v for k , v in weights .items () if prefix in k }
735
+
736
+ weights = {
737
+ k : v .astype (dtype ) if v .dtype != mx .uint32 else v for k , v in weights .items ()
738
+ }
717
739
if only_modulation_dict :
718
740
weights = {k : v for k , v in weights .items () if "adaLN" in k }
719
741
return tree_flatten (weights )
720
- model .update ( tree_unflatten ( tree_flatten ( weights )))
742
+ model .load_weights ( list ( weights . items ( )))
721
743
722
744
return model
723
745
@@ -852,11 +874,15 @@ def load_vae_decoder(
852
874
vae_weights = _MMDIT [key ][model_key ]
853
875
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download (key , vae_weights )
854
876
weights = mx .load (vae_weights_ckpt )
855
- weights = vae_decoder_state_dict_adjustments (
856
- weights , prefix = _PREFIX [key ]["vae_decoder" ]
857
- )
877
+ prefix = _PREFIX [key ]["vae_decoder" ]
878
+
879
+ if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" :
880
+ weights = vae_decoder_state_dict_adjustments (weights , prefix = prefix )
881
+ else :
882
+ weights = {k .replace (prefix , "" ): v for k , v in weights .items () if prefix in k }
883
+
858
884
weights = {k : v .astype (dtype ) for k , v in weights .items ()}
859
- model .update ( tree_unflatten ( tree_flatten ( weights )))
885
+ model .load_weights ( list ( weights . items ( )))
860
886
861
887
return model
862
888
@@ -880,11 +906,15 @@ def load_vae_encoder(
880
906
vae_weights = _MMDIT [key ][model_key ]
881
907
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download (key , vae_weights )
882
908
weights = mx .load (vae_weights_ckpt )
883
- weights = vae_encoder_state_dict_adjustments (
884
- weights , prefix = _PREFIX [key ]["vae_encoder" ]
885
- )
909
+ prefix = _PREFIX [key ]["vae_encoder" ]
910
+
911
+ if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized" :
912
+ weights = vae_encoder_state_dict_adjustments (weights , prefix = prefix )
913
+ else :
914
+ weights = {k .replace (prefix , "" ): v for k , v in weights .items () if prefix in k }
915
+
886
916
weights = {k : v .astype (dtype ) for k , v in weights .items ()}
887
- model .update ( tree_unflatten ( tree_flatten ( weights )))
917
+ model .load_weights ( list ( weights . items ( )))
888
918
889
919
return model
890
920
0 commit comments