Flux.1 [DEV] Model Quality Issues Report
I don't know if I'm the only one with this problem, because I wasn't able to find any other mention of it. But some of the outputs from the Flux.1 [DEV] model are of low quality: either blurry, or looking low-resolution/upscaled even though the actual image resolution is correct.
There seems to be some correlation between the prompt, the seed, and bad results, but I can't pin it down. If I have a "bad prompt", the outputs are mostly of bad quality; if I find a "bad seed", different prompts are likely to produce bad outputs from it too.
I've done extensive testing and I can't pinpoint the issue. I've also compared with ComfyUI, which seems to have the issue too.
For most of my tests, I'm running a qint8 version of the official model, quantized with optimum-quanto.
I've also tested against the full-precision model; that produced some bad results too, but at a slightly lower frequency. And for CLIP, I've also tested without the long-prompt support; that also produced bad results.
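For reference, this is roughly what that full-precision control run looks like: just the stock pipeline, loaded in bf16, with no quantization and no CLIP override. This is a sketch assuming the same settings as the test file below, not the exact script I used:

import torch
from diffusers import FluxPipeline

# Full-precision control run: stock FluxPipeline, no quantization, no CLIP override.
# Settings are assumed to mirror the test file below.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
    prompt="<same prompt as in the test file below>",
    width=1024,
    height=1024,
    num_inference_steps=20,
    guidance_scale=3.5,
    generator=generator,
).images[0]
image.save("flux_fullprecision_seed_42.png")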
The minimal reproducible example is very simple: just load the pipeline, with nothing fancy around it (our solution does include many ControlNets, LoRAs, etc., but I've reproduced identical outputs with a bare-bones implementation).
I'm attaching multiple outputs as examples, along with the prompts. There are a lot of them; I can match specific prompts to specific outputs, but again, I couldn't find any meaningful relationship between them.
Attached Files
(the first row contains outputs from SDXL; you can see the style is completely different)
Example Images
(Almost) Minimal Code Implementation
Main Test File
"""
Simple Flux Dev Model Test
A bare-bones implementation to test Flux.1-dev model directly without any extras.
"""
import torch
from diffusers import FluxPipeline
import time
from flux_utils import quantize_model, override_clip_prompt_embeds
def main(use_quantization=True):
"""Main function to test Flux model generation
Args:
use_quantization (bool): Whether to apply quantization to reduce memory usage
"""
print("Starting simple Flux Dev model test...")
print("-" * 50)
# Configuration
model_id = "black-forest-labs/FLUX.1-dev"
prompt = "The image features a vibrant red ruby viewed from the top, rendered in a stylized cartoon vector style. The ruby's facets are simplified yet clearly defined, showcasing its rich, deep red color with subtle highlights that emphasize its gem-like quality. This striking gem is set against a solid black background, which enhances its vividness and makes it stand out prominently. Designed as a game asset, the ruby combines clean lines and bold colors typical of vector art, making it visually appealing and easily recognizable within a gaming interface."
# Generation parameters
width = 1024
height = 1024
num_inference_steps = 20
guidance_scale = 3.5
seeds = [1, 42, 441360, 604577, 958566] # List of seeds to test
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
print(f"Device: {device}")
print(f"Model: {model_id}")
print(f"Prompt: {prompt}")
print(f"Size: {width}x{height}")
print(f"Steps: {num_inference_steps}")
print(f"Guidance Scale: {guidance_scale}")
print(f"Seeds: {seeds}")
print(f"Quantization: {'Enabled' if use_quantization else 'Disabled'}")
print("-" * 50)
try:
# Load the pipeline on CPU first to avoid memory issues
print("Loading Flux pipeline on CPU...")
start_time = time.time()
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)
load_time = time.time() - start_time
print(f"Pipeline loaded on CPU in {load_time:.2f} seconds")
# Check prompt length
test_tokens = pipe.tokenizer.encode(prompt)
print(f"Prompt token count: {len(test_tokens)}")
print(f"Tokenizer max length: {pipe.tokenizer.model_max_length}")
# Apply quantization to reduce memory usage while on CPU (if enabled)
if use_quantization:
quantize_model(pipe)
else:
print("Quantization disabled - skipping quantization step")
# Apply CLIP override for longer prompts
override_clip_prompt_embeds(pipe)
# Now move to GPU if available
gpu_time = 0
if device == "cuda":
print("Moving quantized model to GPU...")
gpu_start = time.time()
pipe = pipe.to(device)
gpu_time = time.time() - gpu_start
print(f"Model moved to GPU in {gpu_time:.2f} seconds")
# Generate images for each seed
total_generation_time = 0
successful_generations = 0
for i, seed in enumerate(seeds, 1):
print(f"\nGenerating image {i}/{len(seeds)} with seed {seed}...")
# Set up generator for reproducible results
generator = torch.Generator(device=device).manual_seed(seed)
# Generate image
generation_start = time.time()
try:
with torch.no_grad():
result = pipe(
prompt=prompt,
width=width,
height=height,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
output_type="pil",
)
generation_time = time.time() - generation_start
total_generation_time += generation_time
print(f"Image {i} generated in {generation_time:.2f} seconds")
# Get the generated image
image = result.images[0]
# Save the image with seed in filename
output_path = f"flux_test_output_seed_{seed}.png"
image.save(output_path)
print(f"Image saved as: {output_path}")
# Display image info
print(f"Image size: {image.size}")
print(f"Image mode: {image.mode}")
successful_generations += 1
except Exception as e:
print(f"Error generating image with seed {seed}: {e}")
continue
print("-" * 50)
print("Test completed successfully!")
print(f"Generated {successful_generations}/{len(seeds)} images")
total_setup_time = load_time + gpu_time
print(f"Setup time: {total_setup_time:.2f} seconds")
print(f"Total generation time: {total_generation_time:.2f} seconds")
avg_generation_time = (
total_generation_time / successful_generations
if successful_generations > 0
else 0
)
print(f"Average generation time per image: {avg_generation_time:.2f} seconds")
print(f"Total time: {total_setup_time + total_generation_time:.2f} seconds")
# Clean up GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception as e:
print(f"Error during generation: {e}")
import traceback
traceback.print_exc()
# Clean up on error
if torch.cuda.is_available():
torch.cuda.empty_cache()
if __name__ == "__main__":
main()
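To map the comparison runs from the description onto this script (a sketch, not exactly how I invoked it):

# Full-precision comparison: run the same test with quantization disabled.
if __name__ == "__main__":
    main(use_quantization=False)

# For the "without long-prompt support" comparison, additionally comment out the
# override_clip_prompt_embeds(pipe) call in main(), so the stock CLIP encoding
# (77-token truncation) is used.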
flux_utils.py
"""
Flux Utilities
Contains utility functions for Flux model optimization and enhancement.
"""
import torch
import types
from optimum.quanto import quantize, freeze, qint8
def quantize_model(pipe):
"""Apply quantization to reduce memory usage"""
print("Quantizing model components...")
quantization_type = qint8
# Quantize the transformer (main model component)
if hasattr(pipe, "transformer") and pipe.transformer is not None:
quantize(pipe.transformer, weights=quantization_type)
freeze(pipe.transformer)
print(" - Transformer quantized")
# Quantize text encoders if they exist
for encoder_name in ["text_encoder", "text_encoder_2"]:
if hasattr(pipe, encoder_name):
encoder = getattr(pipe, encoder_name)
if encoder is not None:
quantize(encoder, weights=quantization_type)
freeze(encoder)
print(f" - {encoder_name} quantized")
print("Quantization complete!")
def override_clip_prompt_embeds(pipe):
"""Override the CLIP prompt embedding method to handle longer prompts"""
print("Overriding CLIP prompt embedding method...")
def _get_clip_prompt_embeds(
self,
prompt,
num_images_per_prompt: int = 1,
device=None,
):
"""Custom CLIP encoding that overcomes the 77 token limit using batched encoding"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
# Tokenize without truncation to get full length
full_inputs = self.tokenizer(
prompt, padding=False, truncation=False, return_tensors="pt"
)
max_length = self.tokenizer.model_max_length
full_input_ids = full_inputs.input_ids
# If the prompt fits within the token limit, use standard processing
if full_input_ids.shape[-1] <= max_length:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False
)
prompt_embeds = prompt_embeds.pooler_output
else:
# Handle long prompts by chunking
print(
f"Long prompt detected ({full_input_ids.shape[-1]} tokens), using chunked encoding..."
)
all_embeddings = []
for prompt_text in prompt:
# Tokenize the full prompt
tokens = self.tokenizer.encode(prompt_text)
# Process in chunks
chunk_embeddings = []
for i in range(0, len(tokens), max_length - 2): # -2 for special tokens
chunk = tokens[i : i + max_length - 2]
# Add special tokens
chunk_with_special = (
[self.tokenizer.bos_token_id]
+ chunk
+ [self.tokenizer.eos_token_id]
)
# Pad to max_length
while len(chunk_with_special) < max_length:
chunk_with_special.append(self.tokenizer.pad_token_id)
# Convert to tensor and encode
chunk_tensor = torch.tensor([chunk_with_special], device=device)
chunk_embeds = self.text_encoder(
chunk_tensor, output_hidden_states=False
)
chunk_embeddings.append(chunk_embeds.pooler_output)
# Average the chunk embeddings
if len(chunk_embeddings) > 1:
prompt_embed = torch.stack(chunk_embeddings).mean(dim=0)
else:
prompt_embed = chunk_embeddings[0]
all_embeddings.append(prompt_embed)
prompt_embeds = torch.cat(all_embeddings, dim=0)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
# Override the method on the pipeline instance
pipe._get_clip_prompt_embeds = types.MethodType(_get_clip_prompt_embeds, pipe)
print("CLIP override applied!")