Skip to content

Commit 1a17cd8

Browse files
committed
add vae convenience methods
1 parent cf37a8a commit 1a17cd8

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,35 @@ def _denormalize_latents(
113113
latents = latents * latents_std / scaling_factor + latents_mean
114114
return latents
115115

116+
def enable_vae_slicing(self):
117+
r"""
118+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
119+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
120+
"""
121+
self.vae.enable_slicing()
122+
123+
def disable_vae_slicing(self):
124+
r"""
125+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
126+
computing decoding in one step.
127+
"""
128+
self.vae.disable_slicing()
129+
130+
def enable_vae_tiling(self):
131+
r"""
132+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
133+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
134+
processing larger images.
135+
"""
136+
self.vae.enable_tiling()
137+
138+
def disable_vae_tiling(self):
139+
r"""
140+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
141+
computing decoding in one step.
142+
"""
143+
self.vae.disable_tiling()
144+
116145
def check_inputs(self, video, height, width, latents):
117146
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
118147
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

0 commit comments

Comments
 (0)