
Flux DEV - bad output quality (bad seeds?) #12311

@Mahrkeenerh

Description

Flux.1 [DEV] Model Quality Issues Report

I don't know if I'm the only one with this problem, since I wasn't able to find any other mention of it, but some outputs from the Flux.1 [DEV] model are of low quality. By that I mean either blurry or low-resolution-looking (the actual image resolution is correct).

There seems to be some correlation between a "bad prompt" or a "bad seed" and bad results, though I can't pin it down. With a "bad prompt", the outputs are mostly of bad quality; with a "bad seed", different prompts are also likely to produce bad outputs from it.

I've done extensive testing and can't pinpoint the issue. I've also compared with ComfyUI, which seems to have the same issue.

For most of my tests, I'm running a qint8 quantized version of the official model, quantized with optimum-quanto.

I've also tested against the full-precision model; that produced some bad results too, though with a slightly lower frequency. For CLIP, I've also tested without the long-prompt support, and that produced bad results as well (see the note after the test script below for how these variants were run).

The minimal reproducible example is very simple: load the pipeline, nothing fancy around it (our full solution does include many ControlNets, LoRAs, etc., but I've reproduced identical outputs with a barebones implementation).
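
For reference, a minimal sketch of what "barebones" means here: the stock FluxPipeline with no quantization and no CLIP override. The prompt and seed are placeholders, and it assumes enough VRAM to keep the whole pipeline on the GPU:

import torch
from diffusers import FluxPipeline

# Minimal sketch: stock pipeline, single placeholder seed, no quantization or overrides.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

generator = torch.Generator(device="cuda").manual_seed(42)  # placeholder seed
image = pipe(
    prompt="...",  # one of the attached prompts
    width=1024,
    height=1024,
    num_inference_steps=20,
    guidance_scale=3.5,
    generator=generator,
).images[0]
image.save("flux_minimal_test.png")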

I'm attaching multiple outputs as an example, and the prompts too. There's a lot of them, I can find specific prompts for specific outputs, but again, I couldn't find any meaningful relationships between them.

Attached Files

prompts_graded.txt

(the first row contains outputs from SDXL; you can see the style is completely different)

Example Images

(10 example output images attached)

(Almost) Minimal Code Implementation

Main Test File

"""
Simple Flux Dev Model Test
A bare-bones implementation to test Flux.1-dev model directly without any extras.
"""

import torch
from diffusers import FluxPipeline
import time
from flux_utils import quantize_model, override_clip_prompt_embeds


def main(use_quantization=True):
    """Main function to test Flux model generation

    Args:
        use_quantization (bool): Whether to apply quantization to reduce memory usage
    """
    print("Starting simple Flux Dev model test...")
    print("-" * 50)
    
    # Configuration
    model_id = "black-forest-labs/FLUX.1-dev"
    prompt = "The image features a vibrant red ruby viewed from the top, rendered in a stylized cartoon vector style. The ruby's facets are simplified yet clearly defined, showcasing its rich, deep red color with subtle highlights that emphasize its gem-like quality. This striking gem is set against a solid black background, which enhances its vividness and makes it stand out prominently. Designed as a game asset, the ruby combines clean lines and bold colors typical of vector art, making it visually appealing and easily recognizable within a gaming interface."
    
    # Generation parameters
    width = 1024
    height = 1024
    num_inference_steps = 20
    guidance_scale = 3.5
    seeds = [1, 42, 441360, 604577, 958566]  # List of seeds to test

    # Device setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    
    print(f"Device: {device}")
    print(f"Model: {model_id}")
    print(f"Prompt: {prompt}")
    print(f"Size: {width}x{height}")
    print(f"Steps: {num_inference_steps}")
    print(f"Guidance Scale: {guidance_scale}")
    print(f"Seeds: {seeds}")
    print(f"Quantization: {'Enabled' if use_quantization else 'Disabled'}")
    print("-" * 50)
    
    try:
        # Load the pipeline on CPU first to avoid memory issues
        print("Loading Flux pipeline on CPU...")
        start_time = time.time()
        
        pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)

        load_time = time.time() - start_time
        print(f"Pipeline loaded on CPU in {load_time:.2f} seconds")

        # Check prompt length
        test_tokens = pipe.tokenizer.encode(prompt)
        print(f"Prompt token count: {len(test_tokens)}")
        print(f"Tokenizer max length: {pipe.tokenizer.model_max_length}")

        # Apply quantization to reduce memory usage while on CPU (if enabled)
        if use_quantization:
            quantize_model(pipe)
        else:
            print("Quantization disabled - skipping quantization step")

        # Apply CLIP override for longer prompts
        override_clip_prompt_embeds(pipe)

        # Now move to GPU if available
        gpu_time = 0
        if device == "cuda":
            print("Moving quantized model to GPU...")
            gpu_start = time.time()
            pipe = pipe.to(device)
            gpu_time = time.time() - gpu_start
            print(f"Model moved to GPU in {gpu_time:.2f} seconds")

        # Generate images for each seed
        total_generation_time = 0
        successful_generations = 0

        for i, seed in enumerate(seeds, 1):
            print(f"\nGenerating image {i}/{len(seeds)} with seed {seed}...")

            # Set up generator for reproducible results
            generator = torch.Generator(device=device).manual_seed(seed)

            # Generate image
            generation_start = time.time()

            try:
                with torch.no_grad():
                    result = pipe(
                        prompt=prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        output_type="pil",
                    )

                generation_time = time.time() - generation_start
                total_generation_time += generation_time
                print(f"Image {i} generated in {generation_time:.2f} seconds")

                # Get the generated image
                image = result.images[0]

                # Save the image with seed in filename
                output_path = f"flux_test_output_seed_{seed}.png"
                image.save(output_path)
                print(f"Image saved as: {output_path}")

                # Display image info
                print(f"Image size: {image.size}")
                print(f"Image mode: {image.mode}")

                successful_generations += 1

            except Exception as e:
                print(f"Error generating image with seed {seed}: {e}")
                continue
        
        print("-" * 50)
        print("Test completed successfully!")
        print(f"Generated {successful_generations}/{len(seeds)} images")
        total_setup_time = load_time + gpu_time
        print(f"Setup time: {total_setup_time:.2f} seconds")
        print(f"Total generation time: {total_generation_time:.2f} seconds")
        avg_generation_time = (
            total_generation_time / successful_generations
            if successful_generations > 0
            else 0
        )
        print(f"Average generation time per image: {avg_generation_time:.2f} seconds")
        print(f"Total time: {total_setup_time + total_generation_time:.2f} seconds")
        
        # Clean up GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        
        # Clean up on error
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
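
The full-precision and no-long-prompt variants mentioned in the description correspond to small modifications of this script, roughly as follows (assuming no other changes):

# Full-precision comparison: skip the qint8 quantization step
main(use_quantization=False)

# "Without long-prompt support": additionally comment out this line inside main()
# override_clip_prompt_embeds(pipe)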

flux_utils.py

"""
Flux Utilities
Contains utility functions for Flux model optimization and enhancement.
"""

import torch
import types
from optimum.quanto import quantize, freeze, qint8


def quantize_model(pipe):
    """Apply quantization to reduce memory usage"""
    print("Quantizing model components...")
    quantization_type = qint8

    # Quantize the transformer (main model component)
    if hasattr(pipe, "transformer") and pipe.transformer is not None:
        quantize(pipe.transformer, weights=quantization_type)
        freeze(pipe.transformer)
        print("  - Transformer quantized")

    # Quantize text encoders if they exist
    for encoder_name in ["text_encoder", "text_encoder_2"]:
        if hasattr(pipe, encoder_name):
            encoder = getattr(pipe, encoder_name)
            if encoder is not None:
                quantize(encoder, weights=quantization_type)
                freeze(encoder)
                print(f"  - {encoder_name} quantized")

    print("Quantization complete!")


def override_clip_prompt_embeds(pipe):
    """Override the CLIP prompt embedding method to handle longer prompts"""
    print("Overriding CLIP prompt embedding method...")

    def _get_clip_prompt_embeds(
        self,
        prompt,
        num_images_per_prompt: int = 1,
        device=None,
    ):
        """Custom CLIP encoding that overcomes the 77 token limit using batched encoding"""
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Tokenize without truncation to get full length
        full_inputs = self.tokenizer(
            prompt, padding=False, truncation=False, return_tensors="pt"
        )

        max_length = self.tokenizer.model_max_length
        full_input_ids = full_inputs.input_ids

        # If the prompt fits within the token limit, use standard processing
        if full_input_ids.shape[-1] <= max_length:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), output_hidden_states=False
            )
            prompt_embeds = prompt_embeds.pooler_output
        else:
            # Handle long prompts by chunking
            print(
                f"Long prompt detected ({full_input_ids.shape[-1]} tokens), using chunked encoding..."
            )

            all_embeddings = []
            for prompt_text in prompt:
                # Tokenize the full prompt
                tokens = self.tokenizer.encode(prompt_text)

                # Process in chunks
                chunk_embeddings = []
                for i in range(0, len(tokens), max_length - 2):  # -2 for special tokens
                    chunk = tokens[i : i + max_length - 2]

                    # Add special tokens
                    chunk_with_special = (
                        [self.tokenizer.bos_token_id]
                        + chunk
                        + [self.tokenizer.eos_token_id]
                    )

                    # Pad to max_length
                    while len(chunk_with_special) < max_length:
                        chunk_with_special.append(self.tokenizer.pad_token_id)

                    # Convert to tensor and encode
                    chunk_tensor = torch.tensor([chunk_with_special], device=device)
                    chunk_embeds = self.text_encoder(
                        chunk_tensor, output_hidden_states=False
                    )
                    chunk_embeddings.append(chunk_embeds.pooler_output)

                # Average the chunk embeddings
                if len(chunk_embeddings) > 1:
                    prompt_embed = torch.stack(chunk_embeddings).mean(dim=0)
                else:
                    prompt_embed = chunk_embeddings[0]

                all_embeddings.append(prompt_embed)

            prompt_embeds = torch.cat(all_embeddings, dim=0)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    # Override the method on the pipeline instance
    pipe._get_clip_prompt_embeds = types.MethodType(_get_clip_prompt_embeds, pipe)
    print("CLIP override applied!")
