@@ -12,7 +12,6 @@
 from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
 from diffusers.schedulers.scheduling_tcd import TCDScheduler
 from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
-from PIL import Image
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPVisionModelWithProjection
@@ -22,15 +21,14 @@
     ConditioningField,
     DenoiseMaskField,
     FieldDescriptions,
-    ImageField,
     Input,
     InputField,
     LatentsField,
     OutputField,
     UIType,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import DenoiseMaskOutput, LatentsOutput
+from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -55,13 +53,12 @@
     ControlNetData,
     StableDiffusionGeneratorPipeline,
     T2IAdapterData,
-    image_resized_to_grid_as_tensor,
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import TorchDevice
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
-from .model import ModelIdentifierField, UNetField, VAEField
+from .model import ModelIdentifierField, UNetField
 
 DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
 
@@ -91,71 +88,6 @@ def invoke(self, context: InvocationContext) -> SchedulerOutput:
         return SchedulerOutput(scheduler=self.scheduler)
 
 
-@invocation(
-    "create_denoise_mask",
-    title="Create Denoise Mask",
-    tags=["mask", "denoise"],
-    category="latents",
-    version="1.0.2",
-)
-class CreateDenoiseMaskInvocation(BaseInvocation):
-    """Creates mask for denoising model run."""
-
-    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
-    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
-    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
-    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
-    fp32: bool = InputField(
-        default=DEFAULT_PRECISION == "float32",
-        description=FieldDescriptions.fp32,
-        ui_order=4,
-    )
-
-    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
-        if mask_image.mode != "L":
-            mask_image = mask_image.convert("L")
-        mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
-        if mask_tensor.dim() == 3:
-            mask_tensor = mask_tensor.unsqueeze(0)
-        # if shape is not None:
-        #    mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
-        return mask_tensor
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
-        if self.image is not None:
-            image = context.images.get_pil(self.image.image_name)
-            image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
-            if image_tensor.dim() == 3:
-                image_tensor = image_tensor.unsqueeze(0)
-        else:
-            image_tensor = None
-
-        mask = self.prep_mask_tensor(
-            context.images.get_pil(self.mask.image_name),
-        )
-
-        if image_tensor is not None:
-            vae_info = context.models.load(self.vae.vae)
-
-            img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
-            masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
-            # TODO:
-            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
-
-            masked_latents_name = context.tensors.save(tensor=masked_latents)
-        else:
-            masked_latents_name = None
-
-        mask_name = context.tensors.save(tensor=mask)
-
-        return DenoiseMaskOutput.build(
-            mask_name=mask_name,
-            masked_latents_name=masked_latents_name,
-            gradient=False,
-        )
-
-
 def get_scheduler(
     context: InvocationContext,
     scheduler_info: ModelIdentifierField,