Add train_sana_sprint_diffusers file #251
base: main
Conversation
Will review in a bit.
Looking really promising. I left some comments, LMK if they make sense.
Additionally, if we could wrap the loss computations for the different phases into different functions, I think that will be easier to read. LMK what you think.
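For illustration, a rough sketch of what that refactor could look like (the function names and signatures below are hypothetical, not part of the PR; the real argument lists would mirror whatever the current inline blocks already use):

# Hypothetical structure only, shown to illustrate the suggested split.
def compute_generator_loss(transformer, pretrained_model, disc, batch, args):
    """Losses for the generator ("G") phase."""
    raise NotImplementedError  # move the existing "G"-phase code here


def compute_discriminator_loss(disc, transformer, batch, args):
    """Real/fake classification loss for the discriminator ("D") phase."""
    raise NotImplementedError  # move the existing "D"-phase code here


# The training loop then reduces each phase to a single call, e.g.:
# loss = compute_generator_loss(...) if phase == "G" else compute_discriminator_loss(...)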
@@ -0,0 +1,1823 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
Feel free to add SANA Sprint team here too :)
if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False


complex_human_instruction = [
Suggested change:
-complex_human_instruction = [
+COMPLEX_HUMAN_INSTRUCTION = [
    return False


class Text2ImageDataset:
Do we have an example dataset with which it would work?
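(Not sure what format the loader expects, but as a possible smoke test, a small image-caption dataset from the Hub could work. The dataset name and column names below are an assumption for illustration, not something the PR references:)

# Assumption: the dataset exposes "image" / "text" columns; adjust to whatever
# Text2ImageDataset actually expects.
from datasets import load_dataset

ds = load_dataset("lambdalabs/naruto-blip-captions", split="train")
sample = ds[0]
print(sample["text"])   # caption string
print(sample["image"])  # PIL image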
)
# add meta-data to dataloader instance for convenience
self._train_dataloader.num_batches = num_batches
self._train_dataloader.num_samples = num_samples
Could use num_train_examples here, no?
disc.eval()
models_to_accumulate = [transformer]
with accelerator.accumulate(models_to_accumulate):
    with torch.no_grad():
We can then remove this context manager.
images = None
del pipeline


# Save the lora layers
We are not doing LoRA. So, this can be safely omitted.
cfg_y = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
cfg_y_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

cfg_pretrain_pred = pretrained_model(
As another optimization, we could keep the pretrained_model on CPU once this computation is done and load it to the GPU again when needed.
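A minimal sketch of that offload, assuming the teacher is frozen (the helper name is made up, and note that the per-step CPU-to-GPU transfer has its own cost):

import torch

# Hypothetical helper: keep the frozen teacher on CPU and bring it onto the GPU
# only for its forward pass, freeing memory for the student/discriminator updates.
def teacher_forward(pretrained_model, device, **model_inputs):
    pretrained_model.to(device)
    with torch.no_grad():
        pred = pretrained_model(**model_inputs)
    pretrained_model.to("cpu")
    torch.cuda.empty_cache()
    return pred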
phase = "G" | ||
|
||
optimizer_D.step() | ||
optimizer_D.zero_grad(set_to_none=True) |
I think set_to_none is True by default.
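For reference (assuming the script targets PyTorch 2.x, where this became the default):

# Optimizer.zero_grad() has defaulted to set_to_none=True since PyTorch 2.0,
# so the explicit argument can simply be dropped:
optimizer_D.zero_grad()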
lr_scheduler.step()
optimizer_G.zero_grad(set_to_none=True)

elif phase == "D":
So this alternates between two phases in the same training step, right? If so, I would add a comment.
Also, should we let the users control the step interval in which the discriminator should be updated? Or not really?
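If we do expose it, something like the sketch below could work (the flag name --disc_update_interval and the variable names follow the usual diffusers training-script conventions and are assumptions, not part of the PR):

# Hypothetical CLI flag: update the discriminator only every N optimizer steps.
parser.add_argument(
    "--disc_update_interval",
    type=int,
    default=1,
    help="Run the discriminator ('D') phase every N training steps (1 = every step).",
)

# In the training loop, the "G" phase always runs; the "D" phase is skipped
# unless the current step is a multiple of the interval.
run_disc_update = (global_step % args.disc_update_interval) == 0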
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Thanks for your thorough review and helpful suggestions! I'll carefully go through them and incorporate the changes when I'm back. Really appreciate it!
Please don't hesitate to ping me for running tests, etc.
Initial implementation of SANA-Sprint training script adapted for Diffusers.
This needs further refinement and optimization. @lawrence-cj @sayakpaul