
Commit 975cfa1

vae dp for wan

1 parent df8dd77 commit 975cfa1

File tree

1 file changed: +232 −0 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 232 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
@@ -1073,6 +1074,8 @@ def __init__(
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192
 
+        self.use_dp = False
+
         # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
         self._cached_conv_counts = {
             "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
@@ -1113,6 +1116,53 @@ def enable_tiling(
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
 
+    def enable_dp(
+        self,
+        world_size: Optional[int] = None,
+        hw_splits: Optional[Tuple[int, int]] = None,
+        overlap_ratio: Optional[float] = None,
+        overlap_pixels: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable data-parallel tiled decoding: the tile grid is divided among the distributed ranks so that each
+        rank decodes only its share of tiles.
+
+        Args:
+            world_size (`int`, *optional*):
+                Number of ranks to use. Defaults to the full distributed world size.
+            hw_splits (`Tuple[int, int]`, *optional*):
+                Number of tile splits along height and width. Defaults to `(1, world_size)`.
+            overlap_ratio (`float`, *optional*):
+                Tile overlap as a fraction of the latent height/width.
+            overlap_pixels (`int`, *optional*):
+                Tile overlap in output pixels. Takes precedence over `overlap_ratio`.
+        """
+        if world_size is None:
+            world_size = dist.get_world_size()
+
+        if world_size <= 1 or world_size > dist.get_world_size():
+            return
+
+        if hw_splits is None:
+            hw_splits = (1, int(world_size))
+
+        assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}"
+
+        h_split, w_split = map(int, hw_splits)
+
+        self.use_dp = True
+        self.h_split, self.w_split = h_split, w_split
+        self.world_size = world_size
+        self.overlap_ratio = overlap_ratio
+        self.overlap_pixels = overlap_pixels
+
+        dp_ranks = list(range(world_size))
+        self.vae_dp_group = dist.new_group(ranks=dp_ranks)
+        self.rank = dist.get_rank()
+
+        # Assign the h_split x w_split tiles to ranks round-robin
+        self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
+        self.num_tiles_per_rank = [0] * self.world_size
+        rank_idx = 0
+        for h_idx in range(self.h_split):
+            for w_idx in range(self.w_split):
+                rank_idx %= self.world_size
+                self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
+                self.num_tiles_per_rank[rank_idx] += 1
+                rank_idx += 1
+
     def clear_cache(self):
         # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
         self._conv_num = self._cached_conv_counts["decoder"]
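The round-robin loop above wraps around when the grid has more tiles than ranks and leaves trailing ranks with fewer tiles when it has fewer. As a standalone illustration of the resulting assignment, a minimal sketch with hypothetical values (h_split=2, w_split=3, world_size=4 are chosen for the example, not taken from the commit):

    h_split, w_split, world_size = 2, 3, 4
    tile_idxs_per_rank = [[] for _ in range(world_size)]
    rank_idx = 0
    for h_idx in range(h_split):
        for w_idx in range(w_split):
            rank_idx %= world_size
            tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
            rank_idx += 1
    # Six tiles over four ranks:
    # rank 0 -> [(0, 0), (1, 1)]
    # rank 1 -> [(0, 1), (1, 2)]
    # rank 2 -> [(0, 2)]
    # rank 3 -> [(1, 0)]

Note that the decode path below calls torch.cat on each rank's local tiles, so in practice every rank needs to own at least one tile (h_split * w_split >= world_size).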
@@ -1182,6 +1232,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
+        if self.use_dp:
+            return self.tiled_decode_with_dp(z, return_dict=return_dict)
+
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)

@@ -1393,6 +1446,185 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
             return (dec,)
         return DecoderOutput(sample=dec)
 
+    def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder, with the tiles distributed across data-parallel ranks.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        device = z.device
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        # Calculate the latent-space tile stride from the requested grid (h_split x w_split)
+        tile_latent_stride_height = (height + self.h_split - 1) // self.h_split
+        tile_latent_stride_width = (width + self.w_split - 1) // self.w_split
+
+        # Calculate the tile overlap in latent space (defaults to 2 latent pixels)
+        overlap_latent_height = 2
+        overlap_latent_width = 2
+        if self.overlap_pixels is not None:
+            overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
+            overlap_latent_height = overlap_latent
+            overlap_latent_width = overlap_latent
+        elif self.overlap_ratio is not None:
+            overlap_latent_height = int(self.overlap_ratio * height)
+            overlap_latent_width = int(self.overlap_ratio * width)
+
+        # Minimum tile size in latent space
+        tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
+        tile_latent_min_width = tile_latent_stride_width + overlap_latent_width
+
+        # Convert min/stride to sample space
+        tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
+        tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
+        tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
+        tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio
+
+        self.tile_sample_min_height = tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width
+
+        if self.config.patch_size is not None:
+            sample_height = sample_height // self.config.patch_size
+            sample_width = sample_width // self.config.patch_size
+            tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+            tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+            blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+        else:
+            blend_height = self.tile_sample_min_height - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width - tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        num_tile_rows = self.h_split
+        num_tile_cols = self.w_split
+
+        # Each rank decodes only the tiles assigned to it in enable_dp
+        # (every rank is expected to own at least one tile).
+        local_tiles = []
+        local_hw_shapes = []
+
+        for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
+            self.clear_cache()
+            patch_height_start = h_idx * tile_latent_stride_height
+            patch_height_end = patch_height_start + tile_latent_min_height
+            patch_width_start = w_idx * tile_latent_stride_width
+            patch_width_end = patch_width_start + tile_latent_min_width
+            time = []
+            for k in range(num_frames):
+                self._conv_idx = [0]
+                tile = z[:, :, k : k + 1, patch_height_start:patch_height_end, patch_width_start:patch_width_end]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(
+                    tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+                )
+                time.append(decoded)
+            time = torch.cat(time, dim=2)
+            # Flatten the spatial dims so differently sized tiles can share one gather buffer
+            local_tiles.append(time.flatten(3, 4))
+            local_hw_shapes.append(torch.tensor(time.shape[3:5], dtype=torch.int32, device=device))
+        self.clear_cache()
+
+        local_tiles = torch.cat(local_tiles, dim=3)
+        local_hw_shapes = torch.stack(local_hw_shapes)
+
+        # Exchange the per-tile spatial shapes first so each rank can size its receive buffers
+        gathered_shape_list = [
+            torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
+            for num_tiles in self.num_tiles_per_rank
+        ]
+        dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
+
+        b, c, n = local_tiles.shape[:3]
+        gathered_tiles = [
+            torch.empty((b, c, n, tiles_shape.prod(dim=1).sum().item()), dtype=local_tiles.dtype, device=device)
+            for tiles_shape in gathered_shape_list
+        ]
+        dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
+
+        # Unflatten the gathered buffers back into an (h_split x w_split) grid of tiles
+        rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
+        for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
+            rank_tile_hw_shapes = gathered_shape_list[rank_idx]
+            hw_start_idx = 0
+            for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
+                rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()
+                rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
+                    3, rank_tile_hw_shape.tolist()
+                )
+                hw_start_idx = hw_end_idx
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+        dec = torch.clamp(dec, min=-1.0, max=1.0)
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
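Border tiles can come out of the decoder with different spatial sizes, so the tiles cannot be all-gathered directly into identically shaped buffers. The method therefore flattens each rank's tiles along the spatial dims, exchanges the per-tile (H, W) shapes first, and sizes the receive buffers from those shapes before the second all_gather. A minimal standalone sketch of that two-phase pattern, assuming an already-initialized process group (the helper name and shapes are illustrative, not part of the commit):

    import torch
    import torch.distributed as dist

    def gather_uneven_last_dim(local: torch.Tensor, group=None):
        # Phase 1: exchange the variable last-dim sizes so each rank can
        # preallocate correctly sized receive buffers.
        world_size = dist.get_world_size(group=group)
        size = torch.tensor([local.shape[-1]], device=local.device)
        sizes = [torch.empty_like(size) for _ in range(world_size)]
        dist.all_gather(sizes, size, group=group)
        # Phase 2: gather the payloads into per-rank buffers of those sizes.
        outs = [
            torch.empty((*local.shape[:-1], s.item()), dtype=local.dtype, device=local.device)
            for s in sizes
        ]
        dist.all_gather(outs, local, group=group)
        return outs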
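A plausible end-to-end usage sketch of the new API; the checkpoint id, latent shape, and launch command are illustrative assumptions, not part of the commit:

    # launch with: torchrun --nproc_per_node=4 decode_dp.py
    import torch
    import torch.distributed as dist
    from diffusers import AutoencoderKLWan

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
    ).to("cuda")

    # Split the output into a 1 x 4 tile grid (one tile per rank) with a
    # 32-pixel overlap blended away between neighbouring tiles.
    vae.enable_dp(world_size=4, hw_splits=(1, 4), overlap_pixels=32)

    latents = torch.randn(1, 16, 21, 60, 104, device="cuda")  # e.g. an 81x480x832 video
    video = vae.decode(latents).sample  # each rank reconstructs the full video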
