Skip to content

Commit fea9013

Browse files
committed
Move CreateGradientMaskInvocation to its own file. No functional changes.
1 parent 045cadd commit fea9013

File tree

2 files changed

+141
-119
lines changed

2 files changed

+141
-119
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from typing import Literal, Optional
2+
3+
import numpy as np
4+
import torch
5+
import torchvision.transforms as T
6+
from PIL import Image, ImageFilter
7+
from torchvision.transforms.functional import resize as tv_resize
8+
9+
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
10+
from invokeai.app.invocations.fields import (
11+
DenoiseMaskField,
12+
FieldDescriptions,
13+
ImageField,
14+
Input,
15+
InputField,
16+
OutputField,
17+
)
18+
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
19+
from invokeai.app.invocations.latent import DEFAULT_PRECISION
20+
from invokeai.app.invocations.model import UNetField, VAEField
21+
from invokeai.app.services.shared.invocation_context import InvocationContext
22+
from invokeai.backend.model_manager import LoadedModel
23+
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
24+
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
25+
26+
27+
@invocation_output("gradient_mask_output")
class GradientMaskOutput(BaseInvocationOutput):
    """Outputs a denoise mask and an image representing the total gradient of the mask."""

    # Gradient (non-binary) mask consumed by the denoise invocation.
    denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
    # Binary image covering the full blurred/expanded mask area; downstream
    # nodes use it to paste the denoised region back onto the source image.
    expanded_mask_area: ImageField = OutputField(
        description="Image representing the total gradient area of the mask. For paste-back purposes."
    )
