diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 1a72aa3cfeb3..f11c7db25386 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -15,11 +15,12 @@
 
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import deprecate
+from ...utils import deprecate, logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -35,6 +36,9 @@
 from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -127,6 +131,7 @@ def __init__(
 
         self.use_slicing = False
         self.use_tiling = False
+        self.use_dp = False
 
         # only relevant if vae tiling is enabled
         self.tile_sample_min_size = self.config.sample_size
@@ -214,9 +219,58 @@ def set_default_attn_processor(self):
 
         self.set_attn_processor(processor)
 
+    def enable_dp(
+        self,
+        world_size: Optional[int] = None,
+        hw_splits: Optional[Tuple[int, int]] = None,
+        overlap_ratio: Optional[float] = None,
+        overlap_pixels: Optional[int] = None
+    ) -> None:
+        r"""
+        Enable data-parallel tiled encoding/decoding. The input is split into `hw_splits` spatial tiles and the tiles
+        are distributed round-robin over the ranks of the default `torch.distributed` process group, so each rank only
+        runs the encoder/decoder on its own tiles before the results are gathered and blended on every rank.
+
+        Args:
+            world_size (`int`, *optional*):
+                Number of ranks to parallelize over. Defaults to the world size of the default process group.
+            hw_splits (`Tuple[int, int]`, *optional*, defaults to `(1, world_size)`):
+                Number of tile splits along the height and width dimensions.
+            overlap_ratio (`float`, *optional*):
+                Overlap between neighboring tiles as a fraction of the latent height/width.
+            overlap_pixels (`int`, *optional*):
+                Overlap between neighboring tiles in pixels. Takes precedence over `overlap_ratio`.
+        """
+        if world_size is None:
+            world_size = dist.get_world_size() if dist.is_initialized() else 1
+
+        if world_size <= 1 or world_size > dist.get_world_size():
+            logger.warning(
+                f"Supported world_size for vae dp is between 2 and {dist.get_world_size()}, but got {world_size}. "
+                f"Falling back to the regular (non-parallel) VAE."
+            )
+            return
+
+        if hw_splits is None:
+            hw_splits = (1, int(world_size))
+
+        assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}"
+
+        h_split, w_split = map(int, hw_splits)
+
+        self.use_dp = True
+        self.h_split, self.w_split = h_split, w_split
+        self.world_size = world_size
+        self.overlap_ratio = overlap_ratio
+        self.overlap_pixels = overlap_pixels
+        self.spatial_compression_ratio = 2 ** (len(self.config.block_out_channels) - 1)
+
+        dp_ranks = list(range(0, world_size))
+        self.vae_dp_group = dist.new_group(ranks=dp_ranks)
+        self.rank = dist.get_rank()
+        # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)]
+        # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split)
+        self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
+        self.num_tiles_per_rank = [0] * self.world_size
+        rank_idx = 0
+        for h_idx in range(self.h_split):
+            for w_idx in range(self.w_split):
+                rank_idx %= self.world_size
+                self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
+                self.num_tiles_per_rank[rank_idx] += 1
+                rank_idx += 1
+
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = x.shape
 
+        if self.use_dp:
+            return self._tiled_encode_with_dp(x)
+
         if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
             return self._tiled_encode(x)
 
@@ -256,6 +310,8 @@ def encode(
         return AutoencoderKLOutput(latent_dist=posterior)
 
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        if self.use_dp:
+            return self.tiled_decode_with_dp(z, return_dict=return_dict)
         if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)
 
@@ -310,6 +366,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
+    def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        y = torch.arange(0, blend_extent, device=a.device)
+        blend_ratio = (y / blend_extent)[None, None, :, None].to(a.dtype)
+        b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, y, :] * blend_ratio
+        return b
+
+    def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        x = torch.arange(0, blend_extent, device=a.device)
+        blend_ratio = (x / blend_extent)[None, None, None, :].to(a.dtype)
+        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, x] * blend_ratio
+        return b
+
     def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
@@ -469,6 +539,157 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return DecoderOutput(sample=dec) + def calculate_tiled_parallel_size(self, latent_height, latent_width): + # Calculate stride based on h_split and w_split + tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split) + tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split) + + # Calculate overlap in latent space + overlap_latent_height = 3 + overlap_latent_width = 3 + if self.overlap_pixels is not None: + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio + overlap_latent_height = overlap_latent + overlap_latent_width = overlap_latent + elif self.overlap_ratio is not None: + overlap_latent_height = int(self.overlap_ratio * latent_height) + overlap_latent_width = int(self.overlap_ratio * latent_width) + + # Calculate minimum tile size in latent space + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width + + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio + + blend_latent_height = tile_latent_min_height - tile_latent_stride_height + blend_latent_width = tile_latent_min_width - tile_latent_stride_width + + blend_sample_height = tile_sample_min_height - tile_sample_stride_height + blend_sample_width = tile_sample_min_width - tile_sample_stride_width + + return \ + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width + + def _tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + + _, _, height, width = x.shape + device = x.device + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: + tile = x[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + return tile + + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device + ) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height) + if j > 0: + tile = self.blend_h_(row[j - 1], tile, blend_latent_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + _, _, latent_height, latent_width = z.shape + device = z.device + sample_height = latent_height * self.spatial_compression_ratio + sample_width = latent_width * self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_decode_op( + z, patch_height_start, patch_height_end, patch_width_start, patch_width_end + ) -> torch.Tensor: + + tile = z[:, :, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + return decoded + + rows = self.run_vae_tile_parallel( + z, vae_decode_op, + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device + ) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height) + if j > 0: + tile = self.blend_h_(row[j - 1], tile, blend_sample_width) + result_row.append(tile[:, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2)[:, :, :sample_height, :sample_width] + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + def forward( self, sample: torch.Tensor, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index f8bdfeb75524..f034cebde12b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torch.distributed as dist from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin @@ -1073,6 +1074,8 @@ def __init__( self.tile_sample_stride_height = 192 self.tile_sample_stride_width = 192 + self.use_dp = False + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup self._cached_conv_counts = { "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) @@ -1113,6 +1116,52 @@ def enable_tiling( self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + def enable_dp( + self, + world_size: Optional[int] = None, + hw_splits: Optional[Tuple[int, int]] = None, + overlap_ratio: Optional[float] = None, + overlap_pixels: Optional[int] = None + ) -> None: + r""" + """ + if world_size is None: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if world_size <= 1 or world_size > dist.get_world_size(): + logger.warning( + f"Supported world_size for vae dp is between 2 - {dist.get_world_size}, but got {world_size}. 
" \ + f"Fall back to normal vae") + return + + if hw_splits is None: + hw_splits = (1, int(world_size)) + + assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 int, but got length {len(hw_splits)}" + + h_split, w_split = map(int, hw_splits) + + self.use_dp = True + self.h_split, self.w_split = h_split, w_split + self.world_size = world_size + self.overlap_ratio = overlap_ratio + self.overlap_pixels = overlap_pixels + + dp_ranks = list(range(0, world_size)) + self.vae_dp_group = dist.new_group(ranks=dp_ranks) + self.rank = dist.get_rank() + # patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)] + # self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split) + self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] + self.num_tiles_per_rank = [0] * self.world_size + rank_idx = 0 + for h_idx in range(self.h_split): + for w_idx in range(self.w_split): + rank_idx %= self.world_size + self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) + self.num_tiles_per_rank[rank_idx] += 1 + rank_idx += 1 + def clear_cache(self): # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call self._conv_num = self._cached_conv_counts["decoder"] @@ -1130,6 +1179,9 @@ def _encode(self, x: torch.Tensor): if self.config.patch_size is not None: x = patchify(x, patch_size=self.config.patch_size) + if self.use_dp: + return self.tiled_encode_with_dp(x) + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): return self.tiled_encode(x) @@ -1182,6 +1234,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + if self.use_dp: + return self.tiled_decode_with_dp(z, return_dict=return_dict) + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): return self.tiled_decode(z, return_dict=return_dict) @@ -1249,6 +1304,20 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch. ) return b + def blend_v_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + y = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (y / blend_extent)[None, None, None, :, None].to(a.dtype) + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - blend_ratio) + b[:, :, :, y, :] * blend_ratio + return b + + def blend_h_(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + x = torch.arange(0, blend_extent, device=a.device) + blend_ratio = (x / blend_extent)[None, None, None, None, :].to(a.dtype) + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - blend_ratio) + b[:, :, :, :, x] * blend_ratio + return b + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. 
@@ -1393,6 +1462,191 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
             return (dec,)
         return DecoderOutput(sample=dec)
 
+    def calculate_tiled_parallel_size(self, latent_height, latent_width):
+        # Calculate stride based on h_split and w_split
+        tile_latent_stride_height = int((latent_height + self.h_split - 1) / self.h_split)
+        tile_latent_stride_width = int((latent_width + self.w_split - 1) / self.w_split)
+
+        # Calculate overlap in latent space
+        overlap_latent_height = 3
+        overlap_latent_width = 3
+        if self.overlap_pixels is not None:
+            overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
+            overlap_latent_height = overlap_latent
+            overlap_latent_width = overlap_latent
+        elif self.overlap_ratio is not None:
+            overlap_latent_height = int(self.overlap_ratio * latent_height)
+            overlap_latent_width = int(self.overlap_ratio * latent_width)
+
+        # Calculate minimum tile size in latent space
+        tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
+        tile_latent_min_width = tile_latent_stride_width + overlap_latent_width
+
+        tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
+        tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
+        tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
+        tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio
+
+        blend_latent_height = tile_latent_min_height - tile_latent_stride_height
+        blend_latent_width = tile_latent_min_width - tile_latent_stride_width
+
+        if self.config.patch_size is not None:
+            tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+            tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+            blend_sample_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+            blend_sample_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+        else:
+            blend_sample_height = tile_sample_min_height - tile_sample_stride_height
+            blend_sample_width = tile_sample_min_width - tile_sample_stride_width
+
+        return \
+            tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \
+            tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \
+            blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width
+
+    def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+ """ + _, _, num_frames, sample_height, sample_width = x.shape + device = x.device + latent_height = sample_height // self.spatial_compression_ratio + latent_width = sample_width // self.spatial_compression_ratio + + tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \ + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \ + blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \ + self.calculate_tiled_parallel_size(latent_height, latent_width) + + def vae_encode_op( + x, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames + ) -> torch.Tensor: + + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + patch_height_start : patch_height_end, + patch_width_start : patch_width_end, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + time = torch.cat(time, dim=2) + self.clear_cache() + return time + + rows = self.run_vae_tile_parallel( + x, vae_encode_op, + tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, device, + num_frames=num_frames + ) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v_(rows[i - 1][j], tile, blend_latent_height) + if j > 0: + tile = self.blend_h_(row[j - 1], tile, blend_latent_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+        """
+        _, _, num_frames, latent_height, latent_width = z.shape
+        device = z.device
+        sample_height = latent_height * self.spatial_compression_ratio
+        sample_width = latent_width * self.spatial_compression_ratio
+
+        tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, \
+            tile_sample_min_height, tile_sample_min_width, tile_sample_stride_height, tile_sample_stride_width, \
+            blend_latent_height, blend_latent_width, blend_sample_height, blend_sample_width = \
+            self.calculate_tiled_parallel_size(latent_height, latent_width)
+
+        def vae_decode_op(
+            z, patch_height_start, patch_height_end, patch_width_start, patch_width_end, num_frames
+        ) -> torch.Tensor:
+
+            self.clear_cache()
+
+            time = []
+            for k in range(num_frames):
+                self._conv_idx = [0]
+                tile = z[:, :, k : k + 1, patch_height_start : patch_height_end, patch_width_start : patch_width_end]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(
+                    tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+                )
+                time.append(decoded)
+            time = torch.cat(time, dim=2)
+            self.clear_cache()
+            return time
+
+        rows = self.run_vae_tile_parallel(
+            z, vae_decode_op,
+            tile_latent_min_height, tile_latent_min_width, tile_latent_stride_height, tile_latent_stride_width, device,
+            num_frames=num_frames
+        )
+
+        # combine all tiles, same as tiled decode
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v_(rows[i - 1][j], tile, blend_sample_height)
+                if j > 0:
+                    tile = self.blend_h_(row[j - 1], tile, blend_sample_width)
+                result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+        dec = torch.clamp(dec, min=-1.0, max=1.0)
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
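A single-process sketch (no `torch.distributed` required) of the reassembly bookkeeping that `run_vae_tile_parallel` in the vae.py hunk below performs after `all_gather`: each rank contributes its tiles flattened along the last dimension, and the gathered per-tile `(height, width)` shapes are used to cut those buffers back into an `h_split x w_split` grid. The tile shapes and rank assignment here are made up for illustration:

```python
import torch

h_split, w_split = 2, 2
# round-robin assignment as produced by enable_dp for 4 tiles on 2 ranks
tile_idxs_per_rank = [[(0, 0), (1, 0)], [(0, 1), (1, 1)]]
# per-rank (height, width) of each tile; the bottom row / right column tiles are smaller
gathered_shapes = [
    torch.tensor([[8, 8], [6, 8]]),  # rank 0 owns tiles (0, 0) and (1, 0)
    torch.tensor([[8, 5], [6, 5]]),  # rank 1 owns tiles (0, 1) and (1, 1)
]
# stand-in for the all_gather output: each rank's tiles flattened over hw and concatenated
gathered_tiles = [
    torch.cat([torch.full((1, 4, h * w), float(r)) for h, w in shapes.tolist()], dim=-1)
    for r, shapes in enumerate(gathered_shapes)
]

rows = [[None] * w_split for _ in range(h_split)]
for rank_idx, tile_idxs in enumerate(tile_idxs_per_rank):
    shapes = gathered_shapes[rank_idx]
    start = 0
    for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
        h, w = shapes[tile_idx].tolist()
        end = start + h * w
        rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., start:end].unflatten(-1, (h, w))
        start = end

print([[tuple(t.shape) for t in row] for row in rows])
# [[(1, 4, 8, 8), (1, 4, 8, 5)], [(1, 4, 6, 8), (1, 4, 6, 5)]]
```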
diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py
index 9c6031a988f9..d798711ec240 100644
--- a/src/diffusers/models/autoencoders/vae.py
+++ b/src/diffusers/models/autoencoders/vae.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 
 from ...utils import BaseOutput
 from ...utils.torch_utils import randn_tensor
@@ -926,3 +927,78 @@ def disable_slicing(self):
         decoding in one step.
         """
         self.use_slicing = False
+
+    def enable_dp(self):
+        r"""
+        Enable data-parallel tiled VAE encoding/decoding. When this option is enabled, the VAE splits the input tensor
+        into spatial tiles and distributes them over the ranks of a `torch.distributed` process group, so that each
+        rank only encodes/decodes a subset of the tiles before the full result is gathered on every rank. This is
+        useful for reducing per-device memory use and latency when processing large images or videos.
+        """
+        if not hasattr(self, "use_tiling"):
+            raise NotImplementedError(f"Tiling Parallel doesn't seem to be implemented for {self.__class__.__name__}.")
+        self.use_dp = True
+
+    def disable_dp(self):
+        r"""
+        Disable data-parallel tiled VAE encoding/decoding. If `enable_dp` was previously enabled, this method will go
+        back to computing encoding/decoding on a single rank in one step.
+        """
+        self.use_dp = False
+
+    def run_vae_tile_parallel(
+        self,
+        input: torch.Tensor,
+        vae_op,
+        min_height,
+        min_width,
+        stride_height,
+        stride_width,
+        device,
+        **kwargs) -> List[List[torch.Tensor]]:
+
+        local_tiles = []
+        local_hw_shapes = []
+
+        for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
+            patch_height_start = h_idx * stride_height
+            patch_height_end = patch_height_start + min_height
+            patch_width_start = w_idx * stride_width
+            patch_width_end = patch_width_start + min_width
+            tile = vae_op(input, patch_height_start, patch_height_end, patch_width_start, patch_width_end, **kwargs)
+            local_tiles.append(tile.flatten(-2, -1))
+            local_hw_shapes.append(torch.Tensor([*tile.shape[-2:]]).to(device).int())
+
+        # concatenate all tiles computed on the local rank
+        local_tiles = torch.cat(local_tiles, dim=-1)
+        local_hw_shapes = torch.stack(local_hw_shapes)
+
+        # gather the (height, width) of every tile on every rank (the last tile of a rank may have a different shape)
+        gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
+                               for num_tiles in self.num_tiles_per_rank]
+        dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
+
+        # gather tiles on all ranks
+        tile_shape_first = local_tiles.shape[:-1]
+        gathered_tiles = [
+            torch.empty(
+                (*tile_shape_first, tiles_shape.prod(dim=1).sum().item()),
+                dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list
+        ]
+        dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
+
+        # put tiles in rows based on tile_idxs_per_rank
+        rows = [[None] * self.w_split for _ in range(self.h_split)]
+        for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
+            if not tile_idxs:
+                continue
+            rank_tile_hw_shapes = gathered_shape_list[rank_idx]
+            hw_start_idx = 0
+            # a rank may own more than one tile, so split its flattened buffer back into tiles using the gathered shapes
+            for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
+                rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # flattened hw length
+                rows[h_idx][w_idx] = gathered_tiles[rank_idx][..., hw_start_idx:hw_end_idx].unflatten(
+                    -1, rank_tile_hw_shape.tolist())  # unflatten the hw dim
+                hw_start_idx = hw_end_idx
+
+        return rows
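Finally, an end-to-end sketch for the Wan VAE path, assuming a four-process `torchrun` launch; the checkpoint id and latent shape are placeholders rather than values taken from this diff:

```python
# Launch with e.g.: torchrun --nproc_per_node=4 wan_vae_dp_decode.py
import torch
import torch.distributed as dist

from diffusers import AutoencoderKLWan


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
    ).to(device)
    # 2x2 tile grid spread round-robin over all ranks
    vae.enable_dp(world_size=dist.get_world_size(), hw_splits=(2, 2), overlap_ratio=0.05)

    # all ranks must hold the same latents and call decode() together, since it contains collectives
    torch.manual_seed(0)
    latents = torch.randn(1, vae.config.z_dim, 13, 60, 104).to(device)

    with torch.no_grad():
        video = vae.decode(latents).sample  # the full video is reconstructed on every rank

    if rank == 0:
        print(video.shape)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```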