from pathlib import Path
from typing import Literal

import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
from invokeai.app.invocations.primitives import MaskOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline

SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
    "segment-anything-base": "facebook/sam-vit-base",
    "segment-anything-large": "facebook/sam-vit-large",
    "segment-anything-huge": "facebook/sam-vit-huge",
}
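# Note: the values above are Hugging Face Hub repo IDs. They are passed as the `source` to
# `context.models.load_remote_model` in `_segment` below, which is expected to download and cache
# the model on first use.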


@invocation(
    "segment_anything",
    title="Segment Anything",
    tags=["prompt", "segmentation"],
    category="segmentation",
    version="1.0.0",
)
class SegmentAnythingInvocation(BaseInvocation):
    """Runs a Segment Anything Model (SAM) to generate masks for a set of bounding-box prompts."""

    # References:
    # - https://arxiv.org/pdf/2304.02643
    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb

    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
    image: ImageField = InputField(description="The image to segment.")
    bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
    apply_polygon_refinement: bool = InputField(
        description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
        default=True,
    )
    mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
        description="The filtering to apply to the detected masks before merging them into a final output.",
        default="all",
    )

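    # `invoke` only runs inference, so autograd is disabled for it (and the helpers it calls) to
    # reduce memory use.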
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> MaskOutput:
        # The models expect a 3-channel RGB image.
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        if len(self.bounding_boxes) == 0:
            # No prompts were provided; return an empty mask. PIL's size is (width, height), so
            # reverse it to get the (height, width) tensor shape.
            combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
        else:
            masks = self._segment(context=context, image=image_pil)
            masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)

            # masks contains bool values, so we merge them via max-reduce.
            combined_mask, _ = torch.stack(masks).max(dim=0)

        mask_tensor_name = context.tensors.save(combined_mask)
        height, width = combined_mask.shape
        return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)

    @staticmethod
    def _load_sam_model(model_path: Path) -> SegmentAnythingPipeline:
        """Load a SAM model and its processor from disk and wrap them in a SegmentAnythingPipeline."""
        sam_model = AutoModelForMaskGeneration.from_pretrained(
            model_path,
            local_files_only=True,
            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
            # model, and figure out how to make it work in the pipeline.
            # torch_dtype=TorchDevice.choose_torch_dtype(),
        )
        assert isinstance(sam_model, SamModel)

        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
        assert isinstance(sam_processor, SamProcessor)
        return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)

    def _segment(
        self,
        context: InvocationContext,
        image: Image.Image,
    ) -> list[torch.Tensor]:
        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
        # Convert the bounding boxes to the SAM input format: [x_min, y_min, x_max, y_max] in pixel coordinates.
        sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]

        with context.models.load_remote_model(
            source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
        ) as sam_pipeline:
            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)

        masks = self._process_masks(masks)
        if self.apply_polygon_refinement:
            masks = self._apply_polygon_refinement(masks)

        return masks

    def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
        """Convert the tensor output from the Segment Anything model from a tensor of shape
        [num_masks, channels, height, width] to a list of tensors of shape [height, width].
        """
        assert masks.dtype == torch.bool
        # [num_masks, channels, height, width] -> [num_masks, height, width]
        masks, _ = masks.max(dim=1)
        # Split the first dimension into a list of masks.
        return list(masks.cpu().unbind(dim=0))
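
    # Shape walkthrough for `_process_masks` (illustrative numbers only, assuming two bounding
    # boxes on a 512x512 image): the pipeline returns a [2, channels, 512, 512] bool tensor; the
    # max over dim=1 collapses the channel dimension to [2, 512, 512]; unbind then yields two
    # [512, 512] masks, one per bounding box.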
    def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply polygon refinement to the masks.

        Convert each mask to a polygon, then back to a mask. This has the following effects:
        - Smooths the edges of the mask slightly.
        - Ensures that each mask consists of a single closed polygon.
        - Removes small mask pieces.
        - Removes holes from the mask.
        """
        # Convert tensor masks to np masks.
        np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]

        # Apply polygon refinement.
        for idx, mask in enumerate(np_masks):
            shape = mask.shape
            assert len(shape) == 2  # Assert that the mask is 2D to satisfy the type checker.
            polygon = mask_to_polygon(mask)
            np_masks[idx] = polygon_to_mask(polygon, shape)

        # Convert np masks back to tensor masks.
        return [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]

    def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
        """Filter the detected masks based on the specified mask filter."""
        assert len(masks) == len(bounding_boxes)

        if self.mask_filter == "all":
            return masks
        elif self.mask_filter == "largest":
            # Find the largest mask (i.e. the one with the most True pixels).
            return [max(masks, key=lambda x: float(x.sum()))]
        elif self.mask_filter == "highest_box_score":
            # Find the index of the bounding box with the highest score.
            # Note that we fall back to -1.0 if the score is None. This is mainly to satisfy the type checker. In
            # most cases the scores should all be non-None when using this filtering mode. That being said, -1.0
            # is a reasonable fallback since the expected score range is [0.0, 1.0].
            max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
            return [masks[max_score_idx]]
        else:
            raise ValueError(f"Invalid mask filter: {self.mask_filter}")
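

# A minimal standalone sketch (illustrative only; independent of the invocation API) of the
# boolean max-reduce merge used in `invoke` above. Stacking the per-box masks and taking the max
# over dim=0 marks a pixel True if any mask covers it:
#
#     masks = [torch.zeros(4, 4, dtype=torch.bool) for _ in range(2)]
#     masks[0][0, 0] = True
#     masks[1][3, 3] = True
#     combined, _ = torch.stack(masks).max(dim=0)
#     assert bool(combined[0, 0]) and bool(combined[3, 3])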