Commit 045cadd

Move LatentsToImageInvocation to its own file. No functional changes.
1 parent 5869714 commit 045cadd

2 files changed: +108 -92 lines

invokeai/app/invocations/latent.py

Lines changed: 1 addition & 92 deletions
@@ -8,16 +8,7 @@
 import torchvision
 import torchvision.transforms as T
 from diffusers.configuration_utils import ConfigMixin
-from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.adapter import T2IAdapter
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
-from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
 from diffusers.schedulers.scheduling_tcd import TCDScheduler
@@ -38,11 +29,9 @@
     LatentsField,
     OutputField,
     UIType,
-    WithBoard,
-    WithMetadata,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
-from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
+from invokeai.app.invocations.primitives import DenoiseMaskOutput, LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -1033,83 +1022,3 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
-
-
-@invocation(
-    "l2i",
-    title="Latents to Image",
-    tags=["latents", "image", "vae", "l2i"],
-    category="latents",
-    version="1.2.2",
-)
-class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
-    """Generates an image from latents."""
-
-    latents: LatentsField = InputField(
-        description=FieldDescriptions.latents,
-        input=Input.Connection,
-    )
-    vae: VAEField = InputField(
-        description=FieldDescriptions.vae,
-        input=Input.Connection,
-    )
-    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
-    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
-
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-
-        vae_info = context.models.load(self.vae.vae)
-        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
-            assert isinstance(vae, torch.nn.Module)
-            latents = latents.to(vae.device)
-            if self.fp32:
-                vae.to(dtype=torch.float32)
-
-                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
-                    vae.decoder.mid_block.attentions[0].processor,
-                    (
-                        AttnProcessor2_0,
-                        XFormersAttnProcessor,
-                        LoRAXFormersAttnProcessor,
-                        LoRAAttnProcessor2_0,
-                    ),
-                )
-                # if xformers or torch_2_0 is used attention block does not need
-                # to be in float32 which can save lots of memory
-                if use_torch_2_0_or_xformers:
-                    vae.post_quant_conv.to(latents.dtype)
-                    vae.decoder.conv_in.to(latents.dtype)
-                    vae.decoder.mid_block.to(latents.dtype)
-                else:
-                    latents = latents.float()
-
-            else:
-                vae.to(dtype=torch.float16)
-                latents = latents.half()
-
-            if self.tiled or context.config.get().force_tiled_decode:
-                vae.enable_tiling()
-            else:
-                vae.disable_tiling()
-
-            # clear memory as vae decode can request a lot
-            TorchDevice.empty_cache()
-
-            with torch.inference_mode():
-                # copied from diffusers pipeline
-                latents = latents / vae.config.scaling_factor
-                image = vae.decode(latents, return_dict=False)[0]
-                image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
-                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-                np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
-            image = VaeImageProcessor.numpy_to_pil(np_image)[0]
-
-        TorchDevice.empty_cache()
-
-        image_dto = context.images.save(image=image)
-
-        return ImageOutput.build(image_dto)

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import torch
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    LatentsField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.latent import DEFAULT_PRECISION
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "l2i",
+    title="Latents to Image",
+    tags=["latents", "image", "vae", "l2i"],
+    category="latents",
+    version="1.2.2",
+)
+class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Generates an image from latents."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        vae_info = context.models.load(self.vae.vae)
+        assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
+        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+            assert isinstance(vae, torch.nn.Module)
+            latents = latents.to(vae.device)
+            if self.fp32:
+                vae.to(dtype=torch.float32)
+
+                use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
+                    vae.decoder.mid_block.attentions[0].processor,
+                    (
+                        AttnProcessor2_0,
+                        XFormersAttnProcessor,
+                        LoRAXFormersAttnProcessor,
+                        LoRAAttnProcessor2_0,
+                    ),
+                )
+                # if xformers or torch_2_0 is used attention block does not need
+                # to be in float32 which can save lots of memory
+                if use_torch_2_0_or_xformers:
+                    vae.post_quant_conv.to(latents.dtype)
+                    vae.decoder.conv_in.to(latents.dtype)
+                    vae.decoder.mid_block.to(latents.dtype)
+                else:
+                    latents = latents.float()
+
+            else:
+                vae.to(dtype=torch.float16)
+                latents = latents.half()
+
+            if self.tiled or context.config.get().force_tiled_decode:
+                vae.enable_tiling()
+            else:
+                vae.disable_tiling()
+
+            # clear memory as vae decode can request a lot
+            TorchDevice.empty_cache()
+
+            with torch.inference_mode():
+                # copied from diffusers pipeline
+                latents = latents / vae.config.scaling_factor
+                image = vae.decode(latents, return_dict=False)[0]
+                image = (image / 2 + 0.5).clamp(0, 1)  # denormalize
+                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+                np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+            image = VaeImageProcessor.numpy_to_pil(np_image)[0]
+
+        TorchDevice.empty_cache()
+
+        image_dto = context.images.save(image=image)
+
+        return ImageOutput.build(image_dto)
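
For reference, the decode path that the moved invocation wraps (scale by vae.config.scaling_factor, decode, denormalize, convert to PIL) can be exercised on its own with diffusers, outside the InvokeAI invocation framework. Below is a minimal sketch; the "stabilityai/sd-vae-ft-mse" checkpoint and the random 1x4x64x64 latent are assumptions for illustration only and are not part of this commit.

import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

# Assumed checkpoint for illustration; any AutoencoderKL behaves the same way here.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.eval()

# Random latent in the SD-1.x shape (batch, 4, H/8, W/8); real latents would come
# from the denoising loop or from encoding an image.
latents = torch.randn(1, 4, 64, 64)

with torch.inference_mode():
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)  # denormalize to [0, 1]
    np_image = image.cpu().permute(0, 2, 3, 1).float().numpy()

pil_image = VaeImageProcessor.numpy_to_pil(np_image)[0]
pil_image.save("decoded.png")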
