Support Lumina-image-2.0 #1927
base: sd3
Conversation
I got this set up locally. I know it's not ready for anything, but I want to get it working. Let me know if you want to work together on this. I can help with some of the model loading parts, which is where I got stuck after poking at it. If you've progressed past this, I can help wherever else, or just with testing. Thanks.
Thank you, the framework is basically set up at the moment, but there is still some room for improvement in the caching strategy. I think I can discuss with @kohya-ss whether to continue using the previous method.
Does that mean I can download your fork and test it now?
It's still not quite working, but I'm working through some issues at the moment, mostly with model loading; I'll see what else is needed after that. It is fairly barebones, so I wouldn't expect it to be in a working state just yet.
Lumina 2 and Gemma 2 model loading
# Conflicts:
#   library/lumina_models.py
Lumina cache checkpointing
After multiple updates, the project can now run under limited conditions: …
Samples attention
Regarding strategy, I would like you to proceed as is; I would like to refactor it together with other architectures later. The script seems to assume that the model file is .safetensors, but I could only find .pth: https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/tree/main. I would appreciate it if you could tell me where the .safetensors file is.
I'm sorry this is so late. I am testing the training, but the sample image seems to be black even with the Lumina checkpoint downloaded from https://huggingface.co/rockerBOO/lumina-image-2/tree/main, and Gemma2 and the AE downloaded from https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/tree/main/split_files. The command is: …
Most flash_attn builds on Windows lack compiled training backends, so using it results in NaN.
Thank you, I understand. I got the following warning; is it OK? …
If you use PyTorch 2.6, I believe SDPA (which is the default) works correctly.
I think apex is https://github.com/NVIDIA/apex, so this warning is fine unless you wanted to use that library. We may be able to get rid of the warning, but it was in the original implementation.
Thank you! It is true that NaNs occur with SDPA in PyTorch 2.4 (venv with requirements.txt of the sd3 branch), but NaNs do not seem to occur in PyTorch 2.6 (venv with Musubi Tuner). Do you know the reason for this? And should we move to PyTorch 2.6?
I guess it's a bug in 2.4, and the migration to 2.7 didn't cause any issues... So far 2.7 hasn't hit any bugs, and it lets us support 50xx GPUs as soon as possible.
Thank you, I think the code in the sd3 branch will work with PyTorch 2.6 or later, but some more testing may be better. It might be a good idea to make it clear in the Lumina documentation that PyTorch 2.6 or later is required.
Something specific to their architecture, but since the new version resolved it I didn't look further. They had some code for bf16 and flash attention, but I reworked it so it uses SDPA by default, with flash attention as a specific opt-in.
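For illustration, a minimal sketch of that kind of opt-in selection (the names and structure are assumptions, not the PR's actual code; note that flash_attn_func and SDPA take differently shaped arguments, so a real implementation wraps them behind a common interface):

import torch.nn.functional as F

def get_attention_processor(use_flash_attn: bool = False):
    # Default to PyTorch SDPA; only use flash attention when explicitly requested
    # and actually importable (hypothetical helper, not the PR's API).
    if use_flash_attn:
        try:
            from flash_attn import flash_attn_func  # optional dependency
            return flash_attn_func
        except ImportError:
            print("flash_attn is not installed; falling back to SDPA")
    return F.scaled_dot_product_attention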
Sorry for the late review. I think I've found a potential cause of the NaN issues in PyTorch 2.4. It may be better to address this cause to be future-proof. What do you think?
# Refine image context
for layer in self.noise_refiner:
    x = layer(x, x_mask, img_freqs_cis, t)
At this point, x_mask is all zeros (set at line 1113).
Hmm, this might be identifying a separate bug in how this was refactored. You can see it upstream: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py#L730-L733. I can give it another look to fix this, though.
Fixing the zero issue with the attention is probably also a good idea.
if valid_indices.numel() == 0:
    # If all tokens are masked, create a zero output
    batch_output = torch.zeros(
        seqlen, self.n_local_heads, self.head_dim,
        device=q.device, dtype=q.dtype
    )
If x_mask is all zeros, sage_attn returns zeros of the expected output shape.
self.attention_processor(
    xq.permute(0, 2, 1, 3),
    xk.permute(0, 2, 1, 3),
    xv.permute(0, 2, 1, 3),
    attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
    scale=softmax_scale,
)
It seems that SDPA may return NaN if the mask is all zeros. With SDPA, it seems we need code that returns zeros with the expected output shape when the mask is all zeros, like sage_attn does.
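A minimal way to reproduce the behavior described above (a sketch, not the PR's code; whether NaN actually appears can depend on the PyTorch version and the SDPA backend):

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# All-False mask: every key is masked out for every query position.
attn_mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.isnan().any())  # softmax over a fully masked row is 0/0, so this can be NaN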
Perhaps this will work (it does not work if some items of the batch are all-zero and some are not):
valid_indices = torch.nonzero(x_mask, as_tuple=False).squeeze(-1)
if valid_indices.numel() == 0:
    # If all tokens are masked, create a zero output
    # NOTE: This does not assume that there will be a mix of masked and unmasked items in the batch.
    output = torch.zeros_like(xq, dtype=dtype)
else:
    output = (
        self.attention_processor(
            xq.permute(0, 2, 1, 3),
            xk.permute(0, 2, 1, 3),
            xv.permute(0, 2, 1, 3),
            attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
            scale=softmax_scale,
        )
        .permute(0, 2, 1, 3)
        .to(dtype)
    )
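One possible way to also handle the mixed case noted above (a hedged sketch with assumed shapes and names, not the PR's implementation): run SDPA only on the batch items that have at least one valid token, and scatter the results into a zero-initialized output.

import torch
import torch.nn.functional as F

def masked_sdpa(xq, xk, xv, x_mask, n_heads, softmax_scale):
    # xq/xk/xv: (bsz, seqlen, n_heads, head_dim); x_mask: (bsz, seqlen), nonzero = valid token
    bsz, seqlen = x_mask.shape
    output = torch.zeros_like(xq)
    valid_items = x_mask.bool().any(dim=1)  # items with at least one valid token
    if valid_items.any():
        q = xq[valid_items].permute(0, 2, 1, 3)
        k = xk[valid_items].permute(0, 2, 1, 3)
        v = xv[valid_items].permute(0, 2, 1, 3)
        attn_mask = (
            x_mask[valid_items].bool()
            .view(-1, 1, 1, seqlen)
            .expand(-1, n_heads, seqlen, -1)
        )
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=softmax_scale)
        output[valid_items] = out.permute(0, 2, 1, 3).to(xq.dtype)
    return output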
I suspect it's a similar bug in PyTorch 2.4, and the best solution is to prioritize upgrading the PyTorch version to resolve the issue.
It's true about the PyTorch bug. However, as rockerBOO pointed out, it seems necessary to create the mask correctly when calling ….
I think so, because DiT models of the SD3/FLUX type were rarely used with masks before...
It seems that noise_refiner does not work properly without masking, so this part needs to be fixed for correct training. If it is difficult to fix, I will update it myself after merging this PR. This PR also changes FlowMatchEulerDiscreteScheduler, and I am wondering whether this will affect the training of FLUX.1 and SD (whether their training will still work correctly). What do you think?
To fix it, I just need to find the time to do it. The fix should be fairly simple (compared to the upstream version).
FlowMatchEulerDiscreteScheduler was updated to the latest version at the time. This adds dynamic shifting to this noise scheduler. This was partially done because I didn't understand how sigmas/timesteps worked, and I was using this noise scheduler to get behavior similar to what we have for SD3/Flux with the shifting.
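For reference, a sketch of what that dynamic shifting does in diffusers-style flow-matching schedulers (the constants below are the FLUX defaults and are assumptions here, not values taken from this PR):

import math

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    # Interpolate the shift parameter mu from the latent sequence length,
    # so larger images get a stronger shift toward high-noise timesteps.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

def time_shift(mu, sigma, t):
    # Warp sigmas/timesteps t in (0, 1] by exp(mu); sigma=1.0 gives the usual form.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)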
Thank you for the clarification! In my understanding, the dynamic shifting is implemented as a new function (…), and this function is called from (…), so I think we can keep FlowMatchEulerDiscreteScheduler as before, because the dynamic shifting is done in (…).
Lumina test fix mask
I reverted FlowMatchEulerDiscreteScheduler to how it was before. I added the masking in. Because we use buckets that have the same size, it doesn't do much on its own, but it would allow you to have latents of different sizes. Originally it supported a list of tensors. It's accurate now though.
Thank you for updating! There seems to be a problem with handling the system prompt. Add a print statement to …
Use the following dataset settings: …
Then you will see the following string: …
Without …
Adding system_prompt to the DataSet settings has a large impact, so I would like to avoid it if possible. In addition, if we specify …
library/train_util.py
caption = self.process_caption(subset, image_info.caption)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)]  # remove batch dimension
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)]  # remove batch dimension
There is probably a duplication in adding system_prompt here.
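For illustration only (hypothetical names, since the surrounding caption/caching code isn't shown here), one way to avoid prepending the system prompt twice is to guard the concatenation:

def build_text_input(system_prompt: str, caption: str) -> str:
    # Prepend the system prompt only if the caption does not already start with it
    # (hypothetical helper; the PR's actual caption/caching flow may differ).
    if system_prompt and not caption.startswith(system_prompt):
        return system_prompt + caption
    return caption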
@rockerBOO Can you roll back the dataset settings to how they were before the system prompt changes?
If you don't have time to update the system prompt handling, I would like to merge this PR into a new branch and update the masking and system prompt handling there. What do you think?
Sorry, I've been quite busy lately. I'd be happy to.
Still in preparation.
After checking: their sampler uses the FLUX VAE, and the text encoder part uses Google's Gemma 2.
@kohya-ss CC
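As a rough illustration of that text-encoder side (a sketch with an assumed model ID and dtype, not this PR's loader), Gemma 2 can be loaded with Hugging Face transformers and its hidden states used as conditioning:

import torch
from transformers import AutoTokenizer, AutoModel

GEMMA2_REPO = "google/gemma-2-2b"  # assumption; the PR may load local .safetensors instead

tokenizer = AutoTokenizer.from_pretrained(GEMMA2_REPO)
text_encoder = AutoModel.from_pretrained(GEMMA2_REPO, torch_dtype=torch.bfloat16)

tokens = tokenizer("a photo of a cat", return_tensors="pt")
with torch.no_grad():
    hidden_states = text_encoder(**tokens).last_hidden_state  # (1, seq_len, hidden_dim)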