Commit 79ceac2

(minor) Use SilenceWarnings as a decorator rather than a context manager to save an indentation level.
1 parent 8e47e00 commit 79ceac2

1 file changed: +138 -138 lines changed
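The commit message describes the mechanics: a `with SilenceWarnings():` block wrapping the whole body of `invoke()` is replaced by a `@SilenceWarnings()` decorator on the method, which removes one indentation level from roughly 140 lines. A class can support both usage patterns if it is built on `contextlib.ContextDecorator`. Below is a minimal sketch of such a dual-use class, assuming an implementation along these lines; InvokeAI's actual helper lives in its backend utilities and may differ in detail.

import warnings
from contextlib import ContextDecorator


class SilenceWarnings(ContextDecorator):
    """Suppress Python warnings while active.

    Illustrative sketch only; InvokeAI's real helper may differ.
    Subclassing ContextDecorator makes one class usable both as
    `with SilenceWarnings(): ...` and as `@SilenceWarnings()`.
    """

    def __enter__(self) -> "SilenceWarnings":
        # Save the current warning filters and silence everything.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        return self

    def __exit__(self, *exc) -> bool:
        # Restore the previous warning filters; do not swallow exceptions.
        self._catcher.__exit__(*exc)
        return False

Because `ContextDecorator` re-enters the context manager around every call of the decorated function, decorating `invoke()` is behaviourally equivalent to wrapping its entire body in the `with` block, as the diff below shows.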

invokeai/app/invocations/denoise_latents.py

Lines changed: 138 additions & 138 deletions
@@ -657,155 +657,155 @@ def prep_inpaint_mask(
         return 1 - mask, masked_latents, self.denoise_mask.gradient
 
     @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            seed = None
-            noise = None
-            if self.noise is not None:
-                noise = context.tensors.load(self.noise.latents_name)
-                seed = self.noise.seed
-
-            if self.latents is not None:
-                latents = context.tensors.load(self.latents.latents_name)
-                if seed is None:
-                    seed = self.latents.seed
-
-                if noise is not None and noise.shape[1:] != latents.shape[1:]:
-                    raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
-
-            elif noise is not None:
-                latents = torch.zeros_like(noise)
-            else:
-                raise Exception("'latents' or 'noise' must be provided!")
-
+        seed = None
+        noise = None
+        if self.noise is not None:
+            noise = context.tensors.load(self.noise.latents_name)
+            seed = self.noise.seed
+
+        if self.latents is not None:
+            latents = context.tensors.load(self.latents.latents_name)
             if seed is None:
-                seed = 0
+                seed = self.latents.seed
 
-            mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+            if noise is not None and noise.shape[1:] != latents.shape[1:]:
+                raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
 
-            # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
-            # below. Investigate whether this is appropriate.
-            t2i_adapter_data = self.run_t2i_adapters(
-                context,
-                self.t2i_adapter,
-                latents.shape,
-                do_classifier_free_guidance=True,
-            )
+        elif noise is not None:
+            latents = torch.zeros_like(noise)
+        else:
+            raise Exception("'latents' or 'noise' must be provided!")
 
-            ip_adapters: List[IPAdapterField] = []
-            if self.ip_adapter is not None:
-                # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-                if isinstance(self.ip_adapter, list):
-                    ip_adapters = self.ip_adapter
-                else:
-                    ip_adapters = [self.ip_adapter]
-
-            # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
-            # a series of image conditioning embeddings. This is being done here rather than in the
-            # big model context below in order to use less VRAM on low-VRAM systems.
-            # The image prompts are then passed to prep_ip_adapter_data().
-            image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
-
-            # get the unet's config so that we can pass the base to dispatch_progress()
-            unet_config = context.models.get_config(self.unet.unet.key)
-
-            def step_callback(state: PipelineIntermediateState) -> None:
-                context.util.sd_step_callback(state, unet_config.base)
-
-            def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-                for lora in self.unet.loras:
-                    lora_info = context.models.load(lora.lora)
-                    assert isinstance(lora_info.model, LoRAModelRaw)
-                    yield (lora_info.model, lora.weight)
-                    del lora_info
-                return
-
-            unet_info = context.models.load(self.unet.unet)
-            assert isinstance(unet_info.model, UNet2DConditionModel)
-            with (
-                ExitStack() as exit_stack,
-                unet_info.model_on_device() as (model_state_dict, unet),
-                ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-                set_seamless(unet, self.unet.seamless_axes),  # FIXME
-                # Apply the LoRA after unet has been moved to its target device for faster patching.
-                ModelPatcher.apply_lora_unet(
-                    unet,
-                    loras=_lora_loader(),
-                    model_state_dict=model_state_dict,
-                ),
-            ):
-                assert isinstance(unet, UNet2DConditionModel)
-                latents = latents.to(device=unet.device, dtype=unet.dtype)
-                if noise is not None:
-                    noise = noise.to(device=unet.device, dtype=unet.dtype)
-                if mask is not None:
-                    mask = mask.to(device=unet.device, dtype=unet.dtype)
-                if masked_latents is not None:
-                    masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
-
-                scheduler = get_scheduler(
-                    context=context,
-                    scheduler_info=self.unet.scheduler,
-                    scheduler_name=self.scheduler,
-                    seed=seed,
-                )
+        if seed is None:
+            seed = 0
 
-                pipeline = self.create_pipeline(unet, scheduler)
+        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
 
-                _, _, latent_height, latent_width = latents.shape
-                conditioning_data = self.get_conditioning_data(
-                    context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
-                )
+        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
+        # below. Investigate whether this is appropriate.
+        t2i_adapter_data = self.run_t2i_adapters(
+            context,
+            self.t2i_adapter,
+            latents.shape,
+            do_classifier_free_guidance=True,
+        )
 
-                controlnet_data = self.prep_control_data(
-                    context=context,
-                    control_input=self.control,
-                    latents_shape=latents.shape,
-                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                    do_classifier_free_guidance=True,
-                    exit_stack=exit_stack,
-                )
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+        # a series of image conditioning embeddings. This is being done here rather than in the
+        # big model context below in order to use less VRAM on low-VRAM systems.
+        # The image prompts are then passed to prep_ip_adapter_data().
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
+        # get the unet's config so that we can pass the base to dispatch_progress()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+            for lora in self.unet.loras:
+                lora_info = context.models.load(lora.lora)
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
+                del lora_info
+            return
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            ExitStack() as exit_stack,
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
+        ):
+            assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
+            if mask is not None:
+                mask = mask.to(device=unet.device, dtype=unet.dtype)
+            if masked_latents is not None:
+                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
+            )
 
-                ip_adapter_data = self.prep_ip_adapter_data(
-                    context=context,
-                    ip_adapters=ip_adapters,
-                    image_prompts=image_prompts,
-                    exit_stack=exit_stack,
-                    latent_height=latent_height,
-                    latent_width=latent_width,
-                    dtype=unet.dtype,
-                )
+            pipeline = self.create_pipeline(unet, scheduler)
 
-                num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
-                    scheduler,
-                    device=unet.device,
-                    steps=self.steps,
-                    denoising_start=self.denoising_start,
-                    denoising_end=self.denoising_end,
-                    seed=seed,
-                )
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )
 
-                result_latents = pipeline.latents_from_embeddings(
-                    latents=latents,
-                    timesteps=timesteps,
-                    init_timestep=init_timestep,
-                    noise=noise,
-                    seed=seed,
-                    mask=mask,
-                    masked_latents=masked_latents,
-                    gradient_mask=gradient_mask,
-                    num_inference_steps=num_inference_steps,
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    conditioning_data=conditioning_data,
-                    control_data=controlnet_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
-                    callback=step_callback,
-                )
+            controlnet_data = self.prep_control_data(
+                context=context,
+                control_input=self.control,
+                latents_shape=latents.shape,
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
+
+            ip_adapter_data = self.prep_ip_adapter_data(
+                context=context,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
+                exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                dtype=unet.dtype,
+            )
+
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+                scheduler,
+                device=unet.device,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+                seed=seed,
+            )
+
+            result_latents = pipeline.latents_from_embeddings(
+                latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                mask=mask,
+                masked_latents=masked_latents,
+                gradient_mask=gradient_mask,
+                num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                control_data=controlnet_data,
+                ip_adapter_data=ip_adapter_data,
+                t2i_adapter_data=t2i_adapter_data,
+                callback=step_callback,
+            )
 
-            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-            result_latents = result_latents.to("cpu")
-            TorchDevice.empty_cache()
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
+        TorchDevice.empty_cache()
 
-            name = context.tensors.save(tensor=result_latents)
+        name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
