Skip to content

Commit bfbdd0e

Browse files
authored
Merge pull request #28 from raoulritter/main
[Model Support] FLUX.1-dev
2 parents c8083cc + 5cdda4d commit bfbdd0e

File tree

9 files changed

+71
-15
lines changed

9 files changed

+71
-15
lines changed

.flake8

Lines changed: 0 additions & 6 deletions
This file was deleted.

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ __pycache__/
2727
# Distribution / packaging
2828
.Python
2929
build/
30+
.build/
3031
develop-eggs/
3132
dist/
3233
downloads/

README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ pip install -e .
3232
<summary> Click to expand </summary>
3333

3434

35-
[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) requires users to accept the terms before downloading the checkpoint. Once you accept the terms, sign in with your Hugging Face hub READ token as below:
35+
[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) requires users to accept the terms before downloading the checkpoint.
36+
37+
[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) also requires users to accept the terms before downloading the checkpoint.
38+
39+
Once you accept the terms, sign in with your Hugging Face hub READ token as below:
3640
> [!IMPORTANT]
3741
> If using a fine-grained token, it is also necessary to [edit permissions](https://huggingface.co/settings/tokens) to allow `Read access to contents of all public gated repos you can access`
3842
@@ -89,6 +93,8 @@ Some notable optional arguments for:
8993

9094
Please refer to the help menu for all available arguments: `diffusionkit-cli -h`.
9195

96+
Note: When using `FLUX.1-dev`, verify that you have accepted the [FLUX.1-dev license](https://huggingface.co/black-forest-labs/FLUX.1-dev) and have allowed gated-repo access on your [Hugging Face token](https://huggingface.co/settings/tokens).
97+
9298
### Code ###
9399

94100
For Stable Diffusion 3:
@@ -109,7 +115,7 @@ For FLUX:
109115
from diffusionkit.mlx import FluxPipeline
110116
pipeline = FluxPipeline(
111117
shift=1.0,
112-
model_version="argmaxinc/mlx-FLUX.1-schnell",
118+
model_version="argmaxinc/mlx-FLUX.1-schnell", # model_version="argmaxinc/mlx-FLUX.1-dev" for FLUX.1-dev
113119
low_memory_mode=True,
114120
a16=True,
115121
w16=True,
@@ -120,7 +126,7 @@ Finally, to generate the image, use the `generate_image()` function:
120126
```python
121127
HEIGHT = 512
122128
WIDTH = 512
123-
NUM_STEPS = 4 # 4 for FLUX.1-schnell, 50 for SD3
129+
NUM_STEPS = 4 # 4 for FLUX.1-schnell, 50 for SD3 and FLUX.1-dev
124130
CFG_WEIGHT = 0. # 0. for FLUX.1-schnell and FLUX.1-dev, 5. for SD3
125131

126132
image, _ = pipeline.generate_image(

python/src/diffusionkit/mlx/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@
3939
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
4040
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
4141
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
42+
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
4243
}
4344

4445
T5_MAX_LENGTH = {
4546
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
4647
"argmaxinc/mlx-FLUX.1-schnell": 256,
4748
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
49+
"argmaxinc/mlx-FLUX.1-dev": 512,
4850
}
4951

5052

@@ -653,7 +655,9 @@ def encode_text(
653655
text,
654656
(negative_text if cfg_weight > 1 else None),
655657
)
656-
padded_tokens_t5 = mx.zeros((1, 256)).astype(tokens_t5.dtype)
658+
padded_tokens_t5 = mx.zeros((1, T5_MAX_LENGTH[self.model_version])).astype(
659+
tokens_t5.dtype
660+
)
657661
padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[
658662
[0], :
659663
] # Ignore negative text

python/src/diffusionkit/mlx/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def hidden_size(self) -> int:
6868

6969
low_memory_mode: bool = True
7070

71+
guidance_embed: bool = False
72+
7173

7274
SD3_8b = MMDiTConfig(depth_multimodal=38, num_heads=3, upcast_multimodal_blocks=[35])
7375

@@ -90,6 +92,22 @@ def hidden_size(self) -> int:
9092
dtype=mx.bfloat16,
9193
)
9294

95+
FLUX_DEV = MMDiTConfig(
96+
num_heads=24,
97+
depth_multimodal=19,
98+
depth_unified=38,
99+
parallel_mlp_for_unified_blocks=True,
100+
hidden_size_override=3072,
101+
patchify_via_reshape=True,
102+
pos_embed_type=PositionalEncoding.PreSDPARope,
103+
rope_axes_dim=(16, 56, 56),
104+
pooled_text_embed_dim=768, # CLIP-L/14 only
105+
use_qk_norm=True,
106+
float16_dtype=mx.bfloat16,
107+
guidance_embed=True,
108+
dtype=mx.bfloat16,
109+
)
110+
93111

94112
@dataclass
95113
class AutoencoderConfig:

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ def __init__(self, config: MMDiTConfig):
2828
super().__init__()
2929
self.config = config
3030

31+
if config.guidance_embed:
32+
self.guidance_in = MLPEmbedder(
33+
in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size
34+
)
35+
else:
36+
self.guidance_in = nn.Identity()
37+
3138
# Input adapters and embeddings
3239
self.x_embedder = LatentImageAdapter(config)
3340

@@ -209,6 +216,9 @@ def __call__(
209216
else:
210217
positional_encodings = None
211218

219+
if self.config.guidance_embed:
220+
timestep = self.guidance_in(self.t_embedder(timestep))
221+
212222
# MultiModalTransformer layers
213223
if self.config.depth_multimodal > 0:
214224
for bidx, block in enumerate(self.multimodal_transformer_blocks):
@@ -236,7 +246,6 @@ def __call__(
236246
:, token_level_text_embeddings.shape[1] :, ...
237247
]
238248

239-
# Final layer
240249
latent_image_embeddings = self.final_layer(
241250
latent_image_embeddings,
242251
timestep,
@@ -933,6 +942,19 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
933942
)
934943

935944

945+
class MLPEmbedder(nn.Module):
946+
def __init__(self, in_dim: int, hidden_dim: int):
947+
super().__init__()
948+
self.mlp = nn.Sequential(
949+
nn.Linear(in_dim, hidden_dim),
950+
nn.SiLU(),
951+
nn.Linear(hidden_dim, hidden_dim),
952+
)
953+
954+
def __call__(self, x):
955+
return self.mlp(x)
956+
957+
936958
def affine_transform(
937959
x: mx.array,
938960
shift: mx.array,

python/src/diffusionkit/mlx/model_io.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@
4646
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
4747
"vae": "ae.safetensors",
4848
},
49+
"argmaxinc/mlx-FLUX.1-dev": {
50+
"argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors",
51+
"vae": "ae.safetensors",
52+
},
4953
}
5054
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
5155
_MODELS = {
@@ -75,6 +79,10 @@
7579
"vae_encoder": "encoder.",
7680
"vae_decoder": "decoder.",
7781
},
82+
"argmaxinc/mlx-FLUX.1-dev": {
83+
"vae_encoder": "encoder.",
84+
"vae_decoder": "decoder.",
85+
},
7886
}
7987

8088
_FLOAT16 = mx.bfloat16
@@ -704,7 +712,7 @@ def load_flux(
704712
hf_hub_download(key, "config.json")
705713
weights = mx.load(flux_weights_ckpt)
706714

707-
if model_key == "argmaxinc/mlx-FLUX.1-schnell":
715+
if model_key in ["argmaxinc/mlx-FLUX.1-schnell", "argmaxinc/mlx-FLUX.1-dev"]:
708716
weights = flux_state_dict_adjustments(
709717
weights,
710718
prefix="",

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,21 @@
1717
"sd3-8b-unreleased": 1024,
1818
"argmaxinc/mlx-FLUX.1-schnell": 512,
1919
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
20+
"argmaxinc/mlx-FLUX.1-dev": 512,
2021
}
2122
WIDTH = {
2223
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
2324
"sd3-8b-unreleased": 1024,
2425
"argmaxinc/mlx-FLUX.1-schnell": 512,
2526
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
27+
"argmaxinc/mlx-FLUX.1-dev": 512,
2628
}
2729
SHIFT = {
2830
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
2931
"sd3-8b-unreleased": 3.0,
3032
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
3133
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
34+
"argmaxinc/mlx-FLUX.1-dev": 1.0,
3235
}
3336

3437

@@ -111,7 +114,7 @@ def cli():
111114
args.a16 = True
112115

113116
if "FLUX" in args.model_version and args.cfg > 0.0:
114-
logger.warning("Disabling CFG for FLUX.1-schnell model.")
117+
logger.warning(f"Disabling CFG for {args.model_version} model.")
115118
args.cfg = 0.0
116119

117120
if args.benchmark_mode:

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from setuptools import find_packages, setup
44
from setuptools.command.install import install
55

6-
VERSION = "0.3.5"
6+
VERSION = "0.4.0"
77

88

99
class VersionInstallCommand(install):
@@ -29,7 +29,7 @@ def run(self):
2929
"argmaxtools>=0.1.13",
3030
"torch",
3131
"safetensors",
32-
"mlx>=0.16.3",
32+
"mlx>=0.17.1",
3333
"jaxtyping",
3434
"transformers",
3535
"pillow",

0 commit comments

Comments
 (0)