Commit 1076d8e

Ilan Tchenak authored and Ubuntu committed

Added support for Bria's ControlNet

1 parent 7140f2e

File tree

8 files changed: +761 −26 lines

invokeai/backend/bria/controlnet_bria.py

Lines changed: 543 additions & 0 deletions
Large diffs are not rendered by default.
invokeai/backend/bria/controlnet_utils.py

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+from typing import List, Tuple
+from PIL import Image
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+
+from diffusers.image_processor import VaeImageProcessor
+
+import torch
+
+
+@torch.no_grad()
+def prepare_control_images(
+    vae: AutoencoderKL,
+    control_images: list[Image.Image],
+    control_modes: list[int],
+    width: int,
+    height: int,
+    device: torch.device,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+
+    tensored_control_images = []
+    tensored_control_modes = []
+    for idx, control_image_ in enumerate(control_images):
+        tensored_control_image = _prepare_image(
+            image=control_image_,
+            width=width,
+            height=height,
+            device=device,
+            dtype=vae.dtype,
+        )
+        height, width = tensored_control_image.shape[-2:]
+
+        # vae encode
+        tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
+        tensored_control_image = tensored_control_image * 16
+
+        # pack
+        height_control_image, width_control_image = tensored_control_image.shape[2:]
+        tensored_control_image = _pack_latents(
+            tensored_control_image,
+            height_control_image,
+            width_control_image,
+        )
+        tensored_control_images.append(tensored_control_image)
+        tensored_control_modes.append(torch.tensor(control_modes[idx]).expand(
+            tensored_control_image.shape[0]).to(device, dtype=torch.long))
+
+    return tensored_control_images, tensored_control_modes
+
+
+def _prepare_image(
+    image: Image.Image,
+    width: int,
+    height: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    image = image.convert("RGB")
+    image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
+    image = image.repeat_interleave(1, dim=0)
+    image = image.to(device=device, dtype=dtype)
+    return image
+
+
+def _pack_latents(latents, height, width):
+    latents = latents.view(1, 4, height // 2, 2, width // 2, 2)
+    latents = latents.permute(0, 2, 4, 1, 3, 5)
+    latents = latents.reshape(1, (height // 2) * (width // 2), 16)
+
+    return latents
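
The packing step above rearranges 4-channel VAE latents into one 16-dimensional token per 2x2 spatial patch, which is the sequence layout the Bria transformer consumes. A minimal sketch of the transform and a hand-written inverse (the unpack helper below is illustrative only, not part of this commit):

import torch

def pack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # (1, 4, H, W) -> (1, (H//2)*(W//2), 16): each 2x2 patch becomes one token
    latents = latents.view(1, 4, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(1, (height // 2) * (width // 2), 16)

def unpack_latents(tokens: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # Inverse of pack_latents: (1, (H//2)*(W//2), 16) -> (1, 4, H, W)
    tokens = tokens.reshape(1, height // 2, width // 2, 4, 2, 2)
    tokens = tokens.permute(0, 3, 1, 4, 2, 5)
    return tokens.reshape(1, 4, height, width)

x = torch.randn(1, 4, 64, 64)     # e.g. latents for a 1024x1024 image with a 16x-downscaling VAE
packed = pack_latents(x, 64, 64)  # -> (1, 1024, 16)
assert torch.equal(unpack_latents(packed, 64, 64), x)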

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 2 additions & 2 deletions

@@ -126,7 +126,7 @@ class ModelProbe(object):

     CLASS2TYPE = {
         "BriaPipeline": ModelType.Main,
-        "BriaControlNetModel": ModelType.ControlNet,
+        "BriaTransformer2DModel": ModelType.ControlNet,
         "FluxPipeline": ModelType.Main,
         "StableDiffusionPipeline": ModelType.Main,
         "StableDiffusionInpaintPipeline": ModelType.Main,

@@ -1014,7 +1014,7 @@ def get_base_type(self) -> BaseModelType:
         if config.get("_class_name", None) == "FluxControlNetModel":
             return BaseModelType.Flux

-        if config.get("_class_name", None) == "BriaControlNetModel":
+        if config.get("_class_name", None) == "BriaTransformer2DModel":
             return BaseModelType.Bria

         # no obvious way to distinguish between sd2-base and sd2-768
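
For context, the probe keys off the `_class_name` field that diffusers writes into a model folder's config.json. A minimal sketch of the check performed above, with a hypothetical model path:

import json
from pathlib import Path

config = json.loads(Path("models/bria-controlnet/config.json").read_text())
if config.get("_class_name", None) == "BriaTransformer2DModel":
    # After this commit, such a folder probes as ModelType.ControlNet
    # with BaseModelType.Bria.
    print("probed as a Bria ControlNet")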

invokeai/backend/model_manager/load/model_loaders/bria.py

Lines changed: 2 additions & 5 deletions

@@ -31,14 +31,11 @@ def _load_model(
         if isinstance(config, ControlNetCheckpointConfig):
             raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")

-        if submodel_type is None:
-            raise Exception("A submodel type must be provided when loading control net pipelines.")
-
         model_path = Path(config.path)
-        load_class = self.get_hf_load_class(model_path, submodel_type)
+        load_class = self.get_hf_load_class(model_path)
         repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None
         variant = repo_variant.value if repo_variant else None
-        model_path = model_path / submodel_type.value
+        model_path = model_path

         dtype = self._torch_dtype

invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py

Lines changed: 3 additions & 0 deletions

@@ -84,6 +84,9 @@ def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: #
         ]:
             if module == "transformer_bria":
                 module = "invokeai.backend.bria.transformer_bria"
+            elif class_name == "BriaTransformer2DModel":
+                class_name = "BriaControlNetModel"
+                module = "invokeai.backend.bria.controlnet_bria"
             res_type = sys.modules[module]
         else:
             res_type = sys.modules["diffusers"].pipelines
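
The redirect above means a checkpoint config that names "BriaTransformer2DModel" resolves to InvokeAI's own BriaControlNetModel class. A standalone sketch of the same lookup, using importlib rather than sys.modules (illustrative, not the loader's literal code):

import importlib

def resolve_bria_class(module: str, class_name: str):
    # Swap in InvokeAI's ControlNet implementation for the Bria class name.
    if class_name == "BriaTransformer2DModel":
        class_name = "BriaControlNetModel"
        module = "invokeai.backend.bria.controlnet_bria"
    return getattr(importlib.import_module(module), class_name)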

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BriaControlNetModelFieldInputComponent.tsx

Lines changed: 4 additions & 4 deletions

@@ -6,8 +6,8 @@ import type {
   BriaControlNetModelFieldInputTemplate,
 } from 'features/nodes/types/field';
 import { memo, useCallback } from 'react';
-import { useBriaModels } from 'services/api/hooks/modelsByType';
-import type { MainModelConfig } from 'services/api/types';
+import { useBriaControlNetModels } from 'services/api/hooks/modelsByType';
+import type { ControlNetModelConfig } from 'services/api/types';

 import type { FieldComponentProps } from './types';

@@ -16,9 +16,9 @@ type Props = FieldComponentProps<BriaControlNetModelFieldInputInstance, BriaCont
 const BriaControlNetModelFieldInputComponent = (props: Props) => {
   const { nodeId, field } = props;
   const dispatch = useAppDispatch();
-  const [modelConfigs, { isLoading }] = useBriaModels();
+  const [modelConfigs, { isLoading }] = useBriaControlNetModels();
   const onChange = useCallback(
-    (value: MainModelConfig | null) => {
+    (value: ControlNetModelConfig | null) => {
       if (!value) {
         return;
       }

invokeai/nodes/bria_nodes/bria_controlnet.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
+from pydantic import BaseModel, Field
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
+from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
+
+class BriaControlNetField(BaseModel):
+    image: ImageField = Field(description="The control image")
+    model: ModelIdentifierField = Field(description="The ControlNet model to use")
+    mode: BRIA_CONTROL_MODES = Field(description="The mode of the ControlNet")
+    conditioning_scale: float = Field(description="The weight given to the ControlNet")
+
+
+@invocation_output("bria_controlnet_output")
+class BriaControlNetOutput(BaseInvocationOutput):
+    """Bria ControlNet info"""
+
+    control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
+
+
+@invocation(
+    "bria_controlnet",
+    title="Bria ControlNet",
+    tags=["controlnet", "bria"],
+    category="controlnet",
+    version="1.0.0",
+)
+class BriaControlNetInvocation(BaseInvocation):
+    """Collect Bria ControlNet info to pass to denoiser node."""
+
+    control_image: ImageField = InputField(description="The control image")
+    control_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.controlnet_model, ui_type=UIType.BriaControlNetModel
+    )
+    control_mode: BRIA_CONTROL_MODES = InputField(
+        default="depth", description="The mode of the ControlNet"
+    )
+    control_weight: float = InputField(
+        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
+    )
+
+    def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
+        return BriaControlNetOutput(
+            control=BriaControlNetField(
+                image=self.control_image,
+                model=self.control_model,
+                mode=self.control_mode,
+                conditioning_scale=self.control_weight,
+            ),
+        )
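
For a sense of the value this node passes downstream, here is a hand-built BriaControlNetField. The image name and model identifier are hypothetical placeholders (ModelIdentifierField's key/hash/name/base/type fields are assumed from InvokeAI's model manager; real values come from an installed model), so treat this as a sketch, not working configuration:

from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField

control = BriaControlNetField(
    image=ImageField(image_name="depth_map.png"),  # hypothetical image name
    model=ModelIdentifierField(
        key="some-model-key",   # placeholder; assigned by the model manager
        hash="blake3:abc123",   # placeholder hash
        name="bria-controlnet",
        base="bria",
        type="controlnet",
    ),
    mode="depth",
    conditioning_scale=1.0,
)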

invokeai/nodes/bria_nodes/bria_denoiser.py

Lines changed: 81 additions & 15 deletions

@@ -1,16 +1,15 @@
+from typing import List, Tuple
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
+from invokeai.backend.bria.controlnet_utils import prepare_control_images
+from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField
+
 import torch
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

-from invokeai.app.invocations.fields import Input, InputField
-from invokeai.app.invocations.model import SubModelType, TransformerField
-from invokeai.app.invocations.primitives import (
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    LatentsField,
-    OutputField,
-)
+from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
+from invokeai.app.invocations.model import SubModelType, TransformerField, VAEField
+from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output

@@ -43,6 +42,11 @@ class BriaDenoiseInvocation(BaseInvocation):
         input=Input.Connection,
         title="Transformer",
     )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+        title="VAE",
+    )
     latents: LatentsField = InputField(
         description="Latents to denoise",
         input=Input.Connection,

@@ -68,6 +72,12 @@ class BriaDenoiseInvocation(BaseInvocation):
         input=Input.Connection,
         title="Text IDs",
     )
+    control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
+        description="ControlNet",
+        input=Input.Connection,
+        title="ControlNet",
+        default=None,
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:

@@ -83,16 +93,28 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         with (
             context.models.load(self.transformer.transformer) as transformer,
             context.models.load(scheduler_identifier) as scheduler,
+            context.models.load(self.vae.vae) as vae,
         ):
             assert isinstance(transformer, BriaTransformer2DModel)
             assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
+            assert isinstance(vae, AutoencoderKL)
             dtype = transformer.dtype
             device = transformer.device
             latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds))
             prompt_embeds = torch.cat([neg_embeds, pos_embeds]) if self.guidance_scale > 1 else pos_embeds

             sigmas = get_original_sigmas(1000, self.num_steps)
             timesteps, _ = retrieve_timesteps(scheduler, self.num_steps, device, None, sigmas, mu=0.0)
+            width, height = 1024, 1024
+            if self.control is not None:
+                control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
+                    context=context,
+                    vae=vae,
+                    width=width,
+                    height=height,
+                    device=device,
+                )

             for t in timesteps:
                 # Prepare model input efficiently

@@ -101,11 +123,21 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                 else:
                     latent_model_input = latents

-                # Prepare timestep tensor efficiently
-                if isinstance(t, torch.Tensor):
-                    timestep_tensor = t.expand(latent_model_input.shape[0])
-                else:
-                    timestep_tensor = torch.tensor([t] * latent_model_input.shape[0], device=device, dtype=torch.float32)
+                timestep_tensor = t.expand(latent_model_input.shape[0])
+
+                controlnet_block_samples, controlnet_single_block_samples = None, None
+                if self.control is not None:
+                    controlnet_block_samples, controlnet_single_block_samples = control_model(
+                        hidden_states=latents,
+                        controlnet_cond=control_images,  # type: ignore
+                        controlnet_mode=control_modes,  # type: ignore
+                        conditioning_scale=control_scales,  # type: ignore
+                        timestep=timestep_tensor,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        return_dict=False,
+                    )

                 noise_pred = transformer(
                     latent_model_input,

@@ -115,6 +147,8 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                     txt_ids=text_ids,
                     guidance=None,
                     return_dict=False,
+                    controlnet_block_samples=controlnet_block_samples,
+                    controlnet_single_block_samples=controlnet_single_block_samples,
                 )[0]

                 if self.guidance_scale > 1:

@@ -131,3 +165,35 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         saved_input_latents_tensor = context.tensors.save(latents)
         latents_output = LatentsField(latents_name=saved_input_latents_tensor)
         return BriaDenoiseInvocationOutput(latents=latents_output)
+
+    def _prepare_multi_control(
+        self,
+        context: InvocationContext,
+        vae: AutoencoderKL,
+        width: int,
+        height: int,
+        device: torch.device,
+    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
+        control = self.control if isinstance(self.control, list) else [self.control]
+        control_images, control_models, control_modes, control_scales = [], [], [], []
+        for controlnet in control:
+            if controlnet is not None:
+                control_models.append(context.models.load(controlnet.model).model)
+                control_images.append(context.images.get_pil(controlnet.image.image_name))
+                control_modes.append(BriaControlModes[controlnet.mode].value)
+                control_scales.append(controlnet.conditioning_scale)
+
+        control_model = BriaMultiControlNetModel(control_models).to(device)
+        tensored_control_images, tensored_control_modes = prepare_control_images(
+            vae=vae,
+            control_images=control_images,
+            control_modes=control_modes,
+            width=width,
+            height=height,
+            device=device,
+        )
+        return control_model, tensored_control_images, tensored_control_modes, control_scales
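
The guidance path in invoke batches negative and positive embeddings together (the torch.cat above) and, when guidance_scale > 1, the "if self.guidance_scale > 1" branch splits the prediction back apart. A minimal sketch of the standard classifier-free-guidance arithmetic this implies (the node's exact variable names may differ):

import torch

def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred stacks the unconditional and conditional halves along dim 0,
    # matching the [negative, positive] embedding batch fed to the transformer.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)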
