@@ -657,155 +657,155 @@ def prep_inpaint_mask(
         return 1 - mask, masked_latents, self.denoise_mask.gradient

     @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            seed = None
-            noise = None
-            if self.noise is not None:
-                noise = context.tensors.load(self.noise.latents_name)
-                seed = self.noise.seed
-
-            if self.latents is not None:
-                latents = context.tensors.load(self.latents.latents_name)
-                if seed is None:
-                    seed = self.latents.seed
-
-                if noise is not None and noise.shape[1:] != latents.shape[1:]:
-                    raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
-
-            elif noise is not None:
-                latents = torch.zeros_like(noise)
-            else:
-                raise Exception("'latents' or 'noise' must be provided!")
-
+        seed = None
+        noise = None
+        if self.noise is not None:
+            noise = context.tensors.load(self.noise.latents_name)
+            seed = self.noise.seed
+
+        if self.latents is not None:
+            latents = context.tensors.load(self.latents.latents_name)
             if seed is None:
-                seed = 0
+                seed = self.latents.seed

-            mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+            if noise is not None and noise.shape[1:] != latents.shape[1:]:
+                raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

-            # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
-            # below. Investigate whether this is appropriate.
-            t2i_adapter_data = self.run_t2i_adapters(
-                context,
-                self.t2i_adapter,
-                latents.shape,
-                do_classifier_free_guidance=True,
-            )
+        elif noise is not None:
+            latents = torch.zeros_like(noise)
+        else:
+            raise Exception("'latents' or 'noise' must be provided!")

-            ip_adapters: List[IPAdapterField] = []
-            if self.ip_adapter is not None:
-                # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-                if isinstance(self.ip_adapter, list):
-                    ip_adapters = self.ip_adapter
-                else:
-                    ip_adapters = [self.ip_adapter]
-
-            # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
-            # a series of image conditioning embeddings. This is being done here rather than in the
-            # big model context below in order to use less VRAM on low-VRAM systems.
-            # The image prompts are then passed to prep_ip_adapter_data().
-            image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
-
-            # get the unet's config so that we can pass the base to dispatch_progress()
-            unet_config = context.models.get_config(self.unet.unet.key)
-
-            def step_callback(state: PipelineIntermediateState) -> None:
-                context.util.sd_step_callback(state, unet_config.base)
-
-            def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-                for lora in self.unet.loras:
-                    lora_info = context.models.load(lora.lora)
-                    assert isinstance(lora_info.model, LoRAModelRaw)
-                    yield (lora_info.model, lora.weight)
-                    del lora_info
-                return
-
-            unet_info = context.models.load(self.unet.unet)
-            assert isinstance(unet_info.model, UNet2DConditionModel)
-            with (
-                ExitStack() as exit_stack,
-                unet_info.model_on_device() as (model_state_dict, unet),
-                ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-                set_seamless(unet, self.unet.seamless_axes),  # FIXME
-                # Apply the LoRA after unet has been moved to its target device for faster patching.
-                ModelPatcher.apply_lora_unet(
-                    unet,
-                    loras=_lora_loader(),
-                    model_state_dict=model_state_dict,
-                ),
-            ):
-                assert isinstance(unet, UNet2DConditionModel)
-                latents = latents.to(device=unet.device, dtype=unet.dtype)
-                if noise is not None:
-                    noise = noise.to(device=unet.device, dtype=unet.dtype)
-                if mask is not None:
-                    mask = mask.to(device=unet.device, dtype=unet.dtype)
-                if masked_latents is not None:
-                    masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
-
-                scheduler = get_scheduler(
-                    context=context,
-                    scheduler_info=self.unet.scheduler,
-                    scheduler_name=self.scheduler,
-                    seed=seed,
-                )
+        if seed is None:
+            seed = 0

-                pipeline = self.create_pipeline(unet, scheduler)
+        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)

-                _, _, latent_height, latent_width = latents.shape
-                conditioning_data = self.get_conditioning_data(
-                    context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
-                )
+        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
+        # below. Investigate whether this is appropriate.
+        t2i_adapter_data = self.run_t2i_adapters(
+            context,
+            self.t2i_adapter,
+            latents.shape,
+            do_classifier_free_guidance=True,
+        )

-                controlnet_data = self.prep_control_data(
-                    context=context,
-                    control_input=self.control,
-                    latents_shape=latents.shape,
-                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                    do_classifier_free_guidance=True,
-                    exit_stack=exit_stack,
-                )
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+        # a series of image conditioning embeddings. This is being done here rather than in the
+        # big model context below in order to use less VRAM on low-VRAM systems.
+        # The image prompts are then passed to prep_ip_adapter_data().
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
+        # get the unet's config so that we can pass the base to dispatch_progress()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+            for lora in self.unet.loras:
+                lora_info = context.models.load(lora.lora)
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
+                del lora_info
+            return
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            ExitStack() as exit_stack,
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
+        ):
+            assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
+            if mask is not None:
+                mask = mask.to(device=unet.device, dtype=unet.dtype)
+            if masked_latents is not None:
+                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
+            )

-                ip_adapter_data = self.prep_ip_adapter_data(
-                    context=context,
-                    ip_adapters=ip_adapters,
-                    image_prompts=image_prompts,
-                    exit_stack=exit_stack,
-                    latent_height=latent_height,
-                    latent_width=latent_width,
-                    dtype=unet.dtype,
-                )
+            pipeline = self.create_pipeline(unet, scheduler)

-                num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
-                    scheduler,
-                    device=unet.device,
-                    steps=self.steps,
-                    denoising_start=self.denoising_start,
-                    denoising_end=self.denoising_end,
-                    seed=seed,
-                )
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )

-                result_latents = pipeline.latents_from_embeddings(
-                    latents=latents,
-                    timesteps=timesteps,
-                    init_timestep=init_timestep,
-                    noise=noise,
-                    seed=seed,
-                    mask=mask,
-                    masked_latents=masked_latents,
-                    gradient_mask=gradient_mask,
-                    num_inference_steps=num_inference_steps,
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    conditioning_data=conditioning_data,
-                    control_data=controlnet_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
-                    callback=step_callback,
-                )
+            controlnet_data = self.prep_control_data(
+                context=context,
+                control_input=self.control,
+                latents_shape=latents.shape,
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
+
+            ip_adapter_data = self.prep_ip_adapter_data(
+                context=context,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
+                exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                dtype=unet.dtype,
+            )
+
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+                scheduler,
+                device=unet.device,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+                seed=seed,
+            )
+
+            result_latents = pipeline.latents_from_embeddings(
+                latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                mask=mask,
+                masked_latents=masked_latents,
+                gradient_mask=gradient_mask,
+                num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                control_data=controlnet_data,
+                ip_adapter_data=ip_adapter_data,
+                t2i_adapter_data=t2i_adapter_data,
+                callback=step_callback,
+            )

-            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-            result_latents = result_latents.to("cpu")
-            TorchDevice.empty_cache()
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
+        TorchDevice.empty_cache()

-            name = context.tensors.save(tensor=result_latents)
+        name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
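
The functional change is confined to the first two hunks: the `with SilenceWarnings():` block becomes a `@SilenceWarnings()` decorator on `invoke()`, and the method body loses one indentation level; every other +/- pair is the same line, merely dedented. For a single class to support both spellings, it can subclass `contextlib.ContextDecorator`. The following is a minimal sketch of that pattern, not InvokeAI's actual `SilenceWarnings` (which, as far as this diff shows, at least also suppresses the diffusers NSFW nag and may adjust library logger verbosity):

import warnings
from contextlib import ContextDecorator


class SilenceWarnings(ContextDecorator):
    """Usable as `with SilenceWarnings():` or as `@SilenceWarnings()`.

    Illustrative stand-in only; the real InvokeAI implementation does more.
    """

    def __enter__(self) -> "SilenceWarnings":
        # Save the current warning filters so they can be restored on exit.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        return self

    def __exit__(self, *exc_info: object) -> bool:
        # Restore the saved filters; returning False propagates any exception.
        self._catcher.__exit__(*exc_info)
        return False

Note the decorator ordering: `@SilenceWarnings()` sits below `@torch.no_grad()`, and decorators apply bottom-up, so at call time `no_grad` is entered first and the warning suppression wraps exactly the span the old `with` block covered inside the method.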
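A quick way to convince yourself the two spellings behave identically under the sketch above (the `noisy` function is hypothetical, for illustration only):

import warnings


@SilenceWarnings()  # decorator form: covers the whole function body
def noisy() -> None:
    warnings.warn("suppressed in decorator form")


noisy()  # emits nothing

with SilenceWarnings():  # context-manager form still works unchanged
    warnings.warn("suppressed in context-manager form")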