4
4
import io
5
5
from hashlib import sha256
6
6
import json
7
+ import numpy as np
7
8
from platform import system
8
9
from typing import Tuple , List , Dict , Callable
9
10
from pandas import read_csv , read_json
17
18
from . import dbimutils
18
19
from tagger import settings
19
20
from tagger .uiset import QData , IOData , ItRetTP
21
+ import gradio as gr
20
22
21
23
Its = settings .InterrogatorSettings
22
24
23
25
# select a device to process
24
26
use_cpu = ('all' in shared .cmd_opts .use_cpu ) or (
25
27
'interrogate' in shared .cmd_opts .use_cpu )
26
28
29
+ # https://onnxruntime.ai/docs/execution-providers/
30
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
31
+ onnxrt_providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ]
32
+
27
33
if use_cpu :
28
34
TF_DEVICE_NAME = '/cpu:0'
35
+ onnxrt_providers .pop (0 )
29
36
else :
30
37
TF_DEVICE_NAME = '/gpu:0'
31
38
@@ -63,7 +70,7 @@ class Interrogator:
63
70
"output_dir" : '' ,
64
71
}
65
72
output = None
66
- #odd_increment = 0
73
+ # odd_increment = 0
67
74
68
75
@classmethod
69
76
def flip (cls , key ):
@@ -131,10 +138,12 @@ def unload(self) -> bool:
131
138
del self .model
132
139
self .model = None
133
140
unloaded = True
141
+ gr .collect ()
134
142
print (f'Unloaded { self .name } ' )
135
143
136
144
if hasattr (self , 'tags' ):
137
145
del self .tags
146
+ self .tags = None
138
147
139
148
return unloaded
140
149
@@ -259,6 +268,7 @@ class DeepDanbooruInterrogator(Interrogator):
259
268
def __init__ (self , name : str , project_path : os .PathLike ) -> None :
260
269
super ().__init__ (name )
261
270
self .project_path = project_path
271
+ self .model = None
262
272
self .tags = None
263
273
264
274
def load (self ) -> None :
@@ -331,7 +341,7 @@ def interrogate(
331
341
Dict [str , float ] # tag confidences
332
342
]:
333
343
# init model
334
- if not hasattr ( self , 'model' ) or self .model is None :
344
+ if self .model is None :
335
345
self .load ()
336
346
337
347
import deepdanbooru .data as ddd
@@ -363,36 +373,62 @@ def interrogate(
363
373
return ratings , tags
364
374
365
375
376
+ def get_onnxrt ():
377
+ try :
378
+ import onnxruntime
379
+ return onnxruntime
380
+ except ImportError :
381
+ # only one of these packages should be installed at one time in an env
382
+ # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
383
+ # TODO: remove old package when the environment changes?
384
+ from launch import is_installed , run_pip
385
+ if not is_installed ('onnxruntime' ):
386
+ if system () == "Darwin" :
387
+ package_name = "onnxruntime-silicon"
388
+ else :
389
+ package_name = "onnxruntime-gpu"
390
+ package = os .environ .get (
391
+ 'ONNXRUNTIME_PACKAGE' ,
392
+ package_name
393
+ )
394
+
395
+ run_pip (f'install { package } ' , 'onnxruntime' )
396
+
397
+ import onnxruntime
398
+ return onnxruntime
399
+
400
+
366
401
class WaifuDiffusionInterrogator (Interrogator ):
367
402
""" Interrogator for Waifu Diffusion models """
368
403
def __init__ (
369
404
self ,
370
405
name : str ,
371
406
model_path = 'model.onnx' ,
372
407
tags_path = 'selected_tags.csv' ,
373
- ** kwargs
408
+ repo_id = None ,
374
409
) -> None :
375
410
super ().__init__ (name )
411
+ self .repo_id = repo_id
376
412
self .model_path = model_path
377
413
self .tags_path = tags_path
378
414
self .tags = None
379
- self .kwargs = kwargs
380
-
381
- def download (self ) -> Tuple [os .PathLike , os .PathLike ]:
382
- print (f"Loading { self .name } model file from { self .kwargs ['repo_id' ]} " )
415
+ self .model = None
416
+ self .tags = None
383
417
418
+ def download (self ) -> None :
384
419
mdir = Path (shared .models_path , 'interrogators' )
385
- model_path = Path (hf_hub_download (** self .kwargs ,
386
- filename = self .model_path ,
387
- cache_dir = mdir ))
388
- tags_path = Path (hf_hub_download (** self .kwargs ,
389
- filename = self .tags_path ,
390
- cache_dir = mdir ))
420
+ if self .repo_id is not None :
421
+ print (f"Loading { self .name } model file from { self .repo_id } " )
422
+
423
+ self .model_path = hf_hub_download (self .repo_id , self .model_path ,
424
+ cache_dir = mdir )
425
+ self .tags_path = hf_hub_download (self .repo_id , self .tags_path ,
426
+ cache_dir = mdir )
391
427
392
428
download_model = {
393
429
'name' : self .name ,
394
- 'model_path' : str ( model_path ) ,
395
- 'tags_path' : str ( tags_path ) ,
430
+ 'model_path' : self . model_path ,
431
+ 'tags_path' : self . tags_path ,
396
432
}
397
433
mpath = Path (mdir , 'model.json' )
398
434
@@ -411,56 +447,14 @@ def download(self) -> Tuple[os.PathLike, os.PathLike]:
411
447
with io .open (mpath , 'w' ) as filename :
412
448
json .dump (data , filename )
413
449
414
- return model_path , tags_path
415
-
416
- def get_model_path (self ) -> Tuple [os .PathLike , os .PathLike ]:
417
- model_path = ''
418
- tags_path = ''
419
- mpath = Path (shared .models_path , 'interrogators' , 'model.json' )
420
- try :
421
- models = read_json (mpath ).to_dict (orient = 'records' )
422
- i = next (i for i in models if i ['name' ] == self .name )
423
- model_path = i ['model_path' ]
424
- tags_path = i ['tags_path' ]
425
- except Exception as e :
426
- print (f'{ mpath } : requires a name, model_ and tags_path: { repr (e )} ' )
427
- model_path , tags_path = self .download ()
428
- return model_path , tags_path
429
-
430
450
def load (self ) -> None :
431
- if isinstance (self .model_path , str ) or isinstance (self .tags_path , str ):
432
- model_path , tags_path = self .download ()
433
- else :
434
- model_path = self .model_path
435
- tags_path = self .tags_path
436
-
437
- # only one of these packages should be installed a time in any one env
438
- # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
439
- # TODO: remove old package when the environment changes?
440
- from launch import is_installed , run_pip
441
- if not is_installed ('onnxruntime' ):
442
- if system () == "Darwin" :
443
- package_name = "onnxruntime-silicon"
444
- else :
445
- package_name = "onnxruntime-gpu"
446
- package = os .environ .get (
447
- 'ONNXRUNTIME_PACKAGE' ,
448
- package_name
449
- )
450
-
451
- run_pip (f'install { package } ' , 'onnxruntime' )
452
-
453
- from onnxruntime import InferenceSession
454
-
455
- # https://onnxruntime.ai/docs/execution-providers/
456
- # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
457
- providers = ['CUDAExecutionProvider' , 'CPUExecutionProvider' ]
458
- if use_cpu :
459
- providers .pop (0 )
451
+ self .download ()
452
+ ort = get_onnxrt ()
453
+ self .model = ort .InferenceSession (self .model_path ,
454
+ providers = onnxrt_providers )
460
455
461
- print (f'Loading { self .name } model from { model_path } , { tags_path } ' )
462
- self .model = InferenceSession (str (model_path ), providers = providers )
463
- self .tags = read_csv (tags_path )
456
+ print (f'Loaded { self .name } model from { self .repo_id } ' )
457
+ self .tags = read_csv (self .tags_path )
464
458
465
459
def interrogate (
466
460
self ,
@@ -470,7 +464,7 @@ def interrogate(
470
464
Dict [str , float ] # tag confidences
471
465
]:
472
466
# init model
473
- if not hasattr ( self , 'model' ) or self .model is None :
467
+ if self .model is None :
474
468
self .load ()
475
469
476
470
# code for converting the image and running the model is taken from the
@@ -481,15 +475,14 @@ def interrogate(
481
475
_ , height , _ , _ = self .model .get_inputs ()[0 ].shape
482
476
483
477
# alpha to white
484
- image = image .convert ('RGBA' )
485
- new_image = Image .new ('RGBA' , image .size , 'WHITE' )
486
- new_image .paste (image , mask = image )
487
- image = new_image .convert ('RGB' )
488
- image = asarray (image )
478
+ image = dbimutils .fill_transparent (image )
489
479
480
+ image = np .asarray (image )
490
481
# PIL RGB to OpenCV BGR
491
482
image = image [:, :, ::- 1 ]
492
483
484
+ tags = dict
485
+
493
486
image = dbimutils .make_square (image , height )
494
487
image = dbimutils .smart_resize (image , height )
495
488
image = image .astype (float32 )
@@ -609,3 +602,69 @@ def pred_model(model):
609
602
QData .add_tag = orig_add_tags
610
603
del os .environ ["TF_XLA_FLAGS" ]
611
604
return ''
605
+
606
+
607
+ class MLDanbooruInterrogator (Interrogator ):
608
+ def __init__ (
609
+ self ,
610
+ name : str ,
611
+ repo_id : str ,
612
+ model_path : str ,
613
+ tags_path = 'classes.json'
614
+ ) -> None :
615
+ super ().__init__ (name )
616
+ self .model_path = model_path
617
+ self .tags_path = tags_path
618
+ self .repo_id = repo_id
619
+ self .tags = None
620
+ self .model = None
621
+
622
+ def download (self ) -> Tuple [str , str ]:
623
+ print (f"Loading { self .name } model file from { self .repo_id } " )
624
+
625
+ model_path = hf_hub_download (
626
+ repo_id = self .repo_id , filename = self .model_path )
627
+ tags_path = hf_hub_download (
628
+ repo_id = self .repo_id , filename = self .tags_path )
629
+ return model_path , tags_path
630
+
631
+ def load (self ) -> None :
632
+ self .model_path , self .tags_path = self .download ()
633
+
634
+ ort = get_onnxrt ()
635
+ self .model = ort .InferenceSession (self .model_path , providers = onnxrt_providers )
636
+
637
+ print (f'Loaded { self .name } model from { self .model_path } ' )
638
+
639
+ with open (self .tags_path , 'r' , encoding = 'utf-8' ) as f :
640
+ self .tags = json .load (f )
641
+
642
+ def interrogate (
643
+ self ,
644
+ image : Image
645
+ ) -> Tuple [
646
+ Dict [str , float ], # rating confidents
647
+ Dict [str , float ] # tag confidents
648
+ ]:
649
+ # init model
650
+ if self .model is None :
651
+ self .load ()
652
+
653
+ image = dbimutils .fill_transparent (image )
654
+ image = dbimutils .resize (image , 448 ) # TODO CUSTOMIZE
655
+
656
+ x = np .asarray (image , dtype = np .float32 ) / 255
657
+ # HWC -> 1CHW
658
+ x = x .transpose ((2 , 0 , 1 ))
659
+ x = np .expand_dims (x , 0 )
660
+
661
+ input_ = self .model .get_inputs ()[0 ]
662
+ output = self .model .get_outputs ()[0 ]
663
+ # evaluate model
664
+ y , = self .model .run ([output .name ], {input_ .name : x })
665
+
666
+ # Softmax
667
+ y = 1 / (1 + np .exp (- y ))
668
+
669
+ tags = {tag : float (conf ) for tag , conf in zip (self .tags , y .flatten ())}
670
+ return {}, tags
0 commit comments