Add train_sana_sprint_diffusers file #251


Draft · scxue wants to merge 3 commits into main

Conversation

@scxue (Collaborator) commented Apr 7, 2025

Initial implementation of the SANA-Sprint training script, adapted for Diffusers.
It still needs further refinement and optimization. @lawrence-cj @sayakpaul

@sayakpaul (Contributor)

Will review in a bit.

@sayakpaul (Contributor) left a comment

Looking really promising. I left some comments, LMK if they make sense.

Additionally, if we could wrap the loss computations for the different phases into separate functions, I think that would be easier to read. LMK what you think.
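For illustration, a minimal sketch of that refactor, assuming hypothetical helper names and signatures (none of these appear in the PR) and a simple hinge loss standing in for the script's actual objectives:

```python
import torch
import torch.nn.functional as F


def compute_generator_loss(transformer, disc, noisy_latents, timesteps, prompt_embeds):
    # "G" phase: the generator is rewarded when the discriminator scores its output high.
    pred = transformer(noisy_latents, timestep=timesteps, encoder_hidden_states=prompt_embeds)
    return -disc(pred).mean()


def compute_discriminator_loss(disc, real_latents, fake_latents):
    # "D" phase: hinge loss on real latents vs. detached generator outputs.
    real_logits = disc(real_latents)
    fake_logits = disc(fake_latents.detach())
    return (F.relu(1.0 - real_logits) + F.relu(1.0 + fake_logits)).mean()
```

The training loop then only dispatches on the current phase and calls the matching helper.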

@@ -0,0 +1,1823 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
@sayakpaul:

Feel free to add the SANA Sprint team here too :)

if is_torch_npu_available():
torch.npu.config.allow_internal_format = False

complex_human_instruction = [
@sayakpaul:

Suggested change:
- complex_human_instruction = [
+ COMPLEX_HUMAN_INSTRUCTION = [

return False


class Text2ImageDataset:
@sayakpaul:

Do we have an example dataset with which it would work?

)
# add meta-data to dataloader instance for convenience
self._train_dataloader.num_batches = num_batches
self._train_dataloader.num_samples = num_samples
@sayakpaul:

Could we use num_train_examples here, no?

disc.eval()
models_to_accumulate = [transformer]
with accelerator.accumulate(models_to_accumulate):
with torch.no_grad():
@sayakpaul:

We can then remove this context manager.
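For context, a small self-contained sketch of why the wrapper can become redundant, assuming the modules called inside it are frozen: once parameters have requires_grad=False and the inputs carry no grad, autograd builds no graph, so torch.no_grad() adds nothing.

```python
import torch

frozen = torch.nn.Linear(8, 8)
frozen.requires_grad_(False)  # freeze once, e.g. right after loading the model

x = torch.randn(2, 8)         # plain input tensor, requires_grad=False
y = frozen(x)                 # forward pass outside any torch.no_grad() block
assert not y.requires_grad    # no autograd graph was built
```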

images = None
del pipeline

# Save the lora layers
@sayakpaul:

We are not doing LoRA. So, this can be safely omitted.

cfg_y = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
cfg_y_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

cfg_pretrain_pred = pretrained_model(
@sayakpaul:

As another optimization, we could keep pretrained_model on the CPU once this computation is done and load it back to the GPU when needed.
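A minimal sketch of that offload pattern, with a small `teacher` module standing in for pretrained_model and the surrounding training loop elided (the exact hook points in the script are an assumption):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
teacher = torch.nn.Linear(8, 8).to(device)  # stands in for pretrained_model

x = torch.randn(2, 8, device=device)
with torch.no_grad():
    teacher_pred = teacher(x)  # the one-off CFG forward pass on the accelerator

teacher.to("cpu")              # park the frozen teacher to free accelerator memory
# ... generator/discriminator updates run here ...
teacher.to(device)             # move it back right before its next forward pass
```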

phase = "G"

optimizer_D.step()
optimizer_D.zero_grad(set_to_none=True)
@sayakpaul:

I think set_to_none is True by default.

lr_scheduler.step()
optimizer_G.zero_grad(set_to_none=True)

elif phase == "D":
@sayakpaul:

So this alternates between the two phases within the same training step, right? If so, I would add a comment explaining that.

Also, should we let users control the step interval at which the discriminator is updated? Or not really?
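As a sketch of the second point, a hypothetical --disc_update_interval flag (not in the PR) could gate how often the "D" phase runs:

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag: run the discriminator update once every N training steps.
parser.add_argument("--disc_update_interval", type=int, default=1)
args = parser.parse_args([])  # parse no CLI args so the sketch runs standalone

for step in range(6):
    run_g = True                                   # "G" phase runs every step
    run_d = step % args.disc_update_interval == 0  # "D" phase is gated
    print(f"step {step}: G={run_g} D={run_d}")
```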

@scxue (Collaborator, Author) commented Apr 9, 2025

Thanks for your thorough review and helpful suggestions! I'll carefully go through them and incorporate the changes when I'm back. Really appreciate it!

@sayakpaul (Contributor)
Please don't hesitate to ping me for running tests, etc.
