Skip to content

[Model Support] FLUX.1-dev #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/src/diffusionkit/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
"FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
"FLUX.1-dev": "raoulritter/flux-dev-mlx",
}

# Maximum T5 text-encoder sequence length per model key
# (presumably counted in tokens - confirm against the tokenizer call site).
T5_MAX_LENGTH = {
    "stable-diffusion-3-medium": 512,
    "FLUX.1-schnell": 256,
    "FLUX.1-schnell-4bit-quantized": 256,
    # The reference FLUX inference code uses 256 only for the schnell
    # variant and 512 otherwise, so FLUX.1-dev must not inherit schnell's
    # shorter limit (long prompts would be truncated).
    "FLUX.1-dev": 512,
}


Expand Down
18 changes: 18 additions & 0 deletions python/src/diffusionkit/mlx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def hidden_size(self) -> int:

low_memory_mode: bool = True

guidance_embed: bool = False


SD3_8b = MMDiTConfig(depth_multimodal=38, num_heads=3, upcast_multimodal_blocks=[35])

Expand All @@ -90,6 +92,22 @@ def hidden_size(self) -> int:
dtype=mx.bfloat16,
)

# MMDiT configuration for FLUX.1-dev.
FLUX_DEV = MMDiTConfig(
    # Transformer layout
    depth_multimodal=19,
    depth_unified=38,
    num_heads=24,
    hidden_size_override=3072,
    parallel_mlp_for_unified_blocks=True,
    use_qk_norm=True,
    # Input / positional handling
    patchify_via_reshape=True,
    pos_embed_type=PositionalEncoding.PreSDPARope,
    rope_axes_dim=(16, 56, 56),
    pooled_text_embed_dim=768,  # CLIP-L/14 only
    # Makes MMDiT build an MLPEmbedder for `guidance_in` instead of nn.Identity
    guidance_embed=True,
    # Numeric precision
    float16_dtype=mx.bfloat16,
    dtype=mx.bfloat16,
)


@dataclass
class AutoencoderConfig:
Expand Down
26 changes: 22 additions & 4 deletions python/src/diffusionkit/mlx/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def __init__(self, config: MMDiTConfig):
super().__init__()
self.config = config

if config.guidance_embed:
self.guidance_in = MLPEmbedder(in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size)
else:
self.guidance_in = nn.Identity()

# Input adapters and embeddings
self.x_embedder = LatentImageAdapter(config)

