11
11
import torch
12
12
import k_diffusion as K
13
13
14
+ # Debugging notes - the original method apply_model is being called for sd1.5 is in modules.sd_hijack_utils and is ldm.models.diffusion.ddpm.LatentDiffusion
15
+ # For sdxl - OpenAIWrapper will be called, which will call the underlying diffusion_model
16
+
17
+
14
18
def find_noise_for_image (p , cond , uncond , cfg_scale , steps ):
15
19
x = p .init_latent
16
20
@@ -30,15 +34,25 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
30
34
31
35
x_in = torch .cat ([x ] * 2 )
32
36
sigma_in = torch .cat ([sigmas [i ] * s_in ] * 2 )
33
- cond_in = torch .cat ([uncond , cond ])
37
+
38
+ if shared .sd_model .is_sdxl :
39
+ cond_tensor = cond ['crossattn' ]
40
+ uncond_tensor = uncond ['crossattn' ]
41
+ cond_in = torch .cat ([uncond_tensor , cond_tensor ])
42
+ else :
43
+ cond_in = torch .cat ([uncond , cond ])
34
44
35
45
image_conditioning = torch .cat ([p .image_conditioning ] * 2 )
36
46
cond_in = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond_in ]}
37
47
38
48
c_out , c_in = [K .utils .append_dims (k , x_in .ndim ) for k in dnw .get_scalings (sigma_in )[skip :]]
39
49
t = dnw .sigma_to_t (sigma_in )
40
50
41
- eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
51
+ if shared .sd_model .is_sdxl :
52
+ eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
53
+ else :
54
+ eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
55
+
42
56
denoised_uncond , denoised_cond = (x_in + eps * c_out ).chunk (2 )
43
57
44
58
denoised = denoised_uncond + (denoised_cond - denoised_uncond ) * cfg_scale
@@ -64,6 +78,13 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
64
78
65
79
# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
66
80
def find_noise_for_image_sigma_adjustment (p , cond , uncond , cfg_scale , steps ):
81
+ if shared .sd_model .is_sdxl :
82
+ cond_tensor = cond ['crossattn' ]
83
+ uncond_tensor = uncond ['crossattn' ]
84
+ cond_in = torch .cat ([uncond_tensor , cond_tensor ])
85
+ else :
86
+ cond_in = torch .cat ([uncond , cond ])
87
+
67
88
x = p .init_latent
68
89
69
90
s_in = x .new_ones ([x .shape [0 ]])
@@ -82,7 +103,14 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
82
103
83
104
x_in = torch .cat ([x ] * 2 )
84
105
sigma_in = torch .cat ([sigmas [i - 1 ] * s_in ] * 2 )
85
- cond_in = torch .cat ([uncond , cond ])
106
+
107
+
108
+ if shared .sd_model .is_sdxl :
109
+ cond_tensor = cond ['crossattn' ]
110
+ uncond_tensor = uncond ['crossattn' ]
111
+ cond_in = torch .cat ([uncond_tensor , cond_tensor ])
112
+ else :
113
+ cond_in = torch .cat ([uncond , cond ])
86
114
87
115
image_conditioning = torch .cat ([p .image_conditioning ] * 2 )
88
116
cond_in = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond_in ]}
@@ -94,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
94
122
else :
95
123
t = dnw .sigma_to_t (sigma_in )
96
124
97
- eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
125
+
126
+ if shared .sd_model .is_sdxl :
127
+ eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
128
+ else :
129
+ eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
130
+
98
131
denoised_uncond , denoised_cond = (x_in + eps * c_out ).chunk (2 )
99
132
100
133
denoised = denoised_uncond + (denoised_cond - denoised_uncond ) * cfg_scale
0 commit comments