35+
36+
37+
@invocation(
    "create_gradient_mask",
    title="Create Gradient Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.1.0",
)
class CreateGradientMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run.

    Blurs/expands the edges of an input mask and redistributes the blurred
    band into a gradient of denoise strengths, optionally clamped by
    ``minimum_denoise``. When a UNet, VAE and image are all connected and the
    UNet is an Inpaint-variant model, also pre-encodes masked latents for the
    specialized inpainting pipeline.
    """

    # NOTE(review): typed as required ImageField but given default=None —
    # presumably the UI always supplies it; confirm before tightening.
    mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
    # Radius in pixels; 0 disables blurring entirely (see invoke()).
    edge_radius: int = InputField(
        default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
    )
    # Blur strategy. "Staged" blurs like Gaussian but then flattens the whole
    # coherence band to a single denoise level (see invoke()).
    coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
    # Lower bound on denoise strength inside the blurred coherence region.
    minimum_denoise: float = InputField(
        default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
    )
    image: Optional[ImageField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] Image",
        ui_order=6,
    )
    unet: Optional[UNetField] = InputField(
        description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
        default=None,
        input=Input.Connection,
        title="[OPTIONAL] UNet",
        ui_order=5,
    )
    vae: Optional[VAEField] = InputField(
        default=None,
        description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
        title="[OPTIONAL] VAE",
        input=Input.Connection,
        ui_order=7,
    )
    # Passed through to the VAE encode step (only used on the inpaint path).
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
    fp32: bool = InputField(
        default=DEFAULT_PRECISION == "float32",
        description=FieldDescriptions.fp32,
        ui_order=9,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> GradientMaskOutput:
        """Build the gradient denoise mask and (optionally) inpaint masked latents.

        Returns a GradientMaskOutput holding the saved mask tensor name, the
        saved expanded-mask image, and gradient=True so downstream consumers
        treat the mask values as per-pixel denoise strengths.
        """
        mask_image = context.images.get_pil(self.mask.image_name, mode="L")
        if self.edge_radius > 0:
            if self.coherence_mode == "Box Blur":
                blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
            else:  # Gaussian Blur OR Staged
                # Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
                blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))

            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)

            # redistribute blur so that the original edges are 0 and blur outwards to 1
            # (values below 0.5 — outside the original mask — become negative here)
            blur_tensor = (blur_tensor - 0.5) * 2

            threshold = 1 - self.minimum_denoise

            if self.coherence_mode == "Staged":
                # wherever the blur_tensor is less than fully masked, convert it to threshold
                # i.e. the entire gradient band gets one flat denoise level
                blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
            else:
                # wherever the blur_tensor is above threshold but less than 1, drop it to threshold
                # so no pixel in the gradient denoises less than minimum_denoise
                blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)

        else:
            # No edge expansion requested: use the raw mask as-is.
            blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)

        # Save as (1, 1, H, W) — unsqueeze adds the channel dim.
        mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))

        # compute a [0, 1] mask from the blur_tensor
        # (anything not fully masked, including the gradient band, counts as area)
        expanded_mask = torch.where((blur_tensor < 1), 0, 1)
        expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
        expanded_image_dto = context.images.save(expanded_mask_image)

        masked_latents_name = None
        if self.unet is not None and self.vae is not None and self.image is not None:
            # all three fields must be present at the same time
            main_model_config = context.models.get_config(self.unet.unet.key)
            assert isinstance(main_model_config, MainConfigBase)
            if main_model_config.variant is ModelVariantType.Inpaint:
                # Specialized inpainting model: pre-encode the image with the
                # masked region zeroed out so the pipeline can condition on it.
                mask = blur_tensor
                vae_info: LoadedModel = context.models.load(self.vae.vae)
                image = context.images.get_pil(self.image.image_name)
                image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
                if image_tensor.dim() == 3:
                    image_tensor = image_tensor.unsqueeze(0)
                # Resize the mask to image resolution, then binarize at 0.5 to
                # black out the region being inpainted.
                img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
                masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
                masked_latents = ImageToLatentsInvocation.vae_encode(
                    vae_info, self.fp32, self.tiled, masked_image.clone()
                )
                masked_latents_name = context.tensors.save(tensor=masked_latents)

        return GradientMaskOutput(
            denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
            expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
        )

invokeai/app/invocations/latent.py

Lines changed: 3 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
22
import inspect
33
from contextlib import ExitStack
4-
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
4+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
55

6-
import numpy as np
76
import torch
87
import torchvision
98
import torchvision.transforms as T
@@ -13,7 +12,7 @@
1312
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
1413
from diffusers.schedulers.scheduling_tcd import TCDScheduler
1514
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
16-
from PIL import Image, ImageFilter
15+
from PIL import Image
1716
from pydantic import field_validator
1817
from torchvision.transforms.functional import resize as tv_resize
1918
from transformers import CLIPVisionModelWithProjection
@@ -37,8 +36,7 @@
3736
from invokeai.app.util.controlnet_utils import prepare_control_image
3837
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
3938
from invokeai.backend.lora import LoRAModelRaw
40-
from invokeai.backend.model_manager import BaseModelType, LoadedModel
41-
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
39+
from invokeai.backend.model_manager import BaseModelType
4240
from invokeai.backend.model_patcher import ModelPatcher
4341
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
4442
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
@@ -158,120 +156,6 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
158156
)
159157

160158

161-
@invocation_output("gradient_mask_output")
162-
class GradientMaskOutput(BaseInvocationOutput):
163-
"""Outputs a denoise mask and an image representing the total gradient of the mask."""
164-
165-
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
166-
expanded_mask_area: ImageField = OutputField(
167-
description="Image representing the total gradient area of the mask. For paste-back purposes."
168-
)
169-
170-
171-
@invocation(
172-
"create_gradient_mask",
173-
title="Create Gradient Mask",
174-
tags=["mask", "denoise"],
175-
category="latents",
176-
version="1.1.0",
177-
)
178-
class CreateGradientMaskInvocation(BaseInvocation):
179-
"""Creates mask for denoising model run."""
180-
181-
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
182-
edge_radius: int = InputField(
183-
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
184-
)
185-
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
186-
minimum_denoise: float = InputField(
187-
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
188-
)
189-
image: Optional[ImageField] = InputField(
190-
default=None,
191-
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
192-
title="[OPTIONAL] Image",
193-
ui_order=6,
194-
)
195-
unet: Optional[UNetField] = InputField(
196-
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
197-
default=None,
198-
input=Input.Connection,
199-
title="[OPTIONAL] UNet",
200-
ui_order=5,
201-
)
202-
vae: Optional[VAEField] = InputField(
203-
default=None,
204-
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
205-
title="[OPTIONAL] VAE",
206-
input=Input.Connection,
207-
ui_order=7,
208-
)
209-
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
210-
fp32: bool = InputField(
211-
default=DEFAULT_PRECISION == "float32",
212-
description=FieldDescriptions.fp32,
213-
ui_order=9,
214-
)
215-
216-
@torch.no_grad()
217-
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
218-
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
219-
if self.edge_radius > 0:
220-
if self.coherence_mode == "Box Blur":
221-
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
222-
else: # Gaussian Blur OR Staged
223-
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
224-
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
225-
226-
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
227-
228-
# redistribute blur so that the original edges are 0 and blur outwards to 1
229-
blur_tensor = (blur_tensor - 0.5) * 2
230-
231-
threshold = 1 - self.minimum_denoise
232-
233-
if self.coherence_mode == "Staged":
234-
# wherever the blur_tensor is less than fully masked, convert it to threshold
235-
blur_tensor = torch.where((blur_tensor < 1) & (blur_tensor > 0), threshold, blur_tensor)
236-
else:
237-
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
238-
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
239-
240-
else:
241-
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
242-
243-
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
244-
245-
# compute a [0, 1] mask from the blur_tensor
246-
expanded_mask = torch.where((blur_tensor < 1), 0, 1)
247-
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
248-
expanded_image_dto = context.images.save(expanded_mask_image)
249-
250-
masked_latents_name = None
251-
if self.unet is not None and self.vae is not None and self.image is not None:
252-
# all three fields must be present at the same time
253-
main_model_config = context.models.get_config(self.unet.unet.key)
254-
assert isinstance(main_model_config, MainConfigBase)
255-
if main_model_config.variant is ModelVariantType.Inpaint:
256-
mask = blur_tensor
257-
vae_info: LoadedModel = context.models.load(self.vae.vae)
258-
image = context.images.get_pil(self.image.image_name)
259-
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
260-
if image_tensor.dim() == 3:
261-
image_tensor = image_tensor.unsqueeze(0)
262-
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
263-
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
264-
masked_latents = ImageToLatentsInvocation.vae_encode(
265-
vae_info, self.fp32, self.tiled, masked_image.clone()
266-
)
267-
masked_latents_name = context.tensors.save(tensor=masked_latents)
268-
269-
return GradientMaskOutput(
270-
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
271-
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
272-
)
273-
274-
275159
def get_scheduler(
276160
context: InvocationContext,
277161
scheduler_info: ModelIdentifierField,

0 commit comments

Comments
 (0)