39 | 39 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
40 | 40 |
41 | 41 |
| 42 | +class FusedLeakyReLU(nn.Module): |
| 43 | + """ |
| 44 | + Fused LeakyReLU activation with an output scale factor and an optional channel-wise bias. |
| 45 | + """ |
| 46 | + |
| 47 | + def __init__(self, negative_slope: float = 0.2, scale: float = 1.0, channels: Optional[int] = None): |
| 48 | + super().__init__() |
| 49 | + self.negative_slope = negative_slope |
| 50 | + self.scale = scale |
| 51 | + self.channels = channels |
| 52 | + |
| 53 | + if self.channels is not None: |
| 54 | + self.bias = nn.Parameter(torch.zeros(self.channels,)) |
| 55 | + else: |
| 56 | + self.bias = None |
| 57 | + |
| 58 | + def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: |
| 59 | + if self.bias is not None: |
| 60 | + # Reshape self.bias so it broadcasts over every dim except channel_dim |
| 61 | + expanded_shape = [1] * x.ndim |
| 62 | + expanded_shape[channel_dim] = self.bias.shape[0] |
| 63 | + bias = self.bias.reshape(*expanded_shape) |
| 64 | + x = x + bias |
| 65 | + return F.leaky_relu(x, self.negative_slope) * self.scale |
| 66 | + |
| 67 | + |
42 | 68 | class ConvLayer(nn.Module): |
43 | 69 | def __init__( |
44 | 70 | self, |
45 | 71 | in_channel: int, |
46 | 72 | out_channel: int, |
47 | 73 | kernel_size: int, |
48 | | - downsample: bool = False, |
| 74 | + downsample_factor: Optional[int] = None, |
| 75 | + blur_kernel: Tuple[int, ...] = (1, 3, 3, 1), |
49 | 76 | bias: bool = True, |
50 | 77 | activate: bool = True, |
51 | 78 | ): |
52 | 79 | super().__init__() |
53 | 80 |
54 | | - self.downsample = downsample |
| 81 | + self.downsample_factor = downsample_factor |
| 82 | + self.blur_kernel = blur_kernel |
55 | 83 | self.activate = activate |
56 | 84 |
57 | 85 | if activate: |
58 | | - self.act = nn.LeakyReLU(0.2) |
59 | | - self.bias_leaky_relu = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) |
| 86 | + act_fn_channels = out_channel if bias else None |
| 87 | + act_fn_scale = 2**0.5 if bias else 1.0 |
| 88 | + self.act = FusedLeakyReLU(scale=act_fn_scale, channels=act_fn_channels)  # channel_dim defaults to 1 in forward |
60 | 89 |
61 | | - if downsample: |
62 | | - factor = 2 |
63 | | - blur_kernel = (1, 3, 3, 1) |
64 | | - p = (len(blur_kernel) - factor) + (kernel_size - 1) |
| 90 | + if self.downsample_factor is not None: |
| 91 | + p = (len(self.blur_kernel) - self.downsample_factor) + (kernel_size - 1) |
65 | 92 | pad0 = (p + 1) // 2 |
66 | 93 | pad1 = p // 2 |
67 | 94 |
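
As a quick sanity check of the padding arithmetic in the downsample branch above, here is a small worked example with the default `blur_kernel` and a hypothetical `kernel_size` of 3 (values chosen only to illustrate the formula):

```python
kernel_size = 3
downsample_factor = 2
blur_kernel = (1, 3, 3, 1)

# Same formula as in ConvLayer.__init__ above
p = (len(blur_kernel) - downsample_factor) + (kernel_size - 1)  # (4 - 2) + (3 - 1) = 4
pad0 = (p + 1) // 2  # 2
pad1 = p // 2        # 2
print(p, pad0, pad1)  # 4 2 2
```

When `p` is odd, `pad0` takes the extra pixel, so the pair still sums to `p`.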
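
For reference, a minimal usage sketch of the new `FusedLeakyReLU`, assuming the class from the diff above is in scope and the input layout is `(batch, channels, ...)` with `channel_dim=1`. It checks that the result matches the unfused bias-add, leaky ReLU, and scale:

```python
import torch
import torch.nn.functional as F

# FusedLeakyReLU is the class added in this diff; import it from wherever the file above lives.

act = FusedLeakyReLU(negative_slope=0.2, scale=2**0.5, channels=8)

x = torch.randn(2, 8, 16, 16)  # (batch, channels, height, width)
y = act(x, channel_dim=1)      # per-channel bias is broadcast along dim 1

# Equivalent unfused computation
y_ref = F.leaky_relu(x + act.bias.view(1, -1, 1, 1), 0.2) * 2**0.5
assert torch.allclose(y, y_ref)
```

With `bias=False` in `ConvLayer`, the activation is constructed with `channels=None` and `scale=1.0`, so it reduces to a plain leaky ReLU.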