Skip to content

Commit 1508f9f

Browse files
Merge branch 'main' into main
2 parents dce4c87 + c8083cc commit 1508f9f

File tree

7 files changed

+57
-48
lines changed

7 files changed

+57
-48
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,9 @@ For Stable Diffusion 3:
9595
```python
9696
from diffusionkit.mlx import DiffusionPipeline
9797
pipeline = DiffusionPipeline(
98-
model="argmaxinc/stable-diffusion",
9998
shift=3.0,
10099
use_t5=False,
101-
model_version="stable-diffusion-3-medium",
100+
model_version="argmaxinc/mlx-stable-diffusion-3-medium",
102101
low_memory_mode=True,
103102
a16=True,
104103
w16=True,
@@ -109,9 +108,8 @@ For FLUX:
109108
```python
110109
from diffusionkit.mlx import FluxPipeline
111110
pipeline = FluxPipeline(
112-
model="argmaxinc/stable-diffusion",
113111
shift=1.0,
114-
model_version="FLUX.1-schnell",
112+
model_version="argmaxinc/mlx-FLUX.1-schnell",
115113
low_memory_mode=True,
116114
a16=True,
117115
w16=True,

python/src/diffusionkit/mlx/__init__.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,20 @@
3535
logger = get_logger(__name__)
3636

3737
MMDIT_CKPT = {
38-
"stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
38+
"argmaxinc/mlx-stable-diffusion-3-medium": "argmaxinc/mlx-stable-diffusion-3-medium",
3939
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
40-
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
41-
"FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
42-
"FLUX.1-dev": "raoulritter/flux-dev-mlx",
40+
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
41+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
42+
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev"
4343
}
4444

4545
T5_MAX_LENGTH = {
46-
"stable-diffusion-3-medium": 512,
47-
"FLUX.1-schnell": 256,
48-
"FLUX.1-schnell-4bit-quantized": 256,
49-
"FLUX.1-dev": 512,
46+
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
47+
"argmaxinc/mlx-FLUX.1-schnell": 256,
48+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
49+
"argmaxinc/mlx-FLUX.1-dev": 512,
5050
}
5151

52-
5352
class DiffusionKitInferenceContext(AppleSiliconContextMixin, InferenceContextSpec):
5453
def code_spec(self):
5554
return {}
@@ -61,11 +60,10 @@ def model_spec(self):
6160
class DiffusionPipeline:
6261
def __init__(
6362
self,
64-
model: str = _DEFAULT_MODEL,
6563
w16: bool = False,
6664
shift: float = 1.0,
6765
use_t5: bool = True,
68-
model_version: str = "stable-diffusion-3-medium",
66+
model_version: str = "argmaxinc/mlx-stable-diffusion-3-medium",
6967
low_memory_mode: bool = True,
7068
a16: bool = False,
7169
local_ckpt=None,
@@ -78,7 +76,7 @@ def __init__(
7876
self.use_t5 = use_t5
7977
self.mmdit_ckpt = MMDIT_CKPT[model_version]
8078
self.low_memory_mode = low_memory_mode
81-
self.model = model
79+
self.model = _DEFAULT_MODEL
8280
self.model_version = model_version
8381
self.sampler = ModelSamplingDiscreteFlow(shift=shift)
8482
self.latent_format = SD3LatentFormat()
@@ -301,6 +299,13 @@ def generate_image(
301299
image_path: Optional[str] = None,
302300
denoise: float = 1.0,
303301
):
302+
# Check latent size is divisible by 2
303+
assert (
304+
latent_size[0] % 2 == 0
305+
), f"Height must be divisible by 16 ({latent_size[0]*8}/16={latent_size[0]/2})"
306+
assert (
307+
latent_size[1] % 2 == 0
308+
), f"Width must be divisible by 16 ({latent_size[1]*8}/16={latent_size[1]/2})"
304309
self.check_and_load_models()
305310
# Start timing
306311
start_time = time.time()
@@ -588,11 +593,10 @@ def encode_image_to_latents(self, image_path: str, seed):
588593
class FluxPipeline(DiffusionPipeline):
589594
def __init__(
590595
self,
591-
model: str = _DEFAULT_MODEL,
592596
w16: bool = False,
593597
shift: float = 1.0,
594598
use_t5: bool = True,
595-
model_version: str = "FLUX.1-schnell",
599+
model_version: str = "argmaxinc/mlx-FLUX.1-schnell",
596600
low_memory_mode: bool = True,
597601
a16: bool = False,
598602
local_ckpt=None,
@@ -605,7 +609,7 @@ def __init__(
605609
self.activation_dtype = self.float16_dtype if a16 else mx.float32
606610
self.mmdit_ckpt = MMDIT_CKPT[model_version]
607611
self.low_memory_mode = low_memory_mode
608-
self.model = model
612+
self.model = _DEFAULT_MODEL
609613
self.model_version = model_version
610614
self.sampler = FluxSampler(shift=shift)
611615
self.latent_format = FluxLatentFormat()

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -958,11 +958,14 @@ def affine_transform(
958958
norm_module: nn.Module = None,
959959
) -> mx.array:
960960
"""Affine transformation (Used for Adaptive LayerNorm Modulation)"""
961-
if norm_module is not None:
961+
if x.shape[0] == 1 and norm_module is not None:
962962
return mx.fast.layer_norm(
963963
x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps
964964
)
965-
return x * (1.0 + residual_scale) + shift
965+
elif norm_module is not None:
966+
return norm_module(x) * (1.0 + residual_scale) + shift
967+
else:
968+
return x * (1.0 + residual_scale) + shift
966969

967970

968971
def unpatchify(

python/src/diffusionkit/mlx/model_io.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@
3232

3333

3434
RANK = 32
35-
_DEFAULT_MMDIT = "stabilityai/stable-diffusion-3-medium"
35+
_DEFAULT_MMDIT = "argmaxinc/mlx-stable-diffusion-3-medium"
3636
_MMDIT = {
37-
"stabilityai/stable-diffusion-3-medium": {
38-
"stable-diffusion-3-medium": "sd3_medium.safetensors",
37+
"argmaxinc/mlx-stable-diffusion-3-medium": {
38+
"argmaxinc/mlx-stable-diffusion-3-medium": "sd3_medium.safetensors",
3939
"vae": "sd3_medium.safetensors",
4040
},
4141
"argmaxinc/mlx-FLUX.1-schnell": {
42-
"FLUX.1-schnell": "flux-schnell.safetensors",
42+
"argmaxinc/mlx-FLUX.1-schnell": "flux-schnell.safetensors",
4343
"vae": "ae.safetensors",
4444
},
4545
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": {
46-
"FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
46+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
4747
"vae": "ae.safetensors",
4848
},
4949
"raoulritter/flux-dev-mlx": {
@@ -67,7 +67,7 @@
6767
}
6868

6969
_PREFIX = {
70-
"stabilityai/stable-diffusion-3-medium": {
70+
"argmaxinc/mlx-stable-diffusion-3-medium": {
7171
"vae_encoder": "first_stage_model.encoder.",
7272
"vae_decoder": "first_stage_model.decoder.",
7373
},
@@ -88,11 +88,11 @@
8888
_FLOAT16 = mx.bfloat16
8989

9090
DEPTH = {
91-
"stable-diffusion-3-medium": 24,
91+
"argmaxinc/mlx-stable-diffusion-3-medium": 24,
9292
"sd3-8b-unreleased": 38,
9393
}
9494
MAX_LATENT_RESOLUTION = {
95-
"stable-diffusion-3-medium": 96,
95+
"argmaxinc/mlx-stable-diffusion-3-medium": 96,
9696
"sd3-8b-unreleased": 192,
9797
}
9898

@@ -682,6 +682,7 @@ def load_mmdit(
682682

683683
mmdit_weights = _MMDIT[key][model_key]
684684
mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights)
685+
hf_hub_download(key, "config.json")
685686
weights = mx.load(mmdit_weights_ckpt)
686687
weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.")
687688
weights = {k: v.astype(dtype) for k, v in weights.items()}
@@ -696,7 +697,7 @@ def load_mmdit(
696697
def load_flux(
697698
key: str = "argmaxinc/mlx-FLUX.1-schnell",
698699
float16: bool = False,
699-
model_key: str = "FLUX.1-schnell",
700+
model_key: str = "argmaxinc/mlx-FLUX.1-schnell",
700701
low_memory_mode: bool = True,
701702
only_modulation_dict: bool = False,
702703
):
@@ -711,14 +712,16 @@ def load_flux(
711712
hf_hub_download(key, "config.json")
712713
weights = mx.load(flux_weights_ckpt)
713714

714-
if model_key in ["FLUX.1-schnell", "FLUX.1-dev"]:
715+
if model_key in ["argmaxinc/mlx-FLUX.1-schnell", "argmaxinc/mlx-FLUX.1-dev"]:
715716
weights = flux_state_dict_adjustments(
716717
weights,
717718
prefix="",
718719
hidden_size=config.hidden_size,
719720
mlp_ratio=config.mlp_ratio,
720721
)
721-
elif model_key == "FLUX.1-schnell-4bit-quantized": # 4-bit ckpt already adjusted
722+
elif (
723+
model_key == "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized"
724+
): # 4-bit ckpt already adjusted
722725
nn.quantize(model)
723726

724727
weights = {

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,22 @@
1313

1414
# Defaults
1515
HEIGHT = {
16-
"stable-diffusion-3-medium": 512,
16+
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
1717
"sd3-8b-unreleased": 1024,
18-
"FLUX.1-schnell": 512,
19-
"FLUX.1-schnell-4bit-quantized": 512,
18+
"argmaxinc/mlx-FLUX.1-schnell": 512,
19+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2020
}
2121
WIDTH = {
22-
"stable-diffusion-3-medium": 512,
22+
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
2323
"sd3-8b-unreleased": 1024,
24-
"FLUX.1-schnell": 512,
25-
"FLUX.1-schnell-4bit-quantized": 512,
24+
"argmaxinc/mlx-FLUX.1-schnell": 512,
25+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2626
}
2727
SHIFT = {
28-
"stable-diffusion-3-medium": 3.0,
28+
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
2929
"sd3-8b-unreleased": 3.0,
30-
"FLUX.1-schnell": 1.0,
31-
"FLUX.1-schnell-4bit-quantized": 1.0,
30+
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
31+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
3232
}
3333

3434

@@ -43,7 +43,7 @@ def cli():
4343
parser.add_argument(
4444
"--model-version",
4545
choices=tuple(MMDIT_CKPT.keys()),
46-
default="FLUX.1-schnell",
46+
default="argmaxinc/mlx-FLUX.1-schnell",
4747
help="Diffusion model version, e.g. FLUX-1.schnell, stable-diffusion-3-medium",
4848
)
4949
parser.add_argument(
@@ -127,7 +127,6 @@ def cli():
127127

128128
# Load the models
129129
sd = pipeline_class(
130-
model="argmaxinc/stable-diffusion",
131130
w16=args.w16,
132131
shift=shift,
133132
use_t5=args.t5,
@@ -143,6 +142,8 @@ def cli():
143142

144143
height = args.height or HEIGHT[args.model_version]
145144
width = args.width or WIDTH[args.model_version]
145+
assert height % 16 == 0, f"Height must be divisible by 16 ({height}/16={height/16})"
146+
assert width % 16 == 0, f"Width must be divisible by 16 ({width}/16={width/16})"
146147
logger.info(f"Output image resolution will be {height}x{width}")
147148

148149
if args.benchmark_mode:

python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
LOW_MEMORY_MODE = True
2727
SAVE_IMAGES = True
28-
MODEL_VERSION = "stable-diffusion-3-medium"
28+
MODEL_VERSION = "argmaxinc/mlx-stable-diffusion-3-medium"
2929
USE_T5 = False
3030
SKIP_CORRECTNESS = False
3131

@@ -49,7 +49,7 @@ def test_sd3_pipeline_correctness(self):
4949
metadata = json.load(f)
5050

5151
# Group metadata by model size
52-
model_examples = {"stable-diffusion-3-medium": []}
52+
model_examples = {"argmaxinc/mlx-stable-diffusion-3-medium": []}
5353
for data in metadata:
5454
model_examples[data["model_version"]].append(data)
5555

@@ -106,7 +106,7 @@ def test_memory_usage(self):
106106
metadata = json.load(f)
107107

108108
# Group metadata by model size
109-
model_examples = {"stable-diffusion-3-medium": []}
109+
model_examples = {"argmaxinc/mlx-stable-diffusion-3-medium": []}
110110
for data in metadata:
111111
model_examples[data["model_version"]].append(data)
112112

@@ -187,7 +187,7 @@ def main(args):
187187
parser.add_argument(
188188
"--model-size",
189189
type=str,
190-
default="stable-diffusion-3-medium",
190+
default="argmaxinc/mlx-stable-diffusion-3-medium",
191191
choices=tuple(MMDIT_CKPT.keys()),
192192
help="model version to test",
193193
)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from setuptools import find_packages, setup
44
from setuptools.command.install import install
55

6-
VERSION = "0.3.2"
6+
VERSION = "0.3.5"
77

88

99
class VersionInstallCommand(install):

0 commit comments

Comments
 (0)