
Commit 7a5fab1

vae decode dp for wan

1 parent df8dd77 commit 7a5fab1

File tree

1 file changed: +199 -0 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 199 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
@@ -1073,6 +1074,8 @@ def __init__(
         self.tile_sample_stride_height = 192
         self.tile_sample_stride_width = 192

+        self.use_dp = False
+
         # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
         self._cached_conv_counts = {
             "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
@@ -1113,6 +1116,53 @@ def enable_tiling(
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

+    def enable_dp(
+        self,
+        world_size: Optional[int] = None,
+        hw_splits: Optional[Tuple[int, int]] = None,
+        overlap_ratio: Optional[float] = None,
+        overlap_pixels: Optional[int] = None,
+    ) -> None:
+        r"""
+        Enable data-parallel tiled decoding: the decoded output is split into an `h_split` x `w_split` grid of
+        overlapping tiles that are decoded in parallel across `world_size` ranks, then gathered and blended on
+        every rank. This is a no-op if `world_size` is 1 or exceeds the launched world size.
+        """
+        if world_size is None:
+            world_size = dist.get_world_size()
+
+        if world_size <= 1 or world_size > dist.get_world_size():
+            return
+
+        if hw_splits is None:
+            hw_splits = (1, int(world_size))
+
+        assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}"
+
+        h_split, w_split = map(int, hw_splits)
+        num_tiles = h_split * w_split
+
+        # Note: num_tiles does not have to equal world_size; extra tiles are assigned round-robin below.
+        self.use_dp = True
+        self.h_split, self.w_split = h_split, w_split
+        self.world_size = world_size
+        self.overlap_ratio = overlap_ratio
+        self.overlap_pixels = overlap_pixels
+
+        dp_ranks = list(range(world_size))
+        self.vae_dp_group = dist.new_group(ranks=dp_ranks)
+        self.rank = dist.get_rank()
+        # Assign tiles to ranks in round-robin order; a rank may own zero, one, or several tiles.
+        self.tile_idxs_per_rank = [[] for _ in range(self.world_size)]
+        self.num_tiles_per_rank = [0] * self.world_size
+        rank_idx = 0
+        for h_idx in range(self.h_split):
+            for w_idx in range(self.w_split):
+                rank_idx %= self.world_size
+                self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
+                self.num_tiles_per_rank[rank_idx] += 1
+                rank_idx += 1
+
     def clear_cache(self):
         # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
         self._conv_num = self._cached_conv_counts["decoder"]
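
The round-robin loop above determines which `(h_idx, w_idx)` tiles each rank decodes. A minimal standalone sketch of the same assignment (the `assign_tiles` helper name is illustrative, not part of the commit):

```python
def assign_tiles(h_split: int, w_split: int, world_size: int):
    # Walk the tile grid row-major and deal tiles out to ranks like cards.
    tile_idxs_per_rank = [[] for _ in range(world_size)]
    rank_idx = 0
    for h_idx in range(h_split):
        for w_idx in range(w_split):
            rank_idx %= world_size
            tile_idxs_per_rank[rank_idx].append((h_idx, w_idx))
            rank_idx += 1
    return tile_idxs_per_rank

# A 2x3 grid on 4 ranks: ranks 0 and 1 own two tiles each, ranks 2 and 3 own one.
# -> [[(0, 0), (1, 1)], [(0, 1), (1, 2)], [(0, 2)], [(1, 0)]]
print(assign_tiles(2, 3, 4))
```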
@@ -1182,6 +1232,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

+        if self.use_dp:
+            return self.tiled_decode_with_dp(z, return_dict=return_dict)
+
         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)
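
With `_decode` routing to the DP path whenever `use_dp` is set, a typical multi-process launch might look like the sketch below. The checkpoint id, latent shape, and `torchrun` invocation are assumptions for illustration, not part of this commit:

```python
# Run with e.g.: torchrun --nproc_per_node=4 decode_dp.py
import torch
import torch.distributed as dist
from diffusers import AutoencoderKLWan

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
).to("cuda")

# Decode on a 2x2 tile grid across 4 ranks with ~64 output pixels of overlap per seam.
vae.enable_dp(world_size=4, hw_splits=(2, 2), overlap_pixels=64)

latents = torch.randn(1, 16, 21, 60, 104, device="cuda")  # placeholder latents
with torch.no_grad():
    video = vae.decode(latents).sample  # all ranks must call decode (collective all_gather inside)
```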

@@ -1393,6 +1446,152 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
             return (dec,)
         return DecoderOutput(sample=dec)

+    def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of latents using a tiled decoder, with the tiles distributed across the ranks of the VAE
+        data-parallel group and gathered back so every rank returns the full, blended output.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        device = z.device
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        # Calculate stride based on h_split and w_split (ceil division)
+        tile_latent_stride_height = (height + self.h_split - 1) // self.h_split
+        tile_latent_stride_width = (width + self.w_split - 1) // self.w_split
+
+        # Calculate overlap in latent space
+        overlap_latent_height = 3
+        overlap_latent_width = 3
+        if self.overlap_pixels is not None:
+            overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio
+            overlap_latent_height = overlap_latent
+            overlap_latent_width = overlap_latent
+        elif self.overlap_ratio is not None:
+            overlap_latent_height = int(self.overlap_ratio * height)
+            overlap_latent_width = int(self.overlap_ratio * width)
+
+        # Calculate minimum tile size in latent space
+        tile_latent_min_height = tile_latent_stride_height + overlap_latent_height
+        tile_latent_min_width = tile_latent_stride_width + overlap_latent_width
+
+        # Convert min/stride to sample space
+        tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio
+        tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio
+        tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio
+        tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio
+
+        self.tile_sample_min_height = tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width
+
+        if self.config.patch_size is not None:
+            sample_height = sample_height // self.config.patch_size
+            sample_width = sample_width // self.config.patch_size
+            tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+            tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+            blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+        else:
+            blend_height = self.tile_sample_min_height - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width - tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        # The tile grid has shape [h_split, w_split].
+        num_tile_rows = self.h_split
+        num_tile_cols = self.w_split
+
+        # Each rank decodes only the tiles assigned to it in tile_idxs_per_rank
+        local_tiles = []
+        local_hw_shapes = []
+
+        for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]:
+            self.clear_cache()
+            patch_height_start = h_idx * tile_latent_stride_height
+            patch_height_end = patch_height_start + tile_latent_min_height
+            patch_width_start = w_idx * tile_latent_stride_width
+            patch_width_end = patch_width_start + tile_latent_min_width
+            time = []
+            for k in range(num_frames):
+                self._conv_idx = [0]
+                tile = z[:, :, k : k + 1, patch_height_start:patch_height_end, patch_width_start:patch_width_end]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0))
+                time.append(decoded)
+            time = torch.cat(time, dim=2)
+            local_tiles.append(time.flatten(3, 4))  # flatten the h, w dims so all tiles on this rank can be concatenated
+            local_hw_shapes.append(torch.tensor(time.shape[3:5], dtype=torch.int32, device=device))  # record (h, w) for later unflattening
+        self.clear_cache()
+
+        # Concatenate all tiles on the local rank
+        local_tiles = torch.cat(local_tiles, dim=3)
+        local_hw_shapes = torch.stack(local_hw_shapes)
+
+        # Gather the hw shapes of every rank's tiles (the last tile may have a different shape)
+        gathered_shape_list = [
+            torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device)
+            for num_tiles in self.num_tiles_per_rank
+        ]
+        dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group)
+
+        # Gather the tiles onto all ranks
+        b, c, n = local_tiles.shape[:3]
+        gathered_tiles = [
+            torch.empty((b, c, n, tiles_shape.prod(dim=1).sum().item()), dtype=local_tiles.dtype, device=device)
+            for tiles_shape in gathered_shape_list
+        ]
+        dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group)
+
+        # Place the tiles into rows based on tile_idxs_per_rank
+        rows = [[None] * num_tile_cols for _ in range(num_tile_rows)]
+        for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank):
+            rank_tile_hw_shapes = gathered_shape_list[rank_idx]
+            hw_start_idx = 0
+            # A rank may own more than one tile; slice each one out by its hw shape
+            for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs):
+                rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx]
+                hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # length of the flattened hw dim
+                rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten(
+                    3, rank_tile_hw_shape.tolist()
+                )  # unflatten the hw dim back to (h, w)
+                hw_start_idx = hw_end_idx
+
+        # Combine all tiles, same as in tiled_decode
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+        dec = torch.clamp(dec, min=-1.0, max=1.0)
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
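
To make the tile-size arithmetic in `tiled_decode_with_dp` concrete, here is a worked example under assumed values (`spatial_compression_ratio = 8`, `patch_size = None`, a 60x104 latent, a 2x2 split, `overlap_pixels = 64`); all numbers are illustrative:

```python
# Worked example of the stride/overlap/blend arithmetic above.
ratio = 8
height, width = 60, 104
h_split = w_split = 2
overlap_pixels = 64

stride_h = (height + h_split - 1) // h_split           # 30 latent rows kept per tile step
stride_w = (width + w_split - 1) // w_split            # 52 latent cols kept per tile step
overlap = (overlap_pixels + ratio - 1) // ratio        # 8 latent pixels of overlap
min_h, min_w = stride_h + overlap, stride_w + overlap  # 38 x 60 latent tile actually decoded

sample_stride_h, sample_stride_w = stride_h * ratio, stride_w * ratio  # 240 x 416 output pixels kept per tile
sample_min_h, sample_min_w = min_h * ratio, min_w * ratio              # 304 x 480 output pixels decoded per tile
blend_h = sample_min_h - sample_stride_h  # 64 pixels blended vertically across each seam
blend_w = sample_min_w - sample_stride_w  # 64 pixels blended horizontally across each seam
```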
