-
Notifications
You must be signed in to change notification settings - Fork 77
Support ML-Danbooru #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import json | ||
import os | ||
import gc | ||
import pandas as pd | ||
|
@@ -7,7 +8,6 @@ | |
from io import BytesIO | ||
from PIL import Image | ||
|
||
from pathlib import Path | ||
from huggingface_hub import hf_hub_download | ||
|
||
from modules import shared | ||
|
@@ -20,8 +20,13 @@ | |
use_cpu = ('all' in shared.cmd_opts.use_cpu) or ( | ||
'interrogate' in shared.cmd_opts.use_cpu) | ||
|
||
# https://onnxruntime.ai/docs/execution-providers/ | ||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 | ||
onnxrt_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] | ||
|
||
if use_cpu: | ||
tf_device_name = '/cpu:0' | ||
onnxrt_providers.pop(0) | ||
else: | ||
tf_device_name = '/gpu:0' | ||
|
||
|
@@ -94,13 +99,13 @@ def load(self): | |
def unload(self) -> bool: | ||
unloaded = False | ||
|
||
if hasattr(self, 'model') and self.model is not None: | ||
del self.model | ||
if self.model is not None: | ||
self.model = None | ||
unloaded = True | ||
gc.collect() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what exactly does gc.collect() do? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You may have already noticed, this line is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah yes, I noticed and forgot. I also found this, so numba is a possible solution for CUDA users. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new solution looks cool, please give it a try. |
||
print(f'Unloaded {self.name}') | ||
|
||
if hasattr(self, 'tags'): | ||
del self.tags | ||
self.tags = None | ||
|
||
return unloaded | ||
|
||
|
@@ -118,6 +123,7 @@ class DeepDanbooruInterrogator(Interrogator): | |
def __init__(self, name: str, project_path: os.PathLike) -> None: | ||
super().__init__(name) | ||
self.project_path = project_path | ||
self.model = None | ||
|
||
def load(self) -> None: | ||
print(f'Loading {self.name} from {str(self.project_path)}') | ||
|
@@ -183,7 +189,7 @@ def interrogate( | |
Dict[str, float] # tag confidents | ||
]: | ||
# init model | ||
if not hasattr(self, 'model') or self.model is None: | ||
if self.model is None: | ||
self.load() | ||
|
||
import deepdanbooru.data as ddd | ||
|
@@ -212,54 +218,57 @@ def interrogate( | |
return ratings, tags | ||
|
||
|
||
def get_onnxrt(): | ||
try: | ||
import onnxruntime | ||
return onnxruntime | ||
except ImportError: | ||
# only one of these packages should be installed at a time in any one environment | ||
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime | ||
# TODO: remove old package when the environment changes? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In what scenario do you expect the environment to change? installed on an external drive? |
||
from launch import is_installed, run_pip | ||
if not is_installed('onnxruntime'): | ||
package = os.environ.get( | ||
'ONNXRUNTIME_PACKAGE', | ||
'onnxruntime-gpu' | ||
) | ||
|
||
run_pip(f'install {package}', 'onnxruntime') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can just add onnxruntime to the requirements.txt, right?
or people can adapt this |
||
|
||
import onnxruntime | ||
return onnxruntime | ||
|
||
|
||
class WaifuDiffusionInterrogator(Interrogator): | ||
def __init__( | ||
self, | ||
name: str, | ||
repo_id: str, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have some models without a repo_id, so I've changed this to repo_id=None; this also requires some changes elsewhere. |
||
model_path='model.onnx', | ||
tags_path='selected_tags.csv', | ||
**kwargs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the kwargs are now obsolete |
||
) -> None: | ||
super().__init__(name) | ||
self.repo_id = repo_id | ||
self.model_path = model_path | ||
self.tags_path = tags_path | ||
self.kwargs = kwargs | ||
self.model = None | ||
self.tags = None | ||
|
||
def download(self) -> Tuple[os.PathLike, os.PathLike]: | ||
print(f"Loading {self.name} model file from {self.kwargs['repo_id']}") | ||
def download(self) -> Tuple[str, str]: | ||
print(f"Loading {self.name} model file from {self.repo_id}") | ||
|
||
model_path = Path(hf_hub_download( | ||
**self.kwargs, filename=self.model_path)) | ||
tags_path = Path(hf_hub_download( | ||
**self.kwargs, filename=self.tags_path)) | ||
model_path = hf_hub_download(self.repo_id, self.model_path) | ||
tags_path = hf_hub_download(self.repo_id, self.model_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is a bug, I've fixed it. |
||
return model_path, tags_path | ||
|
||
def load(self) -> None: | ||
model_path, tags_path = self.download() | ||
|
||
# only one of these packages should be installed at a time in any one environment | ||
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime | ||
# TODO: remove old package when the environment changes? | ||
from launch import is_installed, run_pip | ||
if not is_installed('onnxruntime'): | ||
package = os.environ.get( | ||
'ONNXRUNTIME_PACKAGE', | ||
'onnxruntime-gpu' | ||
) | ||
|
||
run_pip(f'install {package}', 'onnxruntime') | ||
ort = get_onnxrt() | ||
self.model = ort.InferenceSession(model_path, providers=onnxrt_providers) | ||
|
||
from onnxruntime import InferenceSession | ||
|
||
# https://onnxruntime.ai/docs/execution-providers/ | ||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958 | ||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] | ||
if use_cpu: | ||
providers.pop(0) | ||
|
||
self.model = InferenceSession(str(model_path), providers=providers) | ||
|
||
print(f'Loaded {self.name} model from {model_path}') | ||
print(f'Loaded {self.name} model from {self.repo_id}') | ||
|
||
self.tags = pd.read_csv(tags_path) | ||
|
||
|
@@ -271,7 +280,7 @@ def interrogate( | |
Dict[str, float] # tag confidents | ||
]: | ||
# init model | ||
if not hasattr(self, 'model') or self.model is None: | ||
if self.model is None: | ||
self.load() | ||
|
||
# code for converting the image and running the model is taken from the link below | ||
|
@@ -282,12 +291,9 @@ def interrogate( | |
_, height, _, _ = self.model.get_inputs()[0].shape | ||
|
||
# alpha to white | ||
image = image.convert('RGBA') | ||
new_image = Image.new('RGBA', image.size, 'WHITE') | ||
new_image.paste(image, mask=image) | ||
image = new_image.convert('RGB') | ||
image = np.asarray(image) | ||
image = dbimutils.fill_transparent(image) | ||
|
||
image = np.asarray(image) | ||
# PIL RGB to OpenCV BGR | ||
image = image[:, :, ::-1] | ||
|
||
|
@@ -311,3 +317,69 @@ def interrogate( | |
tags = dict(tags[4:].values) | ||
|
||
return ratings, tags | ||
|
||
|
||
class MLDanbooruInterrogator(Interrogator): | ||
def __init__( | ||
self, | ||
name: str, | ||
repo_id: str, | ||
model_path: str, | ||
tags_path='classes.json' | ||
) -> None: | ||
super().__init__(name) | ||
self.model_path = model_path | ||
self.tags_path = tags_path | ||
self.repo_id = repo_id | ||
self.tags = None | ||
self.model = None | ||
|
||
def download(self) -> Tuple[str, str]: | ||
print(f"Loading {self.name} model file from {self.repo_id}") | ||
|
||
model_path = hf_hub_download( | ||
repo_id=self.repo_id, filename=self.model_path) | ||
tags_path = hf_hub_download( | ||
repo_id=self.repo_id, filename=self.tags_path) | ||
return model_path, tags_path | ||
|
||
def load(self) -> None: | ||
model_path, tags_path = self.download() | ||
|
||
ort = get_onnxrt() | ||
self.model = ort.InferenceSession(model_path, providers=onnxrt_providers) | ||
|
||
print(f'Loaded {self.name} model from {model_path}') | ||
|
||
with open(tags_path, 'r', encoding='utf-8') as f: | ||
self.tags = json.load(f) | ||
|
||
def interrogate( | ||
self, | ||
image: Image | ||
) -> Tuple[ | ||
Dict[str, float], # rating confidents | ||
Dict[str, float] # tag confidents | ||
]: | ||
# init model | ||
if self.model is None: | ||
self.load() | ||
|
||
image = dbimutils.fill_transparent(image) | ||
image = dbimutils.resize(image, 448) # TODO CUSTOMIZE | ||
|
||
x = np.asarray(image, dtype=np.float32) / 255 | ||
# HWC -> 1CHW | ||
x = x.transpose((2, 0, 1)) | ||
x = np.expand_dims(x, 0) | ||
|
||
input_ = self.model.get_inputs()[0] | ||
output = self.model.get_outputs()[0] | ||
# evaluate model | ||
y, = self.model.run([output.name], {input_.name: x}) | ||
|
||
# Softmax | ||
y = 1 / (1 + np.exp(-y)) | ||
|
||
tags = {tag: float(conf) for tag, conf in zip(self.tags, y.flatten())} | ||
return None, tags |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ def on_interrogate( | |
batch_remove_duplicated_tag: bool, | ||
batch_output_save_json: bool, | ||
|
||
interrogator: str, | ||
interrogator_name: str, | ||
threshold: float, | ||
additional_tags: str, | ||
exclude_tags: str, | ||
|
@@ -48,11 +48,10 @@ def on_interrogate( | |
|
||
unload_model_after_running: bool | ||
): | ||
if interrogator not in utils.interrogators: | ||
interrogator: Interrogator = next((i for i in utils.interrogators.values() if interrogator_name == i.name), None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a bit silly |
||
if interrogator is None: | ||
return ['', None, None, f"'{interrogator}' is not a valid interrogator"] | ||
|
||
interrogator: Interrogator = utils.interrogators[interrogator] | ||
|
||
postprocess_opts = ( | ||
threshold, | ||
split_str(additional_tags), | ||
|
@@ -78,7 +77,7 @@ def on_interrogate( | |
return [ | ||
', '.join(processed_tags), | ||
ratings, | ||
tags, | ||
dict(list(tags.items())[:200]), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my branch There is already a solution for this (the slider on in settings -> tagger), so I can omit this. |
||
'' | ||
] | ||
|
||
|
@@ -326,7 +325,12 @@ def on_ui_tabs(): | |
# interrogator selector | ||
with gr.Column(): | ||
with gr.Row(variant='compact'): | ||
interrogator_names = utils.refresh_interrogators() | ||
def refresh(): | ||
utils.refresh_interrogators() | ||
return sorted(x.name for x in utils.interrogators.values()) | ||
|
||
interrogator_names = refresh() | ||
|
||
interrogator = utils.preset.component( | ||
gr.Dropdown, | ||
label='Interrogator', | ||
|
@@ -341,7 +345,7 @@ def on_ui_tabs(): | |
ui.create_refresh_button( | ||
interrogator, | ||
lambda: None, | ||
lambda: {'choices': utils.refresh_interrogators()}, | ||
lambda: {'choices': refresh()}, | ||
'refresh_interrogator' | ||
) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is the same as
right?