1111import torch
1212import k_diffusion as K
1313
14+ # Debugging notes: for SD 1.5, the original apply_model method that gets called is located via modules.sd_hijack_utils and belongs to ldm.models.diffusion.ddpm.LatentDiffusion.
15+ # For SDXL, OpenAIWrapper is called, which in turn calls the underlying diffusion_model.
16+
17+
1418def find_noise_for_image (p , cond , uncond , cfg_scale , steps ):
1519 x = p .init_latent
1620
@@ -30,15 +34,25 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
3034
3135 x_in = torch .cat ([x ] * 2 )
3236 sigma_in = torch .cat ([sigmas [i ] * s_in ] * 2 )
33- cond_in = torch .cat ([uncond , cond ])
37+
38+ if shared .sd_model .is_sdxl :
39+ cond_tensor = cond ['crossattn' ]
40+ uncond_tensor = uncond ['crossattn' ]
41+ cond_in = torch .cat ([uncond_tensor , cond_tensor ])
42+ else :
43+ cond_in = torch .cat ([uncond , cond ])
3444
3545 image_conditioning = torch .cat ([p .image_conditioning ] * 2 )
3646 cond_in = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond_in ]}
3747
3848 c_out , c_in = [K .utils .append_dims (k , x_in .ndim ) for k in dnw .get_scalings (sigma_in )[skip :]]
3949 t = dnw .sigma_to_t (sigma_in )
4050
41- eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
51+ if shared .sd_model .is_sdxl :
52+ eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
53+ else :
54+ eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
55+
4256 denoised_uncond , denoised_cond = (x_in + eps * c_out ).chunk (2 )
4357
4458 denoised = denoised_uncond + (denoised_cond - denoised_uncond ) * cfg_scale
@@ -64,6 +78,13 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
6478
6579# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
6680def find_noise_for_image_sigma_adjustment (p , cond , uncond , cfg_scale , steps ):
81+ if shared .sd_model .is_sdxl :
82+ cond_tensor = cond ['crossattn' ]
83+ uncond_tensor = uncond ['crossattn' ]
84+ cond_in = torch .cat ([uncond_tensor , cond_tensor ])
85+ else :
86+ cond_in = torch .cat ([uncond , cond ])
87+
6788 x = p .init_latent
6889
6990 s_in = x .new_ones ([x .shape [0 ]])
@@ -82,7 +103,14 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
82103
83104 x_in = torch .cat ([x ] * 2 )
84105 sigma_in = torch .cat ([sigmas [i - 1 ] * s_in ] * 2 )
85- cond_in = torch .cat ([uncond , cond ])
106+
107+
108+ if shared .sd_model .is_sdxl :
109+ cond_tensor = cond ['crossattn' ]
110+ uncond_tensor = uncond ['crossattn' ]
111+ cond_in = torch .cat ([uncond_tensor , cond_tensor ])
112+ else :
113+ cond_in = torch .cat ([uncond , cond ])
86114
87115 image_conditioning = torch .cat ([p .image_conditioning ] * 2 )
88116 cond_in = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond_in ]}
@@ -94,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
94122 else :
95123 t = dnw .sigma_to_t (sigma_in )
96124
97- eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
125+
126+ if shared .sd_model .is_sdxl :
127+ eps = shared .sd_model .model (x_in * c_in , t , {"crossattn" : cond_in ["c_crossattn" ][0 ]} )
128+ else :
129+ eps = shared .sd_model .apply_model (x_in * c_in , t , cond = cond_in )
130+
98131 denoised_uncond , denoised_cond = (x_in + eps * c_out ).chunk (2 )
99132
100133 denoised = denoised_uncond + (denoised_cond - denoised_uncond ) * cfg_scale
0 commit comments