Expand Down Expand Up @@ -209,13 +214,15 @@ def __call__(
else:
positional_encodings = None

timestep_embedding = self.guidance_in(self.t_embedder(timestep))

# MultiModalTransformer layers
if self.config.depth_multimodal > 0:
for bidx, block in enumerate(self.multimodal_transformer_blocks):
latent_image_embeddings, token_level_text_embeddings = block(
latent_image_embeddings,
token_level_text_embeddings,
timestep,
timestep_embedding,
positional_encodings=positional_encodings,
)

Expand All @@ -228,18 +235,17 @@ def __call__(
for bidx, block in enumerate(self.unified_transformer_blocks):
latent_unified_embeddings = block(
latent_unified_embeddings,
timestep,
timestep_embedding,
positional_encodings=positional_encodings,
)

latent_image_embeddings = latent_unified_embeddings[
:, token_level_text_embeddings.shape[1] :, ...
]

# Final layer
latent_image_embeddings = self.final_layer(
latent_image_embeddings,
timestep,
timestep_embedding
)

if self.config.patchify_via_reshape:
Expand Down Expand Up @@ -932,6 +938,18 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
.flatten(-2)
)

class MLPEmbedder(nn.Module):
    """Two-layer MLP: Linear -> SiLU -> Linear.

    Maps an input whose last dimension is ``in_dim`` to ``hidden_dim``.
    The layers live in ``self.mlp`` (an ``nn.Sequential``); that attribute
    name and the layer order are part of the parameter tree and therefore
    of the checkpoint key layout, so they must not be renamed casually.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        # Layers are created in application order, first projection first.
        input_proj = nn.Linear(in_dim, hidden_dim)
        activation = nn.SiLU()
        output_proj = nn.Linear(hidden_dim, hidden_dim)
        self.mlp = nn.Sequential(input_proj, activation, output_proj)

    def __call__(self, x):
        """Apply the MLP to ``x`` and return the result."""
        embedded = self.mlp(x)
        return embedded


def affine_transform(
x: mx.array,
Expand Down
15 changes: 15 additions & 0 deletions python/src/diffusionkit/mlx/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
"FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
"vae": "ae.safetensors",
},
"raoulritter/flux-dev-mlx": {
"FLUX.1-dev": "flux1-dev-mlx.safetensors",
"vae": "ae.safetensors",
},
}
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
_MODELS = {
Expand Down Expand Up @@ -75,6 +79,10 @@
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
"raoulritter/flux-dev-mlx": {
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
}

_FLOAT16 = mx.bfloat16
Expand Down Expand Up @@ -710,6 +718,13 @@ def load_flux(
hidden_size=config.hidden_size,
mlp_ratio=config.mlp_ratio,
)
elif model_key == "FLUX.1-dev":
weights = flux_state_dict_adjustments(
weights,
prefix="",
hidden_size=config.hidden_size,
mlp_ratio=config.mlp_ratio,
)
elif model_key == "FLUX.1-schnell-4bit-quantized": # 4-bit ckpt already adjusted
nn.quantize(model)

Expand Down
123 changes: 123 additions & 0 deletions python/src/diffusionkit/mlx/test-conversion-mlx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import mlx.core as mx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please follow the `unittest.TestCase` usage and make this a simple unit test like this? Also, no need to upload to the Hub from within the test 👍

from mlx.utils import tree_flatten, tree_unflatten
from huggingface_hub import hf_hub_download, HfApi
import os
import sys
from pathlib import Path

from sympy import false
from tqdm import tqdm


current_dir = Path(__file__).resolve().parent
parent_dir = current_dir.parent
sys.path.append(str(parent_dir))

# Now try to import using both relative and absolute imports
try:
from .config import FLUX_DEV, FLUX_SCHNELL, MMDiTConfig, PositionalEncoding
from .mmdit import MMDiT
from .model_io import flux_state_dict_adjustments
except ImportError:
from diffusionkit.mlx.config import FLUX_DEV, FLUX_SCHNELL, MMDiTConfig, PositionalEncoding
from diffusionkit.mlx.mmdit import MMDiT
from diffusionkit.mlx.model_io import flux_state_dict_adjustments


def load_flux_weights(model_key="flux-dev"):
    """Load an original FLUX checkpoint together with its MMDiT config.

    Args:
        model_key: ``"flux-dev"`` selects FLUX.1-dev; any other value
            selects FLUX.1-schnell (same fallback as before).

    Returns:
        A ``(weights, config)`` tuple: ``weights`` is whatever
        ``mx.load`` returns for the checkpoint file, ``config`` is
        ``FLUX_DEV`` or ``FLUX_SCHNELL``.

    The cache root is taken from the ``HF_HOME`` environment variable and
    falls back to ``~/.cache/huggingface``. The process environment is no
    longer overwritten with a machine-specific path; export ``HF_HOME``
    before running to use a different location (e.g. an external drive).
    """
    if model_key == "flux-dev":
        config = FLUX_DEV
        repo_id = "black-forest-labs/FLUX.1-dev"
        file_name = "flux1-dev.safetensors"
    else:
        config = FLUX_SCHNELL
        repo_id = "black-forest-labs/FLUX.1-schnell"
        file_name = "flux1-schnell.safetensors"

    # Honour the caller's HF_HOME, otherwise use the default cache location.
    hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))

    # A checkpoint placed manually at <hf_home>/hub/<repo name>/<file> is
    # used as-is; otherwise hf_hub_download fetches it into hf_home.
    local_file = os.path.join(hf_home, "hub", repo_id.split("/")[-1], file_name)

    if not os.path.exists(local_file):
        print(f"Downloading {file_name} to {hf_home}")
        local_file = hf_hub_download(
            repo_id,
            file_name,
            cache_dir=hf_home,
            force_download=False,
            resume_download=True,
        )
    else:
        print(f"Using existing file: {local_file}")

    # Load the weights
    weights = mx.load(local_file)
    return weights, config

def verify_conversion(weights, config):
    """Print a report comparing converted checkpoint keys with an MMDiT model.

    Builds ``MMDiT(config)``, runs ``weights`` through
    ``flux_state_dict_adjustments`` and prints three sections: keys only in
    the adjusted weights, keys only in the model, and shared keys whose
    shapes differ. Nothing is returned.
    """
    # Reference parameter set: every mx.array leaf of a freshly built model.
    model = MMDiT(config)
    model_params = {
        name: value
        for name, value in tree_flatten(model)
        if isinstance(value, mx.array)
    }

    # Checkpoint keys after the FLUX -> MMDiT adjustments.
    converted = flux_state_dict_adjustments(
        weights,
        prefix="",
        hidden_size=config.hidden_size,
        mlp_ratio=config.mlp_ratio,
    )

    converted_keys = set(converted.keys())
    model_keys = set(model_params.keys())

    only_in_weights = converted_keys - model_keys
    print("Keys in weights but not in model:")
    for key in only_in_weights:
        print(key)
    print(f"Count: {len(only_in_weights)}")

    only_in_model = model_keys - converted_keys
    print("\nKeys in model but not in weights:")
    for key in only_in_model:
        print(key)
    print(f"Count: {len(only_in_model)}")

    print("\nShape mismatches:")
    mismatches = 0
    for key in converted_keys & model_keys:
        weight_shape = converted[key].shape
        model_shape = model_params[key].shape
        if weight_shape == model_shape:
            continue
        print(f"{key}: weights {weight_shape}, model {model_shape}")
        mismatches += 1
    print(f"Total mismatches: {mismatches}")

def save_modified_weights(weights, output_file):
    """Write ``weights`` to ``output_file`` in safetensors format.

    Progress is reported on stdout before and after the write.
    """
    start_message = f"Saving modified weights to {output_file}"
    print(start_message)

    mx.save_safetensors(output_file, weights)

    print("Weights saved successfully!")

def upload_to_hub(file_path, repo_id, token):
    """Upload a single local file to the root of a Hugging Face Hub repo.

    Args:
        file_path: Local path of the file to upload; its basename becomes
            the path inside the repository.
        repo_id: Target repository identifier.
        token: Access token passed through to ``HfApi.upload_file``.
    """
    print(f"Uploading {file_path} to {repo_id}")

    remote_name = os.path.basename(file_path)
    HfApi().upload_file(
        path_or_fileobj=file_path,
        path_in_repo=remote_name,
        repo_id=repo_id,
        token=token,
    )

    print("Upload completed successfully!")

def main():
    """Load the FLUX.1-dev checkpoint, report key/shape agreement, and save it.

    Uploading to the Hub is intentionally not done here; call
    ``upload_to_hub`` separately if the saved file should be published.
    """
    # Load the weights and config
    weights, config = load_flux_weights("flux-dev")  # or "flux-schnell"

    # Print the key/shape comparison against a freshly built model
    verify_conversion(weights, config)

    # NOTE(review): the object saved is the one returned by
    # load_flux_weights; whether flux_state_dict_adjustments modified it in
    # place is not visible from here - confirm this is the intended content.
    output_file = "/Volumes/USB/flux1-dev-mlx.safetensors"
    save_modified_weights(weights, output_file)


if __name__ == "__main__":
    main()
65 changes: 65 additions & 0 deletions python/test-gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from pathlib import Path
from diffusionkit.mlx import FluxPipeline
from huggingface_hub import HfFolder, HfApi
from PIL import Image

# Hugging Face cache locations: an external USB volume when mounted,
# otherwise the default per-user cache directory.
usb_cache_path = "/Volumes/USB/huggingface/cache"
local_cache_path = os.path.expanduser("~/.cache/huggingface")


def set_hf_cache():
    """Point ``HF_HOME`` at the USB cache if available, else the local cache.

    When ``/Volumes/USB`` exists the USB cache directory is created if
    needed and selected; otherwise the local cache path is selected. The
    chosen value is printed, and any stored Hugging Face token is re-saved.

    NOTE(review): huggingface_hub appears to read ``HF_HOME`` when it is
    first imported; since it is imported at the top of this script, setting
    the variable here may not affect the library's own paths - confirm.
    """
    if os.path.exists("/Volumes/USB"):
        os.environ["HF_HOME"] = usb_cache_path
        Path(usb_cache_path).mkdir(parents=True, exist_ok=True)
        print(f"Using USB cache: {usb_cache_path}")
    else:
        os.environ["HF_HOME"] = local_cache_path
        print(f"USB not found. Using local cache: {local_cache_path}")

    print(f"HF_HOME is set to: {os.environ['HF_HOME']}")

    # get_token() returns None when no token is stored; passing None to
    # save_token() would raise, so only re-save a token that exists.
    token = HfFolder.get_token()
    if token is not None:
        HfFolder.save_token(token)


# The cache location must be chosen before the pipeline is constructed.
set_hf_cache()

# Build the FLUX.1-dev pipeline.
pipeline = FluxPipeline(
    model_version="FLUX.1-dev",
    shift=1.0,
    low_memory_mode=True,
    w16=True,
    a16=True,
)

# Optional LoRA weights (disabled):
# pipeline.load_lora_weights("XLabs-AI/flux-RealismLora")

# Image generation parameters.
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 10  # 4 for FLUX.1-schnell, 50 for SD3
CFG_WEIGHT = 0.  # for FLUX.1-schnell, 5. for SD3
# LORA_SCALE = 0.8  # LoRA strength

# Text prompt.
prompt = "A photo realistic cat holding a sign that says hello world in the style of a snapchat from 2015"

# Run generation; the latent grid is the pixel size divided by 8 per side.
# The second element of the returned pair is not used here.
image, _ = pipeline.generate_image(
    prompt,
    latent_size=(HEIGHT // 8, WIDTH // 8),
    num_steps=NUM_STEPS,
    cfg_weight=CFG_WEIGHT,
    # lora_scale=LORA_SCALE,
)

# Write the result to the current working directory.
output_format = "png"
output_quality = 100
image.save(
    f"flux_image.{output_format}",
    format=output_format,
    quality=output_quality,
)

print(f"Image generation complete. Saved image in {output_format} format.")