Support Lumina-image-2.0 #1927


Open · wants to merge 69 commits into base branch sd3

Conversation

@sdbds (Contributor) commented Feb 12, 2025

Still in preparation.

After checking: their sampler and VAE follow FLUX, and the text encoder part uses Google's Gemma 2.

@kohya-ss CC

sdbds marked this pull request as draft February 12, 2025 08:32
sdbds mentioned this pull request Feb 12, 2025
@rockerBOO (Contributor)

I got this set up locally. I know it's not ready for anything, but I want to get it working. Let me know if you want to work together on this. I can help with some of the model loading parts, which is where I got stuck after poking at it. If you have progressed past this, I can help elsewhere or just test.

Thanks.

@sdbds (Contributor, Author) commented Feb 15, 2025

> I got this set up locally. I know it's not ready for anything, but I want to get it working. Let me know if you want to work together on this. I can help with some of the model loading parts, which is where I got stuck after poking at it. If you have progressed past this, I can help elsewhere or just test.
>
> Thanks.

Thank you, the framework is basically set up at the moment, but there is still some room for improvement in the caching strategy.

I'll discuss with @kohya-ss whether to continue using the previous method.

#1924 (comment)

sdbds marked this pull request as ready for review February 15, 2025 09:12
@envy-ai commented Feb 15, 2025

> Thank you, the framework is basically set up at the moment, but there is still some room for improvement in the caching strategy.

Does that mean I can download your fork and test it now?

@rockerBOO (Contributor)

It's still not quite working, but I'm working through some issues at the moment, mostly with model loading; I'll see what else is needed after that. It's fairly barebones, so I wouldn't expect it to be in a working state just yet.

@sdbds (Contributor, Author) commented Feb 17, 2025

After multiple updates, the project can now run under limited conditions:

  1. flash_attn on Windows causes NaN, so it must be run in a Linux environment. I will consider switching to an SDPA- or xformers-driven implementation later.
  2. The position ID calculation for token sequences is not padded to the max length, which currently requires batch_size = 1.

@kohya-ss (Owner)

Regarding strategy, I would like you to proceed as is. I would like to refactor it together with other architectures later.

The script seems to assume that the model file is .safetensors, but I could only find .pth: https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/tree/main

I would appreciate it if you could tell me where .safetensors is.

@rockerBOO (Contributor)

I converted their consolidated.00-of-01.pth here https://huggingface.co/rockerBOO/lumina-image-2/blob/main/lumina-image-2.safetensors
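
For reference, a conversion along those lines can be done with a short script. This is only a sketch under the assumption that the .pth file is a plain state dict; the actual conversion may have been done differently:

    # Sketch: convert the Lumina .pth checkpoint to .safetensors (file names as in the repos above)
    import torch
    from safetensors.torch import save_file

    state_dict = torch.load("consolidated.00-of-01.pth", map_location="cpu")
    # keep only tensor entries and make them contiguous, which safetensors requires
    state_dict = {k: v.contiguous() for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
    save_file(state_dict, "lumina-image-2.safetensors")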

@kohya-ss (Owner)

I'm sorry this is so late. I am testing the training, but the sample image seems to be a black image even with --sample_at_first, and the loss is also NaN. Can you give me some hints?

The Lumina checkpoint was downloaded from https://huggingface.co/rockerBOO/lumina-image-2/tree/main, and Gemma2 and the AE were downloaded from https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/tree/main/split_files.

The command is:

 accelerate launch  --mixed_precision bf16 --num_cpu_threads_per_process 1 lumina_train_network.py 
    --pretrained_model_name_or_path path\to\lumina-2.0\lumina-image-2.safetensors  
    --gemma2 path\to\lumina-2.0\gemma_2_2b_fp16.safetensors --ae path\to\lumina-2.0\ae.safetensors 
    --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 
    --seed 42 --mixed_precision bf16 --save_precision bf16 
    --network_module networks.lora_lumina --network_dim 4 
    --optimizer_type adamw8bit --learning_rate 1e-4  --gradient_checkpointing --highvram 
    --max_train_epochs 8 --save_every_n_epochs 1 
    --dataset_config path\to\dataset_config.toml --output_dir path\to\output\lora --output_name lumina-test-1 
    --sample_prompts=path\to\prompts.txt --sample_every_n_epochs 1 --vae_batch_size 4 --sample_at_first

@sdbds (Contributor, Author) commented May 25, 2025

> I'm sorry this is so late. I am testing the training, but the sample image seems to be a black image even with --sample_at_first, and the loss is also NaN. Can you give me some hints?
>
> The Lumina checkpoint was downloaded from https://huggingface.co/rockerBOO/lumina-image-2/tree/main, and Gemma2 and the AE were downloaded from https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/tree/main/split_files.
>
> The command is: (same command as quoted above)

Most flash_attn builds on Windows lack compiled training backends, so using them results in NaN.
Use the version I compiled directly, or compile a version with the training backends yourself:
https://github.com/sdbds/flash-attention-for-windows/releases

@kohya-ss (Owner)

Thank you, I understand. So the --use_flash_attn option is required. The sample image is successfully generated, but the loss goes NaN at the first step. I'm using a Flash Attention build that is the same as the one used in the Musubi Tuner repo, so it should be fine.

I got the following warning. Is it OK? sd-scripts\library\lumina_models.py:51: UserWarning: Cannot import apex RMSNorm, switch to vanilla implementation

@rockerBOO (Contributor) commented May 25, 2025

If you use PyTorch 2.6, I believe SDPA (which is the default) works correctly.

> UserWarning: Cannot import apex RMSNorm, switch to vanilla implementation

I think apex refers to https://github.com/NVIDIA/apex, so this warning is fine unless you wanted to use that library. We may be able to get rid of the warning, but it was in the original implementation.
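
For context, the warning comes from a try/except import pattern roughly like the following, inherited from the upstream model code (a sketch; the exact vanilla implementation in lumina_models.py may differ):

    import warnings
    import torch

    try:
        from apex.normalization import FusedRMSNorm as RMSNorm  # optional NVIDIA apex fast path
    except ImportError:
        warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

        class RMSNorm(torch.nn.Module):
            def __init__(self, dim: int, eps: float = 1e-5):
                super().__init__()
                self.eps = eps
                self.weight = torch.nn.Parameter(torch.ones(dim))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # normalize by the root mean square over the last dimension
                x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
                return x * self.weight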

@kohya-ss (Owner)

Thank you! It is true that NaNs occur with SDPA in PyTorch 2.4 (venv with requirements.txt of the sd3 branch), but NaNs do not seem to occur in PyTorch 2.6 (venv with Musubi Tuner). Do you know the reason for this? And should we move to PyTorch 2.6?

@sdbds (Contributor, Author) commented May 25, 2025

> Thank you! It is true that NaNs occur with SDPA in PyTorch 2.4 (venv with requirements.txt of the sd3 branch), but NaNs do not seem to occur in PyTorch 2.6 (venv with Musubi Tuner). Do you know the reason for this? And should we move to PyTorch 2.6?

I guess it's a bug in 2.4. Migrating to 2.7 didn't cause any issues... I haven't encountered any bugs on 2.7 so far, and it also lets us support 50xx GPUs as soon as possible.

@kohya-ss (Owner)

Thank you. I think the code in the sd3 branch will work with PyTorch 2.6 or later, but some testing may be needed. It might be a good idea to make it clear in the Lumina documentation that PyTorch 2.6 or later is required.

@rockerBOO (Contributor) commented May 25, 2025

> Thank you! It is true that NaNs occur with SDPA in PyTorch 2.4 (venv with requirements.txt of the sd3 branch), but NaNs do not seem to occur in PyTorch 2.6 (venv with Musubi Tuner). Do you know the reason for this? And should we move to PyTorch 2.6?

It's something specific to their architecture, but since the new version resolved it I didn't look further. They had some code for bf16 and flash attention, but I reworked it so that it uses SDPA by default and flash attention is a specific opt-in.
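
A sketch of that opt-in shape (the function and argument names here are illustrative, not the exact code in this PR):

    import torch.nn.functional as F

    def get_attention_processor(use_flash_attn: bool = False):
        if use_flash_attn:
            try:
                # requires a flash-attn build with training (backward) kernels;
                # flash_attn_func expects tensors shaped (batch, seq, heads, head_dim)
                from flash_attn import flash_attn_func
                return flash_attn_func
            except ImportError:
                pass
        # default: PyTorch SDPA, which expects (batch, heads, seq, head_dim),
        # so the caller has to permute accordingly
        return F.scaled_dot_product_attention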

@kohya-ss (Owner) left a review comment

Sorry for the late review. I think I've found a potential cause of the NaN issues in PyTorch 2.4. It may be better to address this cause to be future-proof. What do you think?


# Refine image context
for layer in self.noise_refiner:
    x = layer(x, x_mask, img_freqs_cis, t)
@kohya-ss (Owner):

At this point, x_mask is all zeros (set in line 1113).

@rockerBOO (Contributor):

Hmm, this might be identifying a separate bug in how this was refactored. You can see it upstream: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py#L730-L733. I can give it another look to fix this, though.

Fixing the zero issue with the attention is probably also a good idea.
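
For reference, the upstream code builds a per-sample boolean mask over the valid image tokens, roughly like this (a sketch with example lengths; in the model the lengths come from patchify_and_embed, and details may differ):

    import torch

    # example per-sample image-token lengths
    l_effective_img_len = [1024, 768]
    bsz, max_img_len = len(l_effective_img_len), max(l_effective_img_len)

    # True for each sample's valid (non-padding) image tokens, False elsewhere
    x_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool)
    for i, img_len in enumerate(l_effective_img_len):
        x_mask[i, :img_len] = True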


Comment on lines +470 to +475
if valid_indices.numel() == 0:
    # If all tokens are masked, create a zero output
    batch_output = torch.zeros(
        seqlen, self.n_local_heads, self.head_dim,
        device=q.device, dtype=q.dtype
    )
@kohya-ss (Owner):

If x_mask is all zeros, sage_attn returns zeros in the expected output shape.

Comment on lines +387 to +393
self.attention_processor(
    xq.permute(0, 2, 1, 3),
    xk.permute(0, 2, 1, 3),
    xv.permute(0, 2, 1, 3),
    attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
    scale=softmax_scale,
)
@kohya-ss (Owner):

It seems that SDPA may return NaN if the mask is all zeros. With SDPA, it seems we need to write code that returns zeros in the expected output shape when the mask is all zeros, like sage_attn does.
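
This is easy to reproduce in isolation: with an all-False boolean mask, every attention row has only -inf scores, so the softmax produces NaN (a minimal standalone example, not project code):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)
    mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)  # every key masked out

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.isnan().any())  # tensor(True): fully-masked rows have no valid softmax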

@kohya-ss (Owner) commented Jun 4, 2025:

Perhaps this will work (it does not work if some items of the batch are all-zero and some are not):

            valid_indices = torch.nonzero(x_mask, as_tuple=False).squeeze(-1)
            if valid_indices.numel() == 0:
                # If all tokens are masked, create a zero output
                # NOTE: this assumes the batch does not mix fully-masked and unmasked items.
                output = torch.zeros_like(xq, dtype=dtype)
            else:
                output = (
                    self.attention_processor(
                        xq.permute(0, 2, 1, 3),
                        xk.permute(0, 2, 1, 3),
                        xv.permute(0, 2, 1, 3),
                        attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
                        scale=softmax_scale,
                    )
                    .permute(0, 2, 1, 3)
                    .to(dtype)
                )

@sdbds (Contributor, Author) commented Jun 5, 2025

> Sorry for the late review. I think I've found a potential cause of the NaN issues in PyTorch 2.4. It may be better to address this cause to be future-proof. What do you think?

pytorch/pytorch#130014

I suspect it's a similar bug in PyTorch 2.4, and the best solution is to prioritize upgrading the PyTorch version to resolve the issue, because PyTorch 2.5 changed the cuDNN backend for SDPA.
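
As a side note, individual SDPA backends can also be ruled out at the call site, which helps isolate whether the cuDNN backend is involved (a sketch; torch.nn.attention.sdpa_kernel is available from PyTorch 2.3 onward):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = torch.randn(1, 4, 8, 16, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # allow only the flash / memory-efficient / math backends for this call
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)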

@kohya-ss (Owner) commented Jun 5, 2025

> I suspect it's a similar bug in PyTorch 2.4, and the best solution is to prioritize upgrading the PyTorch version to resolve the issue, because PyTorch 2.5 changed the cuDNN backend for SDPA.

It's true about the PyTorch bug.

However, as rockerBOO pointed out, it seems necessary to create the mask correctly when calling noise_refiner. This will also prevent the effect of the bug in PyTorch.

@sdbds (Contributor, Author) commented Jun 5, 2025

> > I suspect it's a similar bug in PyTorch 2.4, and the best solution is to prioritize upgrading the PyTorch version to resolve the issue, because PyTorch 2.5 changed the cuDNN backend for SDPA.
>
> It's true about the PyTorch bug.
>
> However, as rockerBOO pointed out, it seems necessary to create the mask correctly when calling noise_refiner. This will also prevent the effect of the bug in PyTorch.

I think so, because DiT models like SD3 and FLUX were rarely used with masks before...

@kohya-ss (Owner) commented Jun 8, 2025

It seems that noise_refiner does not work properly without masking, so this part needs to be fixed for correct training. If it is difficult to fix, I will update it myself after merging this PR.

This PR also changes FlowMatchEulerDiscreteScheduler, but I am wondering if this will affect the training of FLUX.1 and SD (whether that training will still work correctly). What do you think?

@rockerBOO (Contributor)

> It seems that noise_refiner does not work properly without masking, so this part needs to be fixed for correct training. If it is difficult to fix, I will update it myself after merging this PR.

To fix it, I just need to find the time to do it. The fix should be fairly simple (comparing against the upstream version).

> This PR also changes FlowMatchEulerDiscreteScheduler, but I am wondering if this will affect the training of FLUX.1 and SD (whether that training will still work correctly). What do you think?

FlowMatchEulerDiscreteScheduler was updated to the latest version at the time, which adds dynamic shifting to this noise scheduler. This was partially done because I didn't understand how sigmas/timesteps worked and was using the noise scheduler to get behavior similar to what we have for SD3/FLUX with shifting.

@kohya-ss (Owner) commented Jun 9, 2025

> FlowMatchEulerDiscreteScheduler was updated to the latest version at the time, which adds dynamic shifting to this noise scheduler. This was partially done because I didn't understand how sigmas/timesteps worked and was using the noise scheduler to get behavior similar to what we have for SD3/FLUX with shifting.

Thank you for clarification!

In my understanding, the dynamic shifting (the new time_shift function added to sd3_train_utils.py) is the same as time_shift in lumina_train_util.py.

And this function is called from get_noisy_model_input_and_timesteps in lumina_train_util.py.

So I think we can keep FlowMatchEulerDiscreteScheduler as before, because the dynamic shifting is done in get_noisy_model_input_and_timesteps and time_shift.
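
For reference, the FLUX-style dynamic shift these functions implement boils down to a single formula (a sketch; how mu is derived from the image sequence length is omitted, and the actual code in lumina_train_util.py / sd3_train_utils.py may differ in details):

    import math
    import torch

    def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
        # shift timesteps t in (0, 1]; a larger mu pushes sampling toward higher-noise timesteps
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    t = torch.rand(4).clamp(min=1e-5)
    shifted = time_shift(mu=1.0, sigma=1.0, t=t)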

@rockerBOO (Contributor)

I reverted FlowMatchEulerDiscreteScheduler to how it was before and added the masking in. Because we use buckets where every item has the same size, it doesn't do much on its own, but it would allow latents of different sizes (originally it supported a list of tensors).

It's accurate now, though.

@kohya-ss (Owner) commented Jun 12, 2025

Thank you for updating!

There seems to be a problem with handling the system prompt.

Add a print statement to the tokenize method as follows:

        text = [text] if isinstance(text, str) else text
        print(f"Tokenizing: {text}")
        encodings = self.tokenizer(

Use the following dataset settings:

[general]
resolution = [1024, 1024]

[[datasets]]
batch_size = 1
enable_bucket = false
caption_extension = ".txt"

  [[datasets.subsets]]
  image_dir = "/path/to/image_dir"
  num_repeats = 1
  caption_prefix = "1girl, orichara1, "
  system_prompt = "DUMMY SYSTEM PROMPT"

Then you will see the following string:

Tokenizing: ['DUMMY SYSTEM PROMPT <Prompt Start> 1girl, orichara1,  DUMMY SYSTEM PROMPT <Prompt Start> 1girl, breasts, looking at viewer, blush, smile, multiple girls, skirt, shirt, medium breasts, closed mouth, white shirt, short sleeves, pleated skirt, outdoors, sky, shorts, solo focus, day, black skirt, blue sky, umbrella, building, people']

Without caption_prefix, the log looks like this:

Tokenizing: ['DUMMY SYSTEM PROMPT <Prompt Start> DUMMY SYSTEM PROMPT <Prompt Start> 1girl, solo, breasts, looking at viewer, blush, skirt, shirt, closed mouth, white shirt, short sleeves, outdoors, sky, day, looking back, cloud, black skirt, blue sky, plant, building, scenery, city, sign, potted plant, road, power lines, street, utility pole, traffic light, crosswalk, storefront']

Adding system_prompt to the dataset settings has a large impact, so I would like to avoid it if possible.
For example, how about adding an argument that passes the system_prompt to the Lumina training script, setting it in the tokenizer (or text encoding) strategy, and processing it in the tokenizer strategy?
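
A minimal sketch of that idea (the class name, constructor arguments, and tokenizer call here are illustrative, not the actual sd-scripts API):

    class LuminaTokenizeStrategy:
        def __init__(self, tokenizer, system_prompt: str = "", max_length: int = 256):
            self.tokenizer = tokenizer
            self.system_prompt = system_prompt  # e.g. passed once via a --system_prompt CLI argument
            self.max_length = max_length

        def tokenize(self, text):
            text = [text] if isinstance(text, str) else text
            # prepend the system prompt exactly once, here and nowhere else,
            # so the dataset-side caption handling does not need to know about it
            if self.system_prompt:
                text = [f"{self.system_prompt} <Prompt Start> {t}" for t in text]
            return self.tokenizer(
                text, max_length=self.max_length, padding="max_length",
                truncation=True, return_tensors="pt",
            )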

In addition, if we specify --cache_text_encoder_outputs, the script will stop with the following error.

  File "path\to\sd-scripts\library\lumina_models.py", line 1148, in forward
    x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)  
  File "path\to\sd-scripts\library\lumina_models.py", line 1105, in patchify_and_embed
    cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
RuntimeError: The expanded size of the tensor (256) must match the existing size (6424135) at non-singleton dimension 0.  Target sizes: [256, 48].  Tensor sizes: [6424135, 48]



  caption = self.process_caption(subset, image_info.caption)
- input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)]  # remove batch dimension
+ input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)]  # remove batch dimension
@kohya-ss (Owner):

There is probably a duplication in adding system_prompt here.

@sdbds (Contributor, Author) commented Jun 14, 2025

@rockerBOO Can you roll the dataset settings back to how they were before the system prompt change?

@kohya-ss (Owner)

If you don't have time to update the system prompt handling, I would like to merge this PR into a new branch and update the masking and system prompt handling there. What do you think?

@sdbds (Contributor, Author) commented Jun 15, 2025

> If you don't have time to update the system prompt handling, I would like to merge this PR into a new branch and update the masking and system prompt handling there. What do you think?

Sorry, I've been quite busy lately. I'd be happy to.
