
[Model Support] FLUX.1-dev #28

Merged · 17 commits · Sep 9, 2024
6 changes: 0 additions & 6 deletions .flake8

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
@@ -27,6 +27,7 @@ __pycache__/
 # Distribution / packaging
 .Python
 build/
+.build/
 develop-eggs/
 dist/
 downloads/
12 changes: 9 additions & 3 deletions README.md
@@ -32,7 +32,11 @@ pip install -e .
 <summary> Click to expand </summary>


-[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) requires users to accept the terms before downloading the checkpoint. Once you accept the terms, sign in with your Hugging Face hub READ token as below:
+[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) requires users to accept the terms before downloading the checkpoint.
+
+[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) also requires users to accept the terms before downloading the checkpoint.
+
+Once you accept the terms, sign in with your Hugging Face hub READ token as below:
 > [!IMPORTANT]
 > If using a fine-grained token, it is also necessary to [edit permissions](https://huggingface.co/settings/tokens) to allow `Read access to contents of all public gated repos you can access`

@@ -89,6 +93,8 @@ Some notable optional arguments for:

 Please refer to the help menu for all available arguments: `diffusionkit-cli -h`.

+Note: When using `FLUX.1-dev`, verify that you have accepted the [FLUX.1-dev license](https://huggingface.co/black-forest-labs/FLUX.1-dev) and allowed gated access on your [Hugging Face token](https://huggingface.co/settings/tokens).
+
 ### Code ###

 For Stable Diffusion 3:
@@ -109,7 +115,7 @@ For FLUX:
 from diffusionkit.mlx import FluxPipeline
 pipeline = FluxPipeline(
     shift=1.0,
-    model_version="argmaxinc/mlx-FLUX.1-schnell",
+    model_version="argmaxinc/mlx-FLUX.1-schnell",  # model_version="argmaxinc/mlx-FLUX.1-dev" for FLUX.1-dev
     low_memory_mode=True,
     a16=True,
     w16=True,
@@ -120,7 +126,7 @@ Finally, to generate the image, use the `generate_image()` function:
 ```python
 HEIGHT = 512
 WIDTH = 512
-NUM_STEPS = 4 # 4 for FLUX.1-schnell, 50 for SD3
+NUM_STEPS = 4 # 4 for FLUX.1-schnell, 50 for SD3 and FLUX.1-dev
 CFG_WEIGHT = 0. # for FLUX.1-schnell, 5. for SD3

 image, _ = pipeline.generate_image(
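
For reference, the sign-in step the README describes can also be done from Python via the `huggingface_hub` package; a minimal sketch (the token value is a placeholder for your own READ token):

```python
# Minimal sketch, assuming `pip install huggingface_hub`.
# Create a READ token at https://huggingface.co/settings/tokens.
from huggingface_hub import login

login(token="hf_...")  # or run `huggingface-cli login` in a terminal
```
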
6 changes: 5 additions & 1 deletion python/src/diffusionkit/mlx/__init__.py
@@ -39,12 +39,14 @@
     "sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
     "argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
     "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
+    "argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
 }

 T5_MAX_LENGTH = {
     "argmaxinc/mlx-stable-diffusion-3-medium": 512,
     "argmaxinc/mlx-FLUX.1-schnell": 256,
     "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
+    "argmaxinc/mlx-FLUX.1-dev": 512,
 }


@@ -653,7 +655,9 @@ def encode_text(
             text,
             (negative_text if cfg_weight > 1 else None),
         )
-        padded_tokens_t5 = mx.zeros((1, 256)).astype(tokens_t5.dtype)
+        padded_tokens_t5 = mx.zeros((1, T5_MAX_LENGTH[self.model_version])).astype(
+            tokens_t5.dtype
+        )
         padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[
             [0], :
         ] # Ignore negative text
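
The hunk above replaces a hardcoded pad width of 256 with a per-model lookup: FLUX.1-dev uses a 512-token T5 context, which the old constant would have truncated. A standalone sketch of the same idea, in plain Python with names mirroring the diff:

```python
T5_MAX_LENGTH = {
    "argmaxinc/mlx-FLUX.1-schnell": 256,
    "argmaxinc/mlx-FLUX.1-dev": 512,
}

def pad_t5_tokens(tokens: list[int], model_version: str) -> list[int]:
    # Pad (or truncate) token ids to the model-specific T5 length.
    max_len = T5_MAX_LENGTH[model_version]
    tokens = tokens[:max_len]
    return tokens + [0] * (max_len - len(tokens))
```
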
18 changes: 18 additions & 0 deletions python/src/diffusionkit/mlx/config.py
@@ -68,6 +68,8 @@ def hidden_size(self) -> int:

     low_memory_mode: bool = True

+    guidance_embed: bool = False
+

 SD3_8b = MMDiTConfig(depth_multimodal=38, num_heads=3, upcast_multimodal_blocks=[35])

@@ -90,6 +92,22 @@ def hidden_size(self) -> int:
     dtype=mx.bfloat16,
 )

+FLUX_DEV = MMDiTConfig(
+    num_heads=24,
+    depth_multimodal=19,
+    depth_unified=38,
+    parallel_mlp_for_unified_blocks=True,
+    hidden_size_override=3072,
+    patchify_via_reshape=True,
+    pos_embed_type=PositionalEncoding.PreSDPARope,
+    rope_axes_dim=(16, 56, 56),
+    pooled_text_embed_dim=768, # CLIP-L/14 only
+    use_qk_norm=True,
+    float16_dtype=mx.bfloat16,
+    guidance_embed=True,
+    dtype=mx.bfloat16,
+)
+

 @dataclass
 class AutoencoderConfig:
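
`guidance_embed=True` reflects that FLUX.1-dev is guidance-distilled: rather than running two forward passes for classifier-free guidance, the desired guidance scale is embedded like a timestep and folded into the conditioning. A minimal sketch of such a scalar embedding (illustrative only; the helper name is an assumption, not part of this PR):

```python
import math
import mlx.core as mx

def frequency_embedding(value: float, dim: int = 256) -> mx.array:
    # Sinusoidal embedding of a scalar (timestep or guidance scale);
    # this is the kind of vector the MLPEmbedder below then projects.
    half = dim // 2
    freqs = mx.exp(-math.log(10000.0) * mx.arange(half) / half)
    args = value * freqs
    return mx.concatenate([mx.cos(args), mx.sin(args)])
```
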
24 changes: 23 additions & 1 deletion python/src/diffusionkit/mlx/mmdit.py
@@ -28,6 +28,13 @@ def __init__(self, config: MMDiTConfig):
         super().__init__()
         self.config = config

+        if config.guidance_embed:
+            self.guidance_in = MLPEmbedder(
+                in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size
+            )
+        else:
+            self.guidance_in = nn.Identity()
+
         # Input adapters and embeddings
         self.x_embedder = LatentImageAdapter(config)

@@ -209,6 +216,9 @@ def __call__(
         else:
             positional_encodings = None

+        if self.config.guidance_embed:
+            timestep = self.guidance_in(self.t_embedder(timestep))
+
         # MultiModalTransformer layers
         if self.config.depth_multimodal > 0:
             for bidx, block in enumerate(self.multimodal_transformer_blocks):
@@ -236,7 +246,6 @@ def __call__(
             :, token_level_text_embeddings.shape[1] :, ...
         ]

-        # Final layer
         latent_image_embeddings = self.final_layer(
             latent_image_embeddings,
             timestep,
@@ -933,6 +942,19 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
     )


+class MLPEmbedder(nn.Module):
+    def __init__(self, in_dim: int, hidden_dim: int):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(in_dim, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+        )
+
+    def __call__(self, x):
+        return self.mlp(x)
+
+
 def affine_transform(
     x: mx.array,
     shift: mx.array,
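
A quick shape check for the `MLPEmbedder` added above, using the FLUX.1-dev dimensions (the 256-dim input assumes `frequency_embed_dim=256`, which is not shown in this hunk):

```python
import mlx.core as mx
from diffusionkit.mlx.mmdit import MLPEmbedder  # available after this PR

embedder = MLPEmbedder(in_dim=256, hidden_dim=3072)  # hidden_size for FLUX.1-dev
t_emb = mx.zeros((1, 256))
print(embedder(t_emb).shape)  # (1, 3072)
```
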
10 changes: 9 additions & 1 deletion python/src/diffusionkit/mlx/model_io.py
@@ -46,6 +46,10 @@
         "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
         "vae": "ae.safetensors",
     },
+    "argmaxinc/mlx-FLUX.1-dev": {
+        "argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors",
+        "vae": "ae.safetensors",
+    },
 }
 _DEFAULT_MODEL = "argmaxinc/stable-diffusion"
 _MODELS = {
@@ -75,6 +79,10 @@
         "vae_encoder": "encoder.",
         "vae_decoder": "decoder.",
     },
+    "argmaxinc/mlx-FLUX.1-dev": {
+        "vae_encoder": "encoder.",
+        "vae_decoder": "decoder.",
+    },
 }

 _FLOAT16 = mx.bfloat16
@@ -704,7 +712,7 @@ def load_flux(
     hf_hub_download(key, "config.json")
     weights = mx.load(flux_weights_ckpt)

-    if model_key == "argmaxinc/mlx-FLUX.1-schnell":
+    if model_key in ["argmaxinc/mlx-FLUX.1-schnell", "argmaxinc/mlx-FLUX.1-dev"]:
         weights = flux_state_dict_adjustments(
             weights,
             prefix="",
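
The mapping above routes FLUX.1-dev to its own `flux1-dev.safetensors` while reusing the shared VAE file. A hypothetical lookup helper showing how such a table is typically consumed (the dict name here is a stand-in; the real module-level name is not visible in the hunk):

```python
MODEL_CKPT = {
    "argmaxinc/mlx-FLUX.1-dev": {
        "argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors",
        "vae": "ae.safetensors",
    },
}

def ckpt_filename(model_key: str, component: str | None = None) -> str:
    # Diffusion weights live under the model's own key; the VAE under "vae".
    return MODEL_CKPT[model_key][component or model_key]
```
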
5 changes: 4 additions & 1 deletion python/src/diffusionkit/mlx/scripts/generate_images.py
@@ -17,18 +17,21 @@
     "sd3-8b-unreleased": 1024,
     "argmaxinc/mlx-FLUX.1-schnell": 512,
     "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
+    "argmaxinc/mlx-FLUX.1-dev": 512,
 }
 WIDTH = {
     "argmaxinc/mlx-stable-diffusion-3-medium": 512,
     "sd3-8b-unreleased": 1024,
     "argmaxinc/mlx-FLUX.1-schnell": 512,
     "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
+    "argmaxinc/mlx-FLUX.1-dev": 512,
 }
 SHIFT = {
     "argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
     "sd3-8b-unreleased": 3.0,
     "argmaxinc/mlx-FLUX.1-schnell": 1.0,
     "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
+    "argmaxinc/mlx-FLUX.1-dev": 1.0,
 }


@@ -111,7 +114,7 @@ def cli():
         args.a16 = True

     if "FLUX" in args.model_version and args.cfg > 0.0:
-        logger.warning("Disabling CFG for FLUX.1-schnell model.")
+        logger.warning(f"Disabling CFG for {args.model_version} model.")
         args.cfg = 0.0

     if args.benchmark_mode:
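
FLUX.1-dev inherits FLUX.1-schnell's 512x512 default resolution and shift of 1.0, and the CFG guard now names whichever FLUX variant is in use. A standalone sketch of that guard (function name assumed for illustration):

```python
import logging

logger = logging.getLogger(__name__)

def sanitize_cfg(model_version: str, cfg: float) -> float:
    # Both FLUX variants run without classifier-free guidance in this
    # pipeline, so any nonzero cfg weight is reset to 0.
    if "FLUX" in model_version and cfg > 0.0:
        logger.warning(f"Disabling CFG for {model_version} model.")
        return 0.0
    return cfg
```
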
4 changes: 2 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
 from setuptools import find_packages, setup
 from setuptools.command.install import install

-VERSION = "0.3.5"
+VERSION = "0.4.0"


 class VersionInstallCommand(install):
@@ -29,7 +29,7 @@ def run(self):
         "argmaxtools>=0.1.13",
         "torch",
         "safetensors",
-        "mlx>=0.16.3",
+        "mlx>=0.17.1",
         "jaxtyping",
         "transformers",
         "pillow",