|
17 | 17 | import torch |
18 | 18 | import torch.nn as nn |
19 | 19 | import torch.nn.functional as F |
| 20 | +import torch.distributed as dist |
20 | 21 |
|
21 | 22 | from ...configuration_utils import ConfigMixin, register_to_config |
22 | 23 | from ...loaders import FromOriginalModelMixin |
@@ -1073,6 +1074,8 @@ def __init__( |
1073 | 1074 | self.tile_sample_stride_height = 192 |
1074 | 1075 | self.tile_sample_stride_width = 192 |
1075 | 1076 |
|
| 1077 | + self.use_dp = False |
| 1078 | + |
1076 | 1079 | # Precompute and cache conv counts for encoder and decoder for clear_cache speedup |
1077 | 1080 | self._cached_conv_counts = { |
1078 | 1081 | "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) |
@@ -1113,6 +1116,53 @@ def enable_tiling( |
1113 | 1116 | self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height |
1114 | 1117 | self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width |
1115 | 1118 |
|
| 1119 | + def enable_dp( |
| 1120 | + self, |
| 1121 | + world_size: Optional[int] = None, |
| 1122 | + hw_splits: Optional[Tuple[int, int]] = None, |
| 1123 | + overlap_ratio: Optional[float] = None, |
| 1124 | + overlap_pixels: Optional[int] = None |
| 1125 | + ) -> None: |
| 1126 | + r""" |
| 1127 | + Enable data-parallel tiled decoding: the spatial tile grid is sharded across ranks, each rank decodes its assigned tiles, and the results are gathered and blended on every rank. |
| | + |
| | + Args: |
| | + world_size (`int`, *optional*): |
| | + Number of ranks to use for decoding. Defaults to `dist.get_world_size()`. |
| | + hw_splits (`Tuple[int, int]`, *optional*): |
| | + Number of tile splits along the height and width axes. Defaults to `(1, world_size)`. |
| | + overlap_ratio (`float`, *optional*): |
| | + Overlap between neighboring tiles, as a fraction of the latent height/width. |
| | + overlap_pixels (`int`, *optional*): |
| | + Overlap between neighboring tiles, in output pixels. Takes precedence over `overlap_ratio`. |
| | + """ |
| 1128 | + if world_size is None: |
| 1129 | + world_size = dist.get_world_size() |
| 1130 | + |
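| | + # DP is a no-op unless there are at least 2 ranks and no more than are actually available |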
| 1131 | + if world_size <= 1 or world_size > dist.get_world_size(): |
| 1132 | + return |
| 1133 | + |
| 1134 | + if hw_splits is None: |
| 1135 | + hw_splits = (1, int(world_size)) |
| 1136 | + |
| 1137 | + assert len(hw_splits) == 2, f"'hw_splits' should be a tuple of 2 ints, but got length {len(hw_splits)}" |
| 1138 | + |
| 1139 | + h_split, w_split = map(int, hw_splits) |
| 1140 | + num_tiles = h_split * w_split |
| 1141 | + |
| 1142 | + # Tiles are assigned to ranks round-robin, so every rank must own at least one tile |
| 1143 | + assert num_tiles >= world_size, f"h_split * w_split = {num_tiles} must be >= world_size ({world_size})" |
| 1144 | + |
| 1145 | + self.use_dp = True |
| 1146 | + self.h_split, self.w_split = h_split, w_split |
| 1147 | + self.world_size = world_size |
| 1148 | + self.overlap_ratio = overlap_ratio |
| 1149 | + self.overlap_pixels = overlap_pixels |
| 1150 | + |
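| | + # dedicated process group for the VAE's collectives, separate from any other parallel groups |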
| 1151 | + dp_ranks = list(range(0, world_size)) |
| 1152 | + self.vae_dp_group = dist.new_group(ranks=dp_ranks) |
| 1153 | + self.rank = dist.get_rank() |
| 1156 | + self.tile_idxs_per_rank = [[] for _ in range(self.world_size)] |
| 1157 | + self.num_tiles_per_rank = [0] * self.world_size |
| 1158 | + # Assign tiles to ranks round-robin; a rank may own more than one tile |
| 1159 | + for tile_idx in range(num_tiles): |
| 1160 | + h_idx, w_idx = divmod(tile_idx, self.w_split) |
| 1161 | + rank_idx = tile_idx % self.world_size |
| 1162 | + self.tile_idxs_per_rank[rank_idx].append((h_idx, w_idx)) |
| 1163 | + self.num_tiles_per_rank[rank_idx] += 1 |
| 1165 | + |
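| | + # Minimal usage sketch (assumes a script launched with `torchrun --nproc_per_node=4` and an |
| | + # initialized default process group; `latents` is a hypothetical latent batch): |
| | + # |
| | + #   vae.enable_dp(world_size=4, hw_splits=(2, 2), overlap_pixels=48) |
| | + #   video = vae.decode(latents).sample  # dispatches to tiled_decode_with_dp |
| | + |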
1116 | 1166 | def clear_cache(self): |
1117 | 1167 | # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call |
1118 | 1168 | self._conv_num = self._cached_conv_counts["decoder"] |
@@ -1182,6 +1232,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): |
1182 | 1232 | tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio |
1183 | 1233 | tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio |
1184 | 1234 |
|
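| | + # When DP is enabled, always take the DP tiled path with the grid configured in enable_dp |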
| 1235 | + if self.use_dp: |
| 1236 | + return self.tiled_decode_with_dp(z, return_dict=return_dict) |
| 1237 | + |
1185 | 1238 | if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): |
1186 | 1239 | return self.tiled_decode(z, return_dict=return_dict) |
1187 | 1240 |
|
@@ -1393,6 +1446,152 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod |
1393 | 1446 | return (dec,) |
1394 | 1447 | return DecoderOutput(sample=dec) |
1395 | 1448 |
|
| 1449 | + def tiled_decode_with_dp(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: |
| 1450 | + r""" |
| 1451 | + Decode a batch of latent frames using tiled decoding, with the tiles sharded across data-parallel ranks. |
| 1452 | +
|
| 1453 | + Args: |
| 1454 | + z (`torch.Tensor`): Input batch of latent vectors. |
| 1455 | + return_dict (`bool`, *optional*, defaults to `True`): |
| 1456 | + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. |
| 1457 | +
|
| 1458 | + Returns: |
| 1459 | + [`~models.vae.DecoderOutput`] or `tuple`: |
| 1460 | + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is |
| 1461 | + returned. |
| 1462 | + """ |
| 1463 | + _, _, num_frames, height, width = z.shape |
| 1464 | + device = z.device |
| 1465 | + sample_height = height * self.spatial_compression_ratio |
| 1466 | + sample_width = width * self.spatial_compression_ratio |
| 1467 | + |
| 1468 | + # Calculate stride from the tile grid (ceil division so the grid covers the full latent) |
| 1469 | + tile_latent_stride_height = (height + self.h_split - 1) // self.h_split |
| 1470 | + tile_latent_stride_width = (width + self.w_split - 1) // self.w_split |
| 1471 | + |
| 1472 | + # Calculate overlap in latent space |
| 1473 | + overlap_latent_height = 3 |
| 1474 | + overlap_latent_width = 3 |
| 1475 | + if self.overlap_pixels is not None: |
| 1476 | + overlap_latent = (self.overlap_pixels + self.spatial_compression_ratio - 1) // self.spatial_compression_ratio |
| 1477 | + overlap_latent_height = overlap_latent |
| 1478 | + overlap_latent_width = overlap_latent |
| 1479 | + elif self.overlap_ratio is not None: |
| 1480 | + overlap_latent_height = int(self.overlap_ratio * height) |
| 1481 | + overlap_latent_width = int(self.overlap_ratio * width) |
| 1482 | + |
| 1483 | + # Minimum tile size in latent space: stride plus overlap, so neighboring tiles share a blend region |
| 1484 | + tile_latent_min_height = tile_latent_stride_height + overlap_latent_height |
| 1485 | + tile_latent_min_width = tile_latent_stride_width + overlap_latent_width |
| 1486 | + |
| 1487 | + # Convert min/stride to sample space |
| 1488 | + tile_sample_min_height = tile_latent_min_height * self.spatial_compression_ratio |
| 1489 | + tile_sample_min_width = tile_latent_min_width * self.spatial_compression_ratio |
| 1490 | + tile_sample_stride_height = tile_latent_stride_height * self.spatial_compression_ratio |
| 1491 | + tile_sample_stride_width = tile_latent_stride_width * self.spatial_compression_ratio |
| 1492 | + |
| 1493 | + self.tile_sample_min_height = tile_sample_min_height |
| 1494 | + self.tile_sample_min_width = tile_sample_min_width |
| 1495 | + self.tile_sample_stride_height = tile_sample_stride_height |
| 1496 | + self.tile_sample_stride_width = tile_sample_stride_width |
| 1497 | + |
| 1498 | + if self.config.patch_size is not None: |
| 1499 | + sample_height = sample_height // self.config.patch_size |
| 1500 | + sample_width = sample_width // self.config.patch_size |
| 1501 | + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size |
| 1502 | + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size |
| 1503 | + blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height |
| 1504 | + blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width |
| 1505 | + else: |
| 1506 | + blend_height = self.tile_sample_min_height - tile_sample_stride_height |
| 1507 | + blend_width = self.tile_sample_min_width - tile_sample_stride_width |
| 1508 | + |
| 1509 | + # Split z into overlapping tiles and decode them separately. |
| 1510 | + # The tiles have an overlap to avoid seams between tiles. |
| 1511 | + # Determine tile grid dimensions from the configured (h_split, w_split) |
| 1512 | + num_tile_rows = self.h_split |
| 1513 | + num_tile_cols = self.w_split |
| 1514 | + |
| 1515 | + # Each rank computes only tiles assigned to it based on tile_idxs_per_rank |
| 1517 | + local_tiles = []  # decoded tiles owned by this rank, flattened over h and w |
| 1518 | + local_hw_shapes = []  # per-tile (h, w) shapes, needed to unflatten after gathering |
| 1519 | + |
| 1520 | + for h_idx, w_idx in self.tile_idxs_per_rank[self.rank]: |
| 1521 | + self.clear_cache() |
| 1522 | + patch_height_start = h_idx * tile_latent_stride_height |
| 1523 | + patch_height_end = patch_height_start + tile_latent_min_height |
| 1524 | + patch_width_start = w_idx * tile_latent_stride_width |
| 1525 | + patch_width_end = patch_width_start + tile_latent_min_width |
| 1526 | + time = [] |
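| | + # decode frame by frame so the causal conv feature cache carries state across frames |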
| 1527 | + for k in range(num_frames): |
| 1528 | + self._conv_idx = [0] |
| 1529 | + tile = z[:, :, k : k + 1, patch_height_start : patch_height_end, patch_width_start : patch_width_end] |
| 1530 | + tile = self.post_quant_conv(tile) |
| 1531 | + decoded = self.decoder( |
| 1532 | + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) |
| 1533 | + ) |
| 1534 | + time.append(decoded) |
| 1535 | + time = torch.cat(time, dim=2) |
| 1536 | + local_tiles.append(time.flatten(3, 4))  # flatten the h, w dims so all of this rank's tiles can be concatenated |
| 1537 | + local_hw_shapes.append(torch.tensor(time.shape[3:5], dtype=torch.int32, device=device))  # record (h, w) for later unflattening |
| 1538 | + self.clear_cache() |
| 1539 | + |
| 1540 | + # concatenate all tiles decoded on this rank along the flattened h*w dim |
| 1541 | + local_tiles = torch.cat(local_tiles, dim=3) |
| 1542 | + local_hw_shapes = torch.stack(local_hw_shapes) |
| 1543 | + |
| 1544 | + # gather every rank's per-tile (h, w) shapes (tiles at the grid edge may be smaller) |
| 1545 | + gathered_shape_list = [torch.empty((num_tiles, 2), dtype=local_hw_shapes.dtype, device=device) |
| 1546 | + for num_tiles in self.num_tiles_per_rank] |
| 1547 | + dist.all_gather(gathered_shape_list, local_hw_shapes, group=self.vae_dp_group) |
| 1548 | + |
| 1549 | + # all-gather the flattened tiles so every rank can assemble the full output |
| 1550 | + b, c, n = local_tiles.shape[:3] |
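| | + # each rank's receive buffer is sized to the total flattened h*w of the tiles that rank owns |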
| 1551 | + gathered_tiles = [ |
| 1552 | + torch.empty( |
| 1553 | + (b, c, n, tiles_shape.prod(dim=1).sum().item()), |
| 1554 | + dtype=local_tiles.dtype, device=device) for tiles_shape in gathered_shape_list |
| 1555 | + ] |
| 1556 | + dist.all_gather(gathered_tiles, local_tiles, group=self.vae_dp_group) |
| 1557 | + |
| 1558 | + # put tiles in rows based on tile_idxs_per_rank |
| 1559 | + rows = [[None] * num_tile_cols for _ in range(num_tile_rows)] |
| 1560 | + for rank_idx, tile_idxs in enumerate(self.tile_idxs_per_rank): |
| 1561 | + rank_tile_hw_shapes = gathered_shape_list[rank_idx] |
| 1562 | + hw_start_idx = 0 |
| 1563 | + # a rank may own several tiles; slice each one out using its recorded (h, w) shape |
| 1564 | + for tile_idx, (h_idx, w_idx) in enumerate(tile_idxs): |
| 1565 | + rank_tile_hw_shape = rank_tile_hw_shapes[tile_idx] |
| 1566 | + hw_end_idx = hw_start_idx + rank_tile_hw_shape.prod().item()  # length of the flattened h*w span |
| 1567 | + rows[h_idx][w_idx] = gathered_tiles[rank_idx][:, :, :, hw_start_idx:hw_end_idx].unflatten( |
| 1568 | + 3, rank_tile_hw_shape.tolist()) # unflatten hw dim |
| 1569 | + hw_start_idx = hw_end_idx |
| 1570 | + |
| 1571 | + # combine all tiles, same as tiled decode |
| 1572 | + result_rows = [] |
| 1573 | + for i, row in enumerate(rows): |
| 1574 | + result_row = [] |
| 1575 | + for j, tile in enumerate(row): |
| 1576 | + # blend the above tile and the left tile |
| 1577 | + # to the current tile and add the current tile to the result row |
| 1578 | + if i > 0: |
| 1579 | + tile = self.blend_v(rows[i - 1][j], tile, blend_height) |
| 1580 | + if j > 0: |
| 1581 | + tile = self.blend_h(row[j - 1], tile, blend_width) |
| 1582 | + result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) |
| 1583 | + result_rows.append(torch.cat(result_row, dim=-1)) |
| 1584 | + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] |
| 1585 | + |
| 1586 | + if self.config.patch_size is not None: |
| 1587 | + dec = unpatchify(dec, patch_size=self.config.patch_size) |
| 1588 | + |
| 1589 | + dec = torch.clamp(dec, min=-1.0, max=1.0) |
| 1590 | + |
| 1591 | + if not return_dict: |
| 1592 | + return (dec,) |
| 1593 | + return DecoderOutput(sample=dec) |
| 1594 | + |
1396 | 1595 | def forward( |
1397 | 1596 | self, |
1398 | 1597 | sample: torch.Tensor, |
|