4
4
__all__ = ["filter2D" , "gaussian" , "gaussian_kernel2d" ]
5
5
6
6
7
+ # https://github.com/kornia/kornia/blob/main/kornia/filters/filter.py#L32
8
+ def _compute_padding (kernel_size : list [int ]) -> list [int ]:
9
+ """Compute padding tuple."""
10
+ if len (kernel_size ) < 2 :
11
+ raise AssertionError (kernel_size )
12
+ computed = [k - 1 for k in kernel_size ]
13
+
14
+ # for even kernels we need to do asymmetric padding :(
15
+ out_padding = 2 * len (kernel_size ) * [0 ]
16
+
17
+ for i in range (len (kernel_size )):
18
+ computed_tmp = computed [- (i + 1 )]
19
+
20
+ pad_front = computed_tmp // 2
21
+ pad_rear = computed_tmp - pad_front
22
+
23
+ out_padding [2 * i + 0 ] = pad_front
24
+ out_padding [2 * i + 1 ] = pad_rear
25
+
26
+ return out_padding
27
+
28
+
7
29
def filter2D (input_tensor : torch .Tensor , kernel : torch .Tensor ) -> torch .Tensor :
8
30
"""Convolves a given kernel on input tensor without losing dimensional shape.
9
31
@@ -22,20 +44,17 @@ def filter2D(input_tensor: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
22
44
(_ , channel , _ , _ ) = input_tensor .size ()
23
45
24
46
# "SAME" padding to avoid losing height and width
25
- pad = [
26
- kernel .size (2 ) // 2 ,
27
- kernel .size (2 ) // 2 ,
28
- kernel .size (3 ) // 2 ,
29
- kernel .size (3 ) // 2 ,
30
- ]
47
+ pad = _compute_padding (kernel .shape [2 :])
31
48
pad_tensor = F .pad (input_tensor , pad , "replicate" )
32
-
33
49
out = F .conv2d (pad_tensor , kernel , groups = channel )
34
50
return out
35
51
36
52
37
53
def gaussian (
38
- window_size : int , sigma : float , device : torch .device = None
54
+ window_size : int ,
55
+ sigma : float ,
56
+ device : torch .device = None ,
57
+ dtype : torch .dtype = None ,
39
58
) -> torch .Tensor :
40
59
"""Create a gaussian 1D tensor.
41
60
@@ -47,13 +66,18 @@ def gaussian(
47
66
Std of the gaussian distribution.
48
67
device : torch.device
49
68
Device for the tensor.
69
+ dtype : torch.dtype
70
+ Data type for the tensor.
50
71
51
72
Returns
52
73
-------
53
74
torch.Tensor:
54
75
A gaussian 1D tensor. Shape: (window_size, ).
55
76
"""
56
- x = torch .arange (window_size , device = device ).float () - window_size // 2
77
+ if dtype is None :
78
+ dtype = torch .float32
79
+
80
+ x = torch .arange (window_size , device = device , dtype = dtype ) - window_size // 2
57
81
if window_size % 2 == 0 :
58
82
x = x + 0.5
59
83
@@ -63,7 +87,11 @@ def gaussian(
63
87
64
88
65
89
def gaussian_kernel2d (
66
- window_size : int , sigma : float , n_channels : int = 1 , device : torch .device = None
90
+ window_size : int ,
91
+ sigma : float ,
92
+ n_channels : int = 1 ,
93
+ device : torch .device = None ,
94
+ dtype : torch .dtype = None ,
67
95
) -> torch .Tensor :
68
96
"""Create 2D window_size**2 sized kernel a gaussial kernel.
69
97
@@ -78,14 +106,16 @@ def gaussian_kernel2d(
78
106
this kernel.
79
107
device : torch.device
80
108
Device for the kernel.
109
+ dtype : torch.dtype
110
+ Data type for the kernel.
81
111
82
112
Returns:
83
113
-----------
84
114
torch.Tensor:
85
115
A tensor of shape (1, 1, window_size, window_size)
86
116
"""
87
- kernel_x = gaussian (window_size , sigma , device = device )
88
- kernel_y = gaussian (window_size , sigma , device = device )
117
+ kernel_x = gaussian (window_size , sigma , device = device , dtype = dtype )
118
+ kernel_y = gaussian (window_size , sigma , device = device , dtype = dtype )
89
119
90
120
kernel_2d = torch .matmul (kernel_x .unsqueeze (- 1 ), kernel_y .unsqueeze (- 1 ).t ())
91
121
kernel_2d = kernel_2d .expand (n_channels , 1 , window_size , window_size )
0 commit comments