@@ -113,6 +113,35 @@ def _denormalize_latents(
113
113
latents = latents * latents_std / scaling_factor + latents_mean
114
114
return latents
115
115
116
+ def enable_vae_slicing (self ):
117
+ r"""
118
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
119
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
120
+ """
121
+ self .vae .enable_slicing ()
122
+
123
+ def disable_vae_slicing (self ):
124
+ r"""
125
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
126
+ computing decoding in one step.
127
+ """
128
+ self .vae .disable_slicing ()
129
+
130
+ def enable_vae_tiling (self ):
131
+ r"""
132
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
133
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
134
+ processing larger images.
135
+ """
136
+ self .vae .enable_tiling ()
137
+
138
+ def disable_vae_tiling (self ):
139
+ r"""
140
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
141
+ computing decoding in one step.
142
+ """
143
+ self .vae .disable_tiling ()
144
+
116
145
def check_inputs (self , video , height , width , latents ):
117
146
if height % self .vae_spatial_compression_ratio != 0 or width % self .vae_spatial_compression_ratio != 0 :
118
147
raise ValueError (f"`height` and `width` have to be divisible by 32 but are { height } and { width } ." )
0 commit comments