[HiDream LoRA] optimizations + small updates #11381


Merged
31 commits merged on Apr 24, 2025
31 commits
- `7f309a4`: 1. add pre-computation of prompt embeddings when custom prompts are u… (linoytsaban, Apr 22, 2025)
- `ca8e79b`: pre encode validation prompt as well (linoytsaban, Apr 22, 2025)
- `4b0aa84`: Update examples/dreambooth/train_dreambooth_lora_hidream.py (linoytsaban, Apr 22, 2025)
- `bb12f88`: Update examples/dreambooth/train_dreambooth_lora_hidream.py (linoytsaban, Apr 22, 2025)
- `e4d365d`: Update examples/dreambooth/train_dreambooth_lora_hidream.py (linoytsaban, Apr 22, 2025)
- `65832ee`: pre encode validation prompt as well (linoytsaban, Apr 22, 2025)
- `8fd8d42`: Merge remote-tracking branch 'origin/hidream-followup' into hidream-f… (linoytsaban, Apr 22, 2025)
- `b27d9bc`: Apply style fixes (github-actions[bot], Apr 22, 2025)
- `c8b2f07`: empty commit (linoytsaban, Apr 22, 2025)
- `9e2091d`: change default trained modules (linoytsaban, Apr 22, 2025)
- `2c59748`: empty commit (linoytsaban, Apr 22, 2025)
- `2652029`: Merge branch 'main' into hidream-followup (sayakpaul, Apr 23, 2025)
- `2bbeb9f`: address comments + change encoding of validation prompt (before it wa… (linoytsaban, Apr 23, 2025)
- `d4dd84f`: Apply style fixes (github-actions[bot], Apr 23, 2025)
- `dd67962`: empty commit (linoytsaban, Apr 23, 2025)
- `3f84f96`: Merge remote-tracking branch 'origin/hidream-followup' into hidream-f… (linoytsaban, Apr 23, 2025)
- `e6b8f01`: fix validation_embeddings definition (linoytsaban, Apr 23, 2025)
- `9c976d2`: fix final inference condition (linoytsaban, Apr 23, 2025)
- `75eaaa7`: fix pipeline deletion in last inference (linoytsaban, Apr 23, 2025)
- `44a9846`: Apply style fixes (github-actions[bot], Apr 23, 2025)
- `46fcd76`: empty commit (linoytsaban, Apr 23, 2025)
- `5dc3468`: Merge remote-tracking branch 'origin/hidream-followup' into hidream-f… (linoytsaban, Apr 23, 2025)
- `d093d08`: Merge branch 'main' into hidream-followup (linoytsaban, Apr 23, 2025)
- `af42f02`: layers (linoytsaban, Apr 23, 2025)
- `d795bb3`: Merge remote-tracking branch 'origin/hidream-followup' into hidream-f… (linoytsaban, Apr 23, 2025)
- `0e6fa1b`: Merge branch 'main' into hidream-followup (linoytsaban, Apr 23, 2025)
- `1efdc2a`: remove readme remarks on only pre-computing when instance prompt is p… (linoytsaban, Apr 23, 2025)
- `cb53d9c`: Merge remote-tracking branch 'origin/hidream-followup' into hidream-f… (linoytsaban, Apr 23, 2025)
- `12afa54`: smol fix (linoytsaban, Apr 23, 2025)
- `dab9132`: Merge branch 'main' into hidream-followup (linoytsaban, Apr 23, 2025)
- `bbf4f1a`: empty commit (linoytsaban, Apr 23, 2025)
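The headline change in these commits is pre-computing the prompt embeddings once, up front, so the text encoders can be dropped from memory during training. A minimal sketch of that pattern, using a dummy encoder as a stand-in for HiDream's real text-encoder stack (all names here are illustrative, not the script's actual API):

```python
import gc
import torch
from torch import nn


# Hypothetical stand-in for the tokenizers + text encoders the real
# train_dreambooth_lora_hidream.py script loads.
class DummyTextEncoder(nn.Module):
    def __init__(self, vocab=1000, dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)

    def forward(self, token_ids):
        return self.embed(token_ids)


def precompute_prompt_embeddings(encoder, token_batches):
    """Encode every prompt exactly once, with gradients disabled."""
    cached = []
    with torch.no_grad():
        for token_ids in token_batches:
            cached.append(encoder(token_ids).cpu())
    return cached


encoder = DummyTextEncoder()
batches = [torch.randint(0, 1000, (2, 8)) for _ in range(3)]
cache = precompute_prompt_embeddings(encoder, batches)

# Once the embeddings are cached, the encoder is deleted so it no longer
# occupies memory; the training loop reads from the cache instead.
del encoder
gc.collect()

assert len(cache) == 3 and cache[0].shape == (2, 8, 64)
```

The same encode-once-then-free idea also covers the validation prompt, which the later commits pre-encode as well.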
40 changes: 13 additions & 27 deletions examples/dreambooth/README_hidream.md
````diff
@@ -51,54 +51,41 @@ When running `accelerate config`, if we specify torch compile mode to True there
 Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
 
 
-### Dog toy example
+### 3d icon example
 
-Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
-
-Let's first download it locally:
-
-```python
-from huggingface_hub import snapshot_download
-
-local_dir = "./dog"
-snapshot_download(
-    "diffusers/dog-example",
-    local_dir=local_dir, repo_type="dataset",
-    ignore_patterns=".gitattributes",
-)
-```
+For this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.
 
 This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
 
 Now, we can launch training using:
 > [!NOTE]
 > The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
-> 8-bit Adam optimizer, latent caching, offloading, no validation.
-> Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image)
-> text embeddings are pre-computed to save memory.
-
+> 8-bit Adam optimizer, latent caching, offloading, no validation.
+> all text embeddings are pre-computed to save memory.
 ```bash
 export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
-export INSTANCE_DIR="dog"
+export INSTANCE_DIR="linoyts/3d_icon"
 export OUTPUT_DIR="trained-hidream-lora"
 
 accelerate launch train_dreambooth_lora_hidream.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
+  --dataset_name=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
   --mixed_precision="bf16" \
-  --instance_prompt="a photo of sks dog" \
+  --instance_prompt="3d icon" \
+  --caption_column="prompt"\
+  --validation_prompt="a 3dicon, a llama eating ramen" \
   --resolution=1024 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
   --use_8bit_adam \
-  --rank=16 \
+  --rank=8 \
   --learning_rate=2e-4 \
   --report_to="wandb" \
-  --lr_scheduler="constant" \
-  --lr_warmup_steps=0 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=100 \
   --max_train_steps=1000 \
-  --cache_latents \
+  --cache_latents\
   --gradient_checkpointing \
   --validation_epochs=25 \
   --seed="0" \
@@ -128,6 +115,5 @@ We provide several options for optimizing memory optimization:
 * `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
 * `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
 * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
-* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done.
 
 Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
````
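The `--cache_latents` option in the README applies the same encode-once-then-free idea to image latents: run every training image through the VAE a single time, then delete the VAE. A minimal sketch of the pattern, using a hypothetical dummy VAE rather than the real HiDream model:

```python
import gc
import torch
from torch import nn


# Hypothetical stand-in for the VAE encoder; a strided conv that maps
# 3-channel pixels to a smaller 4-channel "latent", like a real VAE does.
class DummyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=8, stride=8)

    @torch.no_grad()
    def encode(self, pixels):
        return self.conv(pixels)


vae = DummyVAE()
images = [torch.randn(1, 3, 64, 64) for _ in range(4)]

# Encode every training image once up front...
latents = [vae.encode(img) for img in images]

# ...then drop the VAE so it takes no memory during the training loop,
# which consumes the cached latents instead of re-encoding pixels.
del vae
gc.collect()

assert len(latents) == 4 and latents[0].shape == (1, 4, 8, 8)
```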