
Commit 7d2a666

Ilan Tchenak authored and committed
wip support for bria's controlnet
1 parent 7140f2e commit 7d2a666

File tree: 3 files changed, +340 −5 lines


invokeai/backend/bria/controlnet.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import torch
from diffusers.models.modeling_utils import ModelMixin
from torch import nn

# Control modes supported by Bria's ControlNet, and their integer indices.
BRIA_CONTROL_MODES = Literal["depth", "canny", "cologrid", "recolor", "tile", "pose"]


class BriaControlModes(Enum):
    depth = 0
    canny = 1
    cologrid = 2
    recolor = 3
    tile = 4
    pose = 5


class BriaMultiControlNetModel(ModelMixin):
    r"""
    `BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel

    This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to
    be compatible with `BriaControlNetModel`.

    Args:
        controlnets (`List[BriaControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `BriaControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.Tensor],
        controlnet_mode: List[torch.Tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union["BriaControlNetOutput", Tuple]:
        # NOTE: `BriaControlNetModel` / `BriaControlNetOutput` (the single-ControlNet model and its output type)
        # are not imported in this WIP commit; forward() mirrors their call contract.
        # ControlNet-Union with multiple conditions
        # only load one ControlNet to save memory
        if len(self.nets) == 1 and self.nets[0].union:
            controlnet = self.nets[0]

            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    control_block_samples = [
                        control_block_sample + block_sample
                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                    ]

                    control_single_block_samples = [
                        control_single_block_sample + block_sample
                        for control_single_block_sample, block_sample in zip(
                            control_single_block_samples, single_block_samples
                        )
                    ]

        # Regular Multi-ControlNets
        # load all ControlNets into memory
        else:
            for i, (image, mode, scale, controlnet) in enumerate(
                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
            ):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
                    if single_block_samples is not None and control_single_block_samples is not None:
                        control_single_block_samples = [
                            control_single_block_sample + block_sample
                            for control_single_block_sample, block_sample in zip(
                                control_single_block_samples, single_block_samples
                            )
                        ]

        return control_block_samples, control_single_block_samples
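
For reference, a minimal sketch (not part of this commit) of how the mode enum above is meant to be consumed: the node layer stores a control mode by name, and the denoiser converts it to the integer index that the union ControlNet's `controlnet_mode` argument expects, as a long tensor with one entry per control image.

# Minimal sketch (not part of the commit); assumes this module is importable as
# invokeai.backend.bria.controlnet.
import torch

from invokeai.backend.bria.controlnet import BriaControlModes

mode_name = "canny"
mode_index = BriaControlModes[mode_name].value  # -> 1
# The denoiser passes modes as long tensors, one entry per control image (a single image here).
mode_tensor = torch.tensor(mode_index).expand(1).to(dtype=torch.long)
print(mode_name, mode_index, mode_tensor)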
invokeai/nodes/bria_nodes/bria_controlnet.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
from pydantic import BaseModel, Field, field_validator, model_validator

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
from invokeai.backend.bria.controlnet import BRIA_CONTROL_MODES


class BriaControlNetField(BaseModel):
    image: ImageField = Field(description="The control image")
    model: ModelIdentifierField = Field(description="The ControlNet model to use")
    mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
    controlnet_conditioning_scale: float = Field(description="The weight given to the ControlNet")
    control_guidance_start: float = Field(description="When the ControlNet is first applied (% of total steps)")
    control_guidance_end: float = Field(description="When the ControlNet is last applied (% of total steps)")


@invocation_output("bria_controlnet_output")
class BriaControlNetOutput(BaseInvocationOutput):
    """Bria ControlNet info"""

    control: BriaControlNetField = OutputField(description=FieldDescriptions.control)


@invocation(
    "bria_controlnet",
    title="Bria ControlNet",
    tags=["controlnet", "bria"],
    category="controlnet",
    version="1.0.0",
)
class BriaControlNetInvocation(BaseInvocation):
    """Collect Bria ControlNet info to pass to denoiser node."""

    control_image: ImageField = InputField(description="The control image")
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
    )
    control_mode: BRIA_CONTROL_MODES = InputField(default="depth", description="The mode of the ControlNet")
    control_weight: float | list[float] = InputField(
        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
    )
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )

    def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
        return BriaControlNetOutput(
            control=BriaControlNetField(
                image=self.control_image,
                model=self.control_model,
                mode=self.control_mode,
                controlnet_conditioning_scale=self.control_weight,
                control_guidance_start=self.begin_step_percent,
                control_guidance_end=self.end_step_percent,
            ),
        )
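
The begin/end step percents collected by this node are meant to gate when the ControlNet is applied during denoising. Below is a minimal, pure-Python sketch (not part of this commit) of that schedule, mirroring the formula used by get_controlnet_keep() in the denoiser further down.

# Minimal sketch (not part of the commit): per-step keep mask derived from
# begin_step_percent / end_step_percent, expressed as fractions of the total step count.
def controlnet_keep(num_steps: int, start: float, end: float) -> list[float]:
    # 1.0 while the current step falls inside [start, end], else 0.0.
    return [
        1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
        for i in range(num_steps)
    ]

print(controlnet_keep(10, 0.0, 1.0))  # active on every step
print(controlnet_keep(10, 0.2, 0.8))  # active only on the middle steps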

invokeai/nodes/bria_nodes/bria_denoiser.py

Lines changed: 149 additions & 5 deletions
@@ -1,3 +1,10 @@
+from typing import List, Tuple
+from PIL import Image
+from diffusers.pipelines import AutoencoderKL
+from invokeai.backend.bria.controlnet import BriaControlModes, BriaMultiControlNetModel
+from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField
+from diffusers.image_processor import VaeImageProcessor
+
 import torch
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

@@ -68,6 +75,11 @@ class BriaDenoiseInvocation(BaseInvocation):
         input=Input.Connection,
         title="Text IDs",
     )
+    control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
+        default=None,
+        description="ControlNet",
+        input=Input.Connection,
+        title="ControlNet",
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
@@ -83,16 +95,29 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         with (
             context.models.load(self.transformer.transformer) as transformer,
             context.models.load(scheduler_identifier) as scheduler,
+            context.models.load(self.vae.vae) as vae,
         ):
             assert isinstance(transformer, BriaTransformer2DModel)
             assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
+            assert isinstance(vae, AutoencoderKL)
             dtype = transformer.dtype
             device = transformer.device
             latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds))
             prompt_embeds = torch.cat([neg_embeds, pos_embeds]) if self.guidance_scale > 1 else pos_embeds

             sigmas = get_original_sigmas(1000, self.num_steps)
             timesteps, _ = retrieve_timesteps(scheduler, self.num_steps, device, None, sigmas, mu=0.0)
+            # NOTE: width/height are hard-coded to 1024x1024 below, overriding the latent-derived values.
+            width, height = latents.shape[-2:]
+            width, height = 1024, 1024
+            if self.control is not None:
+                control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
+                    context=context,
+                    vae=vae,
+                    width=width,
+                    height=height,
+                    device=device,
+                    dtype=dtype,
+                    num_channels_latents=transformer.config.in_channels // 4,
+                )

             for t in timesteps:
                 # Prepare model input efficiently
@@ -101,11 +126,21 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                 else:
                     latent_model_input = latents

-                # Prepare timestep tensor efficiently
-                if isinstance(t, torch.Tensor):
-                    timestep_tensor = t.expand(latent_model_input.shape[0])
-                else:
-                    timestep_tensor = torch.tensor([t] * latent_model_input.shape[0], device=device, dtype=torch.float32)
+                timestep_tensor = t.expand(latent_model_input.shape[0])
+
+                controlnet_block_samples, controlnet_single_block_samples = None, None
+                if self.control is not None:
+                    controlnet_block_samples, controlnet_single_block_samples = control_model(
+                        hidden_states=latents,
+                        controlnet_cond=control_images,  # type: ignore
+                        controlnet_mode=control_modes,  # type: ignore
+                        conditioning_scale=control_scales,  # type: ignore
+                        timestep=timestep_tensor,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        return_dict=False,
+                    )

                 noise_pred = transformer(
                     latent_model_input,
@@ -115,6 +150,8 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                     txt_ids=text_ids,
                     guidance=None,
                     return_dict=False,
+                    controlnet_block_samples=controlnet_block_samples,
+                    controlnet_single_block_samples=controlnet_single_block_samples,
                 )[0]

                 if self.guidance_scale > 1:
@@ -131,3 +168,110 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         saved_input_latents_tensor = context.tensors.save(latents)
         latents_output = LatentsField(latents_name=saved_input_latents_tensor)
         return BriaDenoiseInvocationOutput(latents=latents_output)
+
+    def _prepare_multi_control(
+        self,
+        context: InvocationContext,
+        vae: AutoencoderKL,
+        width: int,
+        height: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        num_channels_latents: int,
+    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
+        control = self.control if isinstance(self.control, list) else [self.control]
+        control_images, control_models, control_modes, control_scales = [], [], [], []
+        for controlnet in control:
+            control_models.append(context.models.load(controlnet.model))
+            control_images.append(context.images.get_pil(controlnet.image))
+            control_modes.append(BriaControlModes[controlnet.mode].value)
+            control_scales.append(controlnet.controlnet_conditioning_scale)
+
+        control_model = BriaMultiControlNetModel(control_models)
+        tensored_control_images, tensored_control_modes = self._prepare_control_images(
+            vae, control_images, control_modes, width, height, device, dtype, num_channels_latents
+        )
+        return control_model, tensored_control_images, tensored_control_modes, control_scales
+
+    def _prepare_control_images(
+        self,
+        vae: AutoencoderKL,
+        control_images: list[Image.Image],
+        control_modes: list[int],
+        width: int,
+        height: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        num_channels_latents: int,
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        tensored_control_images = []
+        tensored_control_modes = []
+        for idx, control_image_ in enumerate(control_images):
+            tensored_control_image = self.prepare_image(
+                image=control_image_,
+                width=width,
+                height=height,
+                device=device,
+                dtype=vae.dtype,
+            )
+            height, width = tensored_control_image.shape[-2:]
+
+            # vae encode
+            tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
+            tensored_control_image = (tensored_control_image - vae.config.shift_factor) * vae.config.scaling_factor
+
+            # pack
+            height_control_image, width_control_image = tensored_control_image.shape[2:]
+            tensored_control_image = self._pack_latents(
+                tensored_control_image,
+                height_control_image,
+                width_control_image,
+            )
+            tensored_control_images.append(tensored_control_image)
+            tensored_control_modes.append(
+                torch.tensor(control_modes[idx]).expand(tensored_control_image.shape[0]).to(device, dtype=torch.long)
+            )
+
+        return tensored_control_images, tensored_control_modes
+
+    def prepare_image(
+        self,
+        image: Image.Image,
+        width: int,
+        height: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
+        image = image.repeat_interleave(1, dim=0)
+        image = image.to(device=device, dtype=dtype)
+        return image
+
+    def _pack_latents(self, latents, height, width):
+        # Packs a (1, 1, height, width) latent into (1, (height // 2) * (width // 2), 4) 2x2 patches.
+        latents = latents.view(1, 1, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(1, (height // 2) * (width // 2), 4)
+
+        return latents
+
+    def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if not isinstance(self.control, list) else keeps)
+        return controlnet_keep
+
+    def get_control_start_end(self, control_guidance_start, control_guidance_end):
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = 1  # TODO - why is this 1?
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        return control_guidance_start, control_guidance_end
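
For clarity, a standalone sketch (not part of this commit) of the 2x2 patch packing that _pack_latents performs above. As written in this WIP commit it assumes a batch of one and a single latent channel, i.e. an input of shape (1, 1, H, W) packed into (1, (H/2)*(W/2), 4).

# Minimal sketch (not part of the commit): same view/permute/reshape as _pack_latents.
import torch


def pack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
    latents = latents.view(1, 1, height // 2, 2, width // 2, 2)  # split H and W into 2x2 patches
    latents = latents.permute(0, 2, 4, 1, 3, 5)  # group the 4 values of each patch together
    return latents.reshape(1, (height // 2) * (width // 2), 4)  # one row per patch


x = torch.randn(1, 1, 64, 64)
packed = pack_latents(x, 64, 64)
print(packed.shape)  # torch.Size([1, 1024, 4])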

0 commit comments
