5 changes: 3 additions & 2 deletions library/config_util.py
@@ -576,9 +576,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

dataset.incremental_reg_load()
dataset.make_buckets()

return DatasetGroup(datasets)


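The hunk above interleaves removed and added lines because the rendered diff has lost its +/- markers. A hedged reconstruction of how the loop reads after the change (indentation restored, surrounding lines abbreviated; which exact lines were removed is an assumption): set_seed() now runs first, the new incremental_reg_load() registers the first slice of regularization images, and make_buckets() runs last so buckets are built over the images actually loaded.

    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        logger.info(f"[Dataset {i}]")
        dataset.set_seed(seed)
        dataset.incremental_reg_load()  # base-class placeholder is a no-op; DreamBooth datasets override it
        dataset.make_buckets()

    return DatasetGroup(datasets)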
192 changes: 168 additions & 24 deletions library/train_util.py
@@ -29,6 +29,7 @@
import subprocess
from io import BytesIO
import toml
import copy

from tqdm import tqdm

@@ -164,6 +165,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.latent_cache_checked: bool = False
self.te_cache_checked: bool = False


class BucketManager:
@@ -653,6 +656,11 @@ def __init__(
# caching
self.caching_mode = None # None, 'latents', 'text'

# lists for incremental loading of regularization images
self.reg_infos = None
self.reg_infos_index = None
self.reg_randomize = False

def adjust_min_max_bucket_reso_by_steps(
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
) -> Tuple[int, int]:
@@ -684,6 +692,12 @@ def adjust_min_max_bucket_reso_by_steps(
def set_seed(self, seed):
self.seed = seed

def set_reg_randomize(self, reg_randomize = False):
self.reg_randomize = reg_randomize

def incremental_reg_load(self, make_bucket = False): # Placeholder method, does nothing unless overridden in subclasses.
return

def set_caching_mode(self, mode):
self.caching_mode = mode

@@ -951,11 +965,14 @@ def make_buckets(self):
if self.enable_bucket:
self.bucket_info = {"buckets": {}}
logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
batch_count: int = 0
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
count = len(bucket)
if count > 0:
batch_count += math.ceil(len(bucket) / self.batch_size)
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}, batches: {int(math.ceil(len(bucket) / self.batch_size))}")
logger.info(f"Total batch count: {batch_count}")

if len(img_ar_errors) == 0:
mean_img_ar_error = 0 # avoid NaN
@@ -967,6 +984,7 @@

# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List[BucketBatchIndex] = []
self.buckets_indices.clear()
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
@@ -1025,6 +1043,10 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc
logger.info("caching latents.")

image_infos = list(self.image_data.values())
image_infos = list(filter(lambda info: info.latent_cache_checked == False, image_infos))
if len(image_infos) == 0:
logger.info("All images latents previously checked and cached. Skipping.")
return

# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
@@ -1054,11 +1076,17 @@ def __eq__(self, other):
subset = self.image_to_subset[info.image_key]

if info.latents_npz is not None: # fine tuning dataset
info.latent_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latent_cache_checked = True
continue

# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latent_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latent_cache_checked = True
if not is_main_process: # store to info only
continue

@@ -1094,6 +1122,15 @@ def __eq__(self, other):
logger.info("caching latents...")
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
if self.reg_infos is not None:
for info in batch:
if info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].latents_npz = info.latents_npz
self.reg_infos[info.image_key][0].latents_original_size = info.latents_original_size
self.reg_infos[info.image_key][0].latents_crop_ltrb = info.latents_crop_ltrb
self.reg_infos[info.image_key][0].latents_flipped = info.latents_flipped
self.reg_infos[info.image_key][0].latents = info.latents
self.reg_infos[info.image_key][0].alpha_mask = info.alpha_mask

# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
@@ -1107,6 +1144,10 @@ def cache_text_encoder_outputs(
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching text encoder outputs.")
image_infos = list(self.image_data.values())
image_infos = list(filter(lambda info: info.te_cache_checked == False, image_infos))
if len(image_infos) == 0:
logger.info("Text encoder outputs for all images previously checked and cached. Skipping.")
return

logger.info("checking cache existence...")
image_infos_to_cache = []
@@ -1115,6 +1156,10 @@
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
info.text_encoder_outputs_npz = te_out_npz
info.te_cache_checked = True
if self.reg_infos is not None and info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = te_out_npz
self.reg_infos[info.image_key][0].te_cache_checked = True

if not is_main_process: # store to info only
continue
@@ -1157,6 +1202,14 @@ def cache_text_encoder_outputs(
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
)
if self.reg_infos is not None:
for info in infos:
if info.image_key in self.reg_infos:
self.reg_infos[info.image_key][0].text_encoder_outputs_npz = info.text_encoder_outputs_npz
self.reg_infos[info.image_key][0].te_cache_checked = True
self.reg_infos[info.image_key][0].text_encoder_outputs1 = info.text_encoder_outputs1
self.reg_infos[info.image_key][0].text_encoder_outputs2 = info.text_encoder_outputs2
self.reg_infos[info.image_key][0].text_encoder_pool2 = info.text_encoder_pool2

def get_image_size(self, image_path):
return imagesize.get(image_path)
@@ -1561,6 +1614,9 @@ def __init__(
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.reg_infos: Dict[str, Tuple[ImageInfo, DreamBoothSubset]] = {}
self.reg_infos_index: List[str] = []
self.reg_infos_index_traverser = 0

self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1689,7 +1745,6 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
logger.info("prepare images.")
num_train_images = 0
num_reg_images = 0
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
if subset.num_repeats < 1:
logger.warning(
@@ -1720,7 +1775,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append((info, subset))
if subset.num_repeats > 1:
info.num_repeats = 1
self.reg_infos[info.image_key] = (info, subset)
for i in range(subset.num_repeats):
self.reg_infos_index.append(info.image_key)
else:
self.register_image(info, subset)

@@ -1731,30 +1790,89 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
self.num_train_images = num_train_images

logger.info(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
self.num_reg_images = num_reg_images
self.reg_infos_index_traverser = 0

def set_reg_randomize(self, reg_randomize = False):
self.reg_randomize = reg_randomize
# The first set of data is loaded before the first opportunity to shuffle, so force-reset self.reg_infos_index_traverser and reinitialize the dataset
self.reg_infos_index_traverser = 0
self.bucket_manager = None
self.incremental_reg_load(True)

def subset_loaded_count(self):
count_str = ""
for index, subset in enumerate(self.subsets):
counter = 0
count_str += f"\nSubset {index} (Class: {subset.class_tokens}): " if isinstance(subset, DreamBoothSubset) and subset.class_tokens is not None else f"\nSubset {index}: "
img_keys = [key for key, value in self.image_to_subset.items() if value == subset]
for img_key in img_keys:
counter += self.image_data[img_key].num_repeats
count_str += f"{counter}/{subset.img_count * subset.num_repeats}"
count_str += f"\nSubset dir: {subset.image_dir}" if subset.image_dir is not None else ""
count_str += f"\n\n"
logger.info(count_str)

def incremental_reg_load(self, make_bucket = False):
# Overrides the base-class placeholder to load regularization images incrementally (optionally randomized)
distributed_state = PartialState()

if self.num_reg_images == 0:
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
return
if self.num_train_images < self.num_reg_images:
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")

if num_reg_images == 0:
logger.warning("no regularization images / 正則化画像が見つかりませんでした")
else:
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
if self.num_train_images != self.num_reg_images:
logger.info(f"Inititating loading of regularizaion images.")
for info, subset in self.reg_infos.values():
if info.image_key in self.image_data:
self.image_data.pop(info.image_key, None)
self.image_to_subset.pop(info.image_key, None)

temp_reg_infos = copy.deepcopy(self.reg_infos)
n = 0
first_loop = True
while n < num_train_images:
for info, subset in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
logger.info(f"self.reg_infos_index_traverser at: {self.reg_infos_index_traverser}\n reg_infos_index len = {len(self.reg_infos_index)}")
reg_img_log = f"\nDataset seed: {self.seed}"
start_index = self.reg_infos_index_traverser

while n < self.num_train_images :
if self.reg_randomize and self.reg_infos_index_traverser == 0:
if distributed_state.num_processes > 1:
if not distributed_state.is_main_process:
self.reg_infos_index = []
else:
random.shuffle(self.reg_infos_index)
distributed_state.wait_for_everyone()
self.reg_infos_index = gather_object(self.reg_infos_index)
else:
info.num_repeats += 1 # rewrite registered info
n += 1
if n >= num_train_images:
break
first_loop = False

self.num_reg_images = num_reg_images

random.shuffle(self.reg_infos_index)
info, subset = temp_reg_infos[self.reg_infos_index[self.reg_infos_index_traverser]]
if info.image_key in self.image_data:
info.num_repeats += 1 # rewrite registered info
else:
self.register_image(info, subset)

self.reg_infos_index_traverser += 1
if self.reg_infos_index_traverser % len(self.reg_infos_index) == 0:
self.reg_infos_index_traverser = 0
'''
if n < 5:
reg_img_log += f"\nRegistering image: {info.absolute_path}, count: {info.num_repeats}"
'''
n += 1

# logger.info(reg_img_log)
if distributed_state.is_main_process:
self.subset_loaded_count()
self.bucket_manager = None
if make_bucket:
self.make_buckets()
del temp_reg_infos
else:
logger.warning(f"Number of training images({self.num_train_images}) is the same as number of regularization images({self.num_reg_images}).\nSkipping randomized/incremental loading of regularization images.")

class FineTuningDataset(BaseDataset):
def __init__(
self,
@@ -2098,6 +2216,11 @@ def __init__(

self.conditioning_image_transforms = IMAGE_TRANSFORMS

def incremental_reg_load(self, make_bucket = False):
self.dreambooth_dataset_delegate.incremental_reg_load()
if make_bucket:
self.make_buckets()

def make_buckets(self):
self.dreambooth_dataset_delegate.make_buckets()
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
@@ -2185,9 +2308,13 @@ def add_replacement(self, str_from, str_to):
for dataset in self.datasets:
dataset.add_replacement(str_from, str_to)

# def make_buckets(self):
# for dataset in self.datasets:
# dataset.make_buckets()
def set_reg_randomize(self, reg_randomize = False):
for dataset in self.datasets:
dataset.set_reg_randomize(reg_randomize)

def make_buckets(self):
for dataset in self.datasets:
dataset.make_buckets()

def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
@@ -2234,7 +2361,14 @@ def set_max_train_steps(self, max_train_steps):
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()

def incremental_reg_load(self, make_bucket = False):
for dataset in self.datasets:
dataset.incremental_reg_load(make_bucket)

def __len__(self):
self.cumulative_sizes = self.cumsum(self.datasets)
return self.cumulative_sizes[-1]

def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
@@ -3579,6 +3713,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument(
"--incremental_reg_load",
action="store_true",
help="Forces reload of regularization images at each Epoch. Will sequentially load regularization images unless '--randomized_regularization_image' is set. Useful if there are more regularization images than training images",
)
parser.add_argument(
"--randomized_regularization_image",
action="store_true",
help="Shuffles regularization images to even out distribution. Useful if there are more regularization images than training images",
)

if support_dreambooth:
# DreamBooth training
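The two new flags above only declare the options; this excerpt does not show where the training scripts consume them. A hedged sketch, under the assumption of a typical per-epoch loop and an illustrative train_dataset_group variable, of how they could be wired to the DatasetGroup methods added in this change (set_reg_randomize, incremental_reg_load, and the two argument names come from the diff; everything else is assumed):

    # Before training: enable shuffling of the regularization pool if requested.
    # For DreamBooth datasets this resets the traverser and reloads/re-buckets immediately.
    if args.randomized_regularization_image:
        train_dataset_group.set_reg_randomize(True)

    for epoch in range(num_train_epochs):
        if args.incremental_reg_load and epoch > 0:
            # Swap in the next slice of regularization images and rebuild buckets for this epoch.
            train_dataset_group.incremental_reg_load(make_bucket=True)
        # ... run one epoch of training ...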