@@ -1,4 +1,5 @@
from pathlib import Path
+ from typing import Literal

import torch
from PIL import Image
@@ -12,7 +13,11 @@
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline

- GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
+ GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
+ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
+     "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
+     "grounding-dino-base": "IDEA-Research/grounding-dino-base",
+ }


@invocation(
@@ -30,6 +35,7 @@ class GroundingDinoInvocation(BaseInvocation):
    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb

+     model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
    prompt: str = InputField(description="The prompt describing the object to segment.")
    image: ImageField = InputField(description="The image to segment.")
    detection_threshold: float = InputField(
@@ -88,7 +94,7 @@ def _detect(
        labels = [label if label.endswith(".") else label + "." for label in labels]

        with context.models.load_remote_model(
-             source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
+             source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
        ) as detector:
            assert isinstance(detector, GroundingDinoPipeline)
            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
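For context, here is a minimal standalone sketch of the pattern this change introduces. It is not InvokeAI's loader path: it assumes only the transformers zero-shot-object-detection pipeline from the docs linked in the comments above, and the image path, prompt, and threshold are placeholder values.

# Minimal standalone sketch, not InvokeAI's actual loader. Assumes the
# transformers "zero-shot-object-detection" pipeline; "image.jpg" and the
# threshold are placeholders.
from typing import Literal

from transformers import pipeline

GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]

GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
    "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
    "grounding-dino-base": "IDEA-Research/grounding-dino-base",
}


def detect(model: GroundingDinoModelKey, prompt: str) -> list[dict]:
    # The Literal-keyed dict makes a typo in the model key a type error
    # rather than a runtime KeyError.
    detector = pipeline(
        task="zero-shot-object-detection",
        model=GROUNDING_DINO_MODEL_IDS[model],
    )
    # Grounding DINO expects each label to end with a period, mirroring the
    # label-normalization line in the diff above.
    label = prompt if prompt.endswith(".") else prompt + "."
    return detector("image.jpg", candidate_labels=[label], threshold=0.3)


detections = detect("grounding-dino-tiny", "a cat")

In the invocation itself, the `Literal`-typed `model` field is validated when the node's inputs are parsed, so the `GROUNDING_DINO_MODEL_IDS[self.model]` lookup cannot fail with a KeyError at detection time.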