-
Notifications
You must be signed in to change notification settings - Fork 36
[Model Support] FLUX.1-dev #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
217a211
Temp commit to show issue:
raoulritter 19d9b6c
Temp commit to show issue. Need to add guidance_embed according to im…
raoulritter 8f2c9a6
Working conversion unable to save properly
raoulritter 19d65df
Working conversion no negative prompt
raoulritter dce4c87
Update according to review
raoulritter 1508f9f
Merge branch 'main' into main
ZachNagengast 1577412
Apply suggestions from code review for flux dev
ZachNagengast b076287
Guidance variable fixes
ZachNagengast b329c27
Formatting
ZachNagengast a3c108c
Use proper constants
ZachNagengast 0391f29
Update python/src/diffusionkit/mlx/config.py
ZachNagengast c5726c2
Update formatting
ZachNagengast 6541efb
Merge branch 'main' of https://github.com/raoulritter/DiffusionKit in…
ZachNagengast d0ee0fd
update README.md
arda-argmax 5d55b07
deleted test codes
arda-argmax 9c2289a
update README
arda-argmax 5cdda4d
update version
arda-argmax File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import mlx.core as mx | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please follow the |
||
from mlx.utils import tree_flatten, tree_unflatten | ||
from huggingface_hub import hf_hub_download, HfApi | ||
import os | ||
import sys | ||
from pathlib import Path | ||
|
||
from sympy import false | ||
from tqdm import tqdm | ||
|
||
|
||
current_dir = Path(__file__).resolve().parent | ||
parent_dir = current_dir.parent | ||
sys.path.append(str(parent_dir)) | ||
|
||
# Now try to import using both relative and absolute imports | ||
try: | ||
from .config import FLUX_DEV, FLUX_SCHNELL, MMDiTConfig, PositionalEncoding | ||
from .mmdit import MMDiT | ||
from .model_io import flux_state_dict_adjustments | ||
except ImportError: | ||
from diffusionkit.mlx.config import FLUX_DEV, FLUX_SCHNELL, MMDiTConfig, PositionalEncoding | ||
from diffusionkit.mlx.mmdit import MMDiT | ||
from diffusionkit.mlx.model_io import flux_state_dict_adjustments | ||
|
||
|
||
def load_flux_weights(model_key="flux-dev"): | ||
config = FLUX_DEV if model_key == "flux-dev" else FLUX_SCHNELL | ||
repo_id = "black-forest-labs/FLUX.1-dev" if model_key == "flux-dev" else "black-forest-labs/FLUX.1-schnell" | ||
file_name = "flux1-dev.safetensors" if model_key == "flux-dev" else "flux1-schnell.safetensors" | ||
|
||
# Set custom HF_HOME location | ||
custom_hf_home = "/Volumes/USB/huggingface/hub" | ||
os.environ["HF_HOME"] = custom_hf_home | ||
|
||
# Use the custom HF_HOME location or fall back to the default | ||
hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) | ||
|
||
# Check if the file already exists in the custom location | ||
local_file = os.path.join(hf_home, "hub", repo_id.split("/")[-1], file_name) | ||
# Download the file if it doesn't exist | ||
|
||
if not os.path.exists(local_file): | ||
print(f"Downloading {file_name} to {hf_home}") | ||
local_file = hf_hub_download( | ||
repo_id, | ||
file_name, | ||
cache_dir=hf_home, | ||
force_download=False, | ||
resume_download=True, | ||
) | ||
else: | ||
print(f"Using existing file: {local_file}") | ||
|
||
# Load the weights | ||
weights = mx.load(local_file) | ||
return weights, config | ||
|
||
def verify_conversion(weights, config): | ||
# Initialize the model | ||
model = MMDiT(config) | ||
mlx_model = tree_flatten(model) | ||
mlx_dict = {m[0]: m[1] for m in mlx_model if isinstance(m[1], mx.array)} | ||
|
||
# Adjust the weights | ||
adjusted_weights = flux_state_dict_adjustments( | ||
weights, prefix="", hidden_size=config.hidden_size, mlp_ratio=config.mlp_ratio | ||
) | ||
|
||
# Verify the conversion | ||
weights_set = set(adjusted_weights.keys()) | ||
mlx_dict_set = set(mlx_dict.keys()) | ||
|
||
print("Keys in weights but not in model:") | ||
for k in weights_set - mlx_dict_set: | ||
print(k) | ||
print(f"Count: {len(weights_set - mlx_dict_set)}") | ||
|
||
print("\nKeys in model but not in weights:") | ||
for k in mlx_dict_set - weights_set: | ||
print(k) | ||
print(f"Count: {len(mlx_dict_set - weights_set)}") | ||
|
||
print("\nShape mismatches:") | ||
count = 0 | ||
for k in weights_set & mlx_dict_set: | ||
if adjusted_weights[k].shape != mlx_dict[k].shape: | ||
print(f"{k}: weights {adjusted_weights[k].shape}, model {mlx_dict[k].shape}") | ||
count += 1 | ||
print(f"Total mismatches: {count}") | ||
|
||
def save_modified_weights(weights, output_file): | ||
print(f"Saving modified weights to {output_file}") | ||
mx.save_safetensors(output_file, weights) | ||
print("Weights saved successfully!") | ||
|
||
def upload_to_hub(file_path, repo_id, token): | ||
print(f"Uploading {file_path} to {repo_id}") | ||
api = HfApi() | ||
api.upload_file( | ||
path_or_fileobj=file_path, | ||
path_in_repo=os.path.basename(file_path), | ||
repo_id=repo_id, | ||
token=token | ||
) | ||
print("Upload completed successfully!") | ||
|
||
def main(): | ||
# Load the weights and config | ||
weights, config = load_flux_weights("flux-dev") # or "flux-schnell" | ||
|
||
# Verify the conversion | ||
verify_conversion(weights, config) | ||
|
||
output_file = "/Volumes/USB/flux1-dev-mlx.safetensors" | ||
save_modified_weights(weights, output_file) | ||
|
||
repo_id = "raoulritter/flux-dev-mlx" | ||
token = os.getenv("HF_TOKEN") # Make sure to set this environment variable | ||
# upload_to_hub(output_file, repo_id, token) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
raoulritter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from pathlib import Path | ||
from diffusionkit.mlx import FluxPipeline | ||
from huggingface_hub import HfFolder, HfApi | ||
from PIL import Image | ||
|
||
# Define cache paths | ||
usb_cache_path = "/Volumes/USB/huggingface/cache" | ||
local_cache_path = os.path.expanduser("~/.cache/huggingface") | ||
|
||
|
||
# Function to set and verify cache directory | ||
def set_hf_cache(): | ||
if os.path.exists("/Volumes/USB"): | ||
os.environ["HF_HOME"] = usb_cache_path | ||
Path(usb_cache_path).mkdir(parents=True, exist_ok=True) | ||
print(f"Using USB cache: {usb_cache_path}") | ||
else: | ||
os.environ["HF_HOME"] = local_cache_path | ||
print(f"USB not found. Using local cache: {local_cache_path}") | ||
|
||
print(f"HF_HOME is set to: {os.environ['HF_HOME']}") | ||
HfFolder.save_token(HfFolder.get_token()) | ||
|
||
|
||
# Set cache before initializing the pipeline | ||
set_hf_cache() | ||
|
||
# Initialize the pipeline | ||
pipeline = FluxPipeline( | ||
shift=1.0, | ||
model_version="FLUX.1-dev", | ||
low_memory_mode=True, | ||
a16=True, | ||
w16=True, | ||
) | ||
|
||
# Load LoRA weights | ||
# pipeline.load_lora_weights("XLabs-AI/flux-RealismLora") | ||
|
||
# Define image generation parameters | ||
HEIGHT = 512 | ||
WIDTH = 512 | ||
NUM_STEPS = 10 # 4 for FLUX.1-schnell, 50 for SD3 | ||
CFG_WEIGHT = 0. # for FLUX.1-schnell, 5. for SD3 | ||
# LORA_SCALE = 0.8 # LoRA strength | ||
|
||
# Define the prompt | ||
prompt = "A photo realistic cat holding a sign that says hello world in the style of a snapchat from 2015" | ||
|
||
# Generate the image | ||
image, _ = pipeline.generate_image( | ||
prompt, | ||
cfg_weight=CFG_WEIGHT, | ||
num_steps=NUM_STEPS, | ||
latent_size=(HEIGHT // 8, WIDTH // 8), | ||
# lora_scale=LORA_SCALE, | ||
) | ||
|
||
# Save the generated image | ||
output_format = "png" | ||
output_quality = 100 | ||
image.save(f"flux_image.{output_format}", format=output_format, quality=output_quality) | ||
|
||
print(f"Image generation complete. Saved image in {output_format} format.") |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.