Skip to content

Commit 64a8f9d

Browse files
committed
Update img2imgalt.py
Fix with documentation
1 parent 82a973c commit 64a8f9d

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

scripts/img2imgalt.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
import torch
1212
import k_diffusion as K
1313

14+
# Debugging notes - the original method apply_model is being called for sd1.5 is in modules.sd_hijack_utils and is ldm.models.diffusion.ddpm.LatentDiffusion
15+
# For sdxl - OpenAIWrapper will be called, which will call the underlying diffusion_model
16+
17+
1418
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
1519
x = p.init_latent
1620

@@ -30,15 +34,25 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
3034

3135
x_in = torch.cat([x] * 2)
3236
sigma_in = torch.cat([sigmas[i] * s_in] * 2)
33-
cond_in = torch.cat([uncond, cond])
37+
38+
if shared.sd_model.is_sdxl:
39+
cond_tensor = cond['crossattn']
40+
uncond_tensor = uncond['crossattn']
41+
cond_in = torch.cat([uncond_tensor, cond_tensor])
42+
else:
43+
cond_in = torch.cat([uncond, cond])
3444

3545
image_conditioning = torch.cat([p.image_conditioning] * 2)
3646
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
3747

3848
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
3949
t = dnw.sigma_to_t(sigma_in)
4050

41-
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
51+
if shared.sd_model.is_sdxl:
52+
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
53+
else:
54+
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
55+
4256
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
4357

4458
denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
@@ -64,6 +78,13 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
6478

6579
# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
6680
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
81+
if shared.sd_model.is_sdxl:
82+
cond_tensor = cond['crossattn']
83+
uncond_tensor = uncond['crossattn']
84+
cond_in = torch.cat([uncond_tensor, cond_tensor])
85+
else:
86+
cond_in = torch.cat([uncond, cond])
87+
6788
x = p.init_latent
6889

6990
s_in = x.new_ones([x.shape[0]])
@@ -82,7 +103,14 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
82103

83104
x_in = torch.cat([x] * 2)
84105
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
85-
cond_in = torch.cat([uncond, cond])
106+
107+
108+
if shared.sd_model.is_sdxl:
109+
cond_tensor = cond['crossattn']
110+
uncond_tensor = uncond['crossattn']
111+
cond_in = torch.cat([uncond_tensor, cond_tensor])
112+
else:
113+
cond_in = torch.cat([uncond, cond])
86114

87115
image_conditioning = torch.cat([p.image_conditioning] * 2)
88116
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
@@ -94,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
94122
else:
95123
t = dnw.sigma_to_t(sigma_in)
96124

97-
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
125+
126+
if shared.sd_model.is_sdxl:
127+
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
128+
else:
129+
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
130+
98131
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
99132

100133
denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

0 commit comments

Comments
 (0)