
Commit 6ffdb99

Add fused relu for Wan animate activations
1 parent 7146bb0 commit 6ffdb99

File tree

1 file changed (+35, -8)

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 35 additions & 8 deletions
@@ -39,29 +39,56 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class FusedLeakyReLU(nn.Module):
+    """
+    Fused LeakyReLU with a scale factor and an optional channel-wise bias.
+    """
+
+    def __init__(self, negative_slope: float = 0.2, scale: float = 1.0, channels: Optional[int] = None):
+        super().__init__()
+        self.negative_slope = negative_slope
+        self.scale = scale
+        self.channels = channels
+
+        if self.channels is not None:
+            self.bias = nn.Parameter(torch.zeros(self.channels))
+        else:
+            self.bias = None
+
+    def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+        if self.bias is not None:
+            # Expand self.bias to singleton dims everywhere except at channel_dim
+            expanded_shape = [1] * x.ndim
+            expanded_shape[channel_dim] = self.bias.shape[0]
+            bias = self.bias.reshape(*expanded_shape)
+            x = x + bias
+        return F.leaky_relu(x, self.negative_slope) * self.scale
+
+
 class ConvLayer(nn.Module):
     def __init__(
         self,
         in_channel: int,
         out_channel: int,
         kernel_size: int,
-        downsample: bool = False,
+        downsample_factor: Optional[int] = None,
+        blur_kernel: Tuple[int, ...] = (1, 3, 3, 1),
         bias: bool = True,
         activate: bool = True,
     ):
         super().__init__()

-        self.downsample = downsample
+        self.downsample_factor = downsample_factor
+        self.blur_kernel = blur_kernel
         self.activate = activate

         if activate:
-            self.act = nn.LeakyReLU(0.2)
-            self.bias_leaky_relu = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+            act_fn_channels = out_channel if bias else None
+            act_fn_scale = 2**0.5 if bias else 1.0
+            self.act = FusedLeakyReLU(scale=act_fn_scale, channels=act_fn_channels)

-        if downsample:
-            factor = 2
-            blur_kernel = (1, 3, 3, 1)
-            p = (len(blur_kernel) - factor) + (kernel_size - 1)
+        if self.downsample_factor is not None:
+            p = (len(self.blur_kernel) - self.downsample_factor) + (kernel_size - 1)
             pad0 = (p + 1) // 2
             pad1 = p // 2
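As a quick usage sketch (not part of the commit), the new activation can be exercised on its own. This is a minimal sketch assuming the commit is installed; it imports FusedLeakyReLU from the internal module the diff touches (not a public diffusers API), and the shapes are illustrative only:

# Minimal sketch, assuming this commit is installed; FusedLeakyReLU lives in an
# internal diffusers module, not the public API.
import torch
from diffusers.models.transformers.transformer_wan_animate import FusedLeakyReLU

act = FusedLeakyReLU(negative_slope=0.2, scale=2**0.5, channels=8)

x = torch.randn(2, 8, 16, 16)  # (batch, channels, height, width)
y = act(x)                     # bias broadcast over dim 1, LeakyReLU, then * sqrt(2)
assert y.shape == x.shape

# channel_dim selects which axis the bias broadcasts along, e.g. channels-last:
x_last = x.permute(0, 2, 3, 1)       # (batch, height, width, channels)
y_last = act(x_last, channel_dim=3)
assert y_last.shape == x_last.shape

The 2**0.5 scale used when bias is enabled is the usual StyleGAN2-style gain for fused bias + LeakyReLU layers, roughly compensating for the magnitude the nonlinearity removes.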

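One worked check on the reworked downsampling branch: with the default blur_kernel of length 4, downsample_factor = 2, and kernel_size = 3, p = (4 - 2) + (3 - 1) = 4, so pad0 = (4 + 1) // 2 = 2 and pad1 = 4 // 2 = 2, the same symmetric padding the old hard-coded factor = 2 path produced; when p is odd, pad0 takes the extra pixel.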