@@ -20,7 +20,8 @@ class DropBlock(nn.Module):
             tensor channels. default = False - recommended in paper
         iterative_p: when True, iteratively increases the probability from 0 to p
             till n_iterations = steps_to_max, and maintains p thereafter.
-        steps_to_max: steps to reach p, default = 500000
+            default = True
+        steps_to_max: iterations to reach p, default = 50000

     Return:
         torch.Tensor of shape BCHW
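The hunk documents the ramp but not its shape. Below is a minimal sketch of the schedule the docstring describes, assuming a linear increase from 0 to p over steps_to_max iterations; the linear form is an assumption, since only the endpoints and the cap are stated here.

```python
# Hedged sketch, not code from this commit: effective drop probability
# under iterative_p, assuming a linear ramp capped at p.
def current_p(p: float, n_iterations: int, steps_to_max: int) -> float:
    return p * min(n_iterations / steps_to_max, 1.0)

assert current_p(0.1, 0, 50000) == 0.0        # start of training
assert current_p(0.1, 25000, 50000) == 0.05   # halfway up the ramp
assert current_p(0.1, 80000, 50000) == 0.1    # held at p after steps_to_max
```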
@@ -31,8 +32,8 @@ def __init__(self,
                  p: float = 0.1,
                  block_size: int = 5,
                  shared: bool = False,
-                 iterative_p: bool = False,
-                 steps_to_max: int = 500000):
+                 iterative_p: bool = True,
+                 steps_to_max: int = 50000):

         super(DropBlock, self).__init__()
         # checks
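Since iterative_p and steps_to_max now default to True and 50000, a bare constructor call enables the ramp. A quick usage sketch follows; the import path is an assumption, and the input shape is an arbitrary BCHW example.

```python
import torch
# Assumed import path for this repo's module; adjust to the actual package layout.
from dropblock import DropBlock

block = DropBlock(p=0.1, block_size=5)   # now implies iterative_p=True, steps_to_max=50000
x = torch.randn(8, 64, 32, 32)           # BCHW, as in the docstring's Return section
block.train()
y = block(x)
assert y.shape == x.shape                # DropBlock preserves the input shape
```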
@@ -70,7 +71,7 @@ def __init__(self,
                 "{}".format(type(iterative_p).__name__))
         if iterative_p:
             # steps_to_max = steps to reach p
-            self.steps_to_max = 500000
+            self.steps_to_max = steps_to_max
             self.register_buffer("n_iterations", torch.Tensor([0]).sum())

     def forward(self, tensor):
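Registering n_iterations via register_buffer, rather than as a plain attribute, presumably keeps the counter in state_dict and moves it with the module across devices. A toy illustration of that behavior, independent of this repo:

```python
import torch
import torch.nn as nn

class Counter(nn.Module):
    # Toy module, not from this commit: shows what registering the step
    # counter as a buffer buys - it is checkpointed with state_dict and
    # follows the module on .to(device).
    def __init__(self):
        super().__init__()
        self.register_buffer("n_iterations", torch.tensor(0.0))

m = Counter()
m.n_iterations += 1
assert "n_iterations" in m.state_dict()
assert m.state_dict()["n_iterations"].item() == 1.0
```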
@@ -89,12 +90,14 @@ def forward(self, tensor):
         pad = self.w // 2
         if self.shared:
             c = 1
+
         mask = torch.ones(n, c, h - 2 * pad, w - 2 * pad).to(tensor.device)
         mask = torch.bernoulli(mask * gamma)
         mask = F.pad(mask, (pad, pad, pad, pad))
-        block_mask = F.conv2d(mask, torch.ones(c, 1, self.w, self.w),
-                              padding=pad, groups=c)
+        kernel = torch.ones(c, 1, self.w, self.w).to(tensor.device)
+        block_mask = F.conv2d(mask, kernel, padding=pad, groups=c)
         block_mask = (block_mask == 0).float().detach()
+
         # norm = count(M)/count_ones(M)
         norm = block_mask.sum(2, True).sum(3, True) / h / w
         return tensor * block_mask * norm  # A × count(M)/count_ones(M)
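To see why the depthwise all-ones convolution in this hunk produces square drop regions: each Bernoulli "seed" in mask is spread into a block_size × block_size neighborhood, and the == 0 comparison then inverts the result into the keep-mask M of the DropBlock paper. A standalone toy example with one channel and one seed, sizes chosen only for illustration:

```python
import torch
import torch.nn.functional as F

w = 3                                    # block_size
pad = w // 2
seeds = torch.zeros(1, 1, 7, 7)
seeds[0, 0, 3, 3] = 1.0                  # one sampled block centre
kernel = torch.ones(1, 1, w, w)          # all-ones kernel, as in the diff
spread = F.conv2d(seeds, kernel, padding=pad, groups=1)
block_mask = (spread == 0).float()       # 1 = keep, 0 = inside a dropped block
assert block_mask[0, 0, 2:5, 2:5].sum() == 0    # the full 3x3 block is zeroed
assert block_mask.sum() == 7 * 7 - 3 * 3        # everything else is kept
```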