329 changes: 329 additions & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -17,6 +17,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
@@ -1073,6 +1074,8 @@ def __init__(
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192

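# Data-parallel tiled VAE: spatial tiles are split across torch.distributed ranks. Off by default; opt in via `enable_dp()`.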
self.use_dp = False

# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
@@ -1113,6 +1116,56 @@ def enable_tiling(
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

def enable_dp(
self,
world_size: Optional[int] = None,
hw_splits: Optional[Tuple[int, int]] = None,
overlap_ratio: Optional[float] = None,
overlap_pixels: Optional[int] = None
) -> None:
r"""
"""
if world_size is None:
world_size = dist.get_world_size()

if world_size <= 1 or world_size > dist.get_world_size():
logger.warning(
f"Supported world_size for VAE DP is between 2 and {dist.get_world_size()}, but got {world_size}. "
"Falling back to the non-DP VAE."
)
return

if hw_splits is None:
hw_splits = (1, int(world_size))

assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}"

h_split, w_split = map(int, hw_splits)
num_tiles = h_split * w_split

# assert h_split * w_split == world_size, \
# (f"world_size must be {w_split} * {h_split} = {w_split * h_split}, but got {world_size}")

self.use_dp = True
self.h_split, self.w_split = h_split, w_split
self.world_size = world_size
self.overlap_ratio = overlap_ratio
self.overlap_pixels = overlap_pixels

dp_ranks = list(range(0, world_size))
self.vae_dp_group = dist.new_group(ranks=dp_ranks)
self.rank = dist.get_rank()
# patch_ranks_flatten = [tile_idx % world_size for tile_idx in range(num_tiles)]
# self.patch_ranks = torch.Tensor(patch_ranks_flatten).reshape(h_split, w_split)
self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
self.num_tiles_per_rank = [0] * self.world_size
rank_idx = 0
for h_idx in range(self.h_split):
for w_idx in range(self.w_split):
rank_idx %= self.world_size
self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
self.num_tiles_per_rank[rank_idx] += 1
rank_idx += 1
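# Illustrative usage sketch (hypothetical; assumes one process per GPU, launched e.g. via `torchrun`):
#     dist.init_process_group("nccl")
#     vae.enable_dp()  # defaults to a 1 x world_size tile grid
#     latents = vae.encode(video).latent_dist.sample()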

def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
@@ -1130,6 +1183,9 @@ def _encode(self, x: torch.Tensor):
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)

if self.use_dp:
return self.tiled_encode_with_dp(x)

if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)

@@ -1182,6 +1238,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

if self.use_dp:
return self.tiled_decode_with_dp(z, return_dict=return_dict)

if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)

@@ -1393,6 +1452,276 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
return (dec,)
return DecoderOutput(sample=dec)

def tiled_encode_with_dp(self, x: torch.Tensor) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.

Args:
x (`torch.Tensor`): Input batch of videos.

Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, num_frames, height, width = x.shape
device = x.device
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio

# Calculate stride based on h_split and w_split
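# (ceil division, so the h_split x w_split grid of strides covers the whole latent height/width)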
tile_latent_stride_height = (latent_height + self.h_split - 1) // self.h_split
tile_latent_stride_width = (latent_width + self.w_split - 1) // self.w_split

# Calculate overlap in latent space
overlap_latent_height = 3
overlap_latent_width = 3
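# `overlap_pixels` is specified in pixel space and takes precedence over `overlap_ratio`;
# if neither is set, the default overlap of 3 latent pixels above is used.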
if self.overlap_pixels is not None:
overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
overlap_latent_height = overlap_latent
overlap_latent_width = overlap_latent
elif self.overlap_ratio is not None:
overlap_latent_height = int(self.overlap_ratio * latent_height)
overlap_latent_width = int(self.overlap_ratio * latent_width)

# Calculate minimum tile size in latent space
tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
tile_latent_min_width = tile_latent_stride_width + overlap_latent_width

blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width

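# Convert min/stride sizes from latent space back to sample (pixel) space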
tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio

# Determine tile grid dimensions: h_split rows x w_split columns
num_tile_rows = self.h_split
num_tile_cols = self.w_split

# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
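# Each rank encodes only the tiles assigned to it in tile_idxs_per_rank.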
local_tiles = []
local_hw_shapes = []

for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
self.clear_cache()
patch_height_start = h_idx * tile_sample_stride_height
patch_height_end = patch_height_start + tile_sample_min_height
patch_width_start = w_idx * tile_sample_stride_width
patch_width_end = patch_width_start + tile_sample_min_width
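# Causal temporal encoding: the first frame is processed alone, the remaining frames in chunks of 4
# (matching the 4x temporal compression), carrying the causal feature cache between chunks.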
time = []
frame_range = 1 + (num_frames - 1) // 4
for k in range(frame_range):
self._enc_conv_idx = [0]
if k == 0:
tile = x[:, :, :1, patch_height_start : patch_height_end, patch_width_start : patch_width_end]
else:
tile = x[
:,
:,
1 + 4 * (k - 1) : 1 + 4 * k,
patch_height_start : patch_height_end,
patch_width_start : patch_width_end,
]
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
tile = self.quant_conv(tile)
time.append(tile)
time = torch.cat(time, dim=2)
local_tiles.append(time.flatten(3, 4))
local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int())
self.clear_cache()

# concat all tiles on local rank
local_tiles = torch.cat(local_tiles, dim=3)
local_hw_shapes = torch.stack(local_hw_shapes)

# gather every rank's per-tile (h, w) shapes (edge tiles may be smaller); they determine the buffer sizes for the tile all_gather below
gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
for num_tiles in self.num_tiles_per_rank]
dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)

# gather tiles on all ranks
b, c, n = local_tiles.shape[:3]
gathered_tiles = [
torch.empty(
(b, c, n, tiles_shape.prod(dim=1).sum().item()),
dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list
]
dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)

# put tiles in rows based on tile_idxs_per_rank
rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
if not tile_idxs:
continue
rank_tile_hw_shapes = gathered_shape_list[rank_idx]
hw_start_idx = 0
# a rank may own more than one tile; slice each one out of the flattened buffer using its recorded (h, w) shape
for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattened hw
rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
3, rank_tile_hw_shape.tolist()) # unflatten hw dim
hw_start_idx = hw_end_idx

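# Combine all tiles, blending overlaps with neighbouring tiles (same as `tiled_encode`)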
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))

enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc

def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of latents using a tiled decoder, with spatial tiles distributed across data-parallel ranks.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, num_frames, height, width = z.shape
device = z.device
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio

# Calculate stride based on h_split and w_split
tile_latent_stride_height = (height + self.h_split - 1) // self.h_split
tile_latent_stride_width = (width + self.w_split - 1) // self.w_split

# Calculate overlap in latent space
overlap_latent_height = 3
overlap_latent_width = 3
if self.overlap_pixels is not None:
overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
overlap_latent_height = overlap_latent
overlap_latent_width = overlap_latent
elif self.overlap_ratio is not None:
overlap_latent_height = int(self.overlap_ratio * height)
overlap_latent_width = int(self.overlap_ratio * width)

# Calculate minimum tile size in latent space
tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
tile_latent_min_width = tile_latent_stride_width + overlap_latent_width

# Convert min/stride to sample space
tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio

if self.config.patch_size is not None:
sample_height = sample_height // self.config.patch_size
sample_width = sample_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
blend_height = tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
blend_width = tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
else:
blend_height = tile_sample_min_height - tile_sample_stride_height
blend_width = tile_sample_min_width - tile_sample_stride_width

# Determine tile grid dimensions: h_split rows x w_split columns
num_tile_rows = self.h_split
num_tile_cols = self.w_split

# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
# Each rank computes only tiles assigned to it based on tile_idxs_per_rank
local_tiles = [] # List to store tiles computed by this rank
local_hw_shapes = [] # List to store shapes of tiles by this rank

for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
self.clear_cache()
patch_height_start = h_idx * tile_latent_stride_height
patch_height_end = patch_height_start + tile_latent_min_height
patch_width_start = w_idx * tile_latent_stride_width
patch_width_end = patch_width_start + tile_latent_min_width
time = []
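# Decode the latent one frame at a time, carrying the causal feature cache across frames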
for k in range(num_frames):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, patch_height_start : patch_height_end, patch_width_start : patch_width_end]
tile = self.post_quant_conv(tile)
decoded = self.decoder(
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
)
time.append(decoded)
time = torch.cat(time, dim=2)
local_tiles.append(time.flatten(3, 4)) # flatten the (h, w) dims so all tiles on this rank can be concatenated
local_hw_shapes.append(torch.Tensor([*time.shape[3:5]]).to(device).int()) # record (h, w) for unflattening later
self.clear_cache()

# concat all tiles on local rank
local_tiles = torch.cat(local_tiles, dim=3)
local_hw_shapes = torch.stack(local_hw_shapes)

# gather every rank's per-tile (h, w) shapes (edge tiles may be smaller); they determine the buffer sizes for the tile all_gather below
gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
for num_tiles in self.num_tiles_per_rank]
dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)

# gather tiles on all ranks
b, c, n = local_tiles.shape[:3]
gathered_tiles = [
torch.empty(
(b, c, n, tiles_shape.prod(dim=1).sum().item()),
dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list
]
dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)

# put tiles in rows based on tile_idxs_per_rank
rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
if not tile_idxs:
continue
rank_tile_hw_shapes = gathered_shape_list[rank_idx]
hw_start_idx = 0
# a rank may own more than one tile; slice each one out of the flattened buffer using its recorded (h, w) shape
for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item() # flattened hw
rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
3, rank_tile_hw_shape.tolist()) # unflatten hw dim
hw_start_idx = hw_end_idx

# combine all tiles, same as tiled decode
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

if self.config.patch_size is not None:
dec = unpatchify(dec, patch_size=self.config.patch_size)

dec = torch.clamp(dec, min=-1.0, max=1.0)

if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

def forward(
self,
sample: torch.Tensor,