Skip to content

Commit 9b5b239

Browse files
committed
Fix after using fooocus inpaint,all models become unusable #354
1 parent 4167733 commit 9b5b239

File tree

3 files changed

+118
-122
lines changed

3 files changed

+118
-122
lines changed

py/easyNodes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2431,8 +2431,6 @@ def load_lllite(self, model, model_name, cond_image, strength, steps, start_perc
24312431
#---------------------------------------------------------------Inpaint 开始----------------------------------------------------------------------#
24322432

24332433
# FooocusInpaint
2434-
from .libs.fooocus import InpaintHead, InpaintWorker
2435-
24362434
class applyFooocusInpaint:
24372435
@classmethod
24382436
def INPUT_TYPES(s):
@@ -2451,7 +2449,7 @@ def INPUT_TYPES(s):
24512449
FUNCTION = "apply"
24522450

24532451
def apply(self, model, latent, head, patch):
2454-
2452+
from .fooocus import InpaintHead, InpaintWorker
24552453
head_file = get_local_filepath(FOOOCUS_INPAINT_HEAD[head]["model_url"], INPAINT_DIR)
24562454
inpaint_head_model = InpaintHead()
24572455
sd = torch.load(head_file, map_location='cpu')

py/fooocus/__init__.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#credit to Acly for this module
2+
#from https://github.com/Acly/comfyui-inpaint-nodes
3+
import torch
4+
import torch.nn.functional as F
5+
import comfy
6+
from comfy.model_base import BaseModel
7+
from comfy.model_patcher import ModelPatcher
8+
from comfy.model_management import cast_to_device
9+
10+
from ..libs.log import log_node_warn, log_node_error, log_node_info
11+
12+
class InpaintHead(torch.nn.Module):
13+
def __init__(self, *args, **kwargs):
14+
super().__init__(*args, **kwargs)
15+
self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))
16+
17+
def __call__(self, x):
18+
x = F.pad(x, (1, 1, 1, 1), "replicate")
19+
return F.conv2d(x, weight=self.head)
20+
21+
22+
class applyFooocusPatch:
23+
def calculate_weight_patched(self, patches, weight, key, intermediate_type=torch.float32):
24+
remaining = []
25+
26+
for p in patches:
27+
alpha = p[0]
28+
v = p[1]
29+
30+
is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
31+
if not is_fooocus_patch:
32+
remaining.append(p)
33+
continue
34+
35+
if alpha != 0.0:
36+
v = v[1]
37+
w1 = cast_to_device(v[0], weight.device, torch.float32)
38+
if w1.shape == weight.shape:
39+
w_min = cast_to_device(v[1], weight.device, torch.float32)
40+
w_max = cast_to_device(v[2], weight.device, torch.float32)
41+
w1 = (w1 / 255.0) * (w_max - w_min) + w_min
42+
weight += alpha * cast_to_device(w1, weight.device, weight.dtype)
43+
else:
44+
pass
45+
# log_node_warn(self.node_name,
46+
# f"Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
47+
# )
48+
49+
if len(remaining) > 0:
50+
return self.original_calculate_weight(remaining, weight, key, intermediate_type)
51+
return weight
52+
53+
def __enter__(self):
54+
try:
55+
print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
56+
self.original_calculate_weight = comfy.lora.calculate_weight
57+
comfy.lora.calculate_weight = self.calculate_weight_patched
58+
except AttributeError:
59+
print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
60+
self.original_calculate_weight = ModelPatcher.calculate_weight
61+
ModelPatcher.calculate_weight = self.calculate_weight_patched
62+
def __exit__(self, type, value, traceback):
63+
try:
64+
comfy.lora.calculate_weight = self.original_calculate_weight
65+
except AttributeError:
66+
ModelPatcher.calculate_weight = self.original_calculate_weight
67+
68+
69+
class InpaintWorker:
70+
def __init__(self, node_name):
71+
self.node_name = node_name if node_name is not None else ""
72+
73+
def load_fooocus_patch(self, lora: dict, to_load: dict):
74+
patch_dict = {}
75+
loaded_keys = set()
76+
for key in to_load.values():
77+
if value := lora.get(key, None):
78+
patch_dict[key] = ("fooocus", value)
79+
loaded_keys.add(key)
80+
81+
not_loaded = sum(1 for x in lora if x not in loaded_keys)
82+
if not_loaded > 0:
83+
log_node_info(self.node_name,
84+
f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
85+
)
86+
return patch_dict
87+
88+
89+
def patch(self, model, latent, patch):
90+
with applyFooocusPatch():
91+
base_model: BaseModel = model.model
92+
latent_pixels = base_model.process_latent_in(latent["samples"])
93+
noise_mask = latent["noise_mask"].round()
94+
latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)
95+
96+
inpaint_head_model, inpaint_lora = patch
97+
feed = torch.cat([latent_mask, latent_pixels], dim=1)
98+
inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
99+
inpaint_head_feature = inpaint_head_model(feed)
100+
101+
def input_block_patch(h, transformer_options):
102+
if transformer_options["block"][1] == 0:
103+
h = h + inpaint_head_feature.to(h)
104+
return h
105+
106+
lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
107+
lora_keys.update({x: x for x in base_model.state_dict().keys()})
108+
loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)
109+
110+
m = model.clone()
111+
m.set_model_input_block_patch(input_block_patch)
112+
patched = m.add_patches(loaded_lora, 1.0)
113+
114+
not_patched_count = sum(1 for x in loaded_lora if x not in patched)
115+
if not_patched_count > 0:
116+
log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")
117+
return (m,)

py/libs/fooocus.py

Lines changed: 0 additions & 119 deletions
This file was deleted.

0 commit comments

Comments
 (0)