support SD3 #1374

Draft: wants to merge 533 commits into base: dev

Commits (533)
b8d3fec
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 7, 2024
186aa5b
fix illegal block is swapped #1764
kohya-ss Nov 7, 2024
b3248a8
fix: sort order when getting image size from cache file
feffy380 Nov 7, 2024
2a2042a
Merge pull request #1770 from feffy380/fix-size-from-cache
kohya-ss Nov 9, 2024
8fac3c3
update README
kohya-ss Nov 9, 2024
26bd454
init
sdbds Nov 11, 2024
02bd76e
Refactor block swapping to utilize custom offloading utilities
kohya-ss Nov 11, 2024
7feaae5
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
92482c7
Merge pull request #1774 from sdbds/avif_get_imagesize
kohya-ss Nov 11, 2024
3fe94b0
update comment
kohya-ss Nov 11, 2024
cde90b8
feat: implement block swapping for FLUX.1 LoRA (WIP)
kohya-ss Nov 11, 2024
17cf249
Merge branch 'sd3' into faster-block-swap
kohya-ss Nov 11, 2024
2cb7a6d
feat: add block swap for FLUX.1/SD3 LoRA training
kohya-ss Nov 12, 2024
2bb0f54
update grad hook creation to fix TE lr in sd3 fine tuning
kohya-ss Nov 14, 2024
5c5b544
refactor: remove unused prepare_split_model method from FluxNetworkTr…
kohya-ss Nov 14, 2024
fd2d879
docs: update README
kohya-ss Nov 14, 2024
0047bb1
Merge pull request #1779 from kohya-ss/faster-block-swap
kohya-ss Nov 14, 2024
ccfaa00
add flux controlnet base module
minux302 Nov 15, 2024
42f6edf
fix for adding controlnet
minux302 Nov 15, 2024
e358b11
fix dataloader
minux302 Nov 16, 2024
2a188f0
Fix to work DOP with block swap
kohya-ss Nov 17, 2024
b2660bb
train run
minux302 Nov 17, 2024
35778f0
fix sample_images type
minux302 Nov 17, 2024
4dd4cd6
work cn load and validation
minux302 Nov 18, 2024
31ca899
fix depth value
minux302 Nov 18, 2024
2a61fc0
docs: fix typo from block_to_swap to blocks_to_swap in README
kohya-ss Nov 20, 2024
0b5229a
save cn
minux302 Nov 21, 2024
420a180
Implement pseudo Huber loss for Flux and SD3
recris Nov 27, 2024
740ec1d
Fix issues found in review
recris Nov 28, 2024
9dff44d
fix device
minux302 Nov 29, 2024
575f583
add README
minux302 Nov 29, 2024
be5860f
add schnell option to load_cn
minux302 Nov 29, 2024
f40632b
rm redundant arg
minux302 Nov 29, 2024
928b939
Allow unknown schedule-free optimizers to continue to module loader
rockerBOO Nov 20, 2024
87f5224
Support d*lr for ProdigyPlus optimizer
rockerBOO Nov 20, 2024
6593cfb
Fix d * lr step log
rockerBOO Nov 21, 2024
c7cadbc
Add pytest testing
rockerBOO Nov 29, 2024
2dd063a
add torch torchvision accelerate versions
rockerBOO Nov 29, 2024
e59e276
Add dadaptation
rockerBOO Nov 29, 2024
dd3b846
Install pytorch first to pin version
rockerBOO Nov 29, 2024
89825d6
Run typos workflows once where appropriate
rockerBOO Nov 29, 2024
4f7f248
Bump typos action
rockerBOO Nov 29, 2024
9c885e5
fix: improve pos_embed handling for oversized images and update resol…
kohya-ss Nov 30, 2024
7b61e9e
Fix issues found in review (pt 2)
recris Nov 30, 2024
a5a27fe
Merge pull request #1808 from recris/huber-loss-flux
kohya-ss Dec 1, 2024
14f642f
fix: huber_schedule exponential not working on sd3_train.py
kohya-ss Dec 1, 2024
0fe6320
fix flux_train.py is not working
kohya-ss Dec 1, 2024
cc11989
fix: refactor huber-loss calculation in multiple training scripts
kohya-ss Dec 1, 2024
1476040
fix: update help text for huber loss parameters in train_util.py
kohya-ss Dec 1, 2024
bdf9a8c
Merge pull request #1815 from kohya-ss/flux-huber-loss
kohya-ss Dec 1, 2024
34e7f50
docs: update README for huber loss
kohya-ss Dec 1, 2024
14c9ba9
Merge pull request #1811 from rockerBOO/schedule-free-prodigy
kohya-ss Dec 1, 2024
1dc873d
update README and clean up code for schedulefree optimizer
kohya-ss Dec 1, 2024
e3fd6c5
Merge pull request #1812 from rockerBOO/tests
kohya-ss Dec 2, 2024
09a3740
Merge pull request #1813 from minux302/flux-controlnet
kohya-ss Dec 2, 2024
e369b9a
docs: update README with FLUX.1 ControlNet training details and impro…
kohya-ss Dec 2, 2024
5ab00f9
Update workflow tests with cleanup and documentation
rockerBOO Dec 2, 2024
63738ec
Add tests documentation
rockerBOO Dec 2, 2024
2610e96
Pytest
rockerBOO Dec 2, 2024
3e5d89c
Add more resources
rockerBOO Dec 2, 2024
8b36d90
feat: support block_to_swap for FLUX.1 ControlNet training
kohya-ss Dec 2, 2024
6bee18d
fix: resolve model corruption issue with pos_embed when using --enabl…
kohya-ss Dec 7, 2024
2be3366
Merge pull request #1817 from rockerBOO/workflow-tests-fixes
kohya-ss Dec 7, 2024
abff4b0
Unify controlnet parameters name and change scripts name. (#1821)
sdbds Dec 7, 2024
e425996
feat: unify ControlNet model name option and deprecate old training s…
kohya-ss Dec 7, 2024
3cb8cb2
Prevent git credentials from leaking into other actions
rockerBOO Dec 9, 2024
8e378cf
add RAdamScheduleFree support
nhamanasu Dec 11, 2024
d3305f9
Merge pull request #1828 from rockerBOO/workflow-security-audit
kohya-ss Dec 15, 2024
f2d38e6
Merge pull request #1830 from nhamanasu/sd3
kohya-ss Dec 15, 2024
e896539
update requirements.txt and README to include RAdamScheduleFree optim…
kohya-ss Dec 15, 2024
05bb918
Add Validation loss for LoRA training
hinablue Dec 27, 2024
62164e5
Change val loss calculate method
hinablue Dec 27, 2024
64bd531
Split val latents/batch and pick up val latents shape size which equa…
hinablue Dec 28, 2024
cb89e02
Change val latent loss compare
hinablue Dec 28, 2024
8743532
val
gesen2egee Mar 9, 2024
449c1c5
Adding modified train_util and config_util
rockerBOO Jan 2, 2025
7f6e124
Merge branch 'gesen2egee/val' into validation-loss-upstream
rockerBOO Jan 3, 2025
d23c732
Merge remote-tracking branch 'hina/feature/val-loss' into validation-…
rockerBOO Jan 3, 2025
7470173
Remove defunct code for train_controlnet.py
rockerBOO Jan 3, 2025
534059d
Typos and lingering is_train
rockerBOO Jan 3, 2025
c8c3569
Cleanup order, types, print to logger
rockerBOO Jan 3, 2025
fbfc275
Update text for train/reg with repeats
rockerBOO Jan 3, 2025
58bfa36
Add seed help clarifying info
rockerBOO Jan 3, 2025
6604b36
Remove duplicate assignment
rockerBOO Jan 3, 2025
0522070
Fix training, validation split, revert to using upstream implementation
rockerBOO Jan 3, 2025
695f389
Move get_huber_threshold_if_needed
rockerBOO Jan 3, 2025
1f9ba40
Add step break for validation epoch. Remove unused variable
rockerBOO Jan 3, 2025
1c0ae30
Add missing functions for training batch
rockerBOO Jan 3, 2025
a9c5aa1
add CFG to FLUX.1 sample image
kohya-ss Jan 5, 2025
bbf6bbd
Use self.get_noise_pred_and_target and drop fixed timesteps
rockerBOO Jan 6, 2025
f4840ef
Revert train_db.py
rockerBOO Jan 6, 2025
1c63e7c
Cleanup unused code and formatting
rockerBOO Jan 6, 2025
c64d1a2
Add validate_every_n_epochs, change name validate_every_n_steps
rockerBOO Jan 6, 2025
f885029
Fix validate epoch, cleanup imports
rockerBOO Jan 6, 2025
fcb2ff0
Clean up some validation help documentation
rockerBOO Jan 6, 2025
742bee9
Set validation steps in multiple lines for readability
rockerBOO Jan 6, 2025
1231f51
Remove unused train_util code, fix accelerate.log for wandb, add init…
rockerBOO Jan 8, 2025
556f3f1
Fix documentation, remove unused function, fix bucket reso for sd1.5,…
rockerBOO Jan 8, 2025
9fde0d7
Handle tuple return from generate_dataset_group_by_blueprint
rockerBOO Jan 8, 2025
1e61392
Revert bucket_reso_steps to correct 64
rockerBOO Jan 8, 2025
d6f158d
Fix incorrect destructuring for load_arbitrary_dataset
rockerBOO Jan 8, 2025
264167f
Apply is_training_dataset only to DreamBoothDataset. Add validation_s…
rockerBOO Jan 9, 2025
4c61adc
Add divergence to logs
rockerBOO Jan 12, 2025
2bbb40c
Fix regularization images with validation
rockerBOO Jan 12, 2025
0456858
Fix validate_every_n_steps always running first step
rockerBOO Jan 12, 2025
ee9265c
Fix validate_every_n_steps for gradient accumulation
rockerBOO Jan 12, 2025
25929dd
Remove Validating... print to fix output layout
rockerBOO Jan 12, 2025
b489082
Disable repeats for validation datasets
rockerBOO Jan 12, 2025
c04e5df
Fix loss recorder on 0. Fix validation for cached runs. Assert on val…
rockerBOO Jan 23, 2025
6acdbed
Merge branch 'dev' into sd3
kohya-ss Jan 26, 2025
23ce75c
Merge branch 'dev' into sd3
kohya-ss Jan 26, 2025
b833d47
Merge pull request #1864 from rockerBOO/validation-loss-upstream
kohya-ss Jan 26, 2025
58b82a5
Fix to work with validation dataset
kohya-ss Jan 26, 2025
e852961
README.md: Update recent updates section to include validation loss s…
kohya-ss Jan 26, 2025
f1ac81e
Merge pull request #1899 from kohya-ss/val-loss
kohya-ss Jan 26, 2025
59b3b94
README.md: Update limitation for validation loss support to include s…
kohya-ss Jan 26, 2025
532f5c5
formatting
kohya-ss Jan 27, 2025
86a2f3f
Fix gradient handling when Text Encoders are trained
kohya-ss Jan 27, 2025
b6a3093
call optimizer eval/train fn before/after validation
kohya-ss Jan 27, 2025
29f31d0
add network.train()/eval() for validation
kohya-ss Jan 27, 2025
0750859
validation: Implement timestep-based validation processing
kohya-ss Jan 27, 2025
0778dd9
fix Text Encoder only LoRA training
kohya-ss Jan 27, 2025
42c0a9e
Merge branch 'sd3' into val-loss-improvement
kohya-ss Jan 27, 2025
45ec02b
use same noise for every validation
kohya-ss Jan 27, 2025
de830b8
Move progress bar to account for sampling image first
rockerBOO Jan 29, 2025
c5b803c
rng state management: Implement functions to get and set RNG states f…
kohya-ss Feb 4, 2025
a24db1d
fix: validation timestep generation fails on SD/SDXL training
kohya-ss Feb 4, 2025
0911683
set python random state
kohya-ss Feb 9, 2025
344845b
fix: validation with block swap
kohya-ss Feb 9, 2025
1772038
fix: unpause training progress bar after validation
kohya-ss Feb 11, 2025
cd80752
fix: remove unused parameter 'accelerator' from encode_images_to_late…
kohya-ss Feb 11, 2025
76b7619
fix: simplify validation step condition in NetworkTrainer
kohya-ss Feb 11, 2025
ab88b43
Fix validation epoch divergence
rockerBOO Feb 14, 2025
ee295c7
Merge pull request #1935 from rockerBOO/validation-epoch-fix
kohya-ss Feb 15, 2025
63337d9
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 15, 2025
4671e23
Fix validation epoch loss to check epoch average
rockerBOO Feb 16, 2025
3c7496a
Fix sizes for validation split
rockerBOO Feb 17, 2025
f3a0109
Clear sizes for validation reg images to be consistent
rockerBOO Feb 17, 2025
6051fa8
Merge pull request #1940 from rockerBOO/split-size-fix
kohya-ss Feb 17, 2025
7c22e12
Merge pull request #1938 from rockerBOO/validation-epoch-loss-recorder
kohya-ss Feb 17, 2025
9436b41
Fix validation split and add test
rockerBOO Feb 17, 2025
894037f
Merge pull request #1943 from rockerBOO/validation-split-test
kohya-ss Feb 18, 2025
dc7d5fb
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 18, 2025
4a36996
modify log step calculation
kohya-ss Feb 18, 2025
58e9e14
Add resize interpolation configuration
rockerBOO Feb 14, 2025
d0128d1
Add resize interpolation CLI option
rockerBOO Feb 14, 2025
7729c4c
Add metadata
rockerBOO Feb 14, 2025
545425c
Typo
rockerBOO Feb 14, 2025
ca1c129
Fix metadata
rockerBOO Feb 14, 2025
7f27471
Use resize_image where resizing is required
rockerBOO Feb 19, 2025
efb2a12
fix wandb val logging
kohya-ss Feb 21, 2025
905f081
Merge branch 'dev' into sd3
kohya-ss Feb 24, 2025
67fde01
Merge branch 'dev' into sd3
kohya-ss Feb 24, 2025
6e90c0f
Merge pull request #1909 from rockerBOO/progress_bar
kohya-ss Feb 24, 2025
f4a0047
feat: support metadata loading in MemoryEfficientSafeOpen
kohya-ss Feb 26, 2025
5228db1
feat: add script to merge multiple safetensors files into a single fi…
kohya-ss Feb 26, 2025
ae409e8
fix: FLUX/SD3 network training not working without caching latents cl…
kohya-ss Feb 26, 2025
1fcac98
Merge branch 'sd3' into val-loss-improvement
kohya-ss Feb 26, 2025
4965189
Merge pull request #1903 from kohya-ss/val-loss-improvement
kohya-ss Feb 26, 2025
ec350c8
Merge branch 'dev' into sd3
kohya-ss Feb 26, 2025
3d79239
docs: update README to include recent improvements in validation loss…
kohya-ss Feb 26, 2025
734333d
feat: enhance merging logic for safetensors models to handle key pref…
kohya-ss Feb 28, 2025
272f4c3
Merge branch 'sd3' into sd3_safetensors_merge
kohya-ss Feb 28, 2025
ba52511
fix: save tensors as-is dtype, add save_precision option
kohya-ss Mar 1, 2025
aa2bde7
docs: add utility script for merging SD3 weights into a single .safet…
kohya-ss Mar 5, 2025
75933d7
Merge pull request #1960 from kohya-ss/sd3_safetensors_merge
kohya-ss Mar 5, 2025
ea53290
Add LoRA-GGPO for Flux
rockerBOO Mar 6, 2025
e5b5c7e
Update requirements.txt
gesen2egee Mar 15, 2025
3647d06
Cache weight norms estimate on initialization. Move to update norms e…
rockerBOO Mar 18, 2025
0b25a05
Add IP noise gamma for Flux
rockerBOO Mar 18, 2025
c8be141
Apply IP gamma to noise fix
rockerBOO Mar 18, 2025
b425466
Fix IP noise gamma to use random values
rockerBOO Mar 18, 2025
a4f3a9f
Use ones_like
rockerBOO Mar 18, 2025
6f4d365
zeros_like because we are adding
rockerBOO Mar 18, 2025
b81bcd0
Move IP noise gamma to noise creation to remove complexity and align …
rockerBOO Mar 19, 2025
5b210ad
update prodigyopt and prodigy-plus-schedule-free
gesen2egee Mar 19, 2025
7197266
Perturbed noise should be separate of input noise
rockerBOO Mar 19, 2025
d93ad90
Add perturbation on noisy_model_input if needed
rockerBOO Mar 19, 2025
8e6817b
Remove double noise
rockerBOO Mar 19, 2025
1eddac2
Separate random to a variable, and make sure on device
rockerBOO Mar 19, 2025
5d5a7d2
Fix IP noise calculation
rockerBOO Mar 19, 2025
f974c6b
change order to match upstream
rockerBOO Mar 19, 2025
936d333
Merge pull request #1985 from gesen2egee/pytorch-optimizer
kohya-ss Mar 20, 2025
d151833
docs: update README with recent changes and specify version for pytor…
kohya-ss Mar 20, 2025
16cef81
Refactor sigmas and timesteps
rockerBOO Mar 20, 2025
e8b3254
Add flux_train_utils tests for get_noisy_model_input_and_timesteps
rockerBOO Mar 20, 2025
8aa1265
Scale sigmoid to default 1.0
rockerBOO Mar 20, 2025
d40f5b1
Revert "Scale sigmoid to default 1.0"
rockerBOO Mar 20, 2025
89f0d27
Set sigmoid_scale to default 1.0
rockerBOO Mar 20, 2025
6364379
Merge branch 'dev' into sd3
kohya-ss Mar 21, 2025
8ebe858
Merge branch 'dev' into sd3
kohya-ss Mar 24, 2025
182544d
Remove perturbation seed
rockerBOO Mar 26, 2025
0181b7a
Remove progress bar avg norms
rockerBOO Mar 27, 2025
93a4efa
Merge branch 'sd3' into resize-interpolation
kohya-ss Mar 30, 2025
9e9a13a
Merge pull request #1936 from rockerBOO/resize-interpolation
kohya-ss Mar 30, 2025
1f432e2
use PIL for lanczos and box
kohya-ss Mar 30, 2025
96a133c
README.md: update recent updates section to include new interpolation…
kohya-ss Mar 30, 2025
3149b27
Merge pull request #2018 from kohya-ss/resize-interpolation-small-fix
kohya-ss Mar 30, 2025
59d98e4
Merge pull request #1974 from rockerBOO/lora-ggpo
kohya-ss Mar 30, 2025
d0b5c0e
chore: formatting, add TODO comment
kohya-ss Mar 30, 2025
aaa26bb
docs: update README to include LoRA-GGPO details for FLUX.1 training
kohya-ss Mar 30, 2025
b3c56b2
Merge branch 'dev' into sd3
kohya-ss Mar 31, 2025
ede3470
Ensure all size parameters are integers to prevent type errors
LexSong Apr 1, 2025
b822b7e
Fix the interpolation logic error in resize_image()
LexSong Apr 1, 2025
f1423a7
fix: add resize_interpolation parameter to FineTuningDataset constructor
kohya-ss Apr 3, 2025
92845e8
Merge pull request #2026 from kohya-ss/fix-finetune-dataset-resize-in…
kohya-ss Apr 3, 2025
fd36fd1
Fix resize PR link
rockerBOO Apr 3, 2025
606e687
Merge pull request #2022 from LexSong/fix-resize-issue
kohya-ss Apr 5, 2025
ee0f754
Merge pull request #2028 from rockerBOO/patch-5
kohya-ss Apr 5, 2025
c56dc90
Merge pull request #1992 from rockerBOO/flux-ip-noise-gamma
kohya-ss Apr 6, 2025
4589262
README.md: Update recent updates section to include IP noise gamma fe…
kohya-ss Apr 6, 2025
5a18a03
Merge branch 'dev' into sd3
kohya-ss Apr 7, 2025
06df037
Merge branch 'sd3' into flux-sample-cfg
kohya-ss Apr 16, 2025
629073c
Add guidance scale for prompt param and flux sampling
kohya-ss Apr 16, 2025
7c61c0d
Add autocast wrapper for forward functions in deepspeed_utils.py to t…
sharlynxy Apr 22, 2025
d33d5ec
#
sharlynxy Apr 22, 2025
7f984f4
#
sharlynxy Apr 22, 2025
c8af252
refactor
robertwenquan Apr 22, 2025
f501209
Merge branch 'dev/xy/align_dtype_using_mixed_precision' of github.com…
robertwenquan Apr 22, 2025
0d9da0e
Merge pull request #1 from saibit-tech/dev/xy/align_dtype_using_mixed…
sharlynxy Apr 22, 2025
b11c053
Merge branch 'dev' into sd3
kohya-ss Apr 22, 2025
adb775c
Update: requirement diffusers[torch]==0.25.0
sharlynxy Apr 23, 2025
abf2c44
Dynamically set device in deepspeed wrapper (#2)
sharlynxy Apr 23, 2025
46ad3be
update deepspeed wrapper
sharlynxy Apr 24, 2025
5c50cdb
Merge branch 'sd3' into flux-sample-cfg
kohya-ss Apr 27, 2025
8387e0b
docs: update README to include CFG scale support in FLUX.1 training
kohya-ss Apr 27, 2025
309c44b
Merge pull request #2064 from kohya-ss/flux-sample-cfg
kohya-ss Apr 27, 2025
0e8ac43
Merge branch 'dev' into sd3
kohya-ss Apr 27, 2025
13296ae
Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3
kohya-ss Apr 27, 2025
fd3a445
fix: revert default emb guidance scale and CFG scale for FLUX.1 sampling
kohya-ss Apr 27, 2025
29523c9
docs: add note for user feedback on CFG scale in FLUX.1 training
kohya-ss Apr 27, 2025
80320d2
Merge pull request #2066 from kohya-ss/quick-fix-flux-sampling-scales
kohya-ss Apr 27, 2025
64430eb
Merge branch 'dev' into sd3
kohya-ss Apr 29, 2025
1684aba
remove deepspeed from requirements.txt
sharlynxy Apr 30, 2025
a4fae93
Add pythonpath to pytest.ini
rockerBOO May 1, 2025
f62c68d
Make grad_norm and combined_grad_norm None if not recording
rockerBOO May 1, 2025
b4a89c3
Fix None
rockerBOO May 1, 2025
7c075a9
Merge pull request #2060 from saibit-tech/sd3
kohya-ss May 1, 2025
865c8d5
README.md: Update recent updates and add DeepSpeed installation instr…
kohya-ss May 1, 2025
a27ace7
doc: add DeepSpeed installation in header section
kohya-ss May 1, 2025
e858132
Merge pull request #2074 from kohya-ss/deepspeed-readme
kohya-ss May 1, 2025
e2ed265
Merge pull request #2072 from rockerBOO/pytest-pythonpath
kohya-ss May 1, 2025
5b38d07
Merge pull request #2073 from rockerBOO/fix-mean-grad-norms
kohya-ss May 11, 2025
2bfda12
Update workflows to read-all instead of write-all
rockerBOO May 20, 2025
5753b8f
Merge pull request #2088 from rockerBOO/checkov-update
kohya-ss May 20, 2025
e4d6923
Add tests for syntax checking training scripts
rockerBOO Jun 3, 2025
61eda76
Merge pull request #2108 from rockerBOO/syntax-test
kohya-ss Jun 4, 2025
bb47f1e
Fix unwrap_model handling for None text_encoders in sample_images fun…
kohya-ss Jun 8, 2025
fc40a27
Merge branch 'dev' into sd3
kohya-ss Jun 15, 2025
3e6935a
Merge pull request #2115 from kohya-ss/fix-flux-sampling-accelerate-e…
kohya-ss Jun 15, 2025
51 changes: 51 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,51 @@
name: Test with pytest

on:
  push:
    branches:
      - main
      - dev
      - sd3
  pull_request:
    branches:
      - main
      - dev
      - sd3

# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
permissions: read-all

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"] # Python versions to test
        pytorch-version: ["2.4.0"] # PyTorch versions to test

    steps:
      - uses: actions/checkout@v4
        with:
          # https://woodruffw.github.io/zizmor/audits/#artipacked
          persist-credentials: false

      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'

      - name: Install and update pip, setuptools, wheel
        run: |
          # Setuptools, wheel for compiling some packages
          python -m pip install --upgrade pip setuptools wheel

      - name: Install dependencies
        run: |
          # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
          pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
          pip install -r requirements.txt

      - name: Test with pytest
        run: pytest # See pytest.ini for configuration

14 changes: 11 additions & 3 deletions .github/workflows/typos.yml
@@ -1,21 +1,29 @@
---
# yamllint disable rule:line-length
name: Typos

on: # yamllint disable-line rule:truthy
on:
  push:
    branches:
      - main
      - dev
  pull_request:
    types:
      - opened
      - synchronize
      - reopened

# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
permissions: read-all

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          # https://woodruffw.github.io/zizmor/audits/#artipacked
          persist-credentials: false

      - name: typos-action
        uses: crate-ci/typos@v1.24.3
        uses: crate-ci/typos@v1.28.1
787 changes: 783 additions & 4 deletions README.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions docs/config_README-en.md
@@ -152,6 +152,7 @@ These options are related to subset configuration.
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (not specified) | o | o | o |

* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
@@ -165,6 +166,8 @@ These options are related to subset configuration.
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
* `resize_interpolation`
* Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
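
For illustration only (this sketch is not part of the documented file): a minimal dataset config showing where `resize_interpolation` could be set at the subset level. The path, resolution, and batch size are placeholder values; only `resize_interpolation` and its allowed values are taken from the description above.

```toml
[general]
resolution = 1024          # placeholder value
caption_extension = ".txt"

[[datasets]]
batch_size = 2             # placeholder value

  [[datasets.subsets]]
  image_dir = "/path/to/train_images"   # placeholder path
  num_repeats = 1
  # The same method is then used for both upscaling and downscaling;
  # lanczos and box go through PIL, the other values through OpenCV.
  resize_interpolation = "bicubic"
```

If `resize_interpolation` is left unset, the defaults described above apply (`area` for downscaling, `lanczos` for upscaling).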

### DreamBooth-specific options

4 changes: 4 additions & 0 deletions docs/config_README-ja.md
@@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` |(通常は設定しません) | o | o | o |

* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
@@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。

* `resize_interpolation`
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos` または `box` を指定するとPILが、それ以外を指定するとOpenCVが使用されます。

### DreamBooth 方式専用のオプション

DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
2 changes: 1 addition & 1 deletion docs/train_lllite_README.md
@@ -185,7 +185,7 @@ for img_file in img_files:

### Creating a dataset configuration file

You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.

```toml
[general]
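# Editorial sketch, not part of the original example (which is truncated here):
# a subset entry where the conditioning image directory is given via
# conditioning_data_dir, as described in the paragraph above.
# All paths are placeholders.
[[datasets]]

  [[datasets.subsets]]
  image_dir = "/path/to/train_images"
  conditioning_data_dir = "/path/to/conditioning_images"
  caption_extension = ".txt"
```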
86 changes: 52 additions & 34 deletions fine_tune.py
@@ -10,7 +10,7 @@
from tqdm import tqdm

import torch
from library import deepspeed_utils
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
@@ -39,6 +39,7 @@
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd


def train(args):
@@ -52,7 +53,15 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する

tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

# データセットを準備する
if args.dataset_class is None:
@@ -81,10 +90,11 @@
]
}

blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None

current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -167,8 +177,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)

train_dataset_group.new_cache_latents(vae, accelerator)

vae.to("cpu")
clean_memory_on_device(accelerator.device)

@@ -194,6 +205,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
text_encoder.eval()

text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -216,7 +230,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()

# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -319,7 +337,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -344,25 +367,21 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
@@ -374,11 +393,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
target = noise

huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
@@ -390,9 +408,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -411,7 +427,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
global_step += 1

train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

# 指定ステップごとにモデルを保存
@@ -436,7 +452,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
)

current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -449,7 +465,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)

Expand All @@ -474,7 +490,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
)

train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)

is_main_process = accelerator.is_main_process
if is_main_process:
Expand Down
7 changes: 2 additions & 5 deletions finetune/tag_images_by_wd14_tagger.py
@@ -11,7 +11,7 @@
from tqdm import tqdm

import library.train_util as train_util
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image

setup_logging()
import logging
Expand Down Expand Up @@ -42,10 +42,7 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)

image = image.astype(np.float32)
return image