@@ -19,42 +19,39 @@
 - Gaussian Blur Mask
 - Image Concatenate
 """
+
 import math
-from PIL import Image
-from typing import Optional, List, Dict, Any, Tuple, Literal
+from typing import List, Literal, Optional

-import cv2
 import numpy as np
 import torch
 import torchvision.transforms as T
+from PIL import Image

 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import (
-    Field,
-    Input,
     ImageField,
+    Input,
     InputField,
     OutputField,
     TensorField,
     WithBoard,
-    WithMetadata
+    WithMetadata,
 )
-from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
+from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
-

 DIRECTION_OPTIONS = Literal["right", "left", "down", "up"]


-def concat_images(image1: Image.Image, image2: Image.Image,
-                  direction: str = "right",
-                  match_image_size=True) -> Image.Image:
+def concat_images(
+    image1: Image.Image, image2: Image.Image, direction: str = "right", match_image_size=True
+) -> Image.Image:
     """Concatenate two images either horizontally or vertically."""
     # Ensure that image modes are the same
     if image1.mode != image2.mode:
         image2 = image2.convert(image1.mode)
-
+
     if direction == "right" or direction == "left":
         if direction == "left":
             image1, image2 = image2, image1
@@ -73,7 +70,7 @@ def concat_images(image1: Image.Image, image2: Image.Image,
         new_image.paste(image2, (0, image1.height))
     else:
         raise ValueError("Direction must be one of 'right', 'left', 'down', or 'up'.")
-
+
     return new_image


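For orientation, a minimal sketch of calling `concat_images` on its own with Pillow, outside the InvokeAI graph; the input file names are illustrative:

    from PIL import Image

    left = Image.open("a.png")   # illustrative inputs
    right = Image.open("b.png")
    # direction is one of DIRECTION_OPTIONS: "right", "left", "down", "up"
    combined = concat_images(left, right, direction="right")
    combined.save("combined.png")
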
@@ -89,7 +86,9 @@ class ConcatImagesInvocation(BaseInvocation, WithMetadata, WithBoard):

     image1: ImageField = InputField(description="The first image to process")
     image2: ImageField = InputField(description="The second image to process")
-    mode: DIRECTION_OPTIONS = InputField(default="horizontal", description="Mode of concatenation: 'horizontal' or 'vertical'")
+    mode: DIRECTION_OPTIONS = InputField(
+        default="right", description="Direction of concatenation: 'right', 'left', 'down', or 'up'"
+    )

     def invoke(self, context: InvocationContext) -> ImageOutput:
         image1 = context.images.get_pil(self.image1.image_name)
@@ -103,9 +102,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
 class InpaintCropOutput(BaseInvocationOutput):
     """The output of the Inpaint Crop invocation."""

-    image_crop: ImageField = OutputField(
-        description="Cropped part of image", title="Conditioning"
-    )
+    image_crop: ImageField = OutputField(description="Cropped part of image", title="Conditioning")
     stitcher: List[int] = OutputField(description="Parameter for stitching image after inpainting")


@@ -117,17 +114,17 @@ class InpaintCropOutput(BaseInvocationOutput):
 )
 class InpaintCropInvocation(BaseInvocation, WithMetadata, WithBoard):
     "Crop the masked area from an image, with resize and expand options"
-
+
     image: ImageField = InputField(description="The source image")
     mask: TensorField = InputField(description="Inpaint mask")
-
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
         mask = context.tensors.load(self.mask.tensor_name)
-
+
         # TODO: Finish InpaintCrop implementation
         image_crop = Image.new("RGB", (256, 256))
-
+
         image_dto = context.images.save(image=image_crop)
         return ImageOutput.build(image_dto)

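The crop itself is still a stub: the TODO above returns a blank 256×256 canvas. A minimal sketch of the masked-area crop this node appears to be heading toward, assuming the mask is a (1, H, W) tensor where nonzero marks the inpaint region; the helper name and `pad` parameter are illustrative, not part of the file:

    import torch
    from PIL import Image

    def crop_to_mask(image: Image.Image, mask: torch.Tensor, pad: int = 32):
        """Illustrative helper: crop the masked region's bounding box, expanded by `pad`."""
        ys, xs = torch.where(mask[0] > 0.5)  # coordinates of masked pixels
        x0, y0 = max(int(xs.min()) - pad, 0), max(int(ys.min()) - pad, 0)
        x1 = min(int(xs.max()) + 1 + pad, image.width)
        y1 = min(int(ys.max()) + 1 + pad, image.height)
        # Mirrors InpaintCropOutput.stitcher: enough data to paste the inpainted crop back.
        return image.crop((x0, y0, x1, y1)), [x0, y0, x1, y1]
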
@@ -136,9 +133,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
 class ACEppProcessorOutput(BaseInvocationOutput):
     """The conditioning output of a FLUX Fill invocation."""

-    image: ImageField = OutputField(
-        description="Concatenated image", title="Image"
-    )
+    image: ImageField = OutputField(description="Concatenated image", title="Image")
     mask: TensorField = OutputField(description="Inpaint mask")
     crop_pad: int = OutputField(description="Padding to crop result")
     crop_width: int = OutputField(description="Width of output area")
@@ -155,15 +150,15 @@ class ACEppProcessor(BaseInvocation):
     reference_image: ImageField = InputField(description="Reference Image")
     edit_image: Optional[ImageField] = InputField(description="Edit Image", default=None, input=Input.Connection)
     edit_mask: Optional[TensorField] = InputField(description="Edit Mask", default=None, input=Input.Connection)
-
+
     width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
     height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
-
+
     max_seq_len: int = InputField(default=4096, gt=2048, le=5120, description="Maximum sequence length")
-
+
     def image_check(self, image_pil: Image.Image) -> torch.Tensor:
         max_aspect_ratio = 4
-
+
         image = self.transform_pil_tensor(image_pil)
         image = image.unsqueeze(0)
         # preprocess
@@ -173,20 +168,18 @@ def image_check(self, image_pil: Image.Image) -> torch.Tensor:
         elif W / H > max_aspect_ratio:
             image[0] = T.CenterCrop([H, int(max_aspect_ratio * H)])(image[0])
         return image[0]
-
+
     def transform_pil_tensor(self, pil_image: Image.Image) -> torch.Tensor:
-        transform = T.Compose([
-            T.ToTensor()
-        ])
+        transform = T.Compose([T.ToTensor()])
         tensor_image: torch.Tensor = transform(pil_image)
         return tensor_image
-
+
     def invoke(self, context: InvocationContext) -> ACEppProcessorOutput:
         d = 16  # FLUX pixels per patch
-
+
         image_pil = context.images.get_pil(self.reference_image.image_name, "RGB")
         image = self.image_check(image_pil) - 0.5
-
+
         if self.edit_image is None:
             edit_image = torch.zeros((3, self.height, self.width))
             edit_mask = torch.ones((1, self.height, self.width))
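With no edit image connected, the node falls back to a blank canvas at the requested size and an all-ones mask. Judging from the zeros later concatenated over the reference half, the mask convention appears to be 0 = preserve, 1 = inpaint; a small sanity-check sketch of that layout (sizes are arbitrary):

    import torch

    ref_w, edit_w, h = 64, 64, 64  # arbitrary sizes for illustration
    mask = torch.cat([torch.zeros((1, h, ref_w)), torch.ones((1, h, edit_w))], dim=-1)
    assert mask[..., :ref_w].sum() == 0   # reference half is preserved
    assert bool(mask[..., ref_w:].all())  # edit half is fully inpainted
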
@@ -199,23 +192,21 @@ def invoke(self, context: InvocationContext) -> ACEppProcessorOutput:
                 edit_mask = torch.ones((eH, eW))
             else:
                 edit_mask = context.tensors.load(self.edit_mask.tensor_name)
-
+
         out_H, out_W = edit_image.shape[-2:]
-
+
         _, H, W = image.shape
         _, eH, eW = edit_image.shape
-
+
         # align height with edit_image
         scale = eH / H
         tH, tW = eH, int(W * scale)
-
-        reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
-            image)
+
+        reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(image)
         edit_image = torch.cat([reference_image, edit_image], dim=-1)
-        edit_mask = torch.cat([torch.zeros((1, reference_image.shape[1], reference_image.shape[2])), edit_mask],
-                              dim=-1)
+        edit_mask = torch.cat([torch.zeros((1, reference_image.shape[1], reference_image.shape[2])), edit_mask], dim=-1)
         slice_w = reference_image.shape[-1]
-
+
         H, W = edit_image.shape[-2:]
         scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / d) * (W / d))))
         rH = int(H * scale) // d * d
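At d = 16 pixels per patch, an H×W canvas holds (H/d)·(W/d) patches; the scale factor shrinks the canvas so the patch count stays within about 2·max_seq_len, and each side is then rounded down to a multiple of d. A worked example with the node's default max_seq_len (the canvas size is illustrative):

    import math

    d, max_seq_len = 16, 4096
    H, W = 2048, 2048  # 128 * 128 = 16384 patches before scaling
    scale = min(1.0, math.sqrt(max_seq_len * 2 / ((H / d) * (W / d))))  # ~0.7071
    rH = int(H * scale) // d * d  # 1440
    rW = int(W * scale) // d * d  # 1440 -> 90 * 90 = 8100 patches <= 8192
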
@@ -235,7 +226,7 @@ def invoke(self, context: InvocationContext) -> ACEppProcessorOutput:
         # Convert to torch.bool
         edit_mask = edit_mask > 0.5
         image_out = Image.fromarray((edit_image[0].numpy() * 255).astype(np.uint8))
-
+
         image_dto = context.images.save(image=image_out)
         mask_name = context.tensors.save(edit_mask)
         return ACEppProcessorOutput(