1
+ from typing import List , Tuple
2
+ from diffusers .models .autoencoders .autoencoder_kl import AutoencoderKL
3
+ from invokeai .backend .bria .controlnet_bria import BriaControlModes , BriaMultiControlNetModel
4
+ from invokeai .backend .bria .controlnet_utils import prepare_control_images
5
+ from invokeai .nodes .bria_nodes .bria_controlnet import BriaControlNetField
6
+
1
7
import torch
2
8
from diffusers .schedulers .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
3
9
4
- from invokeai .app .invocations .fields import Input , InputField
5
- from invokeai .app .invocations .model import SubModelType , TransformerField
6
- from invokeai .app .invocations .primitives import (
7
- BaseInvocationOutput ,
8
- FieldDescriptions ,
9
- Input ,
10
- InputField ,
11
- LatentsField ,
12
- OutputField ,
13
- )
10
+ from invokeai .app .invocations .fields import Input , InputField , LatentsField , OutputField
11
+ from invokeai .app .invocations .model import SubModelType , TransformerField , VAEField
12
+ from invokeai .app .invocations .primitives import BaseInvocationOutput , FieldDescriptions
14
13
from invokeai .app .services .shared .invocation_context import InvocationContext
15
14
from invokeai .invocation_api import BaseInvocation , Classification , InputField , invocation , invocation_output
16
15
@@ -43,6 +42,11 @@ class BriaDenoiseInvocation(BaseInvocation):
43
42
input = Input .Connection ,
44
43
title = "Transformer" ,
45
44
)
45
+ vae : VAEField = InputField (
46
+ description = FieldDescriptions .vae ,
47
+ input = Input .Connection ,
48
+ title = "VAE" ,
49
+ )
46
50
latents : LatentsField = InputField (
47
51
description = "Latents to denoise" ,
48
52
input = Input .Connection ,
@@ -68,6 +72,12 @@ class BriaDenoiseInvocation(BaseInvocation):
68
72
input = Input .Connection ,
69
73
title = "Text IDs" ,
70
74
)
75
+ control : BriaControlNetField | list [BriaControlNetField ] | None = InputField (
76
+ description = "ControlNet" ,
77
+ input = Input .Connection ,
78
+ title = "ControlNet" ,
79
+ default = None ,
80
+ )
71
81
72
82
@torch .no_grad ()
73
83
def invoke (self , context : InvocationContext ) -> BriaDenoiseInvocationOutput :
@@ -83,16 +93,28 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
83
93
with (
84
94
context .models .load (self .transformer .transformer ) as transformer ,
85
95
context .models .load (scheduler_identifier ) as scheduler ,
96
+ context .models .load (self .vae .vae ) as vae ,
86
97
):
87
98
assert isinstance (transformer , BriaTransformer2DModel )
88
99
assert isinstance (scheduler , FlowMatchEulerDiscreteScheduler )
100
+ assert isinstance (vae , AutoencoderKL )
89
101
dtype = transformer .dtype
90
102
device = transformer .device
91
103
latents , pos_embeds , neg_embeds = map (lambda x : x .to (device , dtype ), (latents , pos_embeds , neg_embeds ))
92
104
prompt_embeds = torch .cat ([neg_embeds , pos_embeds ]) if self .guidance_scale > 1 else pos_embeds
93
105
94
106
sigmas = get_original_sigmas (1000 , self .num_steps )
95
107
timesteps , _ = retrieve_timesteps (scheduler , self .num_steps , device , None , sigmas , mu = 0.0 )
108
+ width , height = 1024 , 1024
109
+ if self .control is not None :
110
+ control_model , control_images , control_modes , control_scales = self ._prepare_multi_control (
111
+ context = context ,
112
+ vae = vae ,
113
+ width = width ,
114
+ height = height ,
115
+ device = device ,
116
+
117
+ )
96
118
97
119
for t in timesteps :
98
120
# Prepare model input efficiently
@@ -101,11 +123,21 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
101
123
else :
102
124
latent_model_input = latents
103
125
104
- # Prepare timestep tensor efficiently
105
- if isinstance (t , torch .Tensor ):
106
- timestep_tensor = t .expand (latent_model_input .shape [0 ])
107
- else :
108
- timestep_tensor = torch .tensor ([t ] * latent_model_input .shape [0 ], device = device , dtype = torch .float32 )
126
+ timestep_tensor = t .expand (latent_model_input .shape [0 ])
127
+
128
+ controlnet_block_samples , controlnet_single_block_samples = None , None
129
+ if self .control is not None :
130
+ controlnet_block_samples , controlnet_single_block_samples = control_model (
131
+ hidden_states = latents ,
132
+ controlnet_cond = control_images , # type: ignore
133
+ controlnet_mode = control_modes , # type: ignore
134
+ conditioning_scale = control_scales , # type: ignore
135
+ timestep = timestep_tensor ,
136
+ encoder_hidden_states = prompt_embeds ,
137
+ txt_ids = text_ids ,
138
+ img_ids = latent_image_ids ,
139
+ return_dict = False ,
140
+ )
109
141
110
142
noise_pred = transformer (
111
143
latent_model_input ,
@@ -115,6 +147,8 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
115
147
txt_ids = text_ids ,
116
148
guidance = None ,
117
149
return_dict = False ,
150
+ controlnet_block_samples = controlnet_block_samples ,
151
+ controlnet_single_block_samples = controlnet_single_block_samples ,
118
152
)[0 ]
119
153
120
154
if self .guidance_scale > 1 :
@@ -131,3 +165,35 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
131
165
saved_input_latents_tensor = context .tensors .save (latents )
132
166
latents_output = LatentsField (latents_name = saved_input_latents_tensor )
133
167
return BriaDenoiseInvocationOutput (latents = latents_output )
168
+
169
+
170
+
171
def _prepare_multi_control(
    self,
    context: InvocationContext,
    vae: AutoencoderKL,
    width: int,
    height: int,
    device: torch.device,
) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
    """Gather every connected ControlNet into the inputs used by the denoise loop.

    ``self.control`` may hold a single ControlNet field or a list of them;
    ``None`` entries are skipped. For each remaining entry the model is loaded
    through ``context.models``, the conditioning image is fetched as a PIL
    image through ``context.images``, the mode name is looked up in
    ``BriaControlModes`` and the conditioning scale is recorded.

    Args:
        context: Invocation context used to load models and fetch images.
        vae: VAE forwarded unchanged to ``prepare_control_images``.
        width: Target width forwarded to ``prepare_control_images``.
        height: Target height forwarded to ``prepare_control_images``.
        device: Device the combined ControlNet model is moved to; also
            forwarded to ``prepare_control_images``.

    Returns:
        A tuple ``(control_model, control_images, control_modes, control_scales)``:
        the ``BriaMultiControlNetModel`` wrapping all loaded models, the two
        values returned by ``prepare_control_images``, and the raw conditioning
        scales in the same order as the inputs.
    """
    # Normalise the single-field case so the loop below only deals with a list.
    if isinstance(self.control, list):
        entries = self.control
    else:
        entries = [self.control]

    loaded_models = []
    pil_images = []
    mode_values = []
    scales = []
    for entry in entries:
        if entry is None:
            continue
        # Keep the per-entry order: model, image, mode, scale.
        loaded_models.append(context.models.load(entry.model).model)
        pil_images.append(context.images.get_pil(entry.image.image_name))
        mode_values.append(BriaControlModes[entry.mode].value)
        scales.append(entry.conditioning_scale)

    multi_control = BriaMultiControlNetModel(loaded_models).to(device)
    prepared_images, prepared_modes = prepare_control_images(
        vae=vae,
        control_images=pil_images,
        control_modes=mode_values,
        width=width,
        height=height,
        device=device,
    )
    return multi_control, prepared_images, prepared_modes, scales
199
+
0 commit comments