Commit 4eca729

chore: ruff
1 parent 0bded00 commit 4eca729

5 files changed: +51 -69 lines changed

invokeai/app/invocations/fields.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Callable, Optional, Tuple, List
+from typing import Any, Callable, Optional, Tuple

 from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
 from pydantic.fields import _Unset

@@ -283,12 +283,13 @@ class FluxReduxConditioningField(BaseModel):
 class FluxUnoReferenceField(BaseModel):
     """A FLUX Uno image list primitive value"""

-    image_names: List[str] = Field(
+    image_names: list[str] | None = Field(
         default=None,
         description="The name of the image associated with this conditioning tensor. This is used to store the image "
         "in the context.",
     )

+
 class FluxFillConditioningField(BaseModel):
     """A FLUX Fill conditioning field."""

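Note on the type change above: switching from `List[str]` to the builtin-generic `list[str] | None` makes the annotation agree with the `default=None` that the field already declares. A minimal standalone Pydantic sketch (hypothetical class name, not the InvokeAI model itself):

```python
from pydantic import BaseModel, Field


class UnoReferenceSketch(BaseModel):
    # Optional list of image names; when omitted the field is None, matching default=None above.
    image_names: list[str] | None = Field(default=None, description="Names of the reference images")


print(UnoReferenceSketch())                           # image_names=None
print(UnoReferenceSketch(image_names=["ref_1.png"]))  # image_names=['ref_1.png']
```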

invokeai/app/invocations/flux_denoise.py

Lines changed: 4 additions & 4 deletions
@@ -6,8 +6,8 @@
 import numpy.typing as npt
 import torch
 import torchvision.transforms as tv_transforms
-from PIL import Image
 import torchvision.transforms.functional as TVF
+from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

@@ -27,9 +27,9 @@
     WithMetadata,
 )
 from invokeai.app.invocations.flux_controlnet import FluxControlNetField
+from invokeai.app.invocations.flux_uno import preprocess_ref
 from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.flux_uno import preprocess_ref
 from invokeai.app.invocations.model import ControlLoRAField, LoRAField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext

@@ -45,10 +45,10 @@
 from invokeai.backend.flux.sampling_utils import (
     clip_timestep_schedule_fractional,
     generate_img_ids,
-    prepare_multi_ip,
     get_noise,
     get_schedule,
     pack,
+    prepare_multi_ip,
     unpack,
 )
 from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning

@@ -677,7 +677,7 @@ def _prep_controlnet_extensions(
             raise ValueError(f"Unsupported ControlNet model type: {type(model)}")

         return controlnet_extensions
-
+
     def _prep_uno_reference_imgs(self, context: InvocationContext) -> list[torch.Tensor]:
         # Load the conditioning image and resize it to the target image size.
         assert self.controlnet_vae is not None, 'Controlnet Vae must be set for UNO encoding'

invokeai/app/invocations/flux_uno.py

Lines changed: 4 additions & 14 deletions
@@ -9,11 +9,7 @@
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import (
-    InputField,
-    OutputField,
-    FluxUnoReferenceField
-)
+from invokeai.app.invocations.fields import FluxUnoReferenceField, InputField, OutputField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.services.shared.invocation_context import InvocationContext

@@ -56,9 +52,7 @@ def preprocess_ref(raw_image: Image.Image, long_size: int = 512) -> Image.Image:
 class FluxUnoOutput(BaseInvocationOutput):
     """The conditioning output of a FLUX Redux invocation."""

-    uno_refs: FluxUnoReferenceField = OutputField(
-        description="Reference images container", title="Reference images"
-    )
+    uno_refs: FluxUnoReferenceField = OutputField(description="Reference images container", title="Reference images")


 @invocation(

@@ -78,14 +72,10 @@ class FluxUnoInvocation(BaseInvocation):
     image4: Optional[ImageField] = InputField(default=None, description="4th reference")

     def invoke(self, context: InvocationContext) -> FluxUnoOutput:
-
         images: list[str] = []
         for image in [self.image, self.image2, self.image3, self.image4]:
             if image is not None:
                 image_pil = context.images.get_pil(image.image_name)
                 images.append(context.images.save(image=image_pil).image_name)
-
-        return FluxUnoOutput(
-            uno_refs=FluxUnoReferenceField(
-                image_names=images)
-        )
+
+        return FluxUnoOutput(uno_refs=FluxUnoReferenceField(image_names=images))

invokeai/app/invocations/image_context_utils.py

Lines changed: 37 additions & 46 deletions
@@ -19,42 +19,39 @@
 - Gaussian Blur Mask
 - Image Concatenate
 """
+
 import math
-from PIL import Image
-from typing import Optional, List, Dict, Any, Tuple, Literal
+from typing import List, Literal, Optional

-import cv2
 import numpy as np
 import torch
 import torchvision.transforms as T
+from PIL import Image

 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import (
-    Field,
-    Input,
     ImageField,
+    Input,
     InputField,
     OutputField,
     TensorField,
     WithBoard,
-    WithMetadata
+    WithMetadata,
 )
-from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
+from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
-

 DIRECTION_OPTIONS = Literal["right", "left", "down", "up"]


-def concat_images(image1: Image.Image, image2: Image.Image,
-                  direction: str = "right",
-                  match_image_size=True) -> Image.Image:
+def concat_images(
+    image1: Image.Image, image2: Image.Image, direction: str = "right", match_image_size=True
+) -> Image.Image:
     """Concatenate two images either horizontally or vertically."""
     # Ensure that image modes are same
     if image1.mode != image2.mode:
         image2 = image2.convert(image1.mode)
-
+
     if direction == "right" or direction == "left":
         if direction == "left":
             image1, image2 = image2, image1

@@ -73,7 +70,7 @@ def concat_images(image1: Image.Image, image2: Image.Image,
         new_image.paste(image2, (0, image1.height))
     else:
         raise ValueError("Mode must be either 'horizontal' or 'vertical'.")
-
+
     return new_image


@@ -89,7 +86,9 @@ class ConcatImagesInvocation(BaseInvocation, WithMetadata, WithBoard):

     image1: ImageField = InputField(description="The first image to process")
     image2: ImageField = InputField(description="The second image to process")
-    mode: DIRECTION_OPTIONS = InputField(default="horizontal", description="Mode of concatenation: 'horizontal' or 'vertical'")
+    mode: DIRECTION_OPTIONS = InputField(
+        default="horizontal", description="Mode of concatenation: 'horizontal' or 'vertical'"
+    )

     def invoke(self, context: InvocationContext) -> ImageOutput:
         image1 = context.images.get_pil(self.image1.image_name)

@@ -103,9 +102,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
 class InpaintCropOutput(BaseInvocationOutput):
     """The output of Inpain Crop Invocation."""

-    image_crop: ImageField = OutputField(
-        description="Cropped part of image", title="Conditioning"
-    )
+    image_crop: ImageField = OutputField(description="Cropped part of image", title="Conditioning")
     stitcher: List[int] = OutputField(description="Parameter for stitching image after inpainting")


@@ -117,17 +114,17 @@ class InpaintCropOutput(BaseInvocationOutput):
 )
 class InpaintCropInvocation(BaseInvocation, WithMetadata, WithBoard):
     "Crop from image masked area with resize and expand options"
-
+
     image: ImageField = InputField(description="The source image")
     mask: TensorField = InputField(description="Inpaint mask")
-
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name, "RGB")
         mask = context.tensors.load(self.mask.tensor_name)
-
+
         # TODO: Finish InpaintCrop implementation
         image_crop = Image.new("RGB", (256, 256))
-
+
         image_dto = context.images.save(image=image_crop)
         return ImageOutput.build(image_dto)


@@ -136,9 +133,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
 class ACEppProcessorOutput(BaseInvocationOutput):
     """The conditioning output of a FLUX Fill invocation."""

-    image: ImageField = OutputField(
-        description="Concatted image", title="Image"
-    )
+    image: ImageField = OutputField(description="Concatted image", title="Image")
     mask: TensorField = OutputField(description="Inpaint mask")
     crop_pad: int = OutputField(description="Padding to crop result")
     crop_width: int = OutputField(description="Width of output area")

@@ -155,15 +150,15 @@ class ACEppProcessor(BaseInvocation):
     reference_image: ImageField = InputField(description="Reference Image")
     edit_image: Optional[ImageField] = InputField(description="Edit Image", default=None, input=Input.Connection)
     edit_mask: Optional[TensorField] = InputField(description="Edit Mask", default=None, input=Input.Connection)
-
+
     width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
     height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
-
+
     max_seq_len: int = InputField(default=4096, gt=2048, le=5120, description="The height of the crop rectangle")
-
+
     def image_check(self, image_pil: Image.Image) -> torch.Tensor:
         max_aspect_ratio = 4
-
+
         image = self.transform_pil_tensor(image_pil)
         image = image.unsqueeze(0)
         # preprocess

@@ -173,20 +168,18 @@ def image_check(self, image_pil: Image.Image) -> torch.Tensor:
         elif W / H > max_aspect_ratio:
             image[0] = T.CenterCrop([H, int(max_aspect_ratio * H)])(image[0])
         return image[0]
-
+
     def transform_pil_tensor(self, pil_image: Image.Image) -> torch.Tensor:
-        transform = T.Compose([
-            T.ToTensor()
-        ])
+        transform = T.Compose([T.ToTensor()])
         tensor_image: torch.Tensor = transform(pil_image)
         return tensor_image
-
+
     def invoke(self, context: InvocationContext) -> ACEppProcessorOutput:
         d = 16  # Flux pixels per patch rate
-
+
         image_pil = context.images.get_pil(self.reference_image.image_name, "RGB")
         image = self.image_check(image_pil) - 0.5
-
+
         if self.edit_image is None:
             edit_image = torch.zeros((3, self.height, self.width))
             edit_mask = torch.ones((1, self.height, self.width))

@@ -199,23 +192,21 @@ def invoke(self, context: InvocationContext) -> ACEppProcessorOutput:
                 edit_mask = torch.ones((eH, eW))
             else:
                 edit_mask = context.tensors.load(self.edit_mask.tensor_name)
-
+
         out_H, out_W = edit_image.shape[-2:]
-
+
         _, H, W = image.shape
         _, eH, eW = edit_image.shape
-
+
         # align height with edit_image
         scale = eH / H
         tH, tW = eH, int(W * scale)
-
-        reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(
-            image)
+
+        reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(image)
         edit_image = torch.cat([reference_image, edit_image], dim=-1)
-        edit_mask = torch.cat([torch.zeros((1, reference_image.shape[1], reference_image.shape[2])), edit_mask],
-                              dim=-1)
+        edit_mask = torch.cat([torch.zeros((1, reference_image.shape[1], reference_image.shape[2])), edit_mask], dim=-1)
         slice_w = reference_image.shape[-1]
-
+
         H, W = edit_image.shape[-2:]
         scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / d) * (W / d))))
         rH = int(H * scale) // d * d

@@ -235,7 +226,7 @@ def invoke(self, context: InvocationContext) -> ACEppProcessorOutput:
         # Convert to torch.bool
         edit_mask = edit_mask > 0.5
         image_out = Image.fromarray((edit_image[0].numpy() * 255).astype(np.uint8))
-
+
         image_dto = context.images.save(image=image_out)
         mask_name = context.tensors.save(edit_mask)
         return ACEppProcessorOutput(
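Side note on the `max_seq_len` math that passes through the hunks above unchanged: the concatenated canvas is scaled down so that its patch count stays within roughly `2 * max_seq_len`, then snapped to multiples of `d = 16`. A small standalone sketch with assumed (hypothetical) sizes; `rW` simply mirrors the `rH` line shown in the diff:

```python
import math

d = 16              # FLUX pixels per patch (as in the diff above)
max_seq_len = 4096  # default from the ACEppProcessor field
H, W = 2048, 3072   # hypothetical canvas size after concatenation

# Never upscale; shrink until (H/d) * (W/d) is at most ~2 * max_seq_len.
scale = min(1.0, math.sqrt(max_seq_len * 2 / ((H / d) * (W / d))))
rH = int(H * scale) // d * d  # snap down to a multiple of the patch size
rW = int(W * scale) // d * d

print(scale, rH, rW)          # ~0.577 1168 1760
print((rH // d) * (rW // d))  # 8030 patches <= 8192
```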

invokeai/backend/flux/model.py

Lines changed: 3 additions & 3 deletions
@@ -123,10 +123,10 @@ def forward(
         img_end = img.shape[1]  # length of original image vector
         if uno_ref_imgs is not None and uno_ref_ids is not None:
             img_in = [img] + [self.img_in(ref) for ref in uno_ref_imgs]
-            img_ids = [ids] + [ref_ids for ref_ids in uno_ref_ids]
-            img = torch.cat(img_in, dim=1)
+            img_ids = [ids] + uno_ref_ids
+            img = torch.cat(img_in, dim=1)
             ids = torch.cat(img_ids, dim=1)
-
+
         pe = self.pe_embedder(ids)

         # Validate double_block_residuals shape.
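The only substantive-looking change here is dropping the `[ref_ids for ref_ids in uno_ref_ids]` comprehension, which merely copied the list element by element; concatenating `uno_ref_ids` directly gives the same result. A toy sketch (shapes are assumed, not the real model tensors):

```python
import torch

ids = torch.zeros(1, 4, 3)                                       # toy positional ids for the image tokens
uno_ref_ids = [torch.ones(1, 2, 3), torch.full((1, 2, 3), 2.0)]  # toy per-reference ids

old = torch.cat([ids] + [ref_ids for ref_ids in uno_ref_ids], dim=1)
new = torch.cat([ids] + uno_ref_ids, dim=1)

assert torch.equal(old, new)
print(new.shape)  # torch.Size([1, 8, 3])
```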
