Skip to content

Commit 06dd507

Browse files
committed
fix: add proper padding computation to convolution filters
1 parent cf51940 commit 06dd507

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

cellseg_models_pytorch/utils/convolve.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,28 @@
44
__all__ = ["filter2D", "gaussian", "gaussian_kernel2d"]
55

66

7+
# https://github.com/kornia/kornia/blob/main/kornia/filters/filter.py#L32
8+
def _compute_padding(kernel_size: list[int]) -> list[int]:
9+
"""Compute padding tuple."""
10+
if len(kernel_size) < 2:
11+
raise AssertionError(kernel_size)
12+
computed = [k - 1 for k in kernel_size]
13+
14+
# for even kernels we need to do asymmetric padding :(
15+
out_padding = 2 * len(kernel_size) * [0]
16+
17+
for i in range(len(kernel_size)):
18+
computed_tmp = computed[-(i + 1)]
19+
20+
pad_front = computed_tmp // 2
21+
pad_rear = computed_tmp - pad_front
22+
23+
out_padding[2 * i + 0] = pad_front
24+
out_padding[2 * i + 1] = pad_rear
25+
26+
return out_padding
27+
28+
729
def filter2D(input_tensor: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
830
"""Convolves a given kernel on input tensor without losing dimensional shape.
931
@@ -22,20 +44,17 @@ def filter2D(input_tensor: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
2244
(_, channel, _, _) = input_tensor.size()
2345

2446
# "SAME" padding to avoid losing height and width
25-
pad = [
26-
kernel.size(2) // 2,
27-
kernel.size(2) // 2,
28-
kernel.size(3) // 2,
29-
kernel.size(3) // 2,
30-
]
47+
pad = _compute_padding(kernel.shape[2:])
3148
pad_tensor = F.pad(input_tensor, pad, "replicate")
32-
3349
out = F.conv2d(pad_tensor, kernel, groups=channel)
3450
return out
3551

3652

3753
def gaussian(
38-
window_size: int, sigma: float, device: torch.device = None
54+
window_size: int,
55+
sigma: float,
56+
device: torch.device = None,
57+
dtype: torch.dtype = None,
3958
) -> torch.Tensor:
4059
"""Create a gaussian 1D tensor.
4160
@@ -47,13 +66,18 @@ def gaussian(
4766
Std of the gaussian distribution.
4867
device : torch.device
4968
Device for the tensor.
69+
dtype : torch.dtype
70+
Data type for the tensor.
5071
5172
Returns
5273
-------
5374
torch.Tensor:
5475
A gaussian 1D tensor. Shape: (window_size, ).
5576
"""
56-
x = torch.arange(window_size, device=device).float() - window_size // 2
77+
if dtype is None:
78+
dtype = torch.float32
79+
80+
x = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2
5781
if window_size % 2 == 0:
5882
x = x + 0.5
5983

@@ -63,7 +87,11 @@ def gaussian(
6387

6488

6589
def gaussian_kernel2d(
66-
window_size: int, sigma: float, n_channels: int = 1, device: torch.device = None
90+
window_size: int,
91+
sigma: float,
92+
n_channels: int = 1,
93+
device: torch.device = None,
94+
dtype: torch.dtype = None,
6795
) -> torch.Tensor:
6896
"""Create 2D window_size**2 sized kernel a gaussial kernel.
6997
@@ -78,14 +106,16 @@ def gaussian_kernel2d(
78106
this kernel.
79107
device : torch.device
80108
Device for the kernel.
109+
dtype : torch.dtype
110+
Data type for the kernel.
81111
82112
Returns:
83113
-----------
84114
torch.Tensor:
85115
A tensor of shape (1, 1, window_size, window_size)
86116
"""
87-
kernel_x = gaussian(window_size, sigma, device=device)
88-
kernel_y = gaussian(window_size, sigma, device=device)
117+
kernel_x = gaussian(window_size, sigma, device=device, dtype=dtype)
118+
kernel_y = gaussian(window_size, sigma, device=device, dtype=dtype)
89119

90120
kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
91121
kernel_2d = kernel_2d.expand(n_channels, 1, window_size, window_size)

0 commit comments

Comments
 (0)