diff --git a/library/train_util.py b/library/train_util.py
index b9d08f253..33af29f2c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -66,7 +66,12 @@
 from library.original_unet import UNet2DConditionModel
 from huggingface_hub import hf_hub_download
 import numpy as np
+import sys
 from PIL import Image
+try:
+    from PIL import ImageCms
+except ImportError:
+    print( "ImageCms not available. Images will not be converted to sRGB. Colours may be handled incorrectly." )
 import imagesize
 import cv2
 import safetensors.torch
@@ -2490,10 +2495,43 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
 def load_image(image_path, alpha=False):
     try:
         with Image.open(image_path) as image:
+            if getattr(image, "is_animated", False):
+                logger.warning( f"{image_path} is animated" )
+
+            # Convert image to sRGB
+            if "PIL.ImageCms" in sys.modules:
+                icc = image.info.get('icc_profile', '')
+                if icc:
+                    # Bound before the try so the handler below can tell whether
+                    # the embedded profile was parsed at all.
+                    src_profile = None
+                    try:
+                        src_profile = ImageCms.ImageCmsProfile( BytesIO(icc) )
+                        srgb_profile = ImageCms.createProfile("sRGB")
+                        ImageCms.profileToProfile(image, src_profile, srgb_profile, inPlace=True)
+                        image.info["icc_profile"] = ImageCms.ImageCmsProfile(srgb_profile).tobytes()
+                    except Exception as e:
+                        if src_profile is None:
+                            profile_desc = "unreadable ICC profile"
+                        else:
+                            profile_desc = f"{src_profile.profile.model} {src_profile.profile.profile_description}"
+                        logger.warning( f"Could not convert {image_path} to sRGB: {profile_desc}\n{e}" )
+
             if alpha:
                 if not image.mode == "RGBA":
                     image = image.convert("RGBA")
             else:
+                if image.mode == "P":
+                    # Palette images with alpha are easier to handle as RGBA.
+                    image = image.convert('RGBA')
+
+                if "A" in image.getbands():
+                    # Replace transparency with white background.
+                    alpha_layer = image.convert('RGBA').split()[-1]
+                    bg = Image.new("RGBA", image.size, (255, 255, 255, 255) )
+                    bg.paste( image, mask=alpha_layer )
+                    image = bg.convert('RGB')
+
                 if not image.mode == "RGB":
                     image = image.convert("RGB")
             img = np.array(image, np.uint8)