diff --git a/tagger/dbimutils.py b/tagger/dbimutils.py index 332b25b..08588f0 100644 --- a/tagger/dbimutils.py +++ b/tagger/dbimutils.py @@ -5,6 +5,32 @@ from PIL import Image +def fill_transparent(image: Image.Image, color='WHITE'): + image = image.convert('RGBA') + new_image = Image.new('RGBA', image.size, color) + new_image.paste(image, mask=image) + image = new_image.convert('RGB') + return image + + +def resize(pic: Image.Image, size: int, keep_ratio=True) -> Image.Image: + if not keep_ratio: + target_size = (size, size) + else: + min_edge = min(pic.size) + target_size = ( + int(pic.size[0] / min_edge * size), + int(pic.size[1] / min_edge * size), + ) + + target_size = ( + (target_size[0] // 4) * 4, + (target_size[1] // 4) * 4, + ) + + return pic.resize(target_size, resample=Image.Resampling.LANCZOS) + + def smart_imread(img, flag=cv2.IMREAD_UNCHANGED): if img.endswith(".gif"): img = Image.open(img) diff --git a/tagger/interrogator.py b/tagger/interrogator.py index b644a5f..9186929 100644 --- a/tagger/interrogator.py +++ b/tagger/interrogator.py @@ -1,3 +1,4 @@ +import json import os import gc import pandas as pd @@ -7,7 +8,6 @@ from io import BytesIO from PIL import Image -from pathlib import Path from huggingface_hub import hf_hub_download from modules import shared @@ -20,8 +20,13 @@ use_cpu = ('all' in shared.cmd_opts.use_cpu) or ( 'interrogate' in shared.cmd_opts.use_cpu) +# https://onnxruntime.ai/docs/execution-providers/ +# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 +onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + if use_cpu: tf_device_name = '/cpu:0' + onnxrt_providers.pop(0) else: tf_device_name = '/gpu:0' @@ -94,13 +99,13 @@ def load(self): def unload(self) -> bool: unloaded = False - if hasattr(self, 'model') and self.model is not None: - del self.model + if self.model is not None: + self.model = None unloaded = True + gc.collect() print(f'Unloaded 
{self.name}') - if hasattr(self, 'tags'): - del self.tags + self.tags = None return unloaded @@ -118,6 +123,7 @@ class DeepDanbooruInterrogator(Interrogator): def __init__(self, name: str, project_path: os.PathLike) -> None: super().__init__(name) self.project_path = project_path + self.model = None def load(self) -> None: print(f'Loading {self.name} from {str(self.project_path)}') @@ -183,7 +189,7 @@ def interrogate( Dict[str, float] # tag confidents ]: # init model - if not hasattr(self, 'model') or self.model is None: + if self.model is None: self.load() import deepdanbooru.data as ddd @@ -212,54 +218,57 @@ def interrogate( return ratings, tags +def get_onnxrt(): + try: + import onnxruntime + return onnxruntime + except ImportError: + # only one of these packages should be installed at a time in any one environment + # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime + # TODO: remove old package when the environment changes? + from launch import is_installed, run_pip + if not is_installed('onnxruntime'): + package = os.environ.get( + 'ONNXRUNTIME_PACKAGE', + 'onnxruntime-gpu' + ) + + run_pip(f'install {package}', 'onnxruntime') + + import onnxruntime + return onnxruntime + + class WaifuDiffusionInterrogator(Interrogator): def __init__( self, name: str, + repo_id: str, model_path='model.onnx', tags_path='selected_tags.csv', **kwargs ) -> None: super().__init__(name) + self.repo_id = repo_id self.model_path = model_path self.tags_path = tags_path - self.kwargs = kwargs + self.model = None + self.tags = None - def download(self) -> Tuple[os.PathLike, os.PathLike]: - print(f"Loading {self.name} model file from {self.kwargs['repo_id']}") + def download(self) -> Tuple[str, str]: + print(f"Loading {self.name} model file from {self.repo_id}") - model_path = Path(hf_hub_download( - **self.kwargs, filename=self.model_path)) - tags_path = Path(hf_hub_download( - **self.kwargs, filename=self.tags_path)) + model_path = 
hf_hub_download(self.repo_id, self.model_path) + tags_path = hf_hub_download(self.repo_id, self.tags_path) return model_path, tags_path def load(self) -> None: model_path, tags_path = self.download() - # only one of these packages should be installed at a time in any one environment - # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime - # TODO: remove old package when the environment changes? - from launch import is_installed, run_pip - if not is_installed('onnxruntime'): - package = os.environ.get( - 'ONNXRUNTIME_PACKAGE', - 'onnxruntime-gpu' - ) - - run_pip(f'install {package}', 'onnxruntime') + ort = get_onnxrt() + self.model = ort.InferenceSession(model_path, providers=onnxrt_providers) - from onnxruntime import InferenceSession - - # https://onnxruntime.ai/docs/execution-providers/ - # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 - providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] - if use_cpu: - providers.pop(0) - - self.model = InferenceSession(str(model_path), providers=providers) - - print(f'Loaded {self.name} model from {model_path}') + print(f'Loaded {self.name} model from {self.repo_id}') self.tags = pd.read_csv(tags_path) @@ -271,7 +280,7 @@ def interrogate( Dict[str, float] # tag confidents ]: # init model - if not hasattr(self, 'model') or self.model is None: + if self.model is None: self.load() # code for converting the image and running the model is taken from the link below @@ -282,12 +291,9 @@ def interrogate( _, height, _, _ = self.model.get_inputs()[0].shape # alpha to white - image = image.convert('RGBA') - new_image = Image.new('RGBA', image.size, 'WHITE') - new_image.paste(image, mask=image) - image = new_image.convert('RGB') - image = np.asarray(image) + image = dbimutils.fill_transparent(image) + image = np.asarray(image) # PIL RGB to OpenCV BGR image = image[:, :, ::-1] @@ -311,3 +317,69 @@ def interrogate( tags = 
dict(tags[4:].values) return ratings, tags + + +class MLDanbooruInterrogator(Interrogator): + def __init__( + self, + name: str, + repo_id: str, + model_path: str, + tags_path='classes.json' + ) -> None: + super().__init__(name) + self.model_path = model_path + self.tags_path = tags_path + self.repo_id = repo_id + self.tags = None + self.model = None + + def download(self) -> Tuple[str, str]: + print(f"Loading {self.name} model file from {self.repo_id}") + + model_path = hf_hub_download( + repo_id=self.repo_id, filename=self.model_path) + tags_path = hf_hub_download( + repo_id=self.repo_id, filename=self.tags_path) + return model_path, tags_path + + def load(self) -> None: + model_path, tags_path = self.download() + + ort = get_onnxrt() + self.model = ort.InferenceSession(model_path, providers=onnxrt_providers) + + print(f'Loaded {self.name} model from {model_path}') + + with open(tags_path, 'r', encoding='utf-8') as f: + self.tags = json.load(f) + + def interrogate( + self, + image: Image + ) -> Tuple[ + Dict[str, float], # rating confidents + Dict[str, float] # tag confidents + ]: + # init model + if self.model is None: + self.load() + + image = dbimutils.fill_transparent(image) + image = dbimutils.resize(image, 448) # TODO CUSTOMIZE + + x = np.asarray(image, dtype=np.float32) / 255 + # HWC -> 1CHW + x = x.transpose((2, 0, 1)) + x = np.expand_dims(x, 0) + + input_ = self.model.get_inputs()[0] + output = self.model.get_outputs()[0] + # evaluate model + y, = self.model.run([output.name], {input_.name: x}) + + # Element-wise sigmoid (not softmax) + y = 1 / (1 + np.exp(-y)) + + tags = {tag: float(conf) for tag, conf in zip(self.tags, y.flatten())} + return None, tags diff --git a/tagger/ui.py b/tagger/ui.py index ec9bda1..ae552c1 100644 --- a/tagger/ui.py +++ b/tagger/ui.py @@ -36,7 +36,7 @@ def on_interrogate( batch_remove_duplicated_tag: bool, batch_output_save_json: bool, - interrogator: str, + interrogator_name: str, threshold: float, additional_tags: str, exclude_tags: str, @@ -48,11 
+48,10 @@ def on_interrogate( unload_model_after_running: bool ): - if interrogator not in utils.interrogators: + interrogator: Interrogator = next((i for i in utils.interrogators.values() if interrogator_name == i.name), None) + if interrogator is None: - return ['', None, None, f"'{interrogator}' is not a valid interrogator"] + return ['', None, None, f"'{interrogator_name}' is not a valid interrogator"] - interrogator: Interrogator = utils.interrogators[interrogator] - postprocess_opts = ( threshold, split_str(additional_tags), @@ -78,7 +77,7 @@ def on_interrogate( return [ ', '.join(processed_tags), ratings, - tags, + dict(list(tags.items())[:200]), '' ] @@ -326,7 +325,12 @@ def on_ui_tabs(): # interrogator selector with gr.Column(): with gr.Row(variant='compact'): - interrogator_names = utils.refresh_interrogators() + def refresh(): + utils.refresh_interrogators() + return sorted(x.name for x in utils.interrogators.values()) + + interrogator_names = refresh() + interrogator = utils.preset.component( gr.Dropdown, label='Interrogator', @@ -341,7 +345,7 @@ def on_ui_tabs(): ui.create_refresh_button( interrogator, lambda: None, - lambda: {'choices': utils.refresh_interrogators()}, + lambda: {'choices': refresh()}, 'refresh_interrogator' ) diff --git a/tagger/utils.py b/tagger/utils.py index 21f1135..171ed90 100644 --- a/tagger/utils.py +++ b/tagger/utils.py @@ -6,52 +6,48 @@ from modules import shared, scripts from preload import default_ddp_path from tagger.preset import Preset -from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, WaifuDiffusionInterrogator +from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, WaifuDiffusionInterrogator, MLDanbooruInterrogator preset = Preset(Path(scripts.basedir(), 'presets')) -interrogators: Dict[str, Interrogator] = {} - +interrogators: Dict[str, Interrogator] = { + 'wd14-vit.v1': WaifuDiffusionInterrogator( + 'WD14 ViT v1', + repo_id='SmilingWolf/wd-v1-4-vit-tagger'), + 'wd14-vit.v2': WaifuDiffusionInterrogator( + 'WD14 ViT v2', + 
repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2', + ), + 'wd14-convnext.v1': WaifuDiffusionInterrogator( + 'WD14 ConvNeXT v1', + repo_id='SmilingWolf/wd-v1-4-convnext-tagger' + ), + 'wd14-convnext.v2': WaifuDiffusionInterrogator( + 'WD14 ConvNeXT v2', + repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2', + ), + 'wd14-convnextv2.v1': WaifuDiffusionInterrogator( + 'WD14 ConvNeXTV2 v1', + repo_id='SmilingWolf/wd-v1-4-convnextv2-tagger-v2' + ), + 'wd14-swinv2.v2': WaifuDiffusionInterrogator( + 'WD14 SwinV2 v1', + repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2' + ), + 'mld-caformer.dec-5-97527': MLDanbooruInterrogator( + 'ML-Danbooru Caformer dec-5-97527', + repo_id='deepghs/ml-danbooru-onnx', + model_path='ml_caformer_m36_dec-5-97527.onnx' + ), + 'mld-tresnetd.6-30000': MLDanbooruInterrogator( + 'ML-Danbooru TResNet-D 6-30000', + repo_id='deepghs/ml-danbooru-onnx', + model_path='TResnet-D-FLq_ema_6-30000.onnx' + ) +} -def refresh_interrogators() -> List[str]: - global interrogators - interrogators = { - 'wd14-vit-v2': WaifuDiffusionInterrogator( - 'wd14-vit-v2', - repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2', - revision='v2.0' - ), - 'wd14-convnext-v2': WaifuDiffusionInterrogator( - 'wd14-convnext-v2', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2', - revision='v2.0' - ), - 'wd14-swinv2-v2': WaifuDiffusionInterrogator( - 'wd14-swinv2-v2', - repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2', - revision='v2.0' - ), - 'wd14-vit-v2-git': WaifuDiffusionInterrogator( - 'wd14-vit-v2-git', - repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2' - ), - 'wd14-convnext-v2-git': WaifuDiffusionInterrogator( - 'wd14-convnext-v2-git', - repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2' - ), - 'wd14-swinv2-v2-git': WaifuDiffusionInterrogator( - 'wd14-swinv2-v2-git', - repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2' - ), - 'wd14-vit': WaifuDiffusionInterrogator( - 'wd14-vit', - repo_id='SmilingWolf/wd-v1-4-vit-tagger'), - 'wd14-convnext': WaifuDiffusionInterrogator( - 'wd14-convnext', - 
repo_id='SmilingWolf/wd-v1-4-convnext-tagger' - ), - } +def refresh_interrogators(): # load deepdanbooru project os.makedirs( getattr(shared.cmd_opts, 'deepdanbooru_projects_path', default_ddp_path), @@ -65,9 +61,7 @@ def refresh_interrogators() -> List[str]: if not Path(path, 'project.json').is_file(): continue - interrogators[path.name] = DeepDanbooruInterrogator(path.name, path) - - return sorted(interrogators.keys()) + interrogators[path.name] = DeepDanbooruInterrogator(path.name, path) def split_str(s: str, separator=',') -> List[str]: