Skip to content
This repository was archived by the owner on Jul 17, 2023. It is now read-only.

Support ML-Danbooru #77

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tagger/dbimutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@
from PIL import Image


def fill_transparent(image: Image.Image, color='WHITE'):
    """Flatten *image* onto a solid background and return it as RGB.

    The input is converted to RGBA, pasted onto a new RGBA canvas of the
    same size filled with *color* (using the image itself as the paste
    mask), and the canvas is then converted to RGB.
    """
    rgba = image.convert('RGBA')
    backdrop = Image.new('RGBA', rgba.size, color)
    backdrop.paste(rgba, mask=rgba)
    return backdrop.convert('RGB')


def resize(pic: Image.Image, size: int, keep_ratio=True) -> Image.Image:
    """Resize *pic* with LANCZOS resampling to a size derived from *size*.

    With ``keep_ratio`` the shorter edge is scaled to *size* and the other
    edge by the same factor (truncated to an int); otherwise the target is
    a ``size`` x ``size`` square. Either way, both edges are then rounded
    down to a multiple of 4 before resizing.
    """
    if keep_ratio:
        width, height = pic.size
        shortest = min(pic.size)
        scaled = (
            int(width / shortest * size),
            int(height / shortest * size),
        )
    else:
        scaled = (size, size)

    # Floor each edge to the nearest lower multiple of 4.
    target_size = tuple((edge // 4) * 4 for edge in scaled)

    return pic.resize(target_size, resample=Image.Resampling.LANCZOS)


def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
if img.endswith(".gif"):
img = Image.open(img)
Expand Down
154 changes: 113 additions & 41 deletions tagger/interrogator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import gc
import pandas as pd
Expand All @@ -7,7 +8,6 @@
from io import BytesIO
from PIL import Image

from pathlib import Path
from huggingface_hub import hf_hub_download

from modules import shared
Expand All @@ -20,8 +20,13 @@
use_cpu = ('all' in shared.cmd_opts.use_cpu) or (
'interrogate' in shared.cmd_opts.use_cpu)

# https://onnxruntime.ai/docs/execution-providers/
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

if use_cpu:
tf_device_name = '/cpu:0'
onnxrt_providers.pop(0)
else:
tf_device_name = '/gpu:0'

Expand Down Expand Up @@ -94,13 +99,13 @@ def load(self):
def unload(self) -> bool:
    """Release the loaded model and tag data held by this interrogator.

    Returns:
        True when a model was loaded and has been released, False when
        there was nothing to unload.
    """
    unloaded = False

    # getattr keeps this safe for subclasses whose __init__ never
    # assigned `model`; a plain attribute read would raise there.
    if getattr(self, 'model', None) is not None:
        self.model = None
        unloaded = True
        # Collect right away so the dropped model's memory is reclaimed.
        gc.collect()
        print(f'Unloaded {self.name}')

    self.tags = None

    return unloaded

Expand All @@ -118,6 +123,7 @@ class DeepDanbooruInterrogator(Interrogator):
def __init__(self, name: str, project_path: os.PathLike) -> None:
super().__init__(name)
self.project_path = project_path
self.model = None

def load(self) -> None:
print(f'Loading {self.name} from {str(self.project_path)}')
Expand Down Expand Up @@ -183,7 +189,7 @@ def interrogate(
Dict[str, float] # tag confidents
]:
# init model
if not hasattr(self, 'model') or self.model is None:
if self.model is None:
self.load()

import deepdanbooru.data as ddd
Expand Down Expand Up @@ -212,54 +218,57 @@ def interrogate(
return ratings, tags


def get_onnxrt():
    """Return the ``onnxruntime`` module, installing it first if missing.

    When the import fails and the launcher reports the package as not
    installed, the package named by the ``ONNXRUNTIME_PACKAGE`` environment
    variable (default ``onnxruntime-gpu``) is installed via ``run_pip`` and
    the import is retried.
    """
    try:
        import onnxruntime
    except ImportError:
        # only one of these packages should be installed at a time in any one environment
        # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
        # TODO: remove old package when the environment changes?
        from launch import is_installed, run_pip

        if not is_installed('onnxruntime'):
            pip_target = os.environ.get(
                'ONNXRUNTIME_PACKAGE',
                'onnxruntime-gpu'
            )
            run_pip(f'install {pip_target}', 'onnxruntime')

        # Retry now that the package should be present.
        import onnxruntime

    return onnxruntime


class WaifuDiffusionInterrogator(Interrogator):
def __init__(
self,
name: str,
repo_id: str,
model_path='model.onnx',
tags_path='selected_tags.csv',
**kwargs
) -> None:
super().__init__(name)
self.repo_id = repo_id
self.model_path = model_path
self.tags_path = tags_path
self.kwargs = kwargs
self.model = None
self.tags = None

def download(self) -> Tuple[str, str]:
    """Fetch the model file and the tag file from the Hugging Face Hub.

    Returns:
        A ``(model_path, tags_path)`` pair holding the values returned by
        ``hf_hub_download`` for ``self.model_path`` and ``self.tags_path``.
    """
    print(f"Loading {self.name} model file from {self.repo_id}")

    model_path = hf_hub_download(
        repo_id=self.repo_id, filename=self.model_path)
    # Must request tags_path, not model_path again: load() parses this
    # file with pd.read_csv.
    tags_path = hf_hub_download(
        repo_id=self.repo_id, filename=self.tags_path)
    return model_path, tags_path

def load(self) -> None:
model_path, tags_path = self.download()

# only one of these packages should be installed at a time in any one environment
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
# TODO: remove old package when the environment changes?
from launch import is_installed, run_pip
if not is_installed('onnxruntime'):
package = os.environ.get(
'ONNXRUNTIME_PACKAGE',
'onnxruntime-gpu'
)

run_pip(f'install {package}', 'onnxruntime')
ort = get_onnxrt()
self.model = ort.InferenceSession(model_path, providers=onnxrt_providers)

from onnxruntime import InferenceSession

# https://onnxruntime.ai/docs/execution-providers/
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
if use_cpu:
providers.pop(0)

self.model = InferenceSession(str(model_path), providers=providers)

print(f'Loaded {self.name} model from {model_path}')
print(f'Loaded {self.name} model from {self.repo_id}')

self.tags = pd.read_csv(tags_path)

Expand All @@ -271,7 +280,7 @@ def interrogate(
Dict[str, float] # tag confidents
]:
# init model
if not hasattr(self, 'model') or self.model is None:
if self.model is None:
self.load()

# code for converting the image and running the model is taken from the link below
Expand All @@ -282,12 +291,9 @@ def interrogate(
_, height, _, _ = self.model.get_inputs()[0].shape

# alpha to white
image = image.convert('RGBA')
new_image = Image.new('RGBA', image.size, 'WHITE')
new_image.paste(image, mask=image)
image = new_image.convert('RGB')
image = np.asarray(image)
image = dbimutils.fill_transparent(image)

image = np.asarray(image)
# PIL RGB to OpenCV BGR
image = image[:, :, ::-1]

Expand All @@ -311,3 +317,69 @@ def interrogate(
tags = dict(tags[4:].values)

return ratings, tags


class MLDanbooruInterrogator(Interrogator):
    """Interrogator that runs an ONNX model downloaded from the Hugging Face Hub.

    The model and its tag list are fetched lazily on first use; ``model`` and
    ``tags`` stay ``None`` until ``load()`` has run.
    """

    def __init__(
        self,
        name: str,
        repo_id: str,
        model_path: str,
        tags_path='classes.json'
    ) -> None:
        """Record where the model and tag files live; nothing is downloaded here.

        Args:
            name: Name passed to the base ``Interrogator``.
            repo_id: Hub repository to download from.
            model_path: Filename of the ONNX model inside the repository.
            tags_path: Filename of the JSON tag list inside the repository.
        """
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.repo_id = repo_id
        self.tags = None
        self.model = None

    def download(self) -> Tuple[str, str]:
        """Download the model and tag files and return their local paths.

        Returns:
            A ``(model_path, tags_path)`` pair as returned by
            ``hf_hub_download``.
        """
        print(f"Loading {self.name} model file from {self.repo_id}")

        model_path = hf_hub_download(
            repo_id=self.repo_id, filename=self.model_path)
        tags_path = hf_hub_download(
            repo_id=self.repo_id, filename=self.tags_path)
        return model_path, tags_path

    def load(self) -> None:
        """Download the files, create the ONNX session and read the tag list."""
        model_path, tags_path = self.download()

        ort = get_onnxrt()
        self.model = ort.InferenceSession(model_path, providers=onnxrt_providers)

        print(f'Loaded {self.name} model from {model_path}')

        # NOTE(review): assumed to be a JSON sequence of tag names in the
        # same order as the model outputs — confirm against the repository.
        with open(tags_path, 'r', encoding='utf-8') as f:
            self.tags = json.load(f)

    def interrogate(
        self,
        image: Image.Image
    ) -> Tuple[
        None,  # this model yields no rating confidents
        Dict[str, float]  # tag confidents
    ]:
        """Run the model on *image* and return per-tag confidences.

        Args:
            image: Picture to tag; transparency is filled before inference.

        Returns:
            ``(None, tags)`` where ``tags`` maps each tag to a confidence.
            The first element is always ``None`` because no ratings are
            produced here.
        """
        # init model
        if self.model is None:
            self.load()

        image = dbimutils.fill_transparent(image)
        image = dbimutils.resize(image, 448)  # TODO CUSTOMIZE

        # Scale pixel values from 0..255 to 0..1 as float32.
        x = np.asarray(image, dtype=np.float32) / 255
        # HWC -> 1CHW
        x = x.transpose((2, 0, 1))
        x = np.expand_dims(x, 0)

        input_ = self.model.get_inputs()[0]
        output = self.model.get_outputs()[0]
        # evaluate model
        y, = self.model.run([output.name], {input_.name: x})

        # Element-wise sigmoid (not a softmax): each output is squashed
        # into (0, 1) independently of the others.
        y = 1 / (1 + np.exp(-y))

        # zip stops at the shorter of the two sequences.
        tags = {tag: float(conf) for tag, conf in zip(self.tags, y.flatten())}
        return None, tags
18 changes: 11 additions & 7 deletions tagger/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def on_interrogate(
batch_remove_duplicated_tag: bool,
batch_output_save_json: bool,

interrogator: str,
interrogator_name: str,
threshold: float,
additional_tags: str,
exclude_tags: str,
Expand All @@ -48,11 +48,10 @@ def on_interrogate(

unload_model_after_running: bool
):
if interrogator not in utils.interrogators:
interrogator: Interrogator = next((i for i in utils.interrogators.values() if interrogator_name == i.name), None)
if interrogator is None:
return ['', None, None, f"'{interrogator}' is not a valid interrogator"]

interrogator: Interrogator = utils.interrogators[interrogator]

postprocess_opts = (
threshold,
split_str(additional_tags),
Expand All @@ -78,7 +77,7 @@ def on_interrogate(
return [
', '.join(processed_tags),
ratings,
tags,
dict(list(tags.items())[:200]),
''
]

Expand Down Expand Up @@ -326,7 +325,12 @@ def on_ui_tabs():
# interrogator selector
with gr.Column():
with gr.Row(variant='compact'):
interrogator_names = utils.refresh_interrogators()
def refresh():
utils.refresh_interrogators()
return sorted(x.name for x in utils.interrogators.values())

interrogator_names = refresh()

interrogator = utils.preset.component(
gr.Dropdown,
label='Interrogator',
Expand All @@ -341,7 +345,7 @@ def on_ui_tabs():
ui.create_refresh_button(
interrogator,
lambda: None,
lambda: {'choices': utils.refresh_interrogators()},
lambda: {'choices': refresh()},
'refresh_interrogator'
)

Expand Down
Loading