Skip to content

Commit 27ac61a

Browse files
committed
Expose all model options in the GroundingDinoInvocation and the SegmentAnythingInvocation.
1 parent 675ffc2 commit 27ac61a

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

invokeai/app/invocations/grounding_dino.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from typing import Literal
23

34
import torch
45
from PIL import Image
@@ -12,7 +13,11 @@
1213
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
1314
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
1415

15-
GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
16+
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
17+
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
18+
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
19+
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
20+
}
1621

1722

1823
@invocation(
@@ -30,6 +35,7 @@ class GroundingDinoInvocation(BaseInvocation):
3035
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
3136
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
3237

38+
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
3339
prompt: str = InputField(description="The prompt describing the object to segment.")
3440
image: ImageField = InputField(description="The image to segment.")
3541
detection_threshold: float = InputField(
@@ -88,7 +94,7 @@ def _detect(
8894
labels = [label if label.endswith(".") else label + "." for label in labels]
8995

9096
with context.models.load_remote_model(
91-
source=GROUNDING_DINO_MODEL_ID, loader=GroundingDinoInvocation._load_grounding_dino
97+
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
9298
) as detector:
9399
assert isinstance(detector, GroundingDinoPipeline)
94100
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

invokeai/app/invocations/segment_anything.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
1616
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
1717

18-
SEGMENT_ANYTHING_MODEL_ID = "facebook/sam-vit-base"
18+
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
19+
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
20+
"segment-anything-base": "facebook/sam-vit-base",
21+
"segment-anything-large": "facebook/sam-vit-large",
22+
"segment-anything-huge": "facebook/sam-vit-huge",
23+
}
1924

2025

2126
@invocation(
@@ -33,6 +38,7 @@ class SegmentAnythingInvocation(BaseInvocation):
3338
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
3439
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
3540

41+
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
3642
image: ImageField = InputField(description="The image to segment.")
3743
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
3844
apply_polygon_refinement: bool = InputField(
@@ -88,7 +94,7 @@ def _segment(
8894

8995
with (
9096
context.models.load_remote_model(
91-
source=SEGMENT_ANYTHING_MODEL_ID, loader=SegmentAnythingInvocation._load_sam_model
97+
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
9298
) as sam_pipeline,
9399
):
94100
assert isinstance(sam_pipeline, SegmentAnythingPipeline)

0 commit comments

Comments
 (0)