Skip to content

Commit c60069c

Browse files
committed
Annotate types on drop fns to avoid torchscript error
1 parent cc5a11a commit c60069c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

timm/models/layers/drop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import math
2222

2323

24-
def drop_block_2d(x, drop_prob=0.1, training=False, block_size=7, gamma_scale=1.0, drop_with_noise=False):
24+
def drop_block_2d(
25+
x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7,
26+
gamma_scale: float = 1.0, drop_with_noise: bool = False):
2527
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
2628
2729
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
@@ -79,7 +81,7 @@ def forward(self, x):
7981
return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
8082

8183

82-
def drop_path(x, drop_prob=0., training=False):
84+
def drop_path(x, drop_prob: float = 0., training: bool = False):
8385
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
8486
8587
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,

0 commit comments

Comments
 (0)