16 | 16 | from torchvision.transforms.functional import resize as tv_resize
17 | 17 | from transformers import CLIPVisionModelWithProjection
18 | 18 |
   | 19 | +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
19 | 20 | from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
   | 21 | +from invokeai.app.invocations.controlnet_image_processors import ControlField
20 | 22 | from invokeai.app.invocations.fields import (
21 | 23 |     ConditioningField,
22 | 24 |     DenoiseMaskField,
27 | 29 |     UIType,
28 | 30 | )
29 | 31 | from invokeai.app.invocations.ip_adapter import IPAdapterField
   | 32 | +from invokeai.app.invocations.model import ModelIdentifierField, UNetField
30 | 33 | from invokeai.app.invocations.primitives import LatentsOutput
31 | 34 | from invokeai.app.invocations.t2i_adapter import T2IAdapterField
32 | 35 | from invokeai.app.services.shared.invocation_context import InvocationContext
36 | 39 | from invokeai.backend.model_manager import BaseModelType
37 | 40 | from invokeai.backend.model_patcher import ModelPatcher
38 | 41 | from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
   | 42 | +from invokeai.backend.stable_diffusion.diffusers_pipeline import (
   | 43 | +    ControlNetData,
   | 44 | +    StableDiffusionGeneratorPipeline,
   | 45 | +    T2IAdapterData,
   | 46 | +)
39 | 47 | from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
40 | 48 |     BasicConditioningInfo,
41 | 49 |     IPAdapterConditioningInfo,
45 | 53 |     TextConditioningData,
46 | 54 |     TextConditioningRegions,
47 | 55 | )
   | 56 | +from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
   | 57 | +from invokeai.backend.util.devices import TorchDevice
48 | 58 | from invokeai.backend.util.mask import to_standard_float_mask
49 | 59 | from invokeai.backend.util.silence_warnings import SilenceWarnings
50 | 60 |
51 |    | -from ...backend.stable_diffusion.diffusers_pipeline import (
52 |    | -    ControlNetData,
53 |    | -    StableDiffusionGeneratorPipeline,
54 |    | -    T2IAdapterData,
55 |    | -)
56 |    | -from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
57 |    | -from ...backend.util.devices import TorchDevice
58 |    | -from .baseinvocation import BaseInvocation, invocation
59 |    | -from .controlnet_image_processors import ControlField
60 |    | -from .model import ModelIdentifierField, UNetField
61 |    | -
62 | 61 |
63 | 62 | def get_scheduler(
64 | 63 |     context: InvocationContext,
@@ -658,155 +657,155 @@ def prep_inpaint_mask(
658 | 657 |         return 1 - mask, masked_latents, self.denoise_mask.gradient
659 | 658 |
660 | 659 |     @torch.no_grad()
    | 660 | +    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
661 | 661 |     def invoke(self, context: InvocationContext) -> LatentsOutput:
662 |     | -        with SilenceWarnings():  # this quenches NSFW nag from diffusers
663 |     | -            seed = None
664 |     | -            noise = None
665 |     | -            if self.noise is not None:
666 |     | -                noise = context.tensors.load(self.noise.latents_name)
667 |     | -                seed = self.noise.seed
668 |     | -
669 |     | -            if self.latents is not None:
670 |     | -                latents = context.tensors.load(self.latents.latents_name)
671 |     | -                if seed is None:
672 |     | -                    seed = self.latents.seed
673 |     | -
674 |     | -                if noise is not None and noise.shape[1:] != latents.shape[1:]:
675 |     | -                    raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
676 |     | -
677 |     | -            elif noise is not None:
678 |     | -                latents = torch.zeros_like(noise)
679 |     | -            else:
680 |     | -                raise Exception("'latents' or 'noise' must be provided!")
681 |     | -
    | 662 | +        seed = None
    | 663 | +        noise = None
    | 664 | +        if self.noise is not None:
    | 665 | +            noise = context.tensors.load(self.noise.latents_name)
    | 666 | +            seed = self.noise.seed
    | 667 | +
    | 668 | +        if self.latents is not None:
    | 669 | +            latents = context.tensors.load(self.latents.latents_name)
682 | 670 |             if seed is None:
683 |     | -                seed = 0
    | 671 | +                seed = self.latents.seed
684 | 672 |
685 |     | -            mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
    | 673 | +            if noise is not None and noise.shape[1:] != latents.shape[1:]:
    | 674 | +                raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
686 | 675 |
687 |     | -            # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
688 |     | -            # below. Investigate whether this is appropriate.
689 |     | -            t2i_adapter_data = self.run_t2i_adapters(
690 |     | -                context,
691 |     | -                self.t2i_adapter,
692 |     | -                latents.shape,
693 |     | -                do_classifier_free_guidance=True,
694 |     | -            )
    | 676 | +        elif noise is not None:
    | 677 | +            latents = torch.zeros_like(noise)
    | 678 | +        else:
    | 679 | +            raise Exception("'latents' or 'noise' must be provided!")
695 | 680 |
696 |     | -            ip_adapters: List[IPAdapterField] = []
697 |     | -            if self.ip_adapter is not None:
698 |     | -                # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
699 |     | -                if isinstance(self.ip_adapter, list):
700 |     | -                    ip_adapters = self.ip_adapter
701 |     | -                else:
702 |     | -                    ip_adapters = [self.ip_adapter]
703 |     | -
704 |     | -            # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
705 |     | -            # a series of image conditioning embeddings. This is being done here rather than in the
706 |     | -            # big model context below in order to use less VRAM on low-VRAM systems.
707 |     | -            # The image prompts are then passed to prep_ip_adapter_data().
708 |     | -            image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
709 |     | -
710 |     | -            # get the unet's config so that we can pass the base to dispatch_progress()
711 |     | -            unet_config = context.models.get_config(self.unet.unet.key)
712 |     | -
713 |     | -            def step_callback(state: PipelineIntermediateState) -> None:
714 |     | -                context.util.sd_step_callback(state, unet_config.base)
715 |     | -
716 |     | -            def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
717 |     | -                for lora in self.unet.loras:
718 |     | -                    lora_info = context.models.load(lora.lora)
719 |     | -                    assert isinstance(lora_info.model, LoRAModelRaw)
720 |     | -                    yield (lora_info.model, lora.weight)
721 |     | -                    del lora_info
722 |     | -                return
723 |     | -
724 |     | -            unet_info = context.models.load(self.unet.unet)
725 |     | -            assert isinstance(unet_info.model, UNet2DConditionModel)
726 |     | -            with (
727 |     | -                ExitStack() as exit_stack,
728 |     | -                unet_info.model_on_device() as (model_state_dict, unet),
729 |     | -                ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
730 |     | -                set_seamless(unet, self.unet.seamless_axes),  # FIXME
731 |     | -                # Apply the LoRA after unet has been moved to its target device for faster patching.
732 |     | -                ModelPatcher.apply_lora_unet(
733 |     | -                    unet,
734 |     | -                    loras=_lora_loader(),
735 |     | -                    model_state_dict=model_state_dict,
736 |     | -                ),
737 |     | -            ):
738 |     | -                assert isinstance(unet, UNet2DConditionModel)
739 |     | -                latents = latents.to(device=unet.device, dtype=unet.dtype)
740 |     | -                if noise is not None:
741 |     | -                    noise = noise.to(device=unet.device, dtype=unet.dtype)
742 |     | -                if mask is not None:
743 |     | -                    mask = mask.to(device=unet.device, dtype=unet.dtype)
744 |     | -                if masked_latents is not None:
745 |     | -                    masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
746 |     | -
747 |     | -                scheduler = get_scheduler(
748 |     | -                    context=context,
749 |     | -                    scheduler_info=self.unet.scheduler,
750 |     | -                    scheduler_name=self.scheduler,
751 |     | -                    seed=seed,
752 |     | -                )
    | 681 | +        if seed is None:
    | 682 | +            seed = 0
753 | 683 |
754 |     | -                pipeline = self.create_pipeline(unet, scheduler)
    | 684 | +        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
755 | 685 |
756 |     | -                _, _, latent_height, latent_width = latents.shape
757 |     | -                conditioning_data = self.get_conditioning_data(
758 |     | -                    context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
759 |     | -                )
    | 686 | +        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
    | 687 | +        # below. Investigate whether this is appropriate.
    | 688 | +        t2i_adapter_data = self.run_t2i_adapters(
    | 689 | +            context,
    | 690 | +            self.t2i_adapter,
    | 691 | +            latents.shape,
    | 692 | +            do_classifier_free_guidance=True,
    | 693 | +        )
760 | 694 |
761 |     | -                controlnet_data = self.prep_control_data(
762 |     | -                    context=context,
763 |     | -                    control_input=self.control,
764 |     | -                    latents_shape=latents.shape,
765 |     | -                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
766 |     | -                    do_classifier_free_guidance=True,
767 |     | -                    exit_stack=exit_stack,
768 |     | -                )
    | 695 | +        ip_adapters: List[IPAdapterField] = []
    | 696 | +        if self.ip_adapter is not None:
    | 697 | +            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
    | 698 | +            if isinstance(self.ip_adapter, list):
    | 699 | +                ip_adapters = self.ip_adapter
    | 700 | +            else:
    | 701 | +                ip_adapters = [self.ip_adapter]
    | 702 | +
    | 703 | +        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
    | 704 | +        # a series of image conditioning embeddings. This is being done here rather than in the
    | 705 | +        # big model context below in order to use less VRAM on low-VRAM systems.
    | 706 | +        # The image prompts are then passed to prep_ip_adapter_data().
    | 707 | +        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
    | 708 | +
    | 709 | +        # get the unet's config so that we can pass the base to dispatch_progress()
    | 710 | +        unet_config = context.models.get_config(self.unet.unet.key)
    | 711 | +
    | 712 | +        def step_callback(state: PipelineIntermediateState) -> None:
    | 713 | +            context.util.sd_step_callback(state, unet_config.base)
    | 714 | +
    | 715 | +        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
    | 716 | +            for lora in self.unet.loras:
    | 717 | +                lora_info = context.models.load(lora.lora)
    | 718 | +                assert isinstance(lora_info.model, LoRAModelRaw)
    | 719 | +                yield (lora_info.model, lora.weight)
    | 720 | +                del lora_info
    | 721 | +            return
    | 722 | +
    | 723 | +        unet_info = context.models.load(self.unet.unet)
    | 724 | +        assert isinstance(unet_info.model, UNet2DConditionModel)
    | 725 | +        with (
    | 726 | +            ExitStack() as exit_stack,
    | 727 | +            unet_info.model_on_device() as (model_state_dict, unet),
    | 728 | +            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
    | 729 | +            set_seamless(unet, self.unet.seamless_axes),  # FIXME
    | 730 | +            # Apply the LoRA after unet has been moved to its target device for faster patching.
    | 731 | +            ModelPatcher.apply_lora_unet(
    | 732 | +                unet,
    | 733 | +                loras=_lora_loader(),
    | 734 | +                model_state_dict=model_state_dict,
    | 735 | +            ),
    | 736 | +        ):
    | 737 | +            assert isinstance(unet, UNet2DConditionModel)
    | 738 | +            latents = latents.to(device=unet.device, dtype=unet.dtype)
    | 739 | +            if noise is not None:
    | 740 | +                noise = noise.to(device=unet.device, dtype=unet.dtype)
    | 741 | +            if mask is not None:
    | 742 | +                mask = mask.to(device=unet.device, dtype=unet.dtype)
    | 743 | +            if masked_latents is not None:
    | 744 | +                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
    | 745 | +
    | 746 | +            scheduler = get_scheduler(
    | 747 | +                context=context,
    | 748 | +                scheduler_info=self.unet.scheduler,
    | 749 | +                scheduler_name=self.scheduler,
    | 750 | +                seed=seed,
    | 751 | +            )
769 | 752 |
770 |     | -                ip_adapter_data = self.prep_ip_adapter_data(
771 |     | -                    context=context,
772 |     | -                    ip_adapters=ip_adapters,
773 |     | -                    image_prompts=image_prompts,
774 |     | -                    exit_stack=exit_stack,
775 |     | -                    latent_height=latent_height,
776 |     | -                    latent_width=latent_width,
777 |     | -                    dtype=unet.dtype,
778 |     | -                )
    | 753 | +            pipeline = self.create_pipeline(unet, scheduler)
779 | 754 |
780 |     | -                num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
781 |     | -                    scheduler,
782 |     | -                    device=unet.device,
783 |     | -                    steps=self.steps,
784 |     | -                    denoising_start=self.denoising_start,
785 |     | -                    denoising_end=self.denoising_end,
786 |     | -                    seed=seed,
787 |     | -                )
    | 755 | +            _, _, latent_height, latent_width = latents.shape
    | 756 | +            conditioning_data = self.get_conditioning_data(
    | 757 | +                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
    | 758 | +            )
788 | 759 |
789 |     | -                result_latents = pipeline.latents_from_embeddings(
790 |     | -                    latents=latents,
791 |     | -                    timesteps=timesteps,
792 |     | -                    init_timestep=init_timestep,
793 |     | -                    noise=noise,
794 |     | -                    seed=seed,
795 |     | -                    mask=mask,
796 |     | -                    masked_latents=masked_latents,
797 |     | -                    gradient_mask=gradient_mask,
798 |     | -                    num_inference_steps=num_inference_steps,
799 |     | -                    scheduler_step_kwargs=scheduler_step_kwargs,
800 |     | -                    conditioning_data=conditioning_data,
801 |     | -                    control_data=controlnet_data,
802 |     | -                    ip_adapter_data=ip_adapter_data,
803 |     | -                    t2i_adapter_data=t2i_adapter_data,
804 |     | -                    callback=step_callback,
805 |     | -                )
    | 760 | +            controlnet_data = self.prep_control_data(
    | 761 | +                context=context,
    | 762 | +                control_input=self.control,
    | 763 | +                latents_shape=latents.shape,
    | 764 | +                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
    | 765 | +                do_classifier_free_guidance=True,
    | 766 | +                exit_stack=exit_stack,
    | 767 | +            )
    | 768 | +
    | 769 | +            ip_adapter_data = self.prep_ip_adapter_data(
    | 770 | +                context=context,
    | 771 | +                ip_adapters=ip_adapters,
    | 772 | +                image_prompts=image_prompts,
    | 773 | +                exit_stack=exit_stack,
    | 774 | +                latent_height=latent_height,
    | 775 | +                latent_width=latent_width,
    | 776 | +                dtype=unet.dtype,
    | 777 | +            )
    | 778 | +
    | 779 | +            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
    | 780 | +                scheduler,
    | 781 | +                device=unet.device,
    | 782 | +                steps=self.steps,
    | 783 | +                denoising_start=self.denoising_start,
    | 784 | +                denoising_end=self.denoising_end,
    | 785 | +                seed=seed,
    | 786 | +            )
    | 787 | +
    | 788 | +            result_latents = pipeline.latents_from_embeddings(
    | 789 | +                latents=latents,
    | 790 | +                timesteps=timesteps,
    | 791 | +                init_timestep=init_timestep,
    | 792 | +                noise=noise,
    | 793 | +                seed=seed,
    | 794 | +                mask=mask,
    | 795 | +                masked_latents=masked_latents,
    | 796 | +                gradient_mask=gradient_mask,
    | 797 | +                num_inference_steps=num_inference_steps,
    | 798 | +                scheduler_step_kwargs=scheduler_step_kwargs,
    | 799 | +                conditioning_data=conditioning_data,
    | 800 | +                control_data=controlnet_data,
    | 801 | +                ip_adapter_data=ip_adapter_data,
    | 802 | +                t2i_adapter_data=t2i_adapter_data,
    | 803 | +                callback=step_callback,
    | 804 | +            )
806 | 805 |
807 |     | -            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
808 |     | -            result_latents = result_latents.to("cpu")
809 |     | -            TorchDevice.empty_cache()
    | 806 | +        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
    | 807 | +        result_latents = result_latents.to("cpu")
    | 808 | +        TorchDevice.empty_cache()
810 | 809 |
811 |     | -            name = context.tensors.save(tensor=result_latents)
    | 810 | +        name = context.tensors.save(tensor=result_latents)
812 | 811 |         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
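
The diff above swaps the `with SilenceWarnings():` wrapper for a `@SilenceWarnings()` decorator on `invoke()`, which is why most of the method body is shown as removed and re-added one indentation level shallower. That change assumes the helper can act both as a context manager and as a decorator. The sketch below illustrates one common way to get that dual behaviour via `contextlib.ContextDecorator`; it is an assumption for illustration only, the class name `QuietWarnings` is hypothetical, and it is not the project's actual `SilenceWarnings` implementation (which may also quiet library loggers, not just the `warnings` module).

```python
import warnings
from contextlib import ContextDecorator


class QuietWarnings(ContextDecorator):
    """Hypothetical sketch: suppress Python warnings while active.

    Subclassing ContextDecorator lets the same class be used either as
    `with QuietWarnings():` (the old call-site style) or as
    `@QuietWarnings()` (the decorator form shown in the diff).
    """

    def __enter__(self) -> "QuietWarnings":
        # Delegate to warnings.catch_warnings() so the previous filter
        # state is restored automatically on exit.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        return self

    def __exit__(self, *exc_info) -> bool:
        self._catcher.__exit__(*exc_info)
        return False  # never swallow exceptions


@QuietWarnings()  # decorator form, mirroring `@SilenceWarnings()` above
def _noisy() -> None:
    warnings.warn("suppressed while the decorator is active")


_noisy()
```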