@@ -1,3 +1,10 @@
+from typing import List, Tuple
+from PIL import Image
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+from invokeai.backend.bria.controlnet import BriaControlModes, BriaControlNetModel, BriaMultiControlNetModel
+from invokeai.nodes.bria_nodes.bria_controlnet import BriaControlNetField
+
 import torch
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

@@ -68,6 +75,11 @@ class BriaDenoiseInvocation(BaseInvocation):
        input=Input.Connection,
        title="Text IDs",
    )
+    control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
+        default=None,
+        description="ControlNet",
+        input=Input.Connection,
+        title="ControlNet",
+    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
@@ -83,16 +95,29 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
        with (
            context.models.load(self.transformer.transformer) as transformer,
            context.models.load(scheduler_identifier) as scheduler,
+            context.models.load(self.vae.vae) as vae,
        ):
            assert isinstance(transformer, BriaTransformer2DModel)
            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
+            assert isinstance(vae, AutoencoderKL)
            dtype = transformer.dtype
            device = transformer.device
            latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds))
            prompt_embeds = torch.cat([neg_embeds, pos_embeds]) if self.guidance_scale > 1 else pos_embeds

            sigmas = get_original_sigmas(1000, self.num_steps)
            timesteps, _ = retrieve_timesteps(scheduler, self.num_steps, device, None, sigmas, mu=0.0)
+            # Control images are prepared at Bria's fixed 1024x1024 generation resolution.
+            width, height = 1024, 1024
+            if self.control is not None:
+                control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
+                    context=context,
+                    vae=vae,
+                    width=width,
+                    height=height,
+                    device=device,
+                    num_channels_latents=transformer.config.in_channels // 4,
+                )

            for t in timesteps:
                # Prepare model input efficiently
@@ -101,11 +126,21 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                else:
                    latent_model_input = latents

-                # Prepare timestep tensor efficiently
-                if isinstance(t, torch.Tensor):
-                    timestep_tensor = t.expand(latent_model_input.shape[0])
-                else:
-                    timestep_tensor = torch.tensor([t] * latent_model_input.shape[0], device=device, dtype=torch.float32)
+                timestep_tensor = t.expand(latent_model_input.shape[0])
+
+                controlnet_block_samples, controlnet_single_block_samples = None, None
+                if self.control is not None:
+                    controlnet_block_samples, controlnet_single_block_samples = control_model(
+                        hidden_states=latents,
+                        controlnet_cond=control_images,  # type: ignore
+                        controlnet_mode=control_modes,  # type: ignore
+                        conditioning_scale=control_scales,  # type: ignore
+                        timestep=timestep_tensor,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        return_dict=False,
+                    )

                noise_pred = transformer(
                    latent_model_input,
@@ -115,6 +150,8 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
                    txt_ids=text_ids,
                    guidance=None,
                    return_dict=False,
+                    controlnet_block_samples=controlnet_block_samples,
+                    controlnet_single_block_samples=controlnet_single_block_samples,
                )[0]

                if self.guidance_scale > 1:
@@ -131,3 +168,110 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
        saved_input_latents_tensor = context.tensors.save(latents)
        latents_output = LatentsField(latents_name=saved_input_latents_tensor)
        return BriaDenoiseInvocationOutput(latents=latents_output)
+
+    def _prepare_multi_control(
+        self,
+        context: InvocationContext,
+        vae: AutoencoderKL,
+        width: int,
+        height: int,
+        device: torch.device,
+        num_channels_latents: int,
+    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
+        control = self.control if isinstance(self.control, list) else [self.control]
+        control_images, control_models, control_modes, control_scales = [], [], [], []
+        for controlnet in control:
+            control_models.append(context.models.load(controlnet.model))
+            control_images.append(context.images.get_pil(controlnet.image))
+            control_modes.append(BriaControlModes[controlnet.mode].value)
+            control_scales.append(controlnet.controlnet_conditioning_scale)
+
+        control_model = BriaMultiControlNetModel(control_models)
+        tensored_control_images, tensored_control_modes = self._prepare_control_images(
+            vae, control_images, control_modes, width, height, device, num_channels_latents
+        )
+        return control_model, tensored_control_images, tensored_control_modes, control_scales
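+
+    # For the fixed 1024x1024 input and an 8x VAE downscale (assumed here), each packed
+    # control latent carries (1024 / 8 / 2) ** 2 = 4096 tokens of width num_channels_latents * 4.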
+
+    def _prepare_control_images(
+        self,
+        vae: AutoencoderKL,
+        control_images: list[Image.Image],
+        control_modes: list[int],
+        width: int,
+        height: int,
+        device: torch.device,
+        num_channels_latents: int,
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        tensored_control_images = []
+        tensored_control_modes = []
+        for idx, control_image_ in enumerate(control_images):
+            tensored_control_image = self.prepare_image(
+                image=control_image_,
+                width=width,
+                height=height,
+                device=device,
+                dtype=vae.dtype,
+            )
+
+            # Encode the control image into the VAE latent space and normalize it.
+            tensored_control_image = vae.encode(tensored_control_image).latent_dist.sample()
+            tensored_control_image = (tensored_control_image - vae.config.shift_factor) * vae.config.scaling_factor
+
+            # Pack 2x2 latent patches into the transformer's token sequence layout.
+            height_control_image, width_control_image = tensored_control_image.shape[2:]
+            tensored_control_image = self._pack_latents(
+                tensored_control_image,
+                num_channels_latents,
+                height_control_image,
+                width_control_image,
+            )
+            tensored_control_images.append(tensored_control_image)
+            tensored_control_modes.append(
+                torch.tensor(control_modes[idx]).expand(tensored_control_image.shape[0]).to(device, dtype=torch.long)
+            )
+
+        return tensored_control_images, tensored_control_modes
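+
+    # The two returned lists are index-aligned: tensored_control_images[i] and
+    # tensored_control_modes[i] both describe the i-th configured ControlNet.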
+
+    def prepare_image(
+        self,
+        image: Image.Image,
+        width: int,
+        height: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        image = VaeImageProcessor(vae_scale_factor=16).preprocess(image, height=height, width=width)
+        image = image.to(device=device, dtype=dtype)
+        return image
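+
+    # VaeImageProcessor.preprocess resizes to (height, width), scales pixel values to
+    # [-1, 1] (its default normalization), and returns a (1, 3, H, W) tensor.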
+
+    def _pack_latents(self, latents, num_channels_latents, height, width):
+        # FLUX-style packing: fold each 2x2 spatial patch into the channel dimension.
+        latents = latents.view(1, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(1, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        return latents
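+
+    # Shape walk-through, assuming 16 latent channels at 1024px (a 128x128 latent):
+    # (1, 16, 128, 128) -> view (1, 16, 64, 2, 64, 2) -> permute (1, 64, 64, 16, 2, 2)
+    # -> reshape (1, 4096, 64).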
+
+    def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
+        # Per-step keep flags: 1.0 while the step's normalized interval lies inside
+        # [start, end], 0.0 otherwise.
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(self.control, BriaControlNetModel) else keeps)
+        return controlnet_keep
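+
+    # Example: 4 timesteps with start=[0.0], end=[0.5] yield keeps
+    # [[1.0], [1.0], [0.0], [0.0]], i.e. conditioning only for the first half.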
+
+    def get_control_start_end(self, control_guidance_start, control_guidance_end):
+        # Normalize scalar start/end values into per-ControlNet lists of equal length.
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = 1  # both values are scalars, i.e. they describe a single ControlNet
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        return control_guidance_start, control_guidance_end
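+
+    # Example: get_control_start_end(0.0, [0.3, 0.8]) -> ([0.0, 0.0], [0.3, 0.8]),
+    # and get_control_start_end(0.0, 0.5) -> ([0.0], [0.5]).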