Skip to content

Commit 0e63ca6

Browse files
authored
Merge pull request #26 from argmaxinc/update_config
Update model repo names
2 parents d915340 + 5fbd7ad commit 0e63ca6

File tree

5 files changed

+40
-42
lines changed

5 files changed

+40
-42
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,9 @@ For Stable Diffusion 3:
9595
```python
9696
from diffusionkit.mlx import DiffusionPipeline
9797
pipeline = DiffusionPipeline(
98-
model="argmaxinc/stable-diffusion",
9998
shift=3.0,
10099
use_t5=False,
101-
model_version="stable-diffusion-3-medium",
100+
model_version="argmaxinc/mlx-stable-diffusion-3-medium",
102101
low_memory_mode=True,
103102
a16=True,
104103
w16=True,
@@ -109,9 +108,8 @@ For FLUX:
109108
```python
110109
from diffusionkit.mlx import FluxPipeline
111110
pipeline = FluxPipeline(
112-
model="argmaxinc/stable-diffusion",
113111
shift=1.0,
114-
model_version="FLUX.1-schnell",
112+
model_version="argmaxinc/mlx-FLUX.1-schnell",
115113
low_memory_mode=True,
116114
a16=True,
117115
w16=True,

python/src/diffusionkit/mlx/__init__.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,16 @@
3535
logger = get_logger(__name__)
3636

3737
MMDIT_CKPT = {
38-
"stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
38+
"argmaxinc/mlx-stable-diffusion-3-medium": "argmaxinc/mlx-stable-diffusion-3-medium",
3939
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
40-
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
41-
"FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
40+
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
41+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
4242
}
4343

4444
T5_MAX_LENGTH = {
45-
"stable-diffusion-3-medium": 512,
46-
"FLUX.1-schnell": 256,
47-
"FLUX.1-schnell-4bit-quantized": 256,
45+
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
46+
"argmaxinc/mlx-FLUX.1-schnell": 256,
47+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
4848
}
4949

5050

@@ -59,11 +59,10 @@ def model_spec(self):
5959
class DiffusionPipeline:
6060
def __init__(
6161
self,
62-
model: str = _DEFAULT_MODEL,
6362
w16: bool = False,
6463
shift: float = 1.0,
6564
use_t5: bool = True,
66-
model_version: str = "stable-diffusion-3-medium",
65+
model_version: str = "argmaxinc/mlx-stable-diffusion-3-medium",
6766
low_memory_mode: bool = True,
6867
a16: bool = False,
6968
local_ckpt=None,
@@ -76,7 +75,7 @@ def __init__(
7675
self.use_t5 = use_t5
7776
self.mmdit_ckpt = MMDIT_CKPT[model_version]
7877
self.low_memory_mode = low_memory_mode
79-
self.model = model
78+
self.model = _DEFAULT_MODEL
8079
self.model_version = model_version
8180
self.sampler = ModelSamplingDiscreteFlow(shift=shift)
8281
self.latent_format = SD3LatentFormat()
@@ -586,11 +585,10 @@ def encode_image_to_latents(self, image_path: str, seed):
586585
class FluxPipeline(DiffusionPipeline):
587586
def __init__(
588587
self,
589-
model: str = _DEFAULT_MODEL,
590588
w16: bool = False,
591589
shift: float = 1.0,
592590
use_t5: bool = True,
593-
model_version: str = "FLUX.1-schnell",
591+
model_version: str = "argmaxinc/mlx-FLUX.1-schnell",
594592
low_memory_mode: bool = True,
595593
a16: bool = False,
596594
local_ckpt=None,
@@ -603,7 +601,7 @@ def __init__(
603601
self.activation_dtype = self.float16_dtype if a16 else mx.float32
604602
self.mmdit_ckpt = MMDIT_CKPT[model_version]
605603
self.low_memory_mode = low_memory_mode
606-
self.model = model
604+
self.model = _DEFAULT_MODEL
607605
self.model_version = model_version
608606
self.sampler = FluxSampler(shift=shift)
609607
self.latent_format = FluxLatentFormat()

python/src/diffusionkit/mlx/model_io.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@
3232

3333

3434
RANK = 32
35-
_DEFAULT_MMDIT = "stabilityai/stable-diffusion-3-medium"
35+
_DEFAULT_MMDIT = "argmaxinc/mlx-stable-diffusion-3-medium"
3636
_MMDIT = {
37-
"stabilityai/stable-diffusion-3-medium": {
38-
"stable-diffusion-3-medium": "sd3_medium.safetensors",
37+
"argmaxinc/mlx-stable-diffusion-3-medium": {
38+
"argmaxinc/mlx-stable-diffusion-3-medium": "sd3_medium.safetensors",
3939
"vae": "sd3_medium.safetensors",
4040
},
4141
"argmaxinc/mlx-FLUX.1-schnell": {
42-
"FLUX.1-schnell": "flux-schnell.safetensors",
42+
"argmaxinc/mlx-FLUX.1-schnell": "flux-schnell.safetensors",
4343
"vae": "ae.safetensors",
4444
},
4545
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": {
46-
"FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
46+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
4747
"vae": "ae.safetensors",
4848
},
4949
}
@@ -63,7 +63,7 @@
6363
}
6464

6565
_PREFIX = {
66-
"stabilityai/stable-diffusion-3-medium": {
66+
"argmaxinc/mlx-stable-diffusion-3-medium": {
6767
"vae_encoder": "first_stage_model.encoder.",
6868
"vae_decoder": "first_stage_model.decoder.",
6969
},
@@ -80,11 +80,11 @@
8080
_FLOAT16 = mx.bfloat16
8181

8282
DEPTH = {
83-
"stable-diffusion-3-medium": 24,
83+
"argmaxinc/mlx-stable-diffusion-3-medium": 24,
8484
"sd3-8b-unreleased": 38,
8585
}
8686
MAX_LATENT_RESOLUTION = {
87-
"stable-diffusion-3-medium": 96,
87+
"argmaxinc/mlx-stable-diffusion-3-medium": 96,
8888
"sd3-8b-unreleased": 192,
8989
}
9090

@@ -674,6 +674,7 @@ def load_mmdit(
674674

675675
mmdit_weights = _MMDIT[key][model_key]
676676
mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights)
677+
hf_hub_download(key, "config.json")
677678
weights = mx.load(mmdit_weights_ckpt)
678679
weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.")
679680
weights = {k: v.astype(dtype) for k, v in weights.items()}
@@ -688,7 +689,7 @@ def load_mmdit(
688689
def load_flux(
689690
key: str = "argmaxinc/mlx-FLUX.1-schnell",
690691
float16: bool = False,
691-
model_key: str = "FLUX.1-schnell",
692+
model_key: str = "argmaxinc/mlx-FLUX.1-schnell",
692693
low_memory_mode: bool = True,
693694
only_modulation_dict: bool = False,
694695
):
@@ -703,14 +704,16 @@ def load_flux(
703704
hf_hub_download(key, "config.json")
704705
weights = mx.load(flux_weights_ckpt)
705706

706-
if model_key == "FLUX.1-schnell":
707+
if model_key == "argmaxinc/mlx-FLUX.1-schnell":
707708
weights = flux_state_dict_adjustments(
708709
weights,
709710
prefix="",
710711
hidden_size=config.hidden_size,
711712
mlp_ratio=config.mlp_ratio,
712713
)
713-
elif model_key == "FLUX.1-schnell-4bit-quantized": # 4-bit ckpt already adjusted
714+
elif (
715+
model_key == "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized"
716+
): # 4-bit ckpt already adjusted
714717
nn.quantize(model)
715718

716719
weights = {

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,22 @@
1313

1414
# Defaults
1515
HEIGHT = {
16-
"stable-diffusion-3-medium": 512,
16+
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
1717
"sd3-8b-unreleased": 1024,
18-
"FLUX.1-schnell": 512,
19-
"FLUX.1-schnell-4bit-quantized": 512,
18+
"argmaxinc/mlx-FLUX.1-schnell": 512,
19+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2020
}
2121
WIDTH = {
22-
"stable-diffusion-3-medium": 512,
22+
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
2323
"sd3-8b-unreleased": 1024,
24-
"FLUX.1-schnell": 512,
25-
"FLUX.1-schnell-4bit-quantized": 512,
24+
"argmaxinc/mlx-FLUX.1-schnell": 512,
25+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
2626
}
2727
SHIFT = {
28-
"stable-diffusion-3-medium": 3.0,
28+
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
2929
"sd3-8b-unreleased": 3.0,
30-
"FLUX.1-schnell": 1.0,
31-
"FLUX.1-schnell-4bit-quantized": 1.0,
30+
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
31+
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
3232
}
3333

3434

@@ -43,7 +43,7 @@ def cli():
4343
parser.add_argument(
4444
"--model-version",
4545
choices=tuple(MMDIT_CKPT.keys()),
46-
default="FLUX.1-schnell",
46+
default="argmaxinc/mlx-FLUX.1-schnell",
4747
help="Diffusion model version, e.g. FLUX-1.schnell, stable-diffusion-3-medium",
4848
)
4949
parser.add_argument(
@@ -127,7 +127,6 @@ def cli():
127127

128128
# Load the models
129129
sd = pipeline_class(
130-
model="argmaxinc/stable-diffusion",
131130
w16=args.w16,
132131
shift=shift,
133132
use_t5=args.t5,

python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
LOW_MEMORY_MODE = True
2727
SAVE_IMAGES = True
28-
MODEL_VERSION = "stable-diffusion-3-medium"
28+
MODEL_VERSION = "argmaxinc/mlx-stable-diffusion-3-medium"
2929
USE_T5 = False
3030
SKIP_CORRECTNESS = False
3131

@@ -49,7 +49,7 @@ def test_sd3_pipeline_correctness(self):
4949
metadata = json.load(f)
5050

5151
# Group metadata by model size
52-
model_examples = {"stable-diffusion-3-medium": []}
52+
model_examples = {"argmaxinc/mlx-stable-diffusion-3-medium": []}
5353
for data in metadata:
5454
model_examples[data["model_version"]].append(data)
5555

@@ -106,7 +106,7 @@ def test_memory_usage(self):
106106
metadata = json.load(f)
107107

108108
# Group metadata by model size
109-
model_examples = {"stable-diffusion-3-medium": []}
109+
model_examples = {"argmaxinc/mlx-stable-diffusion-3-medium": []}
110110
for data in metadata:
111111
model_examples[data["model_version"]].append(data)
112112

@@ -187,7 +187,7 @@ def main(args):
187187
parser.add_argument(
188188
"--model-size",
189189
type=str,
190-
default="stable-diffusion-3-medium",
190+
default="argmaxinc/mlx-stable-diffusion-3-medium",
191191
choices=tuple(MMDIT_CKPT.keys()),
192192
help="model version to test",
193193
)

0 commit comments

Comments
 (0)