Skip to content

Commit b7918a2

Browse files
author
Roel Kluin
committed
Manually merged: Support ML-Danbooru #6, changes amended from CCRcmcpe's
pull request to Toriato
1 parent 2748a7e commit b7918a2

File tree

5 files changed

+207
-130
lines changed

5 files changed

+207
-130
lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
deepdanbooru
2+
onnxruntime-silicon; sys_platform == 'darwin'
3+
onnxruntime-gpu; sys_platform != 'darwin'
24
fastapi
35
gradio
46
huggingface_hub

tagger/dbimutils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,29 @@
55
from PIL import Image
66

77

8+
def fill_transparent(image: Image.Image, color='WHITE'):
    """Flatten *image* onto an opaque background and return it as RGB.

    Args:
        image: the picture to flatten; it is converted to RGBA first, so
            any mode that PIL can convert is accepted.
        color: background colour handed to ``Image.new`` (default white).

    Returns:
        A new RGB image in which the formerly transparent areas show
        *color*.
    """
    rgba = image.convert('RGBA')
    backdrop = Image.new('RGBA', rgba.size, color)
    # the RGBA picture doubles as its own paste mask (its alpha band)
    backdrop.paste(rgba, mask=rgba)
    return backdrop.convert('RGB')
14+
15+
16+
def resize(pic: Image.Image, size: int, keep_ratio=True) -> Image.Image:
    """Resize *pic* to dimensions that are multiples of four.

    Args:
        pic: the image to resize.
        size: target edge length in pixels. With ``keep_ratio`` it is the
            length of the shorter edge, otherwise of both edges.
        keep_ratio: when true, scale both edges by the same factor so the
            aspect ratio is kept (up to the rounding described below);
            when false, produce a ``size`` x ``size`` image.

    Returns:
        The image resampled with the LANCZOS filter.

    Raises:
        ValueError: if *size* is smaller than 1.
    """
    if size < 1:
        raise ValueError(f'size must be a positive integer, got {size!r}')

    width, height = pic.size
    if keep_ratio:
        min_edge = min(width, height)
        # Integer floor division is exact. The float expression
        # ``edge / min_edge * size`` can land a hair below an integer and
        # then lose a whole 4-pixel step in the masking below.
        width = int(width * size // min_edge)
        height = int(height * size // min_edge)
    else:
        width = height = size

    # Clear the two low bits (round down to a multiple of 4), but never go
    # below 4: a 0-pixel edge cannot be resized to.
    width = max(4, width & ~3)
    height = max(4, height & ~3)

    return pic.resize((width, height), resample=Image.Resampling.LANCZOS)
29+
30+
831
def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
932
""" Read an image, convert to 24-bit if necessary """
1033
if img.endswith(".gif"):

tagger/interrogator.py

Lines changed: 128 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import io
55
from hashlib import sha256
66
import json
7+
import numpy as np
78
from platform import system
89
from typing import Tuple, List, Dict, Callable
910
from pandas import read_csv, read_json
@@ -17,15 +18,21 @@
1718
from . import dbimutils
1819
from tagger import settings
1920
from tagger.uiset import QData, IOData, ItRetTP
21+
import gradio as gr
2022

2123
Its = settings.InterrogatorSettings
2224

2325
# select a device to process
2426
use_cpu = ('all' in shared.cmd_opts.use_cpu) or (
2527
'interrogate' in shared.cmd_opts.use_cpu)
2628

29+
# https://onnxruntime.ai/docs/execution-providers/
30+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
31+
onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
32+
2733
if use_cpu:
2834
TF_DEVICE_NAME = '/cpu:0'
35+
onnxrt_providers.pop(0)
2936
else:
3037
TF_DEVICE_NAME = '/gpu:0'
3138

@@ -63,7 +70,7 @@ class Interrogator:
6370
"output_dir": '',
6471
}
6572
output = None
66-
#odd_increment = 0
73+
# odd_increment = 0
6774

6875
@classmethod
6976
def flip(cls, key):
@@ -131,10 +138,12 @@ def unload(self) -> bool:
131138
del self.model
132139
self.model = None
133140
unloaded = True
141+
gr.collect()
134142
print(f'Unloaded {self.name}')
135143

136144
if hasattr(self, 'tags'):
137145
del self.tags
146+
self.tags = None
138147

139148
return unloaded
140149

@@ -259,6 +268,7 @@ class DeepDanbooruInterrogator(Interrogator):
259268
def __init__(self, name: str, project_path: os.PathLike) -> None:
260269
super().__init__(name)
261270
self.project_path = project_path
271+
self.model = None
262272
self.tags = None
263273

264274
def load(self) -> None:
@@ -331,7 +341,7 @@ def interrogate(
331341
Dict[str, float] # tag confidences
332342
]:
333343
# init model
334-
if not hasattr(self, 'model') or self.model is None:
344+
if self.model is None:
335345
self.load()
336346

337347
import deepdanbooru.data as ddd
@@ -363,36 +373,62 @@ def interrogate(
363373
return ratings, tags
364374

365375

376+
def get_onnxrt():
    """Return the ``onnxruntime`` module, installing it first if needed.

    When the import fails and the launcher reports that ``onnxruntime`` is
    not installed, a platform specific package is installed with pip:
    ``onnxruntime-silicon`` on Darwin, ``onnxruntime-gpu`` elsewhere. The
    ``ONNXRUNTIME_PACKAGE`` environment variable overrides that choice.

    Returns:
        The imported ``onnxruntime`` module.

    Raises:
        ImportError: if the module still cannot be imported afterwards.
    """
    try:
        import onnxruntime
    except ImportError:
        # only one of these packages should be installed at one time in an env
        # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
        # TODO: remove old package when the environment changes?
        from launch import is_installed, run_pip
        if not is_installed('onnxruntime'):
            if system() == "Darwin":
                package_name = "onnxruntime-silicon"
            else:
                package_name = "onnxruntime-gpu"
            package = os.environ.get('ONNXRUNTIME_PACKAGE', package_name)

            run_pip(f'install {package}', 'onnxruntime')

            # The package appeared while the interpreter was running; drop
            # the finders' cached directory listings so the import below
            # can see it.
            import importlib
            importlib.invalidate_caches()

        import onnxruntime

    return onnxruntime
399+
400+
366401
class WaifuDiffusionInterrogator(Interrogator):
367402
""" Interrogator for Waifu Diffusion models """
368403
def __init__(
369404
self,
370405
name: str,
371406
model_path='model.onnx',
372407
tags_path='selected_tags.csv',
373-
**kwargs
408+
repo_id=None,
374409
) -> None:
375410
super().__init__(name)
411+
self.repo_id = repo_id
376412
self.model_path = model_path
377413
self.tags_path = tags_path
378414
self.tags = None
379-
self.kwargs = kwargs
380-
381-
def download(self) -> Tuple[os.PathLike, os.PathLike]:
382-
print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")
415+
self.model = None
416+
self.tags = None
383417

418+
def download(self) -> None:
384419
mdir = Path(shared.models_path, 'interrogators')
385-
model_path = Path(hf_hub_download(**self.kwargs,
386-
filename=self.model_path,
387-
cache_dir=mdir))
388-
tags_path = Path(hf_hub_download(**self.kwargs,
389-
filename=self.tags_path,
390-
cache_dir=mdir))
420+
if self.repo_id is not None:
421+
print(f"Loading {self.name} model file from {self.repo_id}")
422+
423+
self.model_path = hf_hub_download(self.repo_id, self.model_path,
424+
cache_dir=mdir)
425+
self.tags_path = hf_hub_download(self.repo_id, self.tags_path,
426+
cache_dir=mdir)
391427

392428
download_model = {
393429
'name': self.name,
394-
'model_path': str(model_path),
395-
'tags_path': str(tags_path),
430+
'model_path': self.model_path,
431+
'tags_path': self.tags_path,
396432
}
397433
mpath = Path(mdir, 'model.json')
398434

@@ -411,56 +447,14 @@ def download(self) -> Tuple[os.PathLike, os.PathLike]:
411447
with io.open(mpath, 'w') as filename:
412448
json.dump(data, filename)
413449

414-
return model_path, tags_path
415-
416-
def get_model_path(self) -> Tuple[os.PathLike, os.PathLike]:
417-
model_path = ''
418-
tags_path = ''
419-
mpath = Path(shared.models_path, 'interrogators', 'model.json')
420-
try:
421-
models = read_json(mpath).to_dict(orient='records')
422-
i = next(i for i in models if i['name'] == self.name)
423-
model_path = i['model_path']
424-
tags_path = i['tags_path']
425-
except Exception as e:
426-
print(f'{mpath}: requires a name, model_ and tags_path: {repr(e)}')
427-
model_path, tags_path = self.download()
428-
return model_path, tags_path
429-
430450
def load(self) -> None:
431-
if isinstance(self.model_path, str) or isinstance(self.tags_path, str):
432-
model_path, tags_path = self.download()
433-
else:
434-
model_path = self.model_path
435-
tags_path = self.tags_path
436-
437-
# only one of these packages should be installed a time in any one env
438-
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
439-
# TODO: remove old package when the environment changes?
440-
from launch import is_installed, run_pip
441-
if not is_installed('onnxruntime'):
442-
if system() == "Darwin":
443-
package_name = "onnxruntime-silicon"
444-
else:
445-
package_name = "onnxruntime-gpu"
446-
package = os.environ.get(
447-
'ONNXRUNTIME_PACKAGE',
448-
package_name
449-
)
450-
451-
run_pip(f'install {package}', 'onnxruntime')
452-
453-
from onnxruntime import InferenceSession
454-
455-
# https://onnxruntime.ai/docs/execution-providers/
456-
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
457-
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
458-
if use_cpu:
459-
providers.pop(0)
451+
self.download()
452+
ort = get_onnxrt()
453+
self.model = ort.InferenceSession(self.model_path,
454+
providers=onnxrt_providers)
460455

461-
print(f'Loading {self.name} model from {model_path}, {tags_path}')
462-
self.model = InferenceSession(str(model_path), providers=providers)
463-
self.tags = read_csv(tags_path)
456+
print(f'Loaded {self.name} model from {self.repo_id}')
457+
self.tags = read_csv(self.tags_path)
464458

465459
def interrogate(
466460
self,
@@ -470,7 +464,7 @@ def interrogate(
470464
Dict[str, float] # tag confidences
471465
]:
472466
# init model
473-
if not hasattr(self, 'model') or self.model is None:
467+
if self.model is None:
474468
self.load()
475469

476470
# code for converting the image and running the model is taken from the
@@ -481,15 +475,14 @@ def interrogate(
481475
_, height, _, _ = self.model.get_inputs()[0].shape
482476

483477
# alpha to white
484-
image = image.convert('RGBA')
485-
new_image = Image.new('RGBA', image.size, 'WHITE')
486-
new_image.paste(image, mask=image)
487-
image = new_image.convert('RGB')
488-
image = asarray(image)
478+
image = dbimutils.fill_transparent(image)
489479

480+
image = np.asarray(image)
490481
# PIL RGB to OpenCV BGR
491482
image = image[:, :, ::-1]
492483

484+
tags = dict
485+
493486
image = dbimutils.make_square(image, height)
494487
image = dbimutils.smart_resize(image, height)
495488
image = image.astype(float32)
@@ -609,3 +602,69 @@ def pred_model(model):
609602
QData.add_tag = orig_add_tags
610603
del os.environ["TF_XLA_FLAGS"]
611604
return ''
605+
606+
607+
class MLDanbooruInterrogator(Interrogator):
    """ Interrogator for ML-Danbooru ONNX models fetched with hf_hub_download """
    def __init__(
        self,
        name: str,
        repo_id: str,
        model_path: str,
        tags_path='classes.json'
    ) -> None:
        super().__init__(name)
        # File names relative to the repository. They are handed to
        # hf_hub_download() on every load, so they must never be replaced
        # by the resolved local paths.
        self.model_path = model_path
        self.tags_path = tags_path
        self.repo_id = repo_id
        self.tags = None
        self.model = None

    def download(self) -> Tuple[str, str]:
        """ Fetch the model and tags files, return their local paths """
        print(f"Loading {self.name} model file from {self.repo_id}")

        model_path = hf_hub_download(
            repo_id=self.repo_id, filename=self.model_path)
        tags_path = hf_hub_download(
            repo_id=self.repo_id, filename=self.tags_path)
        return model_path, tags_path

    def load(self) -> None:
        """ Download the files, create the session and read the tags """
        # Keep the resolved paths local: storing them on self would make
        # the next load() after an unload() ask the hub for a file named
        # after a local path.
        model_path, tags_path = self.download()

        ort = get_onnxrt()
        self.model = ort.InferenceSession(model_path,
                                          providers=onnxrt_providers)

        print(f'Loaded {self.name} model from {model_path}')

        with open(tags_path, 'r', encoding='utf-8') as f:
            self.tags = json.load(f)

    def interrogate(
        self,
        image: Image.Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences (always empty here)
        Dict[str, float]  # tag confidences
    ]:
        """ Run the model on an image, return ({}, tag confidences) """
        # init model
        if self.model is None:
            self.load()

        image = dbimutils.fill_transparent(image)
        image = dbimutils.resize(image, 448)  # TODO CUSTOMIZE

        # scale the pixel values to the range 0..1
        x = np.asarray(image, dtype=np.float32) / 255
        # HWC -> 1CHW
        x = x.transpose((2, 0, 1))
        x = np.expand_dims(x, 0)

        input_ = self.model.get_inputs()[0]
        output = self.model.get_outputs()[0]
        # evaluate model
        y, = self.model.run([output.name], {input_.name: x})

        # Sigmoid, applied element-wise: every tag is scored on its own
        y = 1 / (1 + np.exp(-y))

        tags = {tag: float(conf) for tag, conf in zip(self.tags, y.flatten())}
        return {}, tags

tagger/ui.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def unload_interrogators() -> List[str]:
3030

3131
def check_for_errors(name) -> str:
3232
errors = It.get_errors()
33-
if name not in utils.interrogators:
33+
if not any(i.name == name for i in utils.interrogators.values()):
3434
errors += f"'{name}': invalid interrogator"
3535

3636
return errors
@@ -44,7 +44,8 @@ def on_interrogate(name: str, inverse=False) -> ItRetTP:
4444
if err != '':
4545
return (None, None, None, err)
4646

47-
interrogator: It = utils.interrogators[name]
47+
such_name = (i for i in utils.interrogators.values() if name == i.name)
48+
interrogator: It = next(such_name, None)
4849
QData.inverse = inverse
4950
return interrogator.batch_interrogate()
5051

@@ -209,7 +210,11 @@ def on_ui_tabs():
209210
# interrogator selector
210211
with gr.Column():
211212
with gr.Row(variant='compact'):
212-
interrogator_names = utils.refresh_interrogators()
213+
def refresh():
214+
utils.refresh_interrogators()
215+
return sorted(x.name for x in utils.interrogators
216+
.values())
217+
interrogator_names = refresh()
213218
interrogator = utils.preset.component(
214219
gr.Dropdown,
215220
label='Interrogator',
@@ -224,7 +229,7 @@ def on_ui_tabs():
224229
ui.create_refresh_button(
225230
interrogator,
226231
lambda: None,
227-
lambda: {'choices': utils.refresh_interrogators()},
232+
lambda: {'choices': refresh()},
228233
'refresh_interrogator'
229234
)
230235

0 commit comments

Comments
 (0)