
Commit 74f0c31

Merge branch 'main' into lstein/feat/load-one-file
2 parents: 3a622af + a43d602

File tree: 6 files changed (+179 / -188 lines)

invokeai/app/invocations/denoise_latents.py

Lines changed: 148 additions & 149 deletions
@@ -16,7 +16,9 @@
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPVisionModelWithProjection
 
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.fields import (
     ConditioningField,
     DenoiseMaskField,
@@ -27,6 +29,7 @@
     UIType,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
+from invokeai.app.invocations.model import ModelIdentifierField, UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -36,6 +39,11 @@
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.diffusers_pipeline import (
+    ControlNetData,
+    StableDiffusionGeneratorPipeline,
+    T2IAdapterData,
+)
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     IPAdapterConditioningInfo,
@@ -45,20 +53,11 @@
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.mask import to_standard_float_mask
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
-from ...backend.stable_diffusion.diffusers_pipeline import (
-    ControlNetData,
-    StableDiffusionGeneratorPipeline,
-    T2IAdapterData,
-)
-from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import TorchDevice
-from .baseinvocation import BaseInvocation, invocation
-from .controlnet_image_processors import ControlField
-from .model import ModelIdentifierField, UNetField
-
 
 def get_scheduler(
     context: InvocationContext,
@@ -658,155 +657,155 @@ def prep_inpaint_mask(
         return 1 - mask, masked_latents, self.denoise_mask.gradient
 
     @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            seed = None
-            noise = None
-            if self.noise is not None:
-                noise = context.tensors.load(self.noise.latents_name)
-                seed = self.noise.seed
-
-            if self.latents is not None:
-                latents = context.tensors.load(self.latents.latents_name)
-                if seed is None:
-                    seed = self.latents.seed
-
-                if noise is not None and noise.shape[1:] != latents.shape[1:]:
-                    raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
-
-            elif noise is not None:
-                latents = torch.zeros_like(noise)
-            else:
-                raise Exception("'latents' or 'noise' must be provided!")
-
+        seed = None
+        noise = None
+        if self.noise is not None:
+            noise = context.tensors.load(self.noise.latents_name)
+            seed = self.noise.seed
+
+        if self.latents is not None:
+            latents = context.tensors.load(self.latents.latents_name)
             if seed is None:
-                seed = 0
+                seed = self.latents.seed
 
-            mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+            if noise is not None and noise.shape[1:] != latents.shape[1:]:
+                raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
 
-            # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
-            # below. Investigate whether this is appropriate.
-            t2i_adapter_data = self.run_t2i_adapters(
-                context,
-                self.t2i_adapter,
-                latents.shape,
-                do_classifier_free_guidance=True,
-            )
+        elif noise is not None:
+            latents = torch.zeros_like(noise)
+        else:
+            raise Exception("'latents' or 'noise' must be provided!")
 
-            ip_adapters: List[IPAdapterField] = []
-            if self.ip_adapter is not None:
-                # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-                if isinstance(self.ip_adapter, list):
-                    ip_adapters = self.ip_adapter
-                else:
-                    ip_adapters = [self.ip_adapter]
-
-            # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
-            # a series of image conditioning embeddings. This is being done here rather than in the
-            # big model context below in order to use less VRAM on low-VRAM systems.
-            # The image prompts are then passed to prep_ip_adapter_data().
-            image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
-
-            # get the unet's config so that we can pass the base to dispatch_progress()
-            unet_config = context.models.get_config(self.unet.unet.key)
-
-            def step_callback(state: PipelineIntermediateState) -> None:
-                context.util.sd_step_callback(state, unet_config.base)
-
-            def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-                for lora in self.unet.loras:
-                    lora_info = context.models.load(lora.lora)
-                    assert isinstance(lora_info.model, LoRAModelRaw)
-                    yield (lora_info.model, lora.weight)
-                    del lora_info
-                return
-
-            unet_info = context.models.load(self.unet.unet)
-            assert isinstance(unet_info.model, UNet2DConditionModel)
-            with (
-                ExitStack() as exit_stack,
-                unet_info.model_on_device() as (model_state_dict, unet),
-                ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-                set_seamless(unet, self.unet.seamless_axes),  # FIXME
-                # Apply the LoRA after unet has been moved to its target device for faster patching.
-                ModelPatcher.apply_lora_unet(
-                    unet,
-                    loras=_lora_loader(),
-                    model_state_dict=model_state_dict,
-                ),
-            ):
-                assert isinstance(unet, UNet2DConditionModel)
-                latents = latents.to(device=unet.device, dtype=unet.dtype)
-                if noise is not None:
-                    noise = noise.to(device=unet.device, dtype=unet.dtype)
-                if mask is not None:
-                    mask = mask.to(device=unet.device, dtype=unet.dtype)
-                if masked_latents is not None:
-                    masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
-
-                scheduler = get_scheduler(
-                    context=context,
-                    scheduler_info=self.unet.scheduler,
-                    scheduler_name=self.scheduler,
-                    seed=seed,
-                )
+        if seed is None:
+            seed = 0
 
-                pipeline = self.create_pipeline(unet, scheduler)
+        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
 
-                _, _, latent_height, latent_width = latents.shape
-                conditioning_data = self.get_conditioning_data(
-                    context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
-                )
+        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
+        # below. Investigate whether this is appropriate.
+        t2i_adapter_data = self.run_t2i_adapters(
+            context,
+            self.t2i_adapter,
+            latents.shape,
+            do_classifier_free_guidance=True,
+        )
 
-                controlnet_data = self.prep_control_data(
-                    context=context,
-                    control_input=self.control,
-                    latents_shape=latents.shape,
-                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                    do_classifier_free_guidance=True,
-                    exit_stack=exit_stack,
-                )
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+        # a series of image conditioning embeddings. This is being done here rather than in the
+        # big model context below in order to use less VRAM on low-VRAM systems.
+        # The image prompts are then passed to prep_ip_adapter_data().
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
+        # get the unet's config so that we can pass the base to dispatch_progress()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+            for lora in self.unet.loras:
+                lora_info = context.models.load(lora.lora)
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
+                del lora_info
+            return
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            ExitStack() as exit_stack,
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
+        ):
+            assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
+            if mask is not None:
+                mask = mask.to(device=unet.device, dtype=unet.dtype)
+            if masked_latents is not None:
+                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
+            )
 
-                ip_adapter_data = self.prep_ip_adapter_data(
-                    context=context,
-                    ip_adapters=ip_adapters,
-                    image_prompts=image_prompts,
-                    exit_stack=exit_stack,
-                    latent_height=latent_height,
-                    latent_width=latent_width,
-                    dtype=unet.dtype,
-                )
+            pipeline = self.create_pipeline(unet, scheduler)
 
-                num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
-                    scheduler,
-                    device=unet.device,
-                    steps=self.steps,
-                    denoising_start=self.denoising_start,
-                    denoising_end=self.denoising_end,
-                    seed=seed,
-                )
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )
 
-                result_latents = pipeline.latents_from_embeddings(
-                    latents=latents,
-                    timesteps=timesteps,
-                    init_timestep=init_timestep,
-                    noise=noise,
-                    seed=seed,
-                    mask=mask,
-                    masked_latents=masked_latents,
-                    gradient_mask=gradient_mask,
-                    num_inference_steps=num_inference_steps,
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    conditioning_data=conditioning_data,
-                    control_data=controlnet_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
-                    callback=step_callback,
-                )
+            controlnet_data = self.prep_control_data(
+                context=context,
+                control_input=self.control,
+                latents_shape=latents.shape,
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
+
+            ip_adapter_data = self.prep_ip_adapter_data(
+                context=context,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
+                exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                dtype=unet.dtype,
+            )
+
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+                scheduler,
+                device=unet.device,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+                seed=seed,
+            )
+
+            result_latents = pipeline.latents_from_embeddings(
+                latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                mask=mask,
+                masked_latents=masked_latents,
+                gradient_mask=gradient_mask,
+                num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                control_data=controlnet_data,
+                ip_adapter_data=ip_adapter_data,
+                t2i_adapter_data=t2i_adapter_data,
+                callback=step_callback,
+            )
 
-                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-                result_latents = result_latents.to("cpu")
-                TorchDevice.empty_cache()
+            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+            result_latents = result_latents.to("cpu")
+            TorchDevice.empty_cache()
 
-                name = context.tensors.save(tensor=result_latents)
+            name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
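
Note: the only behavioural change in this file is that the body of invoke() is no longer wrapped in `with SilenceWarnings():`; the helper is now applied as a `@SilenceWarnings()` decorator and the body is dedented one level. For that to work, SilenceWarnings must be usable both as a context manager and as a decorator. The sketch below shows one common way to get that dual behaviour via contextlib.ContextDecorator; it is an illustration only (the class name SilenceWarningsSketch is hypothetical), and the real implementation in invokeai.backend.util.silence_warnings may differ.

# Illustrative sketch, not the InvokeAI implementation: a warnings-silencing
# helper that works both as `with ...:` and as a `@...()` decorator.
import warnings
from contextlib import ContextDecorator


class SilenceWarningsSketch(ContextDecorator):
    def __enter__(self) -> "SilenceWarningsSketch":
        # Save the current warning filters and ignore everything inside the block.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        return self

    def __exit__(self, *exc_info) -> bool:
        # Restore the previous warning filters; do not swallow exceptions.
        self._catcher.__exit__(*exc_info)
        return False

Because ContextDecorator re-enters the context manager on every decorated call, decorating invoke() has the same effect as the old with-block while removing one level of indentation.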

invokeai/app/services/config/config_default.py

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,7 @@ class InvokeAIAppConfig(BaseSettings):
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.
+        clear_queue_on_startup: Empties session queue on startup.
         allow_nodes: List of nodes to allow. Omit to allow all.
         deny_nodes: List of nodes to deny. Omit to deny none.
         node_cache_size: How many cached nodes to keep in memory.
@@ -184,6 +185,7 @@ class InvokeAIAppConfig(BaseSettings):
     force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
+    clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
 
     # NODES
     allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")

invokeai/app/services/session_queue/session_queue_sqlite.py

Lines changed: 8 additions & 4 deletions
@@ -37,10 +37,14 @@ class SqliteSessionQueue(SessionQueueBase):
     def start(self, invoker: Invoker) -> None:
         self.__invoker = invoker
         self._set_in_progress_to_canceled()
-        prune_result = self.prune(DEFAULT_QUEUE_ID)
-
-        if prune_result.deleted > 0:
-            self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
+        if self.__invoker.services.configuration.clear_queue_on_startup:
+            clear_result = self.clear(DEFAULT_QUEUE_ID)
+            if clear_result.deleted > 0:
+                self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
+        else:
+            prune_result = self.prune(DEFAULT_QUEUE_ID)
+            if prune_result.deleted > 0:
+                self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
 
     def __init__(self, db: SqliteDatabase) -> None:
         super().__init__()
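
Note: taken together with the config change above, the startup path now branches on the new clear_queue_on_startup setting: when it is enabled the whole default queue is cleared, otherwise only finished items are pruned as before. Below is a minimal sketch of exercising the new flag, assuming InvokeAIAppConfig can be instantiated directly like an ordinary pydantic settings class; in a real install the value would normally come from the app's configuration file or environment instead.

# Sketch only: override the new flag programmatically and read it back.
from invokeai.app.services.config.config_default import InvokeAIAppConfig

config = InvokeAIAppConfig(clear_queue_on_startup=True)
print(config.clear_queue_on_startup)  # True -> SqliteSessionQueue.start() clears instead of pruning
print(config.max_queue_size)          # 10000, unchanged default from the Field definition above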

invokeai/backend/model_manager/probe.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
-from invokeai.backend.util.util import SilenceWarnings
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 from .config import (
     AnyModelConfig,

0 commit comments