Skip to content

Commit 13c2a05

Browse files
authored
Merge pull request #23 from EduardoPach/add-quantize-arg
Add quantize arg
2 parents 3d1b7b7 + 4e778a0 commit 13c2a05

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

python/src/diffusionkit/mlx/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@
3838
"stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
3939
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
4040
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
41+
"FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
4142
}
4243

4344
T5_MAX_LENGTH = {
4445
"stable-diffusion-3-medium": 512,
4546
"FLUX.1-schnell": 256,
47+
"FLUX.1-schnell-4bit-quantized": 256,
4648
}
4749

4850

@@ -592,6 +594,7 @@ def __init__(
592594
low_memory_mode: bool = True,
593595
a16: bool = False,
594596
local_ckpt=None,
597+
quantize_mmdit: bool = False,
595598
):
596599
model_io.LOCAl_SD3_CKPT = local_ckpt
597600
self.float16_dtype = mx.bfloat16
@@ -606,16 +609,21 @@ def __init__(
606609
self.latent_format = FluxLatentFormat()
607610
self.use_t5 = True
608611
self.use_clip_g = False
612+
self.quantize_mmdit = quantize_mmdit
609613
self.check_and_load_models()
610614

611615
def load_mmdit(self, only_modulation_dict=False):
612616
if only_modulation_dict:
613617
return load_flux(
618+
key=self.mmdit_ckpt,
619+
model_key=self.model_version,
614620
float16=True if self.dtype == self.float16_dtype else False,
615621
low_memory_mode=self.low_memory_mode,
616622
only_modulation_dict=only_modulation_dict,
617623
)
618624
self.mmdit = load_flux(
625+
key=self.mmdit_ckpt,
626+
model_key=self.model_version,
619627
float16=True if self.dtype == self.float16_dtype else False,
620628
low_memory_mode=self.low_memory_mode,
621629
only_modulation_dict=only_modulation_dict,

python/src/diffusionkit/mlx/model_io.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import mlx.core as mx
1212
from huggingface_hub import hf_hub_download
13+
from mlx import nn
1314
from mlx.utils import tree_flatten, tree_unflatten
1415
from transformers import T5Config
1516

@@ -41,6 +42,10 @@
4142
"FLUX.1-schnell": "flux-schnell.safetensors",
4243
"vae": "ae.safetensors",
4344
},
45+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": {
46+
"FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
47+
"vae": "ae.safetensors",
48+
},
4449
}
4550
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
4651
_MODELS = {
@@ -66,6 +71,10 @@
6671
"vae_encoder": "encoder.",
6772
"vae_decoder": "decoder.",
6873
},
74+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": {
75+
"vae_encoder": "encoder.",
76+
"vae_decoder": "decoder.",
77+
},
6978
}
7079

7180
_FLOAT16 = mx.bfloat16
@@ -693,10 +702,20 @@ def load_flux(
693702
flux_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, flux_weights)
694703
hf_hub_download(key, "config.json")
695704
weights = mx.load(flux_weights_ckpt)
696-
weights = flux_state_dict_adjustments(
697-
weights, prefix="", hidden_size=config.hidden_size, mlp_ratio=config.mlp_ratio
698-
)
699-
weights = {k: v.astype(dtype) for k, v in weights.items()}
705+
706+
if model_key == "FLUX.1-schnell":
707+
weights = flux_state_dict_adjustments(
708+
weights,
709+
prefix="",
710+
hidden_size=config.hidden_size,
711+
mlp_ratio=config.mlp_ratio,
712+
)
713+
elif model_key == "FLUX.1-schnell-4bit-quantized": # 4-bit ckpt already adjusted
714+
nn.quantize(model)
715+
716+
weights = {
717+
k: v.astype(dtype) if v.dtype != mx.uint32 else v for k, v in weights.items()
718+
}
700719
if only_modulation_dict:
701720
weights = {k: v for k, v in weights.items() if "adaLN" in k}
702721
return tree_flatten(weights)

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@
1616
"stable-diffusion-3-medium": 512,
1717
"sd3-8b-unreleased": 1024,
1818
"FLUX.1-schnell": 512,
19+
"FLUX.1-schnell-4bit-quantized": 512,
1920
}
2021
WIDTH = {
2122
"stable-diffusion-3-medium": 512,
2223
"sd3-8b-unreleased": 1024,
2324
"FLUX.1-schnell": 512,
25+
"FLUX.1-schnell-4bit-quantized": 512,
2426
}
2527
SHIFT = {
2628
"stable-diffusion-3-medium": 3.0,
2729
"sd3-8b-unreleased": 3.0,
2830
"FLUX.1-schnell": 1.0,
31+
"FLUX.1-schnell-4bit-quantized": 1.0,
2932
}
3033

3134

@@ -107,7 +110,7 @@ def cli():
107110
args.w16 = True
108111
args.a16 = True
109112

110-
if args.model_version == "FLUX.1-schnell" and args.cfg > 0.0:
113+
if "FLUX" in args.model_version and args.cfg > 0.0:
111114
logger.warning("Disabling CFG for FLUX.1-schnell model.")
112115
args.cfg = 0.0
113116

0 commit comments

Comments
 (0)