
Commit 4f8a4b0

Merge branch 'main' into depth_anything_v2

2 parents: a743f3c + 571ba87

29 files changed: +959 additions, -186 deletions


invokeai/app/invocations/compel.py

Lines changed: 4 additions & 4 deletions

@@ -80,12 +80,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(
                 text_encoder,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(
                 text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
-                model_state_dict=state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
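
Note on the rename above: model_on_device() yields a weights mapping alongside the on-device model, and this diff renames that mapping from model_state_dict to cached_weights to match the new ModelPatcher keyword. Below is a minimal, hypothetical sketch of a context manager with that shape (not InvokeAI's implementation), assuming the yielded dict is a CPU-side snapshot used to restore the model after LoRA patching.

from contextlib import contextmanager
from typing import Dict, Iterator, Tuple

import torch


@contextmanager
def model_on_device(model: torch.nn.Module, device: torch.device) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.nn.Module]]:
    # Hypothetical: snapshot the weights before the model is moved and patched.
    cached_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    model.to(device)
    try:
        yield cached_weights, model
    finally:
        # Hypothetical: restore the unpatched weights when the block exits.
        model.load_state_dict(cached_weights)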

invokeai/app/invocations/create_gradient_mask.py

Lines changed: 2 additions & 1 deletion

@@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
     title="Create Gradient Mask",
     tags=["mask", "denoise"],
     category="latents",
-    version="1.1.0",
+    version="1.2.0",
 )
 class CreateGradientMaskInvocation(BaseInvocation):
     """Creates mask for denoising model run."""
@@ -93,6 +93,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
 
         # redistribute blur so that the original edges are 0 and blur outwards to 1
         blur_tensor = (blur_tensor - 0.5) * 2
+        blur_tensor[blur_tensor < 0] = 0.0
 
         threshold = 1 - self.minimum_denoise
 
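
The one functional change here is the added clamp: (blur_tensor - 0.5) * 2 maps values below 0.5 to negative numbers, and the new line zeroes them so the redistributed mask stays in [0, 1]. A small standalone illustration (values chosen arbitrarily):

import torch

blur_tensor = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
blur_tensor = (blur_tensor - 0.5) * 2      # -> [-1.0, -0.5, 0.0, 0.5, 1.0]
blur_tensor[blur_tensor < 0] = 0.0         # -> [ 0.0,  0.0, 0.0, 0.5, 1.0]
print(blur_tensor)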

invokeai/app/invocations/denoise_latents.py

Lines changed: 13 additions & 2 deletions

@@ -62,6 +62,7 @@
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
 from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
@@ -845,6 +846,16 @@ def step_callback(state: PipelineIntermediateState) -> None:
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
 
+        ### lora
+        if self.unet.loras:
+            for lora_field in self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAExt(
+                        node_context=context,
+                        model_id=lora_field.lora,
+                        weight=lora_field.weight,
+                    )
+                )
         ### seamless
         if self.unet.seamless_axes:
             ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
@@ -964,14 +975,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
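
The second hunk registers one LoRAExt per entry in self.unet.loras with the extension manager. A minimal, hypothetical sketch of that registration pattern, using stand-in classes rather than InvokeAI's real extension manager or LoRAExt:

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class LoRAExtSketch:  # stand-in for invokeai...extensions.lora.LoRAExt
    node_context: Any
    model_id: str
    weight: float


@dataclass
class ExtManagerSketch:  # stand-in for the extension manager used above
    extensions: List[Any] = field(default_factory=list)

    def add_extension(self, ext: Any) -> None:
        self.extensions.append(ext)


ext_manager = ExtManagerSketch()
loras = [("lora-detail", 0.75), ("lora-style", 1.0)]  # stand-in for self.unet.loras
for model_id, weight in loras:
    ext_manager.add_extension(LoRAExtSketch(node_context=None, model_id=model_id, weight=weight))
assert len(ext_manager.extensions) == 2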

invokeai/app/invocations/fields.py

Lines changed: 26 additions & 1 deletion

@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import Any, Callable, Optional, Tuple
 
-from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
+from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
 from pydantic.fields import _Unset
 from pydantic_core import PydanticUndefined
 
@@ -242,6 +242,31 @@ class ConditioningField(BaseModel):
     )
 
 
+class BoundingBoxField(BaseModel):
+    """A bounding box primitive value."""
+
+    x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
+    x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
+    y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
+    y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
+
+    score: Optional[float] = Field(
+        default=None,
+        ge=0.0,
+        le=1.0,
+        description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
+        "when the bounding box was produced by a detector and has an associated confidence score.",
+    )
+
+    @model_validator(mode="after")
+    def check_coords(self):
+        if self.x_min > self.x_max:
+            raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
+        if self.y_min > self.y_max:
+            raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
+        return self
+
+
 class MetadataField(RootModel[dict[str, Any]]):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].
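
Usage sketch for the new BoundingBoxField and its check_coords validator (assumes an environment where invokeai.app.invocations.fields is importable; the values are arbitrary):

from pydantic import ValidationError

from invokeai.app.invocations.fields import BoundingBoxField

# A well-formed box; score is optional and is typically set by a detector.
box = BoundingBoxField(x_min=10, y_min=20, x_max=110, y_max=220, score=0.87)
print(box.x_max - box.x_min, box.y_max - box.y_min)  # 100 200

# Inverted coordinates are rejected by the model_validator.
try:
    BoundingBoxField(x_min=110, y_min=20, x_max=10, y_max=220)
except ValidationError as err:
    print(err)  # x_min (110) is greater than x_max (10).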

Lines changed: 100 additions & 0 deletions (new file)

@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Literal
+
+import torch
+from PIL import Image
+from transformers import pipeline
+from transformers.pipelines import ZeroShotObjectDetectionPipeline
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
+from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
+from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
+
+GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
+GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
+    "grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
+    "grounding-dino-base": "IDEA-Research/grounding-dino-base",
+}
+
+
+@invocation(
+    "grounding_dino",
+    title="Grounding DINO (Text Prompt Object Detection)",
+    tags=["prompt", "object detection"],
+    category="image",
+    version="1.0.0",
+)
+class GroundingDinoInvocation(BaseInvocation):
+    """Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
+
+    # Reference:
+    # - https://arxiv.org/pdf/2303.05499
+    # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
+    # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
+
+    model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
+    prompt: str = InputField(description="The prompt describing the object to segment.")
+    image: ImageField = InputField(description="The image to segment.")
+    detection_threshold: float = InputField(
+        description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
+        ge=0.0,
+        le=1.0,
+        default=0.3,
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
+        # The model expects a 3-channel RGB image.
+        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
+
+        detections = self._detect(
+            context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
+        )
+
+        # Convert detections to BoundingBoxCollectionOutput.
+        bounding_boxes: list[BoundingBoxField] = []
+        for detection in detections:
+            bounding_boxes.append(
+                BoundingBoxField(
+                    x_min=detection.box.xmin,
+                    x_max=detection.box.xmax,
+                    y_min=detection.box.ymin,
+                    y_max=detection.box.ymax,
+                    score=detection.score,
+                )
+            )
+        return BoundingBoxCollectionOutput(collection=bounding_boxes)
+
+    @staticmethod
+    def _load_grounding_dino(model_path: Path):
+        grounding_dino_pipeline = pipeline(
+            model=str(model_path),
+            task="zero-shot-object-detection",
+            local_files_only=True,
+            # TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
+            # model, and figure out how to make it work in the pipeline.
+            # torch_dtype=TorchDevice.choose_torch_dtype(),
+        )
+        assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
+        return GroundingDinoPipeline(grounding_dino_pipeline)
+
+    def _detect(
+        self,
+        context: InvocationContext,
+        image: Image.Image,
+        labels: list[str],
+        threshold: float = 0.3,
+    ) -> list[DetectionResult]:
+        """Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
+        # TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
+        # actually makes a difference.
+        labels = [label if label.endswith(".") else label + "." for label in labels]
+
+        with context.models.load_remote_model(
+            source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
+        ) as detector:
+            assert isinstance(detector, GroundingDinoPipeline)
+            return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
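
For reference, roughly the same zero-shot detection can be reproduced outside InvokeAI with the Hugging Face pipeline directly. This sketch assumes the documented output format of the transformers zero-shot-object-detection pipeline (a list of dicts with "score", "label", and a "box" of xmin/ymin/xmax/ymax); the image path is hypothetical.

from PIL import Image
from transformers import pipeline

detector = pipeline(task="zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny")

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
# Grounding DINO prompts are conventionally "."-terminated, mirroring _detect() above.
results = detector(image, candidate_labels=["a cat."], threshold=0.3)

for r in results:
    box = r["box"]
    print(f'{r["label"]}: score={r["score"]:.2f} box=({box["xmin"]}, {box["ymin"]}, {box["xmax"]}, {box["ymax"]})')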

invokeai/app/invocations/mask.py

Lines changed: 27 additions & 2 deletions

@@ -1,9 +1,10 @@
 import numpy as np
 import torch
+from PIL import Image
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
-from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
-from invokeai.app.invocations.primitives import MaskOutput
+from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
+from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
 
 
 @invocation(
@@ -118,3 +119,27 @@ def invoke(self, context: InvocationContext) -> MaskOutput:
             height=mask.shape[1],
             width=mask.shape[2],
         )
+
+
+@invocation(
+    "tensor_mask_to_image",
+    title="Tensor Mask to Image",
+    tags=["mask"],
+    category="mask",
+    version="1.0.0",
+)
+class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Convert a mask tensor to an image."""
+
+    mask: TensorField = InputField(description="The mask tensor to convert.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        mask = context.tensors.load(self.mask.tensor_name)
+        # Ensure that the mask is binary.
+        if mask.dtype != torch.bool:
+            mask = mask > 0.5
+        mask_np = (mask.float() * 255).byte().cpu().numpy()
+
+        mask_pil = Image.fromarray(mask_np, mode="L")
+        image_dto = context.images.save(image=mask_pil)
+        return ImageOutput.build(image_dto)
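
The new MaskTensorToImageInvocation boils down to a tensor-to-PIL conversion. A standalone illustration using only torch and PIL (no InvocationContext; a 2-D H×W mask is assumed):

import torch
from PIL import Image

mask = torch.rand(512, 512)                    # e.g. a soft mask in [0, 1]
if mask.dtype != torch.bool:
    mask = mask > 0.5                          # binarize, as in the invocation
mask_np = (mask.float() * 255).byte().cpu().numpy()
mask_pil = Image.fromarray(mask_np, mode="L")  # single-channel 8-bit image
mask_pil.save("mask.png")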

invokeai/app/invocations/primitives.py

Lines changed: 40 additions & 0 deletions

@@ -7,6 +7,7 @@
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.fields import (
+    BoundingBoxField,
     ColorField,
     ConditioningField,
     DenoiseMaskField,
@@ -469,3 +470,42 @@ def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
 
 
 # endregion
+
+# region BoundingBox
+
+
+@invocation_output("bounding_box_output")
+class BoundingBoxOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single bounding box"""
+
+    bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
+
+
+@invocation_output("bounding_box_collection_output")
+class BoundingBoxCollectionOutput(BaseInvocationOutput):
+    """Base class for nodes that output a collection of bounding boxes"""
+
+    collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
+
+
+@invocation(
+    "bounding_box",
+    title="Bounding Box",
+    tags=["primitives", "segmentation", "collection", "bounding box"],
+    category="primitives",
+    version="1.0.0",
+)
+class BoundingBoxInvocation(BaseInvocation):
+    """Create a bounding box manually by supplying box coordinates"""
+
+    x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
+    y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
+    x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
+    y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
+
+    def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
+        bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
+        return BoundingBoxOutput(bounding_box=bounding_box)
+
+
+# endregion
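
Downstream consumers can treat BoundingBoxField coordinates as pixel bounds. An illustrative snippet (not part of this commit) that crops a PIL image to a box; x_max/y_max are exclusive, which matches PIL's crop convention:

from PIL import Image

from invokeai.app.invocations.fields import BoundingBoxField

image = Image.open("example.jpg")  # hypothetical input
box = BoundingBoxField(x_min=10, y_min=20, x_max=110, y_max=220)

# PIL's crop takes (left, upper, right, lower) with exclusive right/lower bounds.
cropped = image.crop((box.x_min, box.y_min, box.x_max, box.y_max))
cropped.save("cropped.png")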
