import numpy as np
from .attention_processor import IPAFluxAttnProcessor2_0
from .utils import is_model_pathched, FluxUpdateModules
+from .sd3.resampler import TimeResampler
+from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor

image_proj_model = None
class MLPProjModel(torch.nn.Module):
@@ -95,7 +97,7 @@ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        image_prompt_embeds = image_proj_model(clip_image_embeds)
        return image_prompt_embeds

-    def apply_ipadapter_flux(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
+    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            # self.clip_vision = ipadapter["clipvision"]['model']
@@ -127,3 +129,140 @@ def apply_ipadapter_flux(self, model, ipadapter, image, weight, start_at, end_at

        return (bi, image)

+
+def patch_sd3(
+    patcher,
+    ip_procs,
+    resampler: TimeResampler,
+    clip_embeds,
+    weight=1.0,
+    start=0.0,
+    end=1.0,
+):
+    """
+    Patches a model_sampler to add the ipadapter
+    """
+    mmdit = patcher.model.diffusion_model
+    timestep_schedule_max = patcher.model.model_config.sampling_settings.get(
+        "timesteps", 1000
+    )
+    # hook the model's forward function
+    # so that when it gets called, we can grab the timestep and send it to the resampler
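+    # ip_options is shared by reference with every JointBlockIPWrapper created below,
+    # so values written by the hook are visible to all of the attention processors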
+    ip_options = {
+        "hidden_states": None,
+        "t_emb": None,
+        "weight": weight,
+    }
+
+    def ddit_wrapper(forward, args):
+        # this is between 0 and 1, so the adapters can calculate start_point and end_point
+        # actually, do we need to get the sigma value instead?
+        t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
+        if start <= t_percent <= end:
+            batch_size = args["input"].shape[0] // len(args["cond_or_uncond"])
+            # if we're only doing cond or only doing uncond, only pass one of them through the resampler
+            embeds = clip_embeds[args["cond_or_uncond"]]
+            # slight efficiency optimization todo: pass the embeds through and then afterwards
+            # repeat to the batch size
+            embeds = torch.repeat_interleave(embeds, batch_size, dim=0)
+            # the resampler wants between 0 and MAX_STEPS
+            timestep = args["timestep"] * timestep_schedule_max
+            image_emb, t_emb = resampler(embeds, timestep, need_temb=True)
+            # these will need to be accessible to the IPAdapters
+            ip_options["hidden_states"] = image_emb
+            ip_options["t_emb"] = t_emb
+        else:
+            ip_options["hidden_states"] = None
+            ip_options["t_emb"] = None
+
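+        # run the original forward pass unchanged; the wrapped joint blocks read ip_options during this call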
+        return forward(args["input"], args["timestep"], **args["c"])
+
+    patcher.set_model_unet_function_wrapper(ddit_wrapper)
+    # patch each dit block
+    for i, block in enumerate(mmdit.joint_blocks):
+        wrapper = JointBlockIPWrapper(block, ip_procs[i], ip_options)
+        patcher.set_model_patch_replace(wrapper, "dit", "double_block", i)
+
+class InstantXSD3IpadapterApply:
+    def __init__(self):
+        self.device = None
+        self.dtype = torch.float16
+        self.clip_image_processor = None
+        self.image_encoder = None
+        self.resampler = None
+        self.procs = None
+
+    @torch.inference_mode()
+    def encode(self, image):
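+        # ComfyUI images are float tensors in [0, 1], so the processor's rescaling is disabled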
+        clip_image = self.clip_image_processor.image_processor(image, return_tensors="pt", do_rescale=False).pixel_values
+        clip_image_embeds = self.image_encoder(
+            clip_image.to(self.device, dtype=self.image_encoder.dtype),
+            output_hidden_states=True,
+        ).hidden_states[-2]
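+        # append a zero embedding to serve as the unconditional (negative) image prompt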
+        clip_image_embeds = torch.cat(
+            [clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
+        )
+        clip_image_embeds = clip_image_embeds.to(dtype=torch.float16)
+        return clip_image_embeds
+
+    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
+        self.device = provider.lower()
+        if "clipvision" in ipadapter:
+            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
+            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
+        if "ipadapter" in ipadapter:
+            self.ip_ckpt = ipadapter["ipadapter"]['file']
+            self.state_dict = ipadapter["ipadapter"]['model']
+
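+        # image projector: resamples the CLIP hidden states (plus the timestep) into 64 image tokens;
+        # these dims must match the checkpoint's image_proj weights loaded below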
+        self.resampler = TimeResampler(
+            dim=1280,
+            depth=4,
+            dim_head=64,
+            heads=20,
+            num_queries=64,
+            embedding_dim=1152,
+            output_dim=2432,
+            ff_mult=4,
+            timestep_in_dim=320,
+            timestep_flip_sin_to_cos=True,
+            timestep_freq_shift=0,
+        )
+        self.resampler.eval()
+        self.resampler.to(self.device, dtype=self.dtype)
+        self.resampler.load_state_dict(self.state_dict["image_proj"])
+
+        # now we'll create the attention processors
+        # ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
+        n_procs = len(
+            set(x.split(".")[0] for x in self.state_dict["ip_adapter"].keys())
+        )
+        self.procs = torch.nn.ModuleList(
+            [
+                # this is hardcoded for SD3.5L
+                IPAttnProcessor(
+                    hidden_size=2432,
+                    cross_attention_dim=2432,
+                    ip_hidden_states_dim=2432,
+                    ip_encoder_hidden_states_dim=2432,
+                    head_dim=64,
+                    timesteps_emb_dim=1280,
+                ).to(self.device, dtype=torch.float16)
+                for _ in range(n_procs)
+            ]
+        )
+        self.procs.load_state_dict(self.state_dict["ip_adapter"])
+
+        work_model = model.clone()
+        embeds = self.encode(image)
+
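+        # install the timestep hook and the per-block attention wrappers on the cloned model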
+        patch_sd3(
+            work_model,
+            self.procs,
+            self.resampler,
+            embeds,
+            weight,
+            start_at,
+            end_at,
+        )
+
+        return (work_model, image)