Randomized and Incremental Regularization Image Loading - Overhaul of Regularization Image Loading Logic #2096


Open · wants to merge 9 commits into main
Conversation

DKnight54
Contributor

In the original implementation of regularization image loading, the code loads the first N regularization images, where N is the number of training images multiplied by the number of repeats.

This leads to a couple of edge cases that can produce suboptimal training results whenever the number of regularization images is not equal to N, particularly for users training on free resources such as Google Colab, which limit how many hours training can run per day. When N is greater than the number of regularization images, the first few images consistently receive additional repeats in the dataset; over extended training across multiple epochs and/or resumed training, this gives them a stronger influence on the trained model.

When N is less than the number of regularization images available, training strategies that use regularization images as additional ground-truth images to improve overall quality cannot fully utilize all of the prepared regularization images and captions.

Additionally, using multiple subsets to organize categories of regularization images may unintentionally weight training toward specific concepts, depending on the order in which the subsets are loaded.

This pull request intends to mitigate this by implementing two loading strategies:

  1. Randomized loading of regularization images
  • This shuffles the order in which regularization images are loaded, resulting in a more even distribution of the available regularization images being used.
  2. Incremental loading of regularization images
  • This option causes the dataset to reload the regularization images each epoch, walking down the list (and looping back to the start when the list is exhausted) and loading the next N regularization images (a rough sketch of both strategies follows this list).
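A rough sketch of how the two strategies could select images for an epoch (function and parameter names here are illustrative only, not the actual implementation in this PR):

```python
import random

# Illustrative sketch only; names and structure are assumptions, not the PR's code.
def select_reg_images(reg_image_paths, n_needed, epoch,
                      randomize=False, incremental=False, seed=0):
    paths = list(reg_image_paths)
    if incremental:
        if randomize:
            # Fixed seed: the shuffled order is identical every epoch (and on
            # every process), so the window can walk down a stable list.
            random.Random(seed).shuffle(paths)
        # Walk a window of n_needed images down the list each epoch,
        # wrapping back to the start when the list is exhausted.
        start = (epoch * n_needed) % len(paths)
        return [paths[(start + i) % len(paths)] for i in range(n_needed)]
    if randomize:
        # Re-shuffle each epoch so all available images get used over time.
        random.Random(seed + epoch).shuffle(paths)
    # Original behavior: take the first n_needed images, cycling if fewer are available.
    return [paths[i % len(paths)] for i in range(n_needed)]
```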

Both strategies can be activated separately or together, and both are off by default (preserving the original loading strategy), using the arguments --incremental_reg_load and --randomized_regularization_image.
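For reference, the flags could be registered roughly like this (defaults assumed off to match the original behavior; this is a sketch, not the PR's exact code):

```python
import argparse

# Sketch of how the two flags might be registered; help strings are assumptions.
def add_reg_loading_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--randomized_regularization_image",
        action="store_true",
        help="shuffle the order in which regularization images are loaded",
    )
    parser.add_argument(
        "--incremental_reg_load",
        action="store_true",
        help="load the next window of regularization images at the start of each epoch",
    )
```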

Points to consider:

  • If using --incremental_reg_load, the length of the dataloader will vary between epochs, especially when using buckets, because the number of available batches can differ.
  • If using --incremental_reg_load, persistent dataloader workers will not work, as the dataloaders have to be recreated at the start of each epoch to correctly update the number of batches available (see the per-epoch sketch after this list).
  • Attempts to bypass this by triggering the reloads in the set_current_epoch method of the BaseDataset class have failed, due to the inability to propagate updated __len__() values to the enclosing DatasetGroup class.
  • The use of multiprocessing dataloader workers also complicates syncing data back to the main process, resulting in unexpected behaviour.
  • To implement a more efficient caching strategy that only caches the images that are actually going to be trained on, the cache_text_encoder_outputs_if_needed() function had to be moved to just before training starts for each epoch.
  • To ensure the text encoders are on the correct device, placeholder code that moves the models following the logic of cache_text_encoder_outputs_if_needed() was put in place.
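A rough sketch of the per-epoch flow this implies, assuming a hypothetical load_reg_images_for_epoch() hook on the dataset group (the other names are placeholders too, not the PR's actual code):

```python
from torch.utils.data import DataLoader

# Illustrative only: load_reg_images_for_epoch() stands in for the PR's reload logic.
def train_with_incremental_reg(dataset_group, num_epochs, num_workers, collate_fn, train_step):
    for epoch in range(num_epochs):
        # Reload this epoch's window of regularization images before anything else.
        dataset_group.load_reg_images_for_epoch(epoch)
        # Text encoder output caching would also happen here, after the epoch's
        # images are known but before any batches are drawn.
        dataloader = DataLoader(
            dataset_group,
            batch_size=1,              # the dataset group already yields bucketed batches
            shuffle=True,
            num_workers=num_workers,
            persistent_workers=False,  # must stay off: the loader is rebuilt every epoch
            collate_fn=collate_fn,
        )
        for batch in dataloader:
            train_step(batch)
```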

As this is a proof of concept, it is implemented only in the LoRA training scripts for SD1.5 and SDXL.

Please let me know if you have any questions.

@DKnight54 DKnight54 changed the title Overhaul of Regularization Image Loading Logic Randomized and Incremental Regularization Image Loading - Overhaul of Regularization Image Loading Logic May 27, 2025
@kohya-ss
Owner

Thank you for your excellent PR proposal, and I apologize for the delay in my response.

I appreciate your detailed consideration and suggestions for improvement regarding the behavior when the number of regularization images differs from N (number of training images * number of repeats).

First, let me share my understanding of the current implementation:

  • When N is greater than the number of regularization images, I believed the impact was minor because the difference in the number of times each image is used is at most one.
  • When N is less than the number of regularization images, some images are not used, but I assumed this could be avoided by the user increasing the number of repeats.

However, I agree that the points you raised, such as "potential bias accumulation when resuming training" and "the order dependency of multiple subsets," are indeed points to consider for improving usability and training stability.

Regarding the two options you proposed, I am concerned about the incremental option due to its implementation complexity and performance impact, as it requires recreating the dataloader, which you also mentioned.

With that in mind, how about we focus our discussion on implementing the randomized option first? This option alone seems capable of robustly solving many of the issues you've pointed out (especially the behavior of multiple subsets) with fewer code changes.

On a technical note, regarding multi-GPU support, it might be possible to implement it more simply by providing the same random seed to each process, as an alternative to the proposed method using distributed_state and gather_object. I would like to discuss this point as well.
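For example, a minimal sketch of that seed-based alternative (the seed value and how it is derived are just placeholders): every process seeds its own RNG identically, so all ranks arrive at the same shuffle without any inter-process communication.

```python
import random

# Each process computes the same shuffle from a shared seed; no gather_object needed.
def shuffled_reg_indices(num_reg_images, epoch, base_seed=42):
    rng = random.Random(base_seed + epoch)  # identical seed on every rank
    indices = list(range(num_reg_images))
    rng.shuffle(indices)
    return indices
```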

The incremental option is a very interesting approach, so I would appreciate it if we could discuss it separately in another Issue or PR after we have settled on a direction for the randomized option.
(For that case, perhaps we could implement it while keeping persistent_workers enabled, for example, by preparing a shared list of regularization image indices for all processes and shifting the reading range for each epoch.)
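As one possible sketch of that idea at the sampler level (class and method names are only placeholders, and this ignores bucketing): because samplers run in the main process, the reading range can shift every epoch while persistent dataloader workers stay alive.

```python
import random
from torch.utils.data import Sampler

# Fixed-length sampler over a shared, pre-shuffled index list; the window it reads
# from shifts each epoch, so the dataloader itself never has to be recreated.
class ShiftingWindowSampler(Sampler):
    def __init__(self, num_reg_images, window_size, seed=42):
        self.indices = list(range(num_reg_images))
        random.Random(seed).shuffle(self.indices)  # same order on every process
        self.window_size = window_size
        self.epoch = 0

    def set_epoch(self, epoch):
        # Called from the training loop each epoch, like DistributedSampler.set_epoch.
        self.epoch = epoch

    def __len__(self):
        return self.window_size  # constant, so the number of batches stays predictable

    def __iter__(self):
        start = (self.epoch * self.window_size) % len(self.indices)
        for i in range(self.window_size):
            yield self.indices[(start + i) % len(self.indices)]
```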
