
Commit a46d80b

Renamed FLUX.1-dev to REGULAR - FLUX and SD3.5 only (high strength) preset on easy ipadapterApply
1 parent da57b55 commit a46d80b

File tree

10 files changed: +775 additions, −38 deletions


README.ZH_CN.md

Lines changed: 1 addition & 10 deletions
@@ -50,22 +50,13 @@ git clone https://github.com/yolain/ComfyUI-Easy-Use
 Double-click install.bat to install the dependencies
 ```
 
-## 👨🏻‍🚀 Plan
-
-- [x] Update the new front-end code for easier maintenance
-- [x] Maintain CSS styles with sass
-- [x] Optimize the existing extensions
-- [x] Add new components (e.g. node execution-time statistics)
-- [ ] Upload more workflows (e.g. kolors, sd3) to [ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows) and update the English readme
-- [ ] A gitbook with more detailed feature introductions
-
 ## 📜 Changelog
 
 **v1.2.5**
 
 - Added the `enable (GPU=A1111)` noise generation mode option on `easy preSamplingCustom` and `easy preSamplingAdvanced`
 - Added `easy makeImageForICLora`
-- Added the `FLUX.1-dev` preset to `easy ipadapterApply` to support the InstantX Flux ipadapter
+- Added the `REGULAR - FLUX and SD3.5 only (high strength)` preset to `easy ipadapterApply` to support the InstantX Flux ipadapter
 - Fixed brushnet not working in `--fast` mode
 - Support briaai RMBG-2.0
 - Support the mochi model

README.md

Lines changed: 1 addition & 10 deletions
@@ -45,22 +45,13 @@ git clone https://github.com/yolain/ComfyUI-Easy-Use
 Double-click install.bat to install the required dependencies
 ```
 
-## 👨🏻‍🚀 Plan
-
-- [x] Updated new front-end code for easier maintenance
-- [x] Maintain css styles using sass
-- [x] Optimize existing extensions
-- [x] Add new components
-- [ ] Upload new workflows to [ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows) and translate the readme to an English version
-- [ ] Write a gitbook with more detailed function introductions
-
 ## 📜 Changelog
 
 **v1.2.5**
 
 - Added `enable (GPU=A1111)` noise mode on `easy preSamplingCustom` and `easy preSamplingAdvanced`
 - Added `easy makeImageForICLora`
-- Added `FLUX.1-dev` preset for InstantX Flux ipadapter on `easy ipadapterApply`
+- Added `REGULAR - FLUX and SD3.5 only (high strength)` preset for InstantX Flux ipadapter on `easy ipadapterApply`
 - Fixed brushnet not usable with the startup arg `--fast`
 - Support briaai RMBG-2.0
 - Support mochi

py/config.py

Lines changed: 10 additions & 6 deletions
@@ -231,19 +231,23 @@
         "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors"
     }
   },
-  "PLUS (kolors genernal)":{
-    "sd1":{
-      "model_url":""
+  "PLUS (kolors genernal)": {
+    "sd1": {
+      "model_url": ""
     },
-    "sdxl":{
+    "sdxl": {
      "model_url":"https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-Plus/resolve/main/ip_adapter_plus_general.bin"
    }
  },
-  "FLUX.1-dev": {
-    "flux":{
+  "REGULAR - FLUX and SD3.5 only (high strength)": {
+    "flux": {
      "model_url": "https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/resolve/main/ip-adapter.bin",
      "model_file_name": "ip-adapter_flux_1_dev.bin",
    },
+    "sd3": {
+      "model_url": "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin",
+      "model_file_name": "ip-adapter_sd35.bin",
+    },
  },
  "PLUS FACE (portraits)": {
    "sd1": {

py/easyNodes.py

Lines changed: 17 additions & 9 deletions
@@ -3000,7 +3000,7 @@ def __init__(self):
             'VIT-G (medium strength)',
             'PLUS (high strength)',
             'PLUS (kolors genernal)',
-            'FLUX.1-dev',
+            'REGULAR - FLUX and SD3.5 only (high strength)',
             'PLUS FACE (portraits)',
             'FULL FACE - SD1.5 only (portraits stronger)',
             'COMPOSITION'
@@ -3023,7 +3023,7 @@ def error(self):
     def get_clipvision_file(self, preset, node_name):
         preset = preset.lower()
         clipvision_list = folder_paths.get_filename_list("clip_vision")
-        if preset.startswith("flux"):
+        if preset.startswith("regular"):
            # pattern = 'sigclip.vision.patch14.384'
            pattern = 'siglip.so400m.patch14.384'
        elif preset.startswith("plus (kolors") or preset.startswith("faceid plus kolors"):
@@ -3045,6 +3045,7 @@ def get_ipadapter_file(self, preset, model_type, node_name):
        is_insightface = False
        lora_pattern = None
        is_sdxl = model_type == 'sdxl'
+        is_flux = model_type == 'flux'
 
        if preset.startswith("light"):
            if is_sdxl:
@@ -3063,8 +3064,11 @@ def get_ipadapter_file(self, preset, model_type, node_name):
                pattern = 'ip.adapter.sdxl.(safetensors|bin)$'
            else:
                pattern = 'sd15.vit.g.(safetensors|bin)$'
-        elif preset.startswith("flux"):
-            pattern = 'ip.adapter.flux.1.dev.(safetensors|bin)$'
+        elif preset.startswith("regular"):
+            if is_flux:
+                pattern = 'ip.adapter.flux.1.dev.(safetensors|bin)$'
+            else:
+                pattern = 'ip.adapter.sd35.(safetensors|bin)$'
        elif preset.startswith("plus (high"):
            if is_sdxl:
                pattern = 'plus.sdxl.vit.h.(safetensors|bin)$'
@@ -3218,7 +3222,7 @@ def load_model(self, model, preset, lora_model_strength, provider="CPU", clip_vi
        if not clip_vision:
            clipvision_file, clipvision_name = self.get_clipvision_file(preset, node_name)
            if clipvision_file is None:
-                if preset.lower().startswith("flux"):
+                if preset.lower().startswith("regular"):
                    # model_url = IPADAPTER_CLIPVISION_MODELS["sigclip_vision_patch14_384"]["model_url"]
                    # clipvision_file = get_local_filepath(model_url, IPADAPTER_DIR, "sigclip_vision_patch14_384.bin")
                    from huggingface_hub import snapshot_download
@@ -3253,7 +3257,7 @@ def load_model(self, model, preset, lora_model_strength, provider="CPU", clip_vi
                log_node_info("easy ipadapterApply", f"Using ClipVisonModel {clipvision_name} Cached")
                _, clip_vision = backend_cache.cache[clipvision_name][1]
            else:
-                if preset.lower().startswith("flux"):
+                if preset.lower().startswith("regular"):
                    from transformers import SiglipVisionModel, AutoProcessor
                    image_encoder_path = os.path.dirname(clipvision_file)
                    image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path)
@@ -3352,9 +3356,13 @@ def INPUT_TYPES(cls):
    def apply(self, model, image, preset, lora_strength, provider, weight, weight_faceidv2, start_at, end_at, cache_mode, use_tiled, attn_mask=None, optional_ipadapter=None, weight_kolors=None):
        images, masks = image, [None]
        model, ipadapter = self.load_model(model, preset, lora_strength, provider, clip_vision=None, optional_ipadapter=optional_ipadapter, cache_mode=cache_mode)
-        if preset in ['FLUX.1-dev']:
-            from .ipadapter import InstantXFluxIpadapterApply
-            model, images = InstantXFluxIpadapterApply().apply_ipadapter_flux(model, ipadapter, image, weight, start_at, end_at, provider)
+        if preset == 'REGULAR - FLUX and SD3.5 only (high strength)':
+            from .ipadapter import InstantXFluxIpadapterApply, InstantXSD3IpadapterApply
+            model_type = get_sd_version(model)
+            if model_type == 'flux':
+                model, images = InstantXFluxIpadapterApply().apply_ipadapter(model, ipadapter, image, weight, start_at, end_at, provider)
+            elif model_type == 'sd3':
+                model, images = InstantXSD3IpadapterApply().apply_ipadapter(model, ipadapter, image, weight, start_at, end_at, provider)
        elif use_tiled and preset not in self.faceid_presets:
            if "IPAdapterTiled" not in ALL_NODE_CLASS_MAPPINGS:
                self.error()
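
A side note on the `pattern` strings in `get_ipadapter_file`: they are regular expressions, so each `.` acts as a single-character wildcard and matches either `-` or `_` in the installed file names. Here is a small self-contained sketch of that matching; the file list below is made up for illustration.

```python
import re

# Patterns copied from the diff above; '.' matches any single character,
# so both '-' and '_' separators are accepted.
PATTERNS = {
    "flux": r"ip.adapter.flux.1.dev.(safetensors|bin)$",
    "sd3": r"ip.adapter.sd35.(safetensors|bin)$",
}

# Hypothetical contents of a local ipadapter model folder.
local_files = [
    "ip-adapter_flux_1_dev.bin",
    "ip-adapter_sd35.bin",
    "ip-adapter-plus_sdxl_vit-h.safetensors",
]

def find_ipadapter_file(model_type: str) -> str | None:
    """Return the first file whose lowercased name matches the
    pattern for the detected base model type."""
    pattern = PATTERNS[model_type]
    for filename in local_files:
        if re.search(pattern, filename.lower()):
            return filename
    return None

print(find_ipadapter_file("flux"))  # ip-adapter_flux_1_dev.bin
print(find_ipadapter_file("sd3"))   # ip-adapter_sd35.bin
```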

py/ipadapter/__init__.py

Lines changed: 140 additions & 1 deletion
@@ -5,6 +5,8 @@
 import numpy as np
 from .attention_processor import IPAFluxAttnProcessor2_0
 from .utils import is_model_pathched, FluxUpdateModules
+from .sd3.resampler import TimeResampler
+from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor
 
 image_proj_model = None
 class MLPProjModel(torch.nn.Module):
@@ -95,7 +97,7 @@ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
         image_prompt_embeds = image_proj_model(clip_image_embeds)
         return image_prompt_embeds
 
-    def apply_ipadapter_flux(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
+    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            # self.clip_vision = ipadapter["clipvision"]['model']
@@ -127,3 +129,140 @@ def apply_ipadapter_flux(self, model, ipadapter, image, weight, start_at, end_at
 
        return (bi, image)
 
+
+def patch_sd3(
+    patcher,
+    ip_procs,
+    resampler: TimeResampler,
+    clip_embeds,
+    weight=1.0,
+    start=0.0,
+    end=1.0,
+):
+    """
+    Patches a model_sampler to add the ipadapter
+    """
+    mmdit = patcher.model.diffusion_model
+    timestep_schedule_max = patcher.model.model_config.sampling_settings.get(
+        "timesteps", 1000
+    )
+    # hook the model's forward function
+    # so that when it gets called, we can grab the timestep and send it to the resampler
+    ip_options = {
+        "hidden_states": None,
+        "t_emb": None,
+        "weight": weight,
+    }
+
+    def ddit_wrapper(forward, args):
+        # this is between 0 and 1, so the adapters can calculate start_point and end_point
+        # actually, do we need to get the sigma value instead?
+        t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
+        if start <= t_percent <= end:
+            batch_size = args["input"].shape[0] // len(args["cond_or_uncond"])
+            # if we're only doing cond or only doing uncond, only pass one of them through the resampler
+            embeds = clip_embeds[args["cond_or_uncond"]]
+            # slight efficiency optimization todo: pass the embeds through and then afterwards
+            # repeat to the batch size
+            embeds = torch.repeat_interleave(embeds, batch_size, dim=0)
+            # the resampler wants between 0 and MAX_STEPS
+            timestep = args["timestep"] * timestep_schedule_max
+            image_emb, t_emb = resampler(embeds, timestep, need_temb=True)
+            # these will need to be accessible to the IPAdapters
+            ip_options["hidden_states"] = image_emb
+            ip_options["t_emb"] = t_emb
+        else:
+            ip_options["hidden_states"] = None
+            ip_options["t_emb"] = None
+
+        return forward(args["input"], args["timestep"], **args["c"])
+
+    patcher.set_model_unet_function_wrapper(ddit_wrapper)
+    # patch each dit block
+    for i, block in enumerate(mmdit.joint_blocks):
+        wrapper = JointBlockIPWrapper(block, ip_procs[i], ip_options)
+        patcher.set_model_patch_replace(wrapper, "dit", "double_block", i)
+
+class InstantXSD3IpadapterApply:
+    def __init__(self):
+        self.device = None
+        self.dtype = torch.float16
+        self.clip_image_processor = None
+        self.image_encoder = None
+        self.resampler = None
+        self.procs = None
+
+    @torch.inference_mode()
+    def encode(self, image):
+        clip_image = self.clip_image_processor.image_processor(image, return_tensors="pt", do_rescale=False).pixel_values
+        clip_image_embeds = self.image_encoder(
+            clip_image.to(self.device, dtype=self.image_encoder.dtype),
+            output_hidden_states=True,
+        ).hidden_states[-2]
+        clip_image_embeds = torch.cat(
+            [clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
+        )
+        clip_image_embeds = clip_image_embeds.to(dtype=torch.float16)
+        return clip_image_embeds
+
+    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
+        self.device = provider.lower()
+        if "clipvision" in ipadapter:
+            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
+            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
+        if "ipadapter" in ipadapter:
+            self.ip_ckpt = ipadapter["ipadapter"]['file']
+            self.state_dict = ipadapter["ipadapter"]['model']
+
+        self.resampler = TimeResampler(
+            dim=1280,
+            depth=4,
+            dim_head=64,
+            heads=20,
+            num_queries=64,
+            embedding_dim=1152,
+            output_dim=2432,
+            ff_mult=4,
+            timestep_in_dim=320,
+            timestep_flip_sin_to_cos=True,
+            timestep_freq_shift=0,
+        )
+        self.resampler.eval()
+        self.resampler.to(self.device, dtype=self.dtype)
+        self.resampler.load_state_dict(self.state_dict["image_proj"])
+
+        # now we'll create the attention processors
+        # ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
+        n_procs = len(
+            set(x.split(".")[0] for x in self.state_dict["ip_adapter"].keys())
+        )
+        self.procs = torch.nn.ModuleList(
+            [
+                # this is hardcoded for SD3.5L
+                IPAttnProcessor(
+                    hidden_size=2432,
+                    cross_attention_dim=2432,
+                    ip_hidden_states_dim=2432,
+                    ip_encoder_hidden_states_dim=2432,
+                    head_dim=64,
+                    timesteps_emb_dim=1280,
+                ).to(self.device, dtype=torch.float16)
+                for _ in range(n_procs)
+            ]
+        )
+        self.procs.load_state_dict(self.state_dict["ip_adapter"])
+
+        work_model = model.clone()
+        embeds = self.encode(image)
+
+        patch_sd3(
+            work_model,
+            self.procs,
+            self.resampler,
+            embeds,
+            weight,
+            start_at,
+            end_at,
+        )
+
+        return (work_model, image)
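
The SD3.5 path leans on two ComfyUI `ModelPatcher` hooks: `set_model_unet_function_wrapper`, which is how `ddit_wrapper` observes the timestep of every forward call and stages the resampled image embeddings in `ip_options`, and `set_model_patch_replace`, which swaps each MMDiT joint block for the IP-aware wrapper. Below is a minimal sketch of the first hook's calling convention, assuming ComfyUI's wrapper signature of `(forward, args)` with `args` carrying `input`, `timestep`, `c`, and `cond_or_uncond`.

```python
def make_timestep_window_wrapper(start: float, end: float):
    """Sketch only: a unet-function wrapper in the style of ddit_wrapper
    above. ComfyUI calls it as wrapper(apply_model, args), where args
    holds 'input', 'timestep', 'c', and 'cond_or_uncond'."""
    def wrapper(forward, args):
        # Per the comment in ddit_wrapper, the timestep is normalized to
        # [0, 1]; inverting it gives the fraction of sampling completed.
        t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
        if start <= t_percent <= end:
            # An adapter would stage per-step state here, e.g. run the
            # TimeResampler and publish image_emb/t_emb, as patch_sd3 does.
            pass
        # Always delegate to the model's real forward function.
        return forward(args["input"], args["timestep"], **args["c"])
    return wrapper

# Usage against a real ComfyUI ModelPatcher would look like:
#   patcher.set_model_unet_function_wrapper(make_timestep_window_wrapper(0.2, 0.8))
```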
