
Commit efeedba

Merge pull request #42 from Tensor46/develop
Develop
2 parents 138f2c7 + 622545c


2 files changed: +12 −9 lines changed


core/NeuralEssentials/cudamodel.py

Lines changed: 3 additions & 3 deletions
@@ -34,10 +34,10 @@ def check_precision_device(self, inputs):
                 break
         self.precision = p.dtype if "p" in locals() else torch.float32
         if type(inputs) in [list, tuple]:
+            inputs = [x if x.dtype == torch.long else x.type(self.precision)
+                      for x in inputs]
             if self.is_cuda:
-                inputs = [(x.type(self.precision).cuda() if self.is_cuda else
-                           x.type(self.precision)) if x.dtype != torch.long
-                          else x for x in inputs]
+                inputs = [x.cuda() for x in inputs]
             return inputs
         else:
             if not (inputs.dtype == torch.long):
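The cudamodel.py change splits the old one-liner into two passes: first cast every non-long tensor to self.precision, then move the whole list to the GPU in a separate step. A minimal sketch of the resulting flow as a standalone function with hypothetical names (the real logic lives in CudaModel.check_precision_device):

import torch

def cast_then_move(inputs, precision=torch.float32, is_cuda=False):
    # Pass 1: cast everything except long tensors (typically labels or
    # indices, which must keep their integer dtype)
    inputs = [x if x.dtype == torch.long else x.type(precision)
              for x in inputs]
    # Pass 2: device placement is now a single-purpose step
    if is_cuda:
        inputs = [x.cuda() for x in inputs]
    return inputs

Note that long tensors are now moved to the GPU as well; the old comprehension left them on the CPU.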

core/NeuralLayers/dropblock.py

Lines changed: 9 additions & 6 deletions
@@ -20,7 +20,8 @@ class DropBlock(nn.Module):
             tensor channels. default = False - recommended in paper
         iterative_p: when True, iteratively increases probability from 0 to p
             till n_iterations = steps_to_max, and maintains p there after.
-        steps_to_max: steps to reach p, default = 500000
+            default = True
+        steps_to_max: iterations to reach p, default = 50000
 
     Return:
         torch.Tensor of shape BCHW
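The docstring hunk records the new defaults: iterative_p is now on by default and steps_to_max drops from 500000 to 50000. The ramp itself is not shown in this diff; assuming it is linear in n_iterations (hypothetical helper below), the effective probability would be:

def effective_p(p, n_iterations, steps_to_max=50000):
    # Assumed linear ramp: 0 -> p over steps_to_max iterations, then
    # held at p (the exact schedule is not part of this diff)
    return p * min(n_iterations / steps_to_max, 1.0)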
@@ -31,8 +32,8 @@ def __init__(self,
                 p: float = 0.1,
                 block_size: int = 5,
                 shared: bool = False,
-                iterative_p: bool = False,
-                steps_to_max: int = 500000):
+                iterative_p: bool = True,
+                steps_to_max: int = 50000):
 
         super(DropBlock, self).__init__()
         # checks
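With the new signature, a bare constructor call now enables the iterative ramp out of the box. A usage sketch, assuming the repo is importable as a package from its root:

import torch
from core.NeuralLayers.dropblock import DropBlock

block = DropBlock(p=0.1, block_size=5)  # iterative_p=True, steps_to_max=50000
x = torch.rand(4, 32, 28, 28)           # BCHW input
y = block(x)                            # same shape, blocks zeroed and rescaled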
@@ -70,7 +71,7 @@ def __init__(self,
                             "{}".format(type(iterative_p).__name__))
         if iterative_p:
             # steps_to_max = steps to reach p
-            self.steps_to_max = 500000
+            self.steps_to_max = steps_to_max
         self.register_buffer("n_iterations", torch.Tensor([0]).sum())
 
     def forward(self, tensor):
@@ -89,12 +90,14 @@ def forward(self, tensor):
         pad = self.w//2
         if self.shared:
             c = 1
+
         mask = torch.ones(n, c, h-2*pad, w-2*pad).to(tensor.device)
         mask = torch.bernoulli(mask * gamma)
         mask = F.pad(mask, (pad, pad, pad, pad))
-        block_mask = F.conv2d(mask, torch.ones(c, 1, self.w, self.w),
-                              padding=pad, groups=c)
+        kernel = torch.ones(c, 1, self.w, self.w).to(tensor.device)
+        block_mask = F.conv2d(mask, kernel, padding=pad, groups=c)
         block_mask = (block_mask == 0).float().detach()
+
         # norm = count(M)/count_ones(M)
         norm = block_mask.sum(2, True).sum(3, True) / h / w
         return tensor * block_mask * norm  # A × count(M)/count_ones(M)
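Beyond the default changes, the forward hunk fixes a device bug: the all-ones kernel that dilates the Bernoulli mask into blocks was previously built on the CPU, which fails when mask lives on a GPU; it is now created on tensor.device. A self-contained sketch of the masking steps above (hypothetical standalone function, shared=False case):

import torch
import torch.nn.functional as F

def dropblock_mask(tensor, gamma, block_size=5):
    # mirrors the steps in DropBlock.forward shown in the hunk above
    n, c, h, w = tensor.shape
    pad = block_size // 2
    # sample block centres, keeping them away from the border
    mask = torch.ones(n, c, h - 2 * pad, w - 2 * pad, device=tensor.device)
    mask = torch.bernoulli(mask * gamma)
    mask = F.pad(mask, (pad, pad, pad, pad))
    # grow each centre into a block_size x block_size zeroed block;
    # building the kernel on tensor.device is the fix in this commit
    kernel = torch.ones(c, 1, block_size, block_size, device=tensor.device)
    block_mask = (F.conv2d(mask, kernel, padding=pad, groups=c) == 0).float()
    # norm = count_ones(M) / count(M), matching the original code
    norm = block_mask.sum(2, True).sum(3, True) / h / w
    return tensor * block_mask * norm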
