From 34133359e33b8c4af479e89a26156a4967c795f0 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 24 Aug 2025 17:39:42 +0900 Subject: [PATCH 1/8] fix: support dataset with metadata --- library/train_util.py | 248 +++++++++++++++++------------------------- 1 file changed, 100 insertions(+), 148 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 395183957..61e421086 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -683,7 +683,7 @@ def __init__( resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, - resize_interpolation: Optional[str] = None + resize_interpolation: Optional[str] = None, ) -> None: super().__init__() @@ -719,7 +719,9 @@ def __init__( self.image_transforms = IMAGE_TRANSFORMS if resize_interpolation is not None: - assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation" + assert validate_interpolation_fn( + resize_interpolation + ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation self.image_data: Dict[str, ImageInfo] = {} @@ -1613,7 +1615,11 @@ def __getitem__(self, index): if self.enable_bucket: img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation + subset.random_crop, + img, + image_info.bucket_reso, + image_info.resized_size, + resize_interpolation=image_info.resize_interpolation, ) else: if face_cx > 0: # 顔位置情報あり @@ -2101,7 +2107,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset): for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) - info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else 
self.resize_interpolation + info.resize_interpolation = ( + subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + ) if size is not None: info.image_size = size if subset.is_reg: @@ -2162,6 +2170,23 @@ def __init__( super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.latents_cache = None + + self.enable_bucket = enable_bucket + if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None # この情報は使われない + self.bucket_no_upscale = False self.num_train_images = 0 self.num_reg_images = 0 @@ -2193,28 +2218,43 @@ def __init__( ) continue + strategy = LatentsCachingStrategy.get_strategy() + npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + npz_paths = [os.path.basename(x) for x in npz_paths] + npz_paths = sorted(npz_paths, key=len, reverse=True) # make longer paths come first to speed up matching + tags_list = [] - for image_key, img_md in metadata.items(): - # path情報を作る - abs_path = None - # まず画像を優先して探す + # Match image filename longer to shorter because some images share same prefix + image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True) + + size_set_count = 0 + for image_key in image_keys_sorted_by_length_desc: + img_md = metadata[image_key] + + # make absolute path for image or npz + abs_path, npz_path = None, None + + # full path for image? 
+ image_rel_key = image_key if os.path.exists(image_key): + image_rel_key = os.path.basename(image_key) abs_path = image_key else: - # わりといい加減だがいい方法が思いつかん + # relative path without extension paths = glob_images(subset.image_dir, image_key) if len(paths) > 0: abs_path = paths[0] - # なければnpzを探す - if abs_path is None: - if os.path.exists(os.path.splitext(image_key)[0] + ".npz"): - abs_path = os.path.splitext(image_key)[0] + ".npz" - else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path + # search npz + npz_path = None + for candidate in npz_paths: + if candidate.startswith(image_rel_key): + npz_path = candidate + break + if npz_path is not None: + npz_paths.remove(npz_path) # remove to speed up next search + abs_path = abs_path or npz_path assert abs_path is not None, f"no image / 画像がありません: {image_key}" @@ -2244,14 +2284,20 @@ def __init__( caption = "" image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) - image_info.image_size = img_md.get("train_resolution") + image_info.resize_interpolation = ( + subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + ) - if not subset.color_aug and not subset.random_crop: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + # get image size from npz filename + if npz_path is not None and strategy is not None: + w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path) + image_info.image_size = (w, h) + size_set_count += 1 self.register_image(image_info, subset) + if size_set_count > 0: + logger.info(f"set image size from cache files: {size_set_count}/{len(image_keys_sorted_by_length_desc)}") self.num_train_images += len(metadata) * subset.num_repeats # TODO do not record tag freq when no tag @@ -2259,117 +2305,6 @@ def __init__( subset.img_count = len(metadata) self.subsets.append(subset) - # 
check existence of all npz files - use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) - if use_npz_latents: - flip_aug_in_subset = False - npz_any = False - npz_all = True - - for image_info in self.image_data.values(): - subset = self.image_to_subset[image_info.image_key] - - has_npz = image_info.latents_npz is not None - npz_any = npz_any or has_npz - - if subset.flip_aug: - has_npz = has_npz and image_info.latents_npz_flipped is not None - flip_aug_in_subset = True - npz_all = npz_all and has_npz - - if npz_any and not npz_all: - break - - if not npz_any: - use_npz_latents = False - logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") - elif not npz_all: - use_npz_latents = False - logger.warning( - f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します" - ) - if flip_aug_in_subset: - logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") - # else: - # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") - - # check min/max bucket size - sizes = set() - resos = set() - for image_info in self.image_data.values(): - if image_info.image_size is None: - sizes = None # not calculated - break - sizes.add(image_info.image_size[0]) - sizes.add(image_info.image_size[1]) - resos.add(tuple(image_info.image_size)) - - if sizes is None: - if use_npz_latents: - use_npz_latents = False - logger.warning( - f"npz files exist, but no bucket info in metadata. 
ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します" - ) - - assert ( - resolution is not None - ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" - - self.enable_bucket = enable_bucket - if self.enable_bucket: - min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( - resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps - ) - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - if not enable_bucket: - logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") - self.enable_bucket = True - - assert ( - not bucket_no_upscale - ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" - - # bucket情報を初期化しておく、make_bucketsで再作成しない - self.bucket_manager = BucketManager(False, None, None, None, None) - self.bucket_manager.set_predefined_resos(resos) - - # npz情報をきれいにしておく - if not use_npz_latents: - for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None - - def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): - base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + ".npz" - - if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + "_flip.npz" - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip - - # if not full path, check image_dir. 
if image_dir is None, return None - if subset.image_dir is None: - return None, None - - # image_key is relative path - npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") - npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") - - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None - - return npz_file_norm, npz_file_flip - class ControlNetDataset(BaseDataset): def __init__( @@ -2385,7 +2320,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2448,7 +2383,7 @@ def __init__( self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed self.resize_interpolation = resize_interpolation # assert all conditioning data exists @@ -2538,7 +2473,14 @@ def __getitem__(self, index): cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation) + cond_img = resize_image( + cond_img, + original_size_hw[1], + original_size_hw[0], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2552,7 +2494,14 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or 
cond_img.shape[1] != target_size_hw[1]: - cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation) + cond_img = resize_image( + cond_img, + cond_img.shape[0], + cond_img.shape[1], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -3000,7 +2949,9 @@ def load_images_and_masks_for_caching( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) original_sizes.append(original_size) crop_ltrbs.append(crop_ltrb) @@ -3041,7 +2992,9 @@ def cache_batch_latents( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -3482,9 +3435,9 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, - flux: str = None, # "dev", "schnell" or "chroma" + flux: str = 
None, # "dev", "schnell" or "chroma" lumina: str = None, - optional_metadata: dict[str, str] | None = None + optional_metadata: dict[str, str] | None = None, ): timestamp = time.time() @@ -3513,7 +3466,7 @@ def get_sai_model_spec( # Extract metadata_* fields from args and merge with optional_metadata extracted_metadata = {} - + # Extract all metadata_* attributes from args for attr_name in dir(args): if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): @@ -3523,7 +3476,7 @@ def get_sai_model_spec( field_name = attr_name[9:] # len("metadata_") = 9 if field_name not in ["title", "author", "description", "license", "tags"]: extracted_metadata[field_name] = value - + # Merge extracted metadata with provided optional_metadata all_optional_metadata = {**extracted_metadata} if optional_metadata: @@ -3546,7 +3499,7 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int - model_config=model_config, + model_config=model_config, optional_metadata=all_optional_metadata if all_optional_metadata else None, ) return metadata @@ -3562,7 +3515,7 @@ def get_sai_model_spec_dataclass( sd3: str = None, flux: str = None, lumina: str = None, - optional_metadata: dict[str, str] | None = None + optional_metadata: dict[str, str] | None = None, ) -> sai_model_spec.ModelSpecMetadata: """ Get ModelSpec metadata as a dataclass - preferred for new code. 
@@ -5558,11 +5511,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio def patch_accelerator_for_fp16_training(accelerator): - + from accelerate import DistributedType + if accelerator.distributed_type == DistributedType.DEEPSPEED: return - + org_unscale_grads = accelerator.scaler._unscale_grads_ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): @@ -6279,7 +6233,6 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["renorm_cfg"] = float(m.group(1)) continue - except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex) @@ -6647,4 +6600,3 @@ def moving_average(self) -> float: if losses == 0: return 0 return self.loss_total / losses - From 77ad20bc8f39e3595ccb3f733e1a6bc85d4c41d4 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 30 Aug 2025 15:47:47 +0900 Subject: [PATCH 2/8] feat: support another tagger model --- docs/wd14_tagger_README-en.md | 6 +- docs/wd14_tagger_README-ja.md | 6 +- finetune/tag_images_by_wd14_tagger.py | 434 +++++++++++++++++++------- 3 files changed, 328 insertions(+), 118 deletions(-) diff --git a/docs/wd14_tagger_README-en.md b/docs/wd14_tagger_README-en.md index 34f448823..48a4e9df4 100644 --- a/docs/wd14_tagger_README-en.md +++ b/docs/wd14_tagger_README-en.md @@ -5,9 +5,11 @@ This document is based on the information from this github page (https://github. Using onnx for inference is recommended. Please install onnx with the following command: ```powershell -pip install onnx==1.15.0 onnxruntime-gpu==1.17.1 +pip install onnx onnxruntime-gpu ``` +See [the official documentation](https://onnxruntime.ai/docs/install/#python-installs) for more details. + The model weights will be automatically downloaded from Hugging Face. 
# Usage @@ -49,6 +51,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge # Options +All options can be checked with `python tag_images_by_wd14_tagger.py --help`. + ## General Options - `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately. diff --git a/docs/wd14_tagger_README-ja.md b/docs/wd14_tagger_README-ja.md index 58e9ede95..49d14636c 100644 --- a/docs/wd14_tagger_README-ja.md +++ b/docs/wd14_tagger_README-ja.md @@ -5,9 +5,11 @@ onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。 ```powershell -pip install onnx==1.15.0 onnxruntime-gpu==1.17.1 +pip install onnx onnxruntime-gpu ``` +詳細は[公式ドキュメント](https://onnxruntime.ai/docs/install/#python-installs)をご覧ください。 + モデルの重みはHugging Faceから自動的にダウンロードしてきます。 # 使い方 @@ -48,6 +50,8 @@ python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagge # オプション +全てオプションは `python tag_images_by_wd14_tagger.py --help` で確認できます。 + ## 一般オプション - `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。 diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 07a6510e6..17a75bd76 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -1,12 +1,13 @@ import argparse import csv +import json import os from pathlib import Path import cv2 import numpy as np import torch -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, errors from PIL import Image from tqdm import tqdm @@ -29,8 +30,22 @@ SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] +TAG_JSON_FILE = "tag_mapping.json" + + +def preprocess_image(image: Image.Image) -> np.ndarray: + # If image has transparency, convert to RGBA. 
If not, convert to RGB + if image.mode in ("RGBA", "LA") or "transparency" in image.info: + image = image.convert("RGBA") + elif image.mode != "RGB": + image = image.convert("RGB") + + # If image is RGBA, combine with white background + if image.mode == "RGBA": + background = Image.new("RGB", image.size, (255, 255, 255)) + background.paste(image, mask=image.split()[3]) # Use alpha channel as mask + image = background -def preprocess_image(image): image = np.array(image) image = image[:, :, ::-1] # RGB->BGR @@ -59,7 +74,7 @@ def __getitem__(self, idx): img_path = str(self.images[idx]) try: - image = Image.open(img_path).convert("RGB") + image = Image.open(img_path) image = preprocess_image(image) # tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・) except Exception as e: @@ -81,35 +96,64 @@ def collate_fn_remove_corrupted(batch): def main(args): # model location is model_dir + repo_id - # repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash - model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_")) + # given repo_id may be "namespace/repo_name" or "namespace/repo_name/subdir" + # so we split it to "namespace/reponame" and "subdir" + tokens = args.repo_id.split("/") + + if len(tokens) > 2: + repo_id = "/".join(tokens[:2]) + subdir = "/".join(tokens[2:]) + model_location = os.path.join(args.model_dir, repo_id.replace("/", "_"), subdir) + onnx_model_name = "model_optimized.onnx" + default_format = False + else: + repo_id = args.repo_id + subdir = None + model_location = os.path.join(args.model_dir, repo_id.replace("/", "_")) + onnx_model_name = "model.onnx" + default_format = True - # hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする - # depreacatedの警告が出るけどなくなったらその時 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 + if not os.path.exists(model_location) or args.force_download: os.makedirs(args.model_dir, exist_ok=True) logger.info(f"downloading wd14 
tagger model from hf_hub. id: {args.repo_id}") - files = FILES - if args.onnx: - files = ["selected_tags.csv"] - files += FILES_ONNX - else: - for file in SUB_DIR_FILES: + + if subdir is None: + # SmilingWolf structure + files = FILES + if args.onnx: + files = ["selected_tags.csv"] + files += FILES_ONNX + else: + for file in SUB_DIR_FILES: + hf_hub_download( + repo_id=args.repo_id, + filename=file, + subfolder=SUB_DIR, + local_dir=os.path.join(model_location, SUB_DIR), + force_download=True, + ) + + for file in files: hf_hub_download( repo_id=args.repo_id, filename=file, - subfolder=SUB_DIR, - local_dir=os.path.join(model_location, SUB_DIR), + local_dir=model_location, + force_download=True, + ) + else: + # another structure + files = [onnx_model_name, "tag_mapping.json"] + + for file in files: + hf_hub_download( + repo_id=repo_id, + filename=file, + subfolder=subdir, + local_dir=os.path.join(args.model_dir, repo_id.replace("/", "_")), # because subdir is specified force_download=True, ) - for file in files: - hf_hub_download( - repo_id=args.repo_id, - filename=file, - local_dir=model_location, - force_download=True, - ) else: logger.info("using existing wd14 tagger model") @@ -118,7 +162,7 @@ def main(args): import onnx import onnxruntime as ort - onnx_path = f"{model_location}/model.onnx" + onnx_path = os.path.join(model_location, onnx_model_name) logger.info("Running wd14 tagger with onnx") logger.info(f"loading onnx model: {onnx_path}") @@ -150,39 +194,30 @@ def main(args): ort_sess = ort.InferenceSession( onnx_path, providers=(["OpenVINOExecutionProvider"]), - provider_options=[{'device_type' : "GPU", "precision": "FP32"}], + provider_options=[{"device_type": "GPU", "precision": "FP32"}], ) else: - ort_sess = ort.InferenceSession( - onnx_path, - providers=( - ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else - ["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else - 
["CPUExecutionProvider"] - ), + providers = ( + ["CUDAExecutionProvider"] + if "CUDAExecutionProvider" in ort.get_available_providers() + else ( + ["ROCMExecutionProvider"] + if "ROCMExecutionProvider" in ort.get_available_providers() + else ["CPUExecutionProvider"] + ) ) + logger.info(f"Using onnxruntime providers: {providers}") + ort_sess = ort.InferenceSession(onnx_path, providers=providers) else: from tensorflow.keras.models import load_model model = load_model(f"{model_location}") + # We read the CSV file manually to avoid adding dependencies. # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") - # 依存ライブラリを増やしたくないので自力で読むよ - - with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: - reader = csv.reader(f) - line = [row for row in reader] - header = line[0] # tag_id,name,category,count - rows = line[1:] - assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" - - rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] - general_tags = [row[1] for row in rows[0:] if row[2] == "0"] - character_tags = [row[1] for row in rows[0:] if row[2] == "4"] - - # preprocess tags in advance - if args.character_tag_expand: - for i, tag in enumerate(character_tags): + + def expand_character_tags(char_tags): + for i, tag in enumerate(char_tags): if tag.endswith(")"): # chara_name_(series) -> chara_name, series # chara_name_(costume)_(series) -> chara_name_(costume), series @@ -191,30 +226,86 @@ def main(args): if character_tag.endswith("_"): character_tag = character_tag[:-1] series_tag = tags[-1].replace(")", "") - character_tags[i] = character_tag + args.caption_separator + series_tag + char_tags[i] = character_tag + args.caption_separator + series_tag - if args.remove_underscore: - rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags] - general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags] - character_tags 
= [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags] + def remove_underscore(tags): + return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags] - if args.tag_replacement is not None: - # escape , and ; in tag_replacement: wd14 tag names may contain , and ; - escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####") + def process_tag_replacement(tags: list[str], tag_replacements_arg: str): + # escape , and ; in tag_replacement: wd14 tag names may contain , and ;, + # so user must be specified them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means + # `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD` + escaped_tag_replacements = tag_replacements_arg.replace("\\,", "@@@@").replace("\\;", "####") tag_replacements = escaped_tag_replacements.split(";") - for tag_replacement in tag_replacements: - tags = tag_replacement.split(",") # source, target - assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}" + + for tag_replacements_arg in tag_replacements: + tags = tag_replacements_arg.split(",") # source, target + assert ( + len(tags) == 2 + ), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}" source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags] logger.info(f"replacing tag: {source} -> {target}") - if source in general_tags: - general_tags[general_tags.index(source)] = target - elif source in character_tags: - character_tags[character_tags.index(source)] = target - elif source in rating_tags: - rating_tags[rating_tags.index(source)] = target + if source in tags: + tags[tags.index(source)] = target + + if default_format: + with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: + reader = csv.reader(f) + line = [row for row in reader] + header = line[0] # tag_id,name,category,count + rows 
= line[1:] + assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}" + + rating_tags = [row[1] for row in rows[0:] if row[2] == "9"] + general_tags = [row[1] for row in rows[0:] if row[2] == "0"] + character_tags = [row[1] for row in rows[0:] if row[2] == "4"] + + if args.character_tag_expand: + expand_character_tags(character_tags) + if args.remove_underscore: + rating_tags = remove_underscore(rating_tags) + character_tags = remove_underscore(character_tags) + general_tags = remove_underscore(general_tags) + if args.tag_replacement is not None: + process_tag_replacement(rating_tags, args.tag_replacement) + process_tag_replacement(general_tags, args.tag_replacement) + process_tag_replacement(character_tags, args.tag_replacement) + else: + with open(os.path.join(model_location, TAG_JSON_FILE), "r", encoding="utf-8") as f: + tag_mapping = json.load(f) + + rating_tags = [] + general_tags = [] + character_tags = [] + + tag_id_to_tag_mapping = {} + tag_id_to_category_mapping = {} + for tag_id, tag_info in tag_mapping.items(): + tag = tag_info["tag"] + category = tag_info["category"] + assert category in [ + "Rating", + "General", + "Character", + "Copyright", + "Meta", + "Model", + "Quality", + ], f"unexpected category: {category}" + + if args.remove_underscore: + tag = remove_underscore([tag])[0] + if args.tag_replacement is not None: + tag = process_tag_replacement([tag], args.tag_replacement)[0] + if category == "Character" and args.character_tag_expand: + tag_list = [tag] + expand_character_tags(tag_list) + tag = tag_list[0] + + tag_id_to_tag_mapping[int(tag_id)] = tag + tag_id_to_category_mapping[int(tag_id)] = category # 画像を読み込む train_data_dir_path = Path(args.train_data_dir) @@ -238,6 +329,9 @@ def run_batch(path_imgs): if args.onnx: # if len(imgs) < args.batch_size: # imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) + if not default_format: + 
imgs = imgs.transpose(0, 3, 1, 2) # to NCHW + imgs = imgs / 127.5 - 1.0 probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy probs = probs[: len(path_imgs)] else: @@ -249,42 +343,112 @@ def run_batch(path_imgs): rating_tag_text = "" character_tag_text = "" general_tag_text = "" - - # 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する - # First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold - for i, p in enumerate(prob[4:]): - if i < len(general_tags) and p >= args.general_threshold: - tag_name = general_tags[i] - - if tag_name not in undesired_tags: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += caption_separator + tag_name - combined_tags.append(tag_name) - elif i >= len(general_tags) and p >= args.character_threshold: - tag_name = character_tags[i - len(general_tags)] - - if tag_name not in undesired_tags: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += caption_separator + tag_name - if args.character_tags_first: # insert to the beginning - combined_tags.insert(0, tag_name) - else: + other_tag_text = "" + + if default_format: + # 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する + # First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold + for i, p in enumerate(prob[4:]): + if i < len(general_tags) and p >= args.general_threshold: + tag_name = general_tags[i] + + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += caption_separator + tag_name combined_tags.append(tag_name) - - # 最初の4つはratingなのでargmaxで選ぶ - # First 4 labels are actually ratings: pick one with argmax - if args.use_rating_tags or args.use_rating_tags_as_last_tag: - ratings_probs = prob[:4] - rating_index = ratings_probs.argmax() - found_rating = rating_tags[rating_index] - - if found_rating not in undesired_tags: - tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 - rating_tag_text = 
found_rating - if args.use_rating_tags: - combined_tags.insert(0, found_rating) # insert to the beginning - else: - combined_tags.append(found_rating) + elif i >= len(general_tags) and p >= args.character_threshold: + tag_name = character_tags[i - len(general_tags)] + + if tag_name not in undesired_tags: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += caption_separator + tag_name + if args.character_tags_first: # insert to the beginning + combined_tags.insert(0, tag_name) + else: + combined_tags.append(tag_name) + + # 最初の4つはratingなのでargmaxで選ぶ + # First 4 labels are actually ratings: pick one with argmax + if args.use_rating_tags or args.use_rating_tags_as_last_tag: + ratings_probs = prob[:4] + rating_index = ratings_probs.argmax() + found_rating = rating_tags[rating_index] + + if found_rating not in undesired_tags: + tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1 + rating_tag_text = found_rating + if args.use_rating_tags: + combined_tags.insert(0, found_rating) # insert to the beginning + else: + combined_tags.append(found_rating) + else: + # apply sigmoid to probabilities + prob = 1 / (1 + np.exp(-prob)) + + rating_max_prob = -1 + rating_tag = None + quality_max_prob = -1 + quality_tag = None + character_tags = [] + for i, p in enumerate(prob): + if i in tag_id_to_tag_mapping and p >= args.thresh: + tag_name = tag_id_to_tag_mapping[i] + category = tag_id_to_category_mapping[i] + if tag_name in undesired_tags: + continue + + if category == "Rating": + if p > rating_max_prob: + rating_max_prob = p + rating_tag = tag_name + rating_tag_text = tag_name + continue + elif category == "Quality": + if p > quality_max_prob: + quality_max_prob = p + quality_tag = tag_name + if args.use_quality_tags or args.use_quality_tags_as_last_tag: + other_tag_text += caption_separator + tag_name + continue + + if category == "General" and p >= args.general_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += 
caption_separator + tag_name + combined_tags.append((tag_name, p)) + elif category == "Character" and p >= args.character_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += caption_separator + tag_name + if args.character_tags_first: # we separate character tags + character_tags.append((tag_name, p)) + else: + combined_tags.append((tag_name, p)) + elif ( + (category == "Copyright" and p >= args.copyright_threshold) + or (category == "Meta" and p >= args.meta_threshold) + or (category == "Model" and p >= args.model_threshold) + ): + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + other_tag_text += f"{caption_separator}{tag_name} ({category})" + combined_tags.append((tag_name, p)) + + # sort by probability + combined_tags.sort(key=lambda x: x[1], reverse=True) + if character_tags: + print(character_tags) + character_tags.sort(key=lambda x: x[1], reverse=True) + combined_tags = character_tags + combined_tags + combined_tags = [t[0] for t in combined_tags] # remove probability + + if quality_tag is not None: + if args.use_quality_tags_as_last_tag: + combined_tags.append(quality_tag) + elif args.use_quality_tags: + combined_tags.insert(0, quality_tag) + if rating_tag is not None: + if args.use_rating_tags_as_last_tag: + combined_tags.append(rating_tag) + elif args.use_rating_tags: + combined_tags.insert(0, rating_tag) # 一番最初に置くタグを指定する # Always put some tags at the beginning @@ -299,6 +463,8 @@ def run_batch(path_imgs): general_tag_text = general_tag_text[len(caption_separator) :] if len(character_tag_text) > 0: character_tag_text = character_tag_text[len(caption_separator) :] + if len(other_tag_text) > 0: + other_tag_text = other_tag_text[len(caption_separator) :] caption_file = os.path.splitext(image_path)[0] + args.caption_extension @@ -328,6 +494,8 @@ def run_batch(path_imgs): logger.info(f"\tRating tags: {rating_tag_text}") logger.info(f"\tCharacter tags: {character_tag_text}") logger.info(f"\tGeneral tags: 
{general_tag_text}") + if other_tag_text: + logger.info(f"\tOther tags: {other_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -353,8 +521,6 @@ def run_batch(path_imgs): if image is None: try: image = Image.open(image_path) - if image.mode != "RGB": - image = image.convert("RGB") image = preprocess_image(image) except Exception as e: logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") @@ -381,9 +547,7 @@ def run_batch(path_imgs): def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() - parser.add_argument( - "train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ" - ) + parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument( "--repo_id", type=str, @@ -401,9 +565,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします", ) - parser.add_argument( - "--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ" - ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", type=int, @@ -432,7 +594,29 @@ def setup_parser() -> argparse.ArgumentParser: "--character_threshold", type=float, default=None, - help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ", + help="threshold of confidence to add a tag for character category, same as --thresh if omitted. set above 1 to disable character tags" + " / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcharacterタグを無効化できる", + ) + parser.add_argument( + "--meta_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for meta category, same as --thresh if omitted. 
set above 1 to disable meta tags" + " / metaカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmetaタグを無効化できる", + ) + parser.add_argument( + "--model_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for model category, same as --thresh if omitted. set above 1 to disable model tags" + " / modelカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとmodelタグを無効化できる", + ) + parser.add_argument( + "--copyright_threshold", + type=float, + default=None, + help="threshold of confidence to add a tag for copyright category, same as --thresh if omitted. set above 1 to disable copyright tags" + " / copyrightカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ。1以上にするとcopyrightタグを無効化できる", ) parser.add_argument( "--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する" @@ -442,9 +626,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える", ) - parser.add_argument( - "--debug", action="store_true", help="debug mode" - ) + parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument( "--undesired_tags", type=str, @@ -454,20 +636,34 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する" ) + parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument( - "--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する" + "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" ) parser.add_argument( - "--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する" + "--use_rating_tags", + action="store_true", + help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する", ) parser.add_argument( - 
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する", + "--use_rating_tags_as_last_tag", + action="store_true", + help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する", + ) + parser.add_argument( + "--use_quality_tags", + action="store_true", + help="Adds quality tags as the first tag / クオリティタグを最初のタグとして追加する", ) parser.add_argument( - "--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する", + "--use_quality_tags_as_last_tag", + action="store_true", + help="Adds quality tags as the last tag / クオリティタグを最後のタグとして追加する", ) parser.add_argument( - "--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する", + "--character_tags_first", + action="store_true", + help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する", ) parser.add_argument( "--always_first_tags", @@ -512,5 +708,11 @@ def setup_parser() -> argparse.ArgumentParser: args.general_threshold = args.thresh if args.character_threshold is None: args.character_threshold = args.thresh + if args.meta_threshold is None: + args.meta_threshold = args.thresh + if args.model_threshold is None: + args.model_threshold = args.thresh + if args.copyright_threshold is None: + args.copyright_threshold = args.thresh main(args) From 9e661a5eb0d70b71d0dbb020ea9759a14058aa97 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 3 Sep 2025 20:13:09 +0900 Subject: [PATCH 3/8] fix: improve handling of image size and caption/tag processing in FineTuningDataset --- library/train_util.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index caafcc28f..efc34ab02 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2253,31 +2253,35 @@ def __init__( npz_path = 
candidate break if npz_path is not None: - npz_paths.remove(npz_path) # remove to speed up next search + npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix) abs_path = abs_path or npz_path assert abs_path is not None, f"no image / 画像がありません: {image_key}" caption = img_md.get("caption") tags = img_md.get("tags") + image_size = img_md.get("image_size") + if caption is None: - caption = tags # could be multiline - tags = None + caption = "" if subset.enable_wildcard: - # tags must be single line + # tags must be single line (split by caption separator) if tags is not None: tags = tags.replace("\n", subset.caption_separator) # add tags to each line of caption - if caption is not None and tags is not None: + if tags is not None: caption = "\n".join( [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""] ) + tags_list.append(tags) else: # use as is if tags is not None and len(tags) > 0: - caption = caption + subset.caption_separator + tags + if len(caption) > 0: + caption = caption + subset.caption_separator + caption = caption + tags tags_list.append(tags) if caption is None: @@ -2288,8 +2292,10 @@ def __init__( subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation ) - # get image size from npz filename - if npz_path is not None and strategy is not None: + if image_size is not None: + image_info.image_size = tuple(image_size) # width, height + elif npz_path is not None and strategy is not None: + # get image size from npz filename w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path) image_info.image_size = (w, h) size_set_count += 1 From 668d188def2a34d12864255d09ed495eb7ea3e9a Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 3 Sep 2025 20:25:12 +0900 Subject: [PATCH 4/8] fix: enhance metadata loading to support JSONL format in FineTuningDataset --- library/train_util.py | 22 
+++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index efc34ab02..29b61bf3f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2206,9 +2206,25 @@ def __init__( # メタデータを読み込む if os.path.exists(subset.metadata_file): - logger.info(f"loading existing metadata: {subset.metadata_file}") - with open(subset.metadata_file, "rt", encoding="utf-8") as f: - metadata = json.load(f) + if subset.metadata_file.endswith(".jsonl"): + logger.info(f"loading existing JSONL metadata: {subset.metadata_file}") + # optional JSONL format + # {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1", "image_size": [width, height]} + metadata = {} + with open(subset.metadata_file, "rt", encoding="utf-8") as f: + for line in f: + line_md = json.loads(line) + image_md = {"caption": line_md.get("caption", "")} + if "image_size" in line_md: + image_md["image_size"] = line_md["image_size"] + if "tags" in line_md: + image_md["tags"] = line_md["tags"] + metadata[line_md["image_path"]] = image_md + else: + # standard JSON format + logger.info(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding="utf-8") as f: + metadata = json.load(f) else: raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") From 0fa4a6baa8cb0e234b2948d0433a2b4905d62c81 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 3 Sep 2025 20:25:22 +0900 Subject: [PATCH 5/8] feat: enhance image loading and processing in ImageLoadingPrepDataset with batch support and output options --- finetune/tag_images_by_wd14_tagger.py | 269 ++++++++++++++++---------- 1 file changed, 162 insertions(+), 107 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 17a75bd76..9def06da7 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ 
-1,13 +1,14 @@ import argparse import csv import json +import math import os from pathlib import Path +from typing import Optional -import cv2 import numpy as np import torch -from huggingface_hub import hf_hub_download, errors +from huggingface_hub import hf_hub_download from PIL import Image from tqdm import tqdm @@ -64,33 +65,40 @@ def preprocess_image(image: Image.Image) -> np.ndarray: class ImageLoadingPrepDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths: list[str], batch_size: int): + self.image_paths = image_paths + self.batch_size = batch_size def __len__(self): - return len(self.images) + return math.ceil(len(self.image_paths) / self.batch_size) - def __getitem__(self, idx): - img_path = str(self.images[idx]) + def __getitem__(self, batch_index: int) -> tuple[str, np.ndarray, tuple[int, int]]: + image_index_start = batch_index * self.batch_size + image_index_end = min((batch_index + 1) * self.batch_size, len(self.image_paths)) - try: - image = Image.open(img_path) - image = preprocess_image(image) - # tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None - - return (image, img_path) - - -def collate_fn_remove_corrupted(batch): - """Collate function that allows to remove corrupted examples in the - dataloader. It expects that the dataloader returns 'None' when that occurs. - The 'None's in the batch are removed. 
- """ - # Filter out all the Nones (corrupted examples) - batch = list(filter(lambda x: x is not None, batch)) + batch_image_paths = [] + images = [] + image_sizes = [] + for idx in range(image_index_start, image_index_end): + img_path = str(self.image_paths[idx]) + + try: + image = Image.open(img_path) + image_size = image.size + image = preprocess_image(image) + + batch_image_paths.append(img_path) + images.append(image) + image_sizes.append(image_size) + except Exception as e: + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + + images = np.stack(images) if len(images) > 0 else np.zeros((0, IMAGE_SIZE, IMAGE_SIZE, 3)) + return batch_image_paths, images, image_sizes + + +def collate_fn_no_op(batch): + """Collate function that does nothing and returns the batch as is.""" return batch @@ -311,6 +319,7 @@ def process_tag_replacement(tags: list[str], tag_replacements_arg: str): train_data_dir_path = Path(args.train_data_dir) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive) logger.info(f"found {len(image_paths)} images.") + image_paths = [str(ip) for ip in image_paths] tag_freq = {} @@ -323,8 +332,11 @@ def process_tag_replacement(tags: list[str], tag_replacements_arg: str): if args.always_first_tags is not None: always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""] - def run_batch(path_imgs): - imgs = np.array([im for _, im in path_imgs]) + def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[list[str]]: + nonlocal args, default_format, model, ort_sess, input_name, tag_freq + + imgs = path_imgs[1] + result = {} if args.onnx: # if len(imgs) < args.batch_size: @@ -333,12 +345,12 @@ def run_batch(path_imgs): imgs = imgs.transpose(0, 3, 1, 2) # to NCHW imgs = imgs / 127.5 - 1.0 probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy - probs = probs[: len(path_imgs)] + probs = probs[: len(imgs)] # 
remove padding else: probs = model(imgs, training=False) probs = probs.numpy() - for (image_path, _), prob in zip(path_imgs, probs): + for image_path, image_size, prob in zip(path_imgs[0], path_imgs[2], probs): combined_tags = [] rating_tag_text = "" character_tag_text = "" @@ -390,51 +402,64 @@ def run_batch(path_imgs): quality_max_prob = -1 quality_tag = None character_tags = [] - for i, p in enumerate(prob): - if i in tag_id_to_tag_mapping and p >= args.thresh: - tag_name = tag_id_to_tag_mapping[i] - category = tag_id_to_category_mapping[i] - if tag_name in undesired_tags: - continue - - if category == "Rating": - if p > rating_max_prob: - rating_max_prob = p - rating_tag = tag_name - rating_tag_text = tag_name - continue - elif category == "Quality": - if p > quality_max_prob: - quality_max_prob = p - quality_tag = tag_name - if args.use_quality_tags or args.use_quality_tags_as_last_tag: - other_tag_text += caption_separator + tag_name - continue - - if category == "General" and p >= args.general_threshold: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += caption_separator + tag_name - combined_tags.append((tag_name, p)) - elif category == "Character" and p >= args.character_threshold: - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += caption_separator + tag_name - if args.character_tags_first: # we separate character tags - character_tags.append((tag_name, p)) - else: - combined_tags.append((tag_name, p)) - elif ( - (category == "Copyright" and p >= args.copyright_threshold) - or (category == "Meta" and p >= args.meta_threshold) - or (category == "Model" and p >= args.model_threshold) - ): - tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - other_tag_text += f"{caption_separator}{tag_name} ({category})" + + min_thres = min( + args.thresh, + args.general_threshold, + args.character_threshold, + args.copyright_threshold, + args.meta_threshold, + args.model_threshold, + ) + prob_indices = np.where(prob >= 
min_thres)[0] + # for i, p in enumerate(prob): + for i in prob_indices: + if i not in tag_id_to_tag_mapping: + continue + p = prob[i] + + tag_name = tag_id_to_tag_mapping[i] + category = tag_id_to_category_mapping[i] + if tag_name in undesired_tags: + continue + + if category == "Rating": + if p > rating_max_prob: + rating_max_prob = p + rating_tag = tag_name + rating_tag_text = tag_name + continue + elif category == "Quality": + if p > quality_max_prob: + quality_max_prob = p + quality_tag = tag_name + if args.use_quality_tags or args.use_quality_tags_as_last_tag: + other_tag_text += caption_separator + tag_name + continue + + if category == "General" and p >= args.general_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + general_tag_text += caption_separator + tag_name + combined_tags.append((tag_name, p)) + elif category == "Character" and p >= args.character_threshold: + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + character_tag_text += caption_separator + tag_name + if args.character_tags_first: # we separate character tags + character_tags.append((tag_name, p)) + else: combined_tags.append((tag_name, p)) + elif ( + (category == "Copyright" and p >= args.copyright_threshold) + or (category == "Meta" and p >= args.meta_threshold) + or (category == "Model" and p >= args.model_threshold) + ): + tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 + other_tag_text += f"{caption_separator}{tag_name} ({category})" + combined_tags.append((tag_name, p)) # sort by probability combined_tags.sort(key=lambda x: x[1], reverse=True) if character_tags: - print(character_tags) character_tags.sort(key=lambda x: x[1], reverse=True) combined_tags = character_tags + combined_tags combined_tags = [t[0] for t in combined_tags] # remove probability @@ -486,55 +511,79 @@ def run_batch(path_imgs): # Create new tag_text tag_text = caption_separator.join(existing_tags + new_tags) - with open(caption_file, "wt", encoding="utf-8") as f: - f.write(tag_text + "\n") - 
if args.debug: - logger.info("") - logger.info(f"{image_path}:") - logger.info(f"\tRating tags: {rating_tag_text}") - logger.info(f"\tCharacter tags: {character_tag_text}") - logger.info(f"\tGeneral tags: {general_tag_text}") - if other_tag_text: - logger.info(f"\tOther tags: {other_tag_text}") + if not args.output_path: + with open(caption_file, "wt", encoding="utf-8") as f: + f.write(tag_text + "\n") + else: + entry = {"tags": tag_text, "image_size": list(image_size)} + result[image_path] = entry + + if args.debug: + logger.info("") + logger.info(f"{image_path}:") + logger.info(f"\tRating tags: {rating_tag_text}") + logger.info(f"\tCharacter tags: {character_tag_text}") + logger.info(f"\tGeneral tags: {general_tag_text}") + if other_tag_text: + logger.info(f"\tOther tags: {other_tag_text}") + + return result # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: - dataset = ImageLoadingPrepDataset(image_paths) + dataset = ImageLoadingPrepDataset(image_paths, args.batch_size) data = torch.utils.data.DataLoader( dataset, - batch_size=args.batch_size, + batch_size=1, shuffle=False, num_workers=args.max_data_loader_n_workers, - collate_fn=collate_fn_remove_corrupted, + collate_fn=collate_fn_no_op, drop_last=False, ) else: - data = [[(None, ip)] for ip in image_paths] - - b_imgs = [] + # data = [[(ip, None, None)] for ip in image_paths] + data = [[]] + for ip in image_paths: + if len(data[-1]) >= args.batch_size: + data.append([]) + data[-1].append((ip, None, None)) + + results = {} for data_entry in tqdm(data, smoothing=0.0): - for data in data_entry: - if data is None: - continue - - image, image_path = data - if image is None: - try: - image = Image.open(image_path) - image = preprocess_image(image) - except Exception as e: - logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}") - continue - b_imgs.append((image_path, image)) - - if len(b_imgs) >= args.batch_size: - b_imgs = [(str(image_path), image) for 
image_path, image in b_imgs] # Convert image_path to string - run_batch(b_imgs) - b_imgs.clear() - - if len(b_imgs) > 0: - b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string - run_batch(b_imgs) + if data_entry is None or len(data_entry) == 0: + continue + + if data_entry[0][1] is None: + # No preloaded image, need to load + images = [] + image_sizes = [] + for image_path, _, _ in data_entry: + image = Image.open(image_path) + image_size = image.size + image = preprocess_image(image) + images.append(image) + image_sizes.append(image_size) + b_imgs = ([ip for ip, _, _ in data_entry], np.stack(images), image_sizes) + else: + b_imgs = data_entry[0] + + r = run_batch(b_imgs) + if args.output_path and r is not None: + results.update(r) + + if args.output_path: + if args.output_path.endswith(".jsonl"): + # optional JSONL metadata + with open(args.output_path, "wt", encoding="utf-8") as f: + for image_path, entry in results.items(): + f.write( + json.dumps({"image_path": image_path, "caption": entry["tags"], "image_size": entry["image_size"]}) + "\n" + ) + else: + # standard JSON metadata + with open(args.output_path, "wt", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=4) + logger.info(f"captions saved to {args.output_path}") if args.frequency_tags: sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True) @@ -572,6 +621,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--output_path", + type=str, + default=None, + help="path for output captions (json format). 
if this is set, captions will be saved to this file / 出力キャプションのパス(json形式)。このオプションが設定されている場合、キャプションはこのファイルに保存されます", + ) parser.add_argument( "--caption_extention", type=str, From f3d5b063376ea554eb0f8a21977990a43eb6dedc Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 3 Sep 2025 22:00:20 +0900 Subject: [PATCH 6/8] fix: improve image path handling and memory management in dataset classes --- library/train_util.py | 90 ++++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 29b61bf3f..131cb6129 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1131,7 +1131,8 @@ def __init__(self, reso, flip_aug, alpha_mask, random_crop): def __eq__(self, other): return ( - self.reso == other.reso + other is not None + and self.reso == other.reso and self.flip_aug == other.flip_aug and self.alpha_mask == other.alpha_mask and self.random_crop == other.random_crop @@ -1193,6 +1194,8 @@ def submit_batch(batch, cond): if len(batch) > 0 and current_condition != condition: submit_batch(batch, current_condition) batch = [] + if condition != current_condition and HIGH_VRAM: # even with high VRAM, if shape is changed + clean_memory_on_device(accelerator.device) if info.image is None: # load image in parallel @@ -1205,7 +1208,7 @@ def submit_batch(batch, cond): if len(batch) >= caching_strategy.batch_size: submit_batch(batch, current_condition) batch = [] - current_condition = None + # current_condition = None if len(batch) > 0: submit_batch(batch, current_condition) @@ -2234,49 +2237,53 @@ def __init__( ) continue - strategy = LatentsCachingStrategy.get_strategy() - npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths = [os.path.basename(x) for x in npz_paths] - npz_paths = sorted(npz_paths, key=len, reverse=True) # make longer paths come first to speed up matching + # Add full path for 
image + image_dirs = set() + if subset.image_dir is not None: + image_dirs.add(subset.image_dir) + for image_key in metadata.keys(): + if not os.path.isabs(image_key): + assert ( + subset.image_dir is not None + ), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}" + abs_path = os.path.join(subset.image_dir, image_key) + else: + abs_path = image_key + image_dirs.add(os.path.dirname(abs_path)) + metadata[image_key]["abs_path"] = abs_path - tags_list = [] + # Enumerate existing npz files + strategy = LatentsCachingStrategy.get_strategy() + npz_paths = [] + for image_dir in image_dirs: + npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix))) + npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True) # longer paths first # Match image filename longer to shorter because some images share same prefix image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True) - size_set_count = 0 + # Collect tags and sizes + tags_list = [] + size_set_from_metadata = 0 + size_set_from_cache_filename = 0 for image_key in image_keys_sorted_by_length_desc: img_md = metadata[image_key] - - # make absolute path for image or npz - abs_path, npz_path = None, None - - # full path for image? 
- image_rel_key = image_key - if os.path.exists(image_key): - image_rel_key = os.path.basename(image_key) - abs_path = image_key - else: - # relative path without extension - paths = glob_images(subset.image_dir, image_key) - if len(paths) > 0: - abs_path = paths[0] - - # search npz - npz_path = None - for candidate in npz_paths: - if candidate.startswith(image_rel_key): - npz_path = candidate - break - if npz_path is not None: - npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix) - abs_path = abs_path or npz_path - - assert abs_path is not None, f"no image / 画像がありません: {image_key}" - caption = img_md.get("caption") tags = img_md.get("tags") image_size = img_md.get("image_size") + abs_path = img_md.get("abs_path") + + # search npz if image_size is not given + npz_path = None + if image_size is None: + image_without_ext = os.path.splitext(image_key)[0] + for candidate in npz_paths: + if candidate.startswith(image_without_ext): + npz_path = candidate + break + if npz_path is not None: + npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix) + abs_path = npz_path if caption is None: caption = "" @@ -2310,16 +2317,21 @@ def __init__( if image_size is not None: image_info.image_size = tuple(image_size) # width, height - elif npz_path is not None and strategy is not None: + size_set_from_metadata += 1 + elif npz_path is not None: # get image size from npz filename w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path) image_info.image_size = (w, h) - size_set_count += 1 + size_set_from_cache_filename += 1 self.register_image(image_info, subset) - if size_set_count > 0: - logger.info(f"set image size from cache files: {size_set_count}/{len(image_keys_sorted_by_length_desc)}") + if size_set_from_cache_filename > 0: + logger.info( + f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}" + ) + if size_set_from_metadata > 0: + logger.info(f"set image size 
from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}") self.num_train_images += len(metadata) * subset.num_repeats # TODO do not record tag freq when no tag From 3f01189bd9e266f326e6cc46f9079c3c64842d8e Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 3 Sep 2025 22:11:09 +0900 Subject: [PATCH 7/8] Update finetune/tag_images_by_wd14_tagger.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- finetune/tag_images_by_wd14_tagger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 9def06da7..38d8449b1 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -332,7 +332,7 @@ def process_tag_replacement(tags: list[str], tag_replacements_arg: str): if args.always_first_tags is not None: always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""] - def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[list[str]]: + def run_batch(path_imgs: tuple[list[str], np.ndarray, list[tuple[int, int]]]) -> Optional[dict[str, dict]]: nonlocal args, default_format, model, ort_sess, input_name, tag_freq imgs = path_imgs[1] From b1b9dec980a0056f4d2b57ef10b00af7e4d98681 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 3 Sep 2025 22:13:40 +0900 Subject: [PATCH 8/8] fix: add return type annotation for process_tag_replacement function and ensure tags are returned --- finetune/tag_images_by_wd14_tagger.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 38d8449b1..5c426c569 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -239,7 +239,7 @@ def 
expand_character_tags(char_tags): def remove_underscore(tags): return [tag.replace("_", " ") if len(tag) > 3 else tag for tag in tags] - def process_tag_replacement(tags: list[str], tag_replacements_arg: str): + def process_tag_replacement(tags: list[str], tag_replacements_arg: str) -> list[str]: # escape , and ; in tag_replacement: wd14 tag names may contain , and ;, # so user must be specified them like `aa\,bb,AA\,BB;cc\;dd,CC\;DD` which means # `aa,bb` is replaced with `AA,BB` and `cc;dd` is replaced with `CC;DD` @@ -258,6 +258,8 @@ def process_tag_replacement(tags: list[str], tag_replacements_arg: str): if source in tags: tags[tags.index(source)] = target + return tags + if default_format: with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f: reader = csv.reader(f)