
Commit f27b6e2

Add Grounded SAM support (text prompt image segmentation) (#6701)
## Summary

This PR enables Grounded SAM workflows (https://arxiv.org/pdf/2401.14159) via the following:

- `GroundingDinoInvocation` for running a Grounding DINO model.
- `SegmentAnythingInvocation` for running a SAM model.
- `MaskTensorToImageInvocation` for convenient visualization.

Other notes:

- Uses the transformers implementation of Grounding DINO and SAM.
- The new models are treated as 'utility models', meaning that they are not visible in the Models tab and are downloaded automatically the first time they are used.

<img width="874" alt="image" src="https://github.com/user-attachments/assets/1cbaa97d-0e27-4943-86b1-dc7327ba8675">

## Example

Input image:

![be10ec0c-20a8-4ac7-840e-d1a05fffdb6a](https://github.com/user-attachments/assets/bf21572c-635d-4703-b4ab-7aba658a9671)

Prompt: "wheels", all other configs default.

Result:

![2221c44e-64e6-4b18-b4cb-610514b7a554](https://github.com/user-attachments/assets/344b91f4-7f4a-4b70-8e2e-3b4a0e55176d)

## Related Issues / Discussions

Thanks to @blessedcoolant for the initial draft here: #6678

## QA Instructions

Manual tests:

- [ ] Test that default settings work well.
- [ ] Test with / without apply_polygon_refinement.
- [ ] Test mask_filter options.
- [ ] Test detection_threshold values.
- [ ] Test RGB input image.
- [ ] Test RGBA input image.
- [ ] Test grayscale input image.
- [ ] Smoke test that an empty mask is returned when 0 objects are detected.
- [ ] Test on CPU.
- [ ] Test on MPS (works on macOS, but both models had to be forced to run on CPU instead of MPS).

Performance:

- Peak GPU memory utilization with both the Grounding DINO and SAM models loaded is ~4.5 GB. (The models do not need to be loaded at the same time, so they could be offloaded by the MM if needed.)
- On an RTX 4090, with the models already cached, node execution takes ~0.6 s.
- On my CPU, with the models cached, node execution takes ~10 s.

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2 parents 08b1fee + 981475a commit f27b6e2
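Before the file-level diffs, here is a minimal standalone sketch of the Grounded SAM flow that the new nodes wrap, written directly against the Hugging Face `transformers` APIs (following the Grounding DINO / SAM usage documented there). The model IDs match the ones registered below in `GROUNDING_DINO_MODEL_IDS` and `SEGMENT_ANYTHING_MODEL_IDS`; the file names and the prompt are placeholders, and the exact post-processing call is an assumption based on the transformers SAM docs rather than the node implementation.

```python
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline

image = Image.open("input.png").convert("RGB")  # placeholder input path

# 1) Text-prompted detection: Grounding DINO returns scored bounding boxes.
detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")
detections = detector(image, candidate_labels=["wheels."], threshold=0.3)
boxes = [[d["box"]["xmin"], d["box"]["ymin"], d["box"]["xmax"], d["box"]["ymax"]] for d in detections]
# (Assumes at least one detection; the SegmentAnything node returns an empty mask in that case.)

# 2) Box-prompted segmentation: SAM turns each detected box into a mask.
sam = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base")
processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
inputs = processor(images=image, input_boxes=[boxes], return_tensors="pt")
with torch.no_grad():
    outputs = sam(**inputs)
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"], inputs["reshaped_input_sizes"]
)[0]  # bool tensor of shape [num_boxes, num_channels, height, width]

# 3) Merge the per-box masks into a single binary mask and save it for inspection.
combined = masks.flatten(0, 1).any(dim=0)
Image.fromarray((combined.numpy() * 255).astype("uint8"), mode="L").save("mask.png")
```

Inside InvokeAI the same three steps are split across `GroundingDinoInvocation`, `SegmentAnythingInvocation`, and `MaskTensorToImageInvocation`, with model loading and caching handled by `context.models.load_remote_model`.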

File tree: 12 files changed (+529, -4 lines)


invokeai/app/invocations/fields.py

Lines changed: 26 additions & 1 deletion
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Callable, Optional, Tuple

-from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
 from pydantic.fields import _Unset
 from pydantic_core import PydanticUndefined

@@ -242,6 +242,31 @@ class ConditioningField(BaseModel):
    )


+class BoundingBoxField(BaseModel):
+    """A bounding box primitive value."""
+
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    score: Optional[float] = Field(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
+        "when the bounding box was produced by a detector and has an associated confidence score.",
+    )
+
+    @model_validator(mode="after")
+    def check_coords(self):
+        if self.x_min > self.x_max:
+            raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
+        if self.y_min > self.y_max:
+            raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
+        return self
+
+
 class MetadataField(RootModel[dict[str, Any]]):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
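For reference, a small hypothetical usage sketch of the new `BoundingBoxField` (the coordinate and score values are made up for illustration); the `check_coords` validator rejects inverted boxes:

```python
from pydantic import ValidationError

from invokeai.app.invocations.fields import BoundingBoxField

# A valid box: x_min <= x_max and y_min <= y_max, with an optional detector score.
box = BoundingBoxField(x_min=10, y_min=20, x_max=110, y_max=220, score=0.87)
print(box.model_dump())

# An inverted box fails the check_coords model validator.
try:
    BoundingBoxField(x_min=110, y_min=20, x_max=10, y_max=220)
except ValidationError as err:
    print(err)  # "x_min (110) is greater than x_max (10)."
```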
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Literal
+
+import torch
+from PIL import Image
+from transformers import pipeline
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
+from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
+
+GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
+GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
+    "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
+    "grounding-dino-base": "IDEA-Research/grounding-dino-base",
+}
+
+
+@invocation(
+    "grounding_dino",
+    title="Grounding DINO (Text Prompt Object Detection)",
+    tags=["prompt", "object detection"],
+    category="image",
+    version="1.0.0",
+)
+class GroundingDinoInvocation(BaseInvocation):
+    """Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
+
+    # Reference:
+    # - https://arxiv.org/pdf/2303.05499
+    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+    model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
+    prompt: str = InputField(description="The prompt describing the object to segment.")
+    image: ImageField = InputField(description="The image to segment.")
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
+        ge=0.0,
+        le=1.0,
+        default=0.3,
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
+        # The model expects a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        # Convert detections to BoundingBoxCollectionOutput.
+        bounding_boxes: list[BoundingBoxField] = []
+        for detection in detections:
+            bounding_boxes.append(
+                BoundingBoxField(
+                    x_min=detection.box.xmin,
+                    x_max=detection.box.xmax,
+                    y_min=detection.box.ymin,
+                    y_max=detection.box.ymax,
+                    score=detection.score,
+                )
+            )
+        return BoundingBoxCollectionOutput(collection=bounding_boxes)
+
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
+
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
+        ) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
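As a rough illustration of what `invoke` does with the detector output, the sketch below maps the plain dicts returned by the transformers `zero-shot-object-detection` pipeline onto `BoundingBoxField`s. The node itself goes through the backend `DetectionResult` wrapper instead, and the detection values here are invented for the example:

```python
from invokeai.app.invocations.fields import BoundingBoxField

# Example shape of the zero-shot-object-detection pipeline output (values invented).
detections = [
    {"score": 0.92, "label": "wheels.", "box": {"xmin": 34, "ymin": 120, "xmax": 310, "ymax": 400}},
    {"score": 0.41, "label": "wheels.", "box": {"xmin": 512, "ymin": 133, "xmax": 798, "ymax": 415}},
]

bounding_boxes = [
    BoundingBoxField(
        x_min=d["box"]["xmin"],
        y_min=d["box"]["ymin"],
        x_max=d["box"]["xmax"],
        y_max=d["box"]["ymax"],
        score=d["score"],
    )
    for d in detections
]
```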

invokeai/app/invocations/mask.py

Lines changed: 27 additions & 2 deletions
@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from PIL import Image

 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
-from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
+from invokeai.app.invocations.primitives import ImageOutput, MaskOutput


 @invocation(
@@ -118,3 +119,27 @@ def invoke(self, context: InvocationContext) -> MaskOutput:
            height=mask.shape[1],
            width=mask.shape[2],
        )
+
+
+@invocation(
+    "tensor_mask_to_image",
+    title="Tensor Mask to Image",
+    tags=["mask"],
+    category="mask",
+    version="1.0.0",
+)
+class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Convert a mask tensor to an image."""
+
+    mask: TensorField = InputField(description="The mask tensor to convert.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        mask = context.tensors.load(self.mask.tensor_name)
+        # Ensure that the mask is binary.
+        if mask.dtype != torch.bool:
+            mask = mask > 0.5
+        mask_np = (mask.float() * 255).byte().cpu().numpy()
+
+        mask_pil = Image.fromarray(mask_np, mode="L")
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
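The conversion that `MaskTensorToImageInvocation` performs is small enough to show in isolation; a sketch with a random toy tensor standing in for a saved mask, and a placeholder output path:

```python
import torch
from PIL import Image

# Any non-bool tensor is binarized at 0.5, matching the node's behavior.
mask = torch.rand(512, 512)
if mask.dtype != torch.bool:
    mask = mask > 0.5

# Scale the boolean mask to 0/255 and save it as a grayscale ("L") image.
mask_np = (mask.float() * 255).byte().cpu().numpy()
Image.fromarray(mask_np, mode="L").save("mask_preview.png")
```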

invokeai/app/invocations/primitives.py

Lines changed: 40 additions & 0 deletions
@@ -7,6 +7,7 @@
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     ConditioningField,
     DenoiseMaskField,
@@ -469,3 +470,42 @@ def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:


 # endregion
+
+# region BoundingBox
+
+
+@invocation_output("bounding_box_output")
+class BoundingBoxOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single bounding box"""
+
+    bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
+
+
+@invocation_output("bounding_box_collection_output")
+class BoundingBoxCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of bounding boxes"""
+
+    collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
+
+
+@invocation(
+    "bounding_box",
+    title="Bounding Box",
+    tags=["primitives", "segmentation", "collection", "bounding box"],
+    category="primitives",
+    version="1.0.0",
+)
+class BoundingBoxInvocation(BaseInvocation):
+    """Create a bounding box manually by supplying box coordinates"""
+
+    x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
+    y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
+    x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
+    y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
+
+    def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
+        bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
+        return BoundingBoxOutput(bounding_box=bounding_box)
+
+
+# endregion
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoModelForMaskGeneration, AutoProcessor
+from transformers.models.sam import SamModel
+from transformers.models.sam.processing_sam import SamProcessor
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
+from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
+from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
+
+SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
+SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
+    "segment-anything-base": "facebook/sam-vit-base",
+    "segment-anything-large": "facebook/sam-vit-large",
+    "segment-anything-huge": "facebook/sam-vit-huge",
+}
+
+
+@invocation(
+    "segment_anything",
+    title="Segment Anything",
+    tags=["prompt", "segmentation"],
+    category="segmentation",
+    version="1.0.0",
+)
+class SegmentAnythingInvocation(BaseInvocation):
+    """Runs a Segment Anything Model."""
+
+    # Reference:
+    # - https://arxiv.org/pdf/2304.02643
+    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
+    image: ImageField = InputField(description="The image to segment.")
+    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
+    apply_polygon_refinement: bool = InputField(
+        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
+        default=True,
+    )
+    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
+        description="The filtering to apply to the detected masks before merging them into a final output.",
+        default="all",
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> MaskOutput:
+        # The models expect a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        if len(self.bounding_boxes) == 0:
+            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
+        else:
+            masks = self._segment(context=context, image=image_pil)
+            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
+
+            # masks contains bool values, so we merge them via max-reduce.
+            combined_mask, _ = torch.stack(masks).max(dim=0)
+
+        mask_tensor_name = context.tensors.save(combined_mask)
+        height, width = combined_mask.shape
+        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
+
+    @staticmethod
+    def _load_sam_model(model_path: Path):
+        sam_model = AutoModelForMaskGeneration.from_pretrained(
+            model_path,
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(sam_model, SamModel)
+
+        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
+        assert isinstance(sam_processor, SamProcessor)
+        return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
+
+    def _segment(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+    ) -> list[torch.Tensor]:
+        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
+        # Convert the bounding boxes to the SAM input format.
+        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
+
+        with (
+            context.models.load_remote_model(
+                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
+            ) as sam_pipeline,
+        ):
+            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
+            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
+
+        masks = self._process_masks(masks)
+        if self.apply_polygon_refinement:
+            masks = self._apply_polygon_refinement(masks)
+
+        return masks
+
+    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
+        """Convert the tensor output from the Segment Anything model from a tensor of shape
+        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
+        """
+        assert masks.dtype == torch.bool
+        # [num_masks, channels, height, width] -> [num_masks, height, width]
+        masks, _ = masks.max(dim=1)
+        # Split the first dimension into a list of masks.
+        return list(masks.cpu().unbind(dim=0))
+
+    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
+        """Apply polygon refinement to the masks.
+
+        Convert each mask to a polygon, then back to a mask. This has the following effect:
+        - Smooth the edges of the mask slightly.
+        - Ensure that each mask consists of a single closed polygon
+        - Removes small mask pieces.
+        - Removes holes from the mask.
+        """
+        # Convert tensor masks to np masks.
+        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
+
+        # Apply polygon refinement.
+        for idx, mask in enumerate(np_masks):
+            shape = mask.shape
+            assert len(shape) == 2  # Assert length to satisfy type checker.
+            polygon = mask_to_polygon(mask)
+            mask = polygon_to_mask(polygon, shape)
+            np_masks[idx] = mask
+
+        # Convert np masks back to tensor masks.
+        masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
+
+        return masks
+
+    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
+        """Filter the detected masks based on the specified mask filter."""
+        assert len(masks) == len(bounding_boxes)
+
+        if self.mask_filter == "all":
+            return masks
+        elif self.mask_filter == "largest":
+            # Find the largest mask.
+            return [max(masks, key=lambda x: float(x.sum()))]
+        elif self.mask_filter == "highest_box_score":
+            # Find the index of the bounding box with the highest score.
+            # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
+            # cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
+            # reasonable fallback since the expected score range is [0.0, 1.0].
+            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
+            return [masks[max_score_idx]]
+        else:
+            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
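To make the merging and filtering logic above concrete, here is a toy sketch of the two steps that happen after SAM returns per-box masks. The 4x4 masks are invented for illustration; the node operates on full-resolution boolean masks:

```python
import torch

# Two overlapping toy masks standing in for SAM's per-box outputs.
masks = [torch.zeros(4, 4, dtype=torch.bool) for _ in range(2)]
masks[0][:2, :2] = True   # 4 pixels
masks[1][1:, 1:] = True   # 9 pixels

# mask_filter="largest" keeps only the mask with the most True pixels.
largest = max(masks, key=lambda m: float(m.sum()))

# The surviving masks are merged via max-reduce (a boolean union).
combined, _ = torch.stack(masks).max(dim=0)

print(int(largest.sum()), int(combined.sum()))  # 9 12
```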

invokeai/backend/image_util/grounding_dino/__init__.py

Whitespace-only changes.
