diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index f8bdfeb75524..2291499043cd 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
@@ -1073,6 +1074,8 @@ def __init__(
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
 
+        self.use_dp = False
+
         # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
         self._cached_conv_counts = {
             "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
@@ -1113,6 +1116,56 @@ def enable_tiling(
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
 
+    def enable_dp(
+        self,
+        world_size: Optional[int] = None,
+        hw_splits: Optional[Tuple[int, int]] = None,
+        overlap_ratio: Optional[float] = None,
+        overlap_pixels: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable data-parallel (DP) tiled encoding and decoding: the spatial tile grid is split across the ranks of a
+        `torch.distributed` process group, each rank encodes/decodes only its own tiles, and the results are
+        all-gathered so that every rank returns the full output.
+
+        Args:
+            world_size (`int`, *optional*):
+                Number of ranks to use for the VAE DP group. Defaults to `dist.get_world_size()`.
+            hw_splits (`Tuple[int, int]`, *optional*):
+                Number of tile splits along height and width, as `(h_split, w_split)`. Defaults to
+                `(1, world_size)`.
+            overlap_ratio (`float`, *optional*):
+                Overlap between neighboring tiles, as a fraction of the latent height/width.
+            overlap_pixels (`int`, *optional*):
+                Overlap between neighboring tiles in pixels; takes precedence over `overlap_ratio`. If neither is
+                given, an overlap of 3 latent pixels is used.
+        """
+        if world_size is None:
+            world_size = dist.get_world_size()
+
+        if world_size <= 1 or world_size > dist.get_world_size():
+            logger.warning(
+                f"Supported world_size for VAE DP is between 2 and {dist.get_world_size()}, but got {world_size}. "
+                "Falling back to the regular VAE path."
+            )
+            return
+
+        if hw_splits is None:
+            hw_splits = (1, int(world_size))
+
+        assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}"
+
+        h_split, w_split = map(int, hw_splits)
+
+        # world_size does not have to equal h_split * w_split: a rank may own several tiles, and ranks beyond the
+        # number of tiles simply receive none.
+        self.use_dp = True
+        self.h_split, self.w_split = h_split, w_split
+        self.world_size = world_size
+        self.overlap_ratio = overlap_ratio
+        self.overlap_pixels = overlap_pixels
+
+        dp_ranks = list(range(world_size))
+        self.vae_dp_group = dist.new_group(ranks=dp_ranks)
+        self.rank = dist.get_rank()
+
+        # Assign tiles to ranks round-robin over the (h_split, w_split) grid.
+        self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
+        self.num_tiles_per_rank = [0] * self.world_size
+        rank_idx = 0
+        for h_idx in range(self.h_split):
+            for w_idx in range(self.w_split):
+                rank_idx %= self.world_size
+                self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
+                self.num_tiles_per_rank[rank_idx] += 1
+                rank_idx += 1
+
     def clear_cache(self):
         # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
         self._conv_num = self._cached_conv_counts["decoder"]
@@ -1130,6 +1183,9 @@ def _encode(self, x: torch.Tensor):
         if self.config.patch_size is not None:
             x = patchify(x, patch_size=self.config.patch_size)
 
+        if self.use_dp:
+            return self.tiled_encode_with_dp(x)
+
         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
             return self.tiled_encode(x)
 
@@ -1182,6 +1238,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
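+        # The DP path takes precedence over enable_tiling(); otherwise decoding falls through to the
+        # regular tiled / non-tiled paths below.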
+        if self.use_dp:
+            return self.tiled_decode_with_dp(z, return_dict=return_dict)
+
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)
 
@@ -1393,6 +1452,276 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
             return (dec,)
         return DecoderOutput(sample=dec)
 
+    def tiled_encode_with_dp(self, x: torch.Tensor) -> torch.Tensor:
+        r"""
+        Encode a batch of videos using a tiled encoder, with the tiles distributed across the ranks of the VAE
+        data-parallel group (see `enable_dp`).
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        device = x.device
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        # Calculate stride based on h_split and w_split (ceil division so the whole frame is covered)
+        tile_latent_stride_height = (latent_height + self.h_split - 1) // self.h_split
+        tile_latent_stride_width = (latent_width + self.w_split - 1) // self.w_split
+
+        # Calculate overlap in latent space (defaults to 3 latent pixels)
+        overlap_latent_height = 3
+        overlap_latent_width = 3
+        if self.overlap_pixels is not None:
+            overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
+            overlap_latent_height = overlap_latent
+            overlap_latent_width = overlap_latent
+        elif self.overlap_ratio is not None:
+            overlap_latent_height = int(self.overlap_ratio * latent_height)
+            overlap_latent_width = int(self.overlap_ratio * latent_width)
+
+        # Calculate minimum tile size in latent space
+        tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
+        tile_latent_min_width = tile_latent_stride_width + overlap_latent_width
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
+        tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
+        tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
+        tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio
+
+        # Tile grid dimensions: [h_split, w_split]
+        num_tile_rows = self.h_split
+        num_tile_cols = self.w_split
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
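+        # Each rank encodes only the tiles assigned to it in enable_dp(). Frames are processed in causal
+        # chunks (the first frame alone, then groups of 4 frames) to match the encoder's temporal compression.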
+        local_tiles = []
+        local_hw_shapes = []
+
+        for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
+            self.clear_cache()
+            patch_height_start = h_idx * tile_sample_stride_height
+            patch_height_end = patch_height_start + tile_sample_min_height
+            patch_width_start = w_idx * tile_sample_stride_width
+            patch_width_end = patch_width_start + tile_sample_min_width
+            time = []
+            frame_range = 1 + (num_frames - 1) // 4
+            for k in range(frame_range):
+                self._enc_conv_idx = [0]
+                if k == 0:
+                    tile = x[:, :, :1, patch_height_start : patch_height_end, patch_width_start : patch_width_end]
+                else:
+                    tile = x[
+                        :,
+                        :,
+                        1 + 4 * (k - 1) : 1 + 4 * k,
+                        patch_height_start : patch_height_end,
+                        patch_width_start : patch_width_end,
+                    ]
+                tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                tile = self.quant_conv(tile)
+                time.append(tile)
+            time = torch.cat(time, dim=2)
+            # Flatten the (h, w) dims so all local tiles can be concatenated along one axis.
+            local_tiles.append(time.flatten(3, 4))
+            # Record each tile's (h, w) so it can be unflattened after the all_gather.
+            local_hw_shapes.append(torch.tensor(time.shape[3:5], dtype=torch.int32, device=device))
+        self.clear_cache()
+
+        # Concatenate all tiles on the local rank.
+        local_tiles = torch.cat(local_tiles, dim=3)
+        local_hw_shapes = torch.stack(local_hw_shapes)
+
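+        # Edge tiles can be smaller than interior ones, so tiles are exchanged flattened over the spatial dims;
+        # the per-rank shapes gathered first are used to size the all_gather buffers and to unflatten the tiles
+        # during reassembly.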
+        # Gather every rank's tile shapes (the last tile in a row/column may be smaller).
+        gathered_shape_list = [
+            torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
+            for num_tiles in self.num_tiles_per_rank
+        ]
+        dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
+
+        # Gather the encoded tiles on all ranks.
+        b, c, n = local_tiles.shape[:3]
+        gathered_tiles = [
+            torch.empty((b, c, n, tiles_shape.prod(dim=1).sum().item()), dtype=local_tiles.dtype, device=device)
+            for tiles_shape in gathered_shape_list
+        ]
+        dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
+
+        # Place the gathered tiles back on the (h_split, w_split) grid based on tile_idxs_per_rank.
+        rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
+        for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
+            if not tile_idxs:
+                continue
+            rank_tile_hw_shapes = gathered_shape_list[rank_idx]
+            hw_start_idx = 0
+            # A rank may own more than one tile; recover each one from its flattened (h, w) extent.
+            for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
+                rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # flattened hw size
+                rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
+                    3, rank_tile_hw_shape.tolist()
+                )  # unflatten the hw dim
+                hw_start_idx = hw_end_idx
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current tile,
+                # then add the current tile to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of latent vectors using a tiled decoder, with the tiles distributed across the ranks of the
+        VAE data-parallel group (see `enable_dp`).
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        device = z.device
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        # Calculate stride based on h_split and w_split (ceil division so the whole frame is covered)
+        tile_latent_stride_height = (height + self.h_split - 1) // self.h_split
+        tile_latent_stride_width = (width + self.w_split - 1) // self.w_split
+
+        # Calculate overlap in latent space (defaults to 3 latent pixels)
+        overlap_latent_height = 3
+        overlap_latent_width = 3
+        if self.overlap_pixels is not None:
+            overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
+            overlap_latent_height = overlap_latent
+            overlap_latent_width = overlap_latent
+        elif self.overlap_ratio is not None:
+            overlap_latent_height = int(self.overlap_ratio * height)
+            overlap_latent_width = int(self.overlap_ratio * width)
+
+        # Calculate minimum tile size in latent space
+        tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
+        tile_latent_min_width = tile_latent_stride_width + overlap_latent_width
+
+        # Convert min/stride to sample space
+        tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
+        tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
+        tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
+        tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio
+
+        if self.config.patch_size is not None:
+            sample_height = sample_height // self.config.patch_size
+            sample_width = sample_width // self.config.patch_size
+            tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+            tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+            blend_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+            blend_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+        else:
+            blend_height = tile_sample_min_height - tile_sample_stride_height
+            blend_width = tile_sample_min_width - tile_sample_stride_width
+
+        # Tile grid dimensions: [h_split, w_split]
+        num_tile_rows = self.h_split
+        num_tile_cols = self.w_split
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        # Each rank decodes only the tiles assigned to it based on tile_idxs_per_rank.
+        local_tiles = []  # tiles decoded by this rank
+        local_hw_shapes = []  # spatial (h, w) shape of each local tile
+
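+        # Within each tile the decoder is run one latent frame at a time so that the causal feature cache
+        # (self._feat_map) is filled in temporal order, mirroring the non-DP tiled_decode.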
+        for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
+            self.clear_cache()
+            patch_height_start = h_idx * tile_latent_stride_height
+            patch_height_end = patch_height_start + tile_latent_min_height
+            patch_width_start = w_idx * tile_latent_stride_width
+            patch_width_end = patch_width_start + tile_latent_min_width
+            time = []
+            for k in range(num_frames):
+                self._conv_idx = [0]
+                tile = z[:, :, k : k + 1, patch_height_start : patch_height_end, patch_width_start : patch_width_end]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(
+                    tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+                )
+                time.append(decoded)
+            time = torch.cat(time, dim=2)
+            # Flatten the (h, w) dims so all local tiles can be concatenated along one axis.
+            local_tiles.append(time.flatten(3, 4))
+            # Record each tile's (h, w) so it can be unflattened after the all_gather.
+            local_hw_shapes.append(torch.tensor(time.shape[3:5], dtype=torch.int32, device=device))
+        self.clear_cache()
+
+        # Concatenate all tiles on the local rank.
+        local_tiles = torch.cat(local_tiles, dim=3)
+        local_hw_shapes = torch.stack(local_hw_shapes)
+
+        # Gather every rank's tile shapes (the last tile in a row/column may be smaller).
+        gathered_shape_list = [
+            torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
+            for num_tiles in self.num_tiles_per_rank
+        ]
+        dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
+
+        # Gather the decoded tiles on all ranks.
+        b, c, n = local_tiles.shape[:3]
+        gathered_tiles = [
+            torch.empty((b, c, n, tiles_shape.prod(dim=1).sum().item()), dtype=local_tiles.dtype, device=device)
+            for tiles_shape in gathered_shape_list
+        ]
+        dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
+
+        # Place the gathered tiles back on the (h_split, w_split) grid based on tile_idxs_per_rank.
+        rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
+        for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
+            if not tile_idxs:
+                continue
+            rank_tile_hw_shapes = gathered_shape_list[rank_idx]
+            hw_start_idx = 0
+            # A rank may own more than one tile; recover each one from its flattened (h, w) extent.
+            for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
+                rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # flattened hw size
+                rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
+                    3, rank_tile_hw_shape.tolist()
+                )  # unflatten the hw dim
+                hw_start_idx = hw_end_idx
+
+        # Combine all tiles, same as in tiled_decode.
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # Blend the tile above and the tile to the left into the current tile,
+                # then add the current tile to the result row.
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+        dec = torch.clamp(dec, min=-1.0, max=1.0)
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
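A minimal usage sketch for the DP path added above, not part of the patch: it assumes a `torch.distributed` launch (e.g. `torchrun --nproc_per_node=2`) with NCCL available, and the checkpoint id shown is illustrative; every rank loads the same weights and every rank returns the full decoded video.

```python
# Illustrative only: launch with `torchrun --nproc_per_node=2 decode_dp.py`.
import torch
import torch.distributed as dist

from diffusers import AutoencoderKLWan

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# Checkpoint id is illustrative; any Wan VAE checkpoint loads the same way.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
).to("cuda")

# 1x2 tile grid: one tile per rank, with the default overlap of 3 latent pixels.
vae.enable_dp(world_size=2, hw_splits=(1, 2))

# Dummy latents with Wan's 4x temporal / 8x spatial compression: (B, C, T, H, W).
latents = torch.randn(1, 16, 13, 60, 104, device="cuda", dtype=torch.float32)
video = vae.decode(latents).sample  # each rank gets the full (1, 3, 49, 480, 832) video

dist.destroy_process_group()
```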