# credit to Acly for this module
# from https://github.com/Acly/comfyui-inpaint-nodes
import torch
import torch.nn.functional as F
import comfy
import comfy.lora  # imported explicitly: patch() below calls comfy.lora.model_lora_keys_unet
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.model_management import cast_to_device

from ..libs.log import log_node_warn, log_node_error, log_node_info

class InpaintHead(torch.nn.Module):
    """Small conv head that projects the mask + latent into the UNet's first-block feature space."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 320 output channels (the UNet's first input block), 5 input channels
        # (1 mask channel + 4 latent channels), 3x3 kernel
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    def __call__(self, x):
        # replicate-pad by 1 so the 3x3 convolution preserves spatial size
        x = F.pad(x, (1, 1, 1, 1), "replicate")
        return F.conv2d(x, weight=self.head)

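# Illustrative sketch (not part of this module): how the head gets fed in patch()
# below. Shapes assume a 4-channel latent; the random weight stands in for the
# Fooocus inpaint head checkpoint and is used here only to show the shapes.
#
#   head = InpaintHead()
#   head.head.data = torch.randn(320, 5, 3, 3)
#   latent = torch.zeros(1, 4, 64, 64)
#   mask = torch.ones(1, 1, 64, 64)
#   feature = head(torch.cat([mask, latent], dim=1))   # -> (1, 320, 64, 64)
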
class applyFooocusPatch:
    def calculate_weight_patched(self, patches, weight, key, intermediate_type=torch.float32):
        remaining = []

        for p in patches:
            alpha = p[0]
            v = p[1]

            # Fooocus inpaint patches are stored as ("fooocus", (w, w_min, w_max));
            # anything else is handed back to the original calculate_weight below.
            is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
            if not is_fooocus_patch:
                remaining.append(p)
                continue

            if alpha != 0.0:
                v = v[1]
                w1 = cast_to_device(v[0], weight.device, torch.float32)
                if w1.shape == weight.shape:
                    # Dequantize the 8-bit tensor: map [0, 255] linearly onto
                    # [w_min, w_max], then merge it into the weight scaled by alpha.
                    w_min = cast_to_device(v[1], weight.device, torch.float32)
                    w_max = cast_to_device(v[2], weight.device, torch.float32)
                    w1 = (w1 / 255.0) * (w_max - w_min) + w_min
                    weight += alpha * cast_to_device(w1, weight.device, weight.dtype)
                else:
                    pass
                    # log_node_warn(self.node_name,
                    #     f"Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
                    # )

        if len(remaining) > 0:
            return self.original_calculate_weight(remaining, weight, key, intermediate_type)
        return weight

    def __enter__(self):
        # Newer ComfyUI exposes calculate_weight as comfy.lora.calculate_weight;
        # older versions keep it as a method on ModelPatcher. Patch whichever
        # exists and remember the original so __exit__ can restore it.
        try:
            print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight")
            self.original_calculate_weight = comfy.lora.calculate_weight
            comfy.lora.calculate_weight = self.calculate_weight_patched
        except AttributeError:
            print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
            self.original_calculate_weight = ModelPatcher.calculate_weight
            ModelPatcher.calculate_weight = self.calculate_weight_patched

    def __exit__(self, type, value, traceback):
        try:
            comfy.lora.calculate_weight = self.original_calculate_weight
        except AttributeError:
            ModelPatcher.calculate_weight = self.original_calculate_weight

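# Illustrative sketch (values are made up): a fooocus patch entry stores a weight
# delta quantized to 8 bits together with its original value range, and the patched
# calculate_weight restores it via (w / 255) * (w_max - w_min) + w_min.
#
#   w_min, w_max = torch.tensor(-0.05), torch.tensor(0.05)
#   w_q = torch.randint(0, 256, (320, 4, 3, 3), dtype=torch.uint8)
#   patch_entry = (1.0, ("fooocus", (w_q, w_min, w_max)))   # (alpha, value)
#
# The patched merge is only active inside the context manager:
#
#   with applyFooocusPatch():
#       ...  # add_patches / weight merging happens here
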
class InpaintWorker:
    def __init__(self, node_name):
        self.node_name = node_name if node_name is not None else ""

    def load_fooocus_patch(self, lora: dict, to_load: dict):
        patch_dict = {}
        loaded_keys = set()
        # to_load maps patchable names onto the keys the fooocus patch file uses
        for key in to_load.values():
            if value := lora.get(key, None):
                patch_dict[key] = ("fooocus", value)
                loaded_keys.add(key)

        not_loaded = sum(1 for x in lora if x not in loaded_keys)
        if not_loaded > 0:
            log_node_info(self.node_name,
                f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
            )
        return patch_dict

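    # Illustrative sketch (made-up keys): given to_load = {"lora_name": "model_key"}
    # and a fooocus file lora = {"model_key": (w_q, w_min, w_max)}, the result is
    # {"model_key": ("fooocus", (w_q, w_min, w_max))}, ready for m.add_patches().
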
    def patch(self, model, latent, patch):
        with applyFooocusPatch():
            base_model: BaseModel = model.model
            latent_pixels = base_model.process_latent_in(latent["samples"])
            # Downsample the pixel-space mask by 8x to match the latent resolution
            noise_mask = latent["noise_mask"].round()
            latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)

            # Run the inpaint head on mask + latent to get a residual feature map
            inpaint_head_model, inpaint_lora = patch
            feed = torch.cat([latent_mask, latent_pixels], dim=1)
            inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
            inpaint_head_feature = inpaint_head_model(feed)

            # Add the head's output to the very first UNet input block
            def input_block_patch(h, transformer_options):
                if transformer_options["block"][1] == 0:
                    h = h + inpaint_head_feature.to(h)
                return h

            lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
            lora_keys.update({x: x for x in base_model.state_dict().keys()})
            loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)

            m = model.clone()
            m.set_model_input_block_patch(input_block_patch)
            patched = m.add_patches(loaded_lora, 1.0)

            not_patched_count = sum(1 for x in loaded_lora if x not in patched)
            if not_patched_count > 0:
                log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")
            return (m,)
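
# Illustrative usage sketch (the node name is hypothetical, not part of this module):
#
#   worker = InpaintWorker(node_name="easy applyFooocusInpaint")
#   head = InpaintHead()
#   head.head.data = ...  # loaded from the Fooocus inpaint head file
#   lora = ...            # state dict loaded from the Fooocus inpaint patch file
#   (patched_model,) = worker.patch(model, latent, (head, lora))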