Skip to content

Commit 854bca6

Browse files
committed
Move CreateDenoiseMaskInvocation to its own file. No functional changes.
1 parent fea9013 commit 854bca6

File tree

2 files changed

+82
-70
lines changed

2 files changed

+82
-70
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Optional
2+
3+
import torch
4+
import torchvision.transforms as T
5+
from PIL import Image
6+
from torchvision.transforms.functional import resize as tv_resize
7+
8+
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
9+
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
10+
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
11+
from invokeai.app.invocations.latent import DEFAULT_PRECISION
12+
from invokeai.app.invocations.model import VAEField
13+
from invokeai.app.invocations.primitives import DenoiseMaskOutput
14+
from invokeai.app.services.shared.invocation_context import InvocationContext
15+
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
16+
17+
18+
@invocation(
    "create_denoise_mask",
    title="Create Denoise Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run."""

    # VAE used to encode the masked image into latents (only needed when `image` is supplied).
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
    # Optional source image; when present, its masked latents are precomputed for inpainting.
    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
    fp32: bool = InputField(
        # BUGFIX: DEFAULT_PRECISION is a torch.dtype (TorchDevice.choose_torch_dtype()),
        # so comparing it to the string "float32" was always False. Compare to the dtype.
        default=DEFAULT_PRECISION == torch.float32,
        description=FieldDescriptions.fp32,
        ui_order=4,
    )

    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
        """Convert a PIL mask image to a batched single-channel float tensor.

        The image is converted to grayscale ("L") if needed, resized to the
        latent grid, and given a leading batch dimension (NCHW).
        """
        if mask_image.mode != "L":
            mask_image = mask_image.convert("L")
        # normalize=False keeps mask values in [0, 1] rather than [-1, 1].
        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
        if mask_tensor.dim() == 3:
            mask_tensor = mask_tensor.unsqueeze(0)
        return mask_tensor

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
        """Build the denoise mask; also precompute masked latents when an image is given."""
        if self.image is not None:
            image = context.images.get_pil(self.image.image_name)
            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
            if image_tensor.dim() == 3:
                image_tensor = image_tensor.unsqueeze(0)
        else:
            image_tensor = None

        mask = self.prep_mask_tensor(
            context.images.get_pil(self.mask.image_name),
        )

        if image_tensor is not None:
            vae_info = context.models.load(self.vae.vae)

            # Resize the mask to the image resolution, then zero out the masked
            # region (mask < 0.5) before VAE-encoding the result.
            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

            masked_latents_name = context.tensors.save(tensor=masked_latents)
        else:
            masked_latents_name = None

        mask_name = context.tensors.save(tensor=mask)

        return DenoiseMaskOutput.build(
            mask_name=mask_name,
            masked_latents_name=masked_latents_name,
            gradient=False,
        )

invokeai/app/invocations/latent.py

Lines changed: 2 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
1313
from diffusers.schedulers.scheduling_tcd import TCDScheduler
1414
from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
15-
from PIL import Image
1615
from pydantic import field_validator
1716
from torchvision.transforms.functional import resize as tv_resize
1817
from transformers import CLIPVisionModelWithProjection
@@ -22,15 +21,14 @@
2221
ConditioningField,
2322
DenoiseMaskField,
2423
FieldDescriptions,
25-
ImageField,
2624
Input,
2725
InputField,
2826
LatentsField,
2927
OutputField,
3028
UIType,
3129
)
3230
from invokeai.app.invocations.ip_adapter import IPAdapterField
33-
from invokeai.app.invocations.primitives import DenoiseMaskOutput, LatentsOutput
31+
from invokeai.app.invocations.primitives import LatentsOutput
3432
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
3533
from invokeai.app.services.shared.invocation_context import InvocationContext
3634
from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -55,13 +53,12 @@
5553
ControlNetData,
5654
StableDiffusionGeneratorPipeline,
5755
T2IAdapterData,
58-
image_resized_to_grid_as_tensor,
5956
)
6057
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
6158
from ...backend.util.devices import TorchDevice
6259
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
6360
from .controlnet_image_processors import ControlField
64-
from .model import ModelIdentifierField, UNetField, VAEField
61+
from .model import ModelIdentifierField, UNetField
6562

6663
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
6764

@@ -91,71 +88,6 @@ def invoke(self, context: InvocationContext) -> SchedulerOutput:
9188
return SchedulerOutput(scheduler=self.scheduler)
9289

9390

94-
@invocation(
95-
"create_denoise_mask",
96-
title="Create Denoise Mask",
97-
tags=["mask", "denoise"],
98-
category="latents",
99-
version="1.0.2",
100-
)
101-
class CreateDenoiseMaskInvocation(BaseInvocation):
102-
"""Creates mask for denoising model run."""
103-
104-
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
105-
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
106-
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
107-
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
108-
fp32: bool = InputField(
109-
default=DEFAULT_PRECISION == "float32",
110-
description=FieldDescriptions.fp32,
111-
ui_order=4,
112-
)
113-
114-
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
115-
if mask_image.mode != "L":
116-
mask_image = mask_image.convert("L")
117-
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
118-
if mask_tensor.dim() == 3:
119-
mask_tensor = mask_tensor.unsqueeze(0)
120-
# if shape is not None:
121-
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
122-
return mask_tensor
123-
124-
@torch.no_grad()
125-
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
126-
if self.image is not None:
127-
image = context.images.get_pil(self.image.image_name)
128-
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
129-
if image_tensor.dim() == 3:
130-
image_tensor = image_tensor.unsqueeze(0)
131-
else:
132-
image_tensor = None
133-
134-
mask = self.prep_mask_tensor(
135-
context.images.get_pil(self.mask.image_name),
136-
)
137-
138-
if image_tensor is not None:
139-
vae_info = context.models.load(self.vae.vae)
140-
141-
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
142-
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
143-
# TODO:
144-
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
145-
146-
masked_latents_name = context.tensors.save(tensor=masked_latents)
147-
else:
148-
masked_latents_name = None
149-
150-
mask_name = context.tensors.save(tensor=mask)
151-
152-
return DenoiseMaskOutput.build(
153-
mask_name=mask_name,
154-
masked_latents_name=masked_latents_name,
155-
gradient=False,
156-
)
157-
158-
15991
def get_scheduler(
16092
context: InvocationContext,
16193
scheduler_info: ModelIdentifierField,

0 commit comments

Comments
 (0)