Skip to content

Commit d5af2fb

Browse files
committed
fix DiscretizationArgs in gspaces and R2Diffop
2 parents 2031843 + f6bf462 commit d5af2fb

File tree

7 files changed

+179
-13
lines changed

7 files changed

+179
-13
lines changed

e2cnn/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
__title__ = "e2cnn"
1313
__summary__ = "E(2)-Equivariant CNNs Library for PyTorch"
1414
__url__ = 'https://github.com/QUVA-Lab/e2cnn'
15-
__version__ = "0.2"
15+
__version__ = "0.2.1"
1616
__author__ = "Gabriele Cesa, Maurice Weiler"
1717
__email__ = "cesa.gabriele@gmail.com"
1818
__license__ = "BSD 3-Clause Clear"

e2cnn/gspaces/r2/flip2d_on_r2.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
from typing import Tuple, Callable, List
20+
from e2cnn.diffops import DiscretizationArgs
2021

2122
__all__ = ["Flip2dOnR2"]
2223

@@ -136,6 +137,7 @@ def _diffop_basis_generator(self,
136137
in_repr: Representation,
137138
out_repr: Representation,
138139
max_power: int,
140+
discretization: DiscretizationArgs,
139141
**kwargs,
140142
) -> diffops.DiffopBasis:
141143
r"""
@@ -148,6 +150,7 @@ def _diffop_basis_generator(self,
148150
in_repr: the input representation
149151
out_repr: the output representation
150152
max_power (int): the maximum power of Laplacians to use
153+
discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs
151154
152155
Keyword Args:
153156
maximum_frequency (int): the maximum frequency allowed in the basis vectors
@@ -160,7 +163,7 @@ def _diffop_basis_generator(self,
160163
"""
161164
maximum_frequency = None
162165
maximum_offset = None
163-
166+
164167
if 'maximum_frequency' in kwargs and kwargs['maximum_frequency'] is not None:
165168
maximum_frequency = kwargs['maximum_frequency']
166169
assert isinstance(maximum_frequency, int) and maximum_frequency >= 0
@@ -175,7 +178,7 @@ def _diffop_basis_generator(self,
175178
return diffops.diffops_Flip_act_R2(in_repr, out_repr, max_power,
176179
axis=self.axis,
177180
max_frequency=maximum_frequency,
178-
max_offset=maximum_offset)
181+
max_offset=maximum_offset, discretization=discretization)
179182

180183
def _basespace_action(self, input: np.ndarray, element: int) -> np.ndarray:
181184

e2cnn/gspaces/r2/fliprot2d_on_r2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from e2cnn.group import dihedral_group
1515
from e2cnn.group import o2_group
1616

17+
from e2cnn.diffops import DiscretizationArgs
18+
1719
import numpy as np
1820

1921

@@ -218,6 +220,7 @@ def _diffop_basis_generator(self,
218220
in_repr: Representation,
219221
out_repr: Representation,
220222
max_power: int,
223+
discretization: DiscretizationArgs,
221224
** kwargs,
222225
) -> diffops.DiffopBasis:
223226
r"""
@@ -231,6 +234,7 @@ def _diffop_basis_generator(self,
231234
in_repr: the input representation
232235
out_repr: the output representation
233236
max_power (int): the maximum power of Laplacians to use
237+
discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs
234238
235239
Keyword Args:
236240
maximum_frequency (int): the maximum frequency allowed in the basis vectors
@@ -241,7 +245,7 @@ def _diffop_basis_generator(self,
241245
the basis built
242246
243247
"""
244-
248+
245249
if self.fibergroup.order() > 0:
246250
maximum_frequency = None
247251
maximum_offset = None
@@ -260,9 +264,9 @@ def _diffop_basis_generator(self,
260264
return diffops.diffops_DN_act_R2(in_repr, out_repr, max_power,
261265
axis=self.axis,
262266
max_frequency=maximum_frequency,
263-
max_offset=maximum_offset)
267+
max_offset=maximum_offset, discretization=discretization)
264268
else:
265-
return diffops.diffops_O2_act_R2(in_repr, out_repr, max_power, axis=self.axis)
269+
return diffops.diffops_O2_act_R2(in_repr, out_repr, max_power, axis=self.axis, discretization=discretization)
266270

267271
def _basespace_action(self, input: np.ndarray, element: Tuple[int, Union[float, int]]) -> np.ndarray:
268272

e2cnn/gspaces/r2/general_r2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from e2cnn import kernels, diffops
55
from e2cnn.group import Group
66
from e2cnn.group import Representation
7+
from e2cnn.diffops import DiscretizationArgs
78

89
from abc import abstractmethod
910
from typing import List, Union
@@ -128,6 +129,7 @@ def build_diffop_basis(self,
128129
in_repr: Representation,
129130
out_repr: Representation,
130131
max_power: int,
132+
discretization: DiscretizationArgs,
131133
**kwargs) -> diffops.DiffopBasis:
132134
r"""
133135
@@ -167,6 +169,7 @@ def build_diffop_basis(self,
167169
in_repr (Representation): the input representation
168170
out_repr (Representation): the output representation
169171
max_power (int): the largest power of the Laplacian that will be used
172+
discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs
170173
**kwargs: Group-specific keywords arguments for ``_basis_generator`` method
171174
172175
Returns:
@@ -209,5 +212,6 @@ def _diffop_basis_generator(self,
209212
in_repr: Representation,
210213
out_repr: Representation,
211214
max_power: int,
215+
discretization: DiscretizationArgs,
212216
**kwargs):
213217
pass

e2cnn/gspaces/r2/rot2d_on_r2.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from e2cnn.group import so2_group
1919

2020
import numpy as np
21+
from e2cnn.diffops import DiscretizationArgs
2122

2223

2324
__all__ = ["Rot2dOnR2"]
@@ -156,6 +157,7 @@ def _diffop_basis_generator(self,
156157
in_repr: Representation,
157158
out_repr: Representation,
158159
max_power: int,
160+
discretization: DiscretizationArgs,
159161
**kwargs,
160162
) -> diffops.DiffopBasis:
161163
r"""
@@ -169,6 +171,7 @@ def _diffop_basis_generator(self,
169171
in_repr: the input representation
170172
out_repr: the output representation
171173
max_power (int): the maximum power of Laplacians to use
174+
discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs
172175
173176
Keyword Args:
174177
maximum_frequency (int): the maximum frequency allowed in the basis vectors
@@ -179,7 +182,7 @@ def _diffop_basis_generator(self,
179182
the basis built
180183
181184
"""
182-
185+
183186
if self.fibergroup.order() > 0:
184187
maximum_frequency = None
185188
maximum_offset = None
@@ -197,9 +200,9 @@ def _diffop_basis_generator(self,
197200

198201
return diffops.diffops_CN_act_R2(in_repr, out_repr, max_power,
199202
maximum_frequency,
200-
max_offset=maximum_offset)
203+
max_offset=maximum_offset, discretization=discretization)
201204
else:
202-
return diffops.diffops_SO2_act_R2(in_repr, out_repr, max_power)
205+
return diffops.diffops_SO2_act_R2(in_repr, out_repr, max_power, discretization=discretization)
203206

204207
def _basespace_action(self, input: np.ndarray, element: Union[float, int]) -> np.ndarray:
205208

e2cnn/gspaces/r2/trivial_on_r2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from e2cnn.group import CyclicGroup
1414
from e2cnn.group import cyclic_group
1515

16+
from e2cnn.diffops import DiscretizationArgs
17+
1618
import numpy as np
1719

1820
__all__ = ["TrivialOnR2"]
@@ -122,6 +124,7 @@ def _diffop_basis_generator(self,
122124
in_repr: Representation,
123125
out_repr: Representation,
124126
max_power: int,
127+
discretization: DiscretizationArgs,
125128
**kwargs,
126129
) -> diffops.DiffopBasis:
127130
r"""
@@ -134,6 +137,7 @@ def _diffop_basis_generator(self,
134137
in_repr: the input representation
135138
out_repr: the output representation
136139
max_power (int): the maximum power of Laplacians to use
140+
discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs
137141
138142
Keyword Args:
139143
maximum_frequency (int): the maximum frequency allowed in the basis vectors
@@ -160,7 +164,7 @@ def _diffop_basis_generator(self,
160164

161165
return diffops.diffops_Trivial_act_R2(in_repr, out_repr, max_power,
162166
maximum_frequency,
163-
max_offset=maximum_offset)
167+
max_offset=maximum_offset, discretization=discretization)
164168

165169
def _basespace_action(self, input: np.ndarray, element: Union[float, int]) -> np.ndarray:
166170

test/nn/test_diffop.py

Lines changed: 151 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,81 @@ def test_cyclic(self):
3131
cl.train()
3232
cl.check_equivariance()
3333

34-
def test_so2(self):
34+
def test_cyclic_gauss(self):
35+
N = 8
36+
g = Rot2dOnR2(N)
37+
38+
r1 = FieldType(g, list(g.representations.values()))
39+
r2 = FieldType(g, list(g.representations.values()) * 2)
40+
41+
s = 7
42+
43+
cl = R2Diffop(r1, r2, s, maximum_order=4, smoothing=1., bias=True)
44+
cl.bias.data = 20 * torch.randn_like(cl.bias.data)
45+
46+
cl.eval()
47+
cl.check_equivariance()
48+
49+
cl.train()
50+
cl.check_equivariance()
51+
52+
def test_cyclic_rbffd(self):
53+
N = 8
54+
g = Rot2dOnR2(N)
55+
56+
r1 = FieldType(g, list(g.representations.values()))
57+
r2 = FieldType(g, list(g.representations.values()) * 2)
58+
59+
s = 7
60+
61+
cl = R2Diffop(r1, r2, s, maximum_order=4, rbffd=True, bias=True)
62+
cl.bias.data = 20 * torch.randn_like(cl.bias.data)
63+
64+
cl.eval()
65+
cl.check_equivariance()
66+
67+
cl.train()
68+
cl.check_equivariance()
69+
70+
def test_so2_gauss(self):
3571
N = 7
3672
g = Rot2dOnR2(-1, N)
73+
74+
r1 = FieldType(g, list(g.representations.values()))
75+
r2 = FieldType(g, list(g.representations.values()))
76+
77+
s = 7
78+
79+
cl = R2Diffop(r1, r2, s, maximum_order=4, smoothing=1., bias=True)
80+
81+
cl.eval()
82+
cl.check_equivariance()
3783

84+
def test_so2_rbffd(self):
85+
N = 7
86+
g = Rot2dOnR2(-1, N)
87+
3888
r1 = FieldType(g, list(g.representations.values()))
3989
r2 = FieldType(g, list(g.representations.values()))
4090

4191
s = 7
42-
92+
93+
cl = R2Diffop(r1, r2, s, maximum_order=4, rbffd=True, bias=True)
94+
95+
cl.eval()
96+
cl.check_equivariance()
97+
98+
def test_so2(self):
99+
N = 7
100+
g = Rot2dOnR2(-1, N)
101+
102+
r1 = FieldType(g, list(g.representations.values()))
103+
r2 = FieldType(g, list(g.representations.values()))
104+
105+
s = 7
106+
43107
cl = R2Diffop(r1, r2, s, maximum_order=4, bias=True)
44-
108+
45109
cl.eval()
46110
cl.check_equivariance()
47111

@@ -59,6 +123,34 @@ def test_dihedral(self):
59123
cl.eval()
60124
cl.check_equivariance()
61125

126+
def test_dihedral_gauss(self):
127+
N = 8
128+
g = FlipRot2dOnR2(N, axis=np.pi / 3)
129+
130+
r1 = FieldType(g, list(g.representations.values()))
131+
r2 = FieldType(g, list(g.representations.values()))
132+
133+
s = 7
134+
135+
cl = R2Diffop(r1, r2, s, maximum_order=4, smoothing=1., bias=True)
136+
137+
cl.eval()
138+
cl.check_equivariance()
139+
140+
def test_dihedral_rbffd(self):
141+
N = 8
142+
g = FlipRot2dOnR2(N, axis=np.pi / 3)
143+
144+
r1 = FieldType(g, list(g.representations.values()))
145+
r2 = FieldType(g, list(g.representations.values()))
146+
147+
s = 7
148+
149+
cl = R2Diffop(r1, r2, s, maximum_order=4, rbffd=True, bias=True)
150+
151+
cl.eval()
152+
cl.check_equivariance()
153+
62154
def test_o2(self):
63155
N = 7
64156
g = FlipRot2dOnR2(-1, N)
@@ -73,6 +165,34 @@ def test_o2(self):
73165
cl.eval()
74166
cl.check_equivariance()
75167

168+
def test_o2_gauss(self):
169+
N = 7
170+
g = FlipRot2dOnR2(-1, N)
171+
172+
r1 = FieldType(g, list(g.representations.values()))
173+
r2 = FieldType(g, list(g.representations.values()))
174+
175+
s = 7
176+
177+
cl = R2Diffop(r1, r2, s, maximum_order=4, bias=True, smoothing=1.)
178+
179+
cl.eval()
180+
cl.check_equivariance()
181+
182+
def test_o2_rbffd(self):
183+
N = 7
184+
g = FlipRot2dOnR2(-1, N)
185+
186+
r1 = FieldType(g, list(g.representations.values()))
187+
r2 = FieldType(g, list(g.representations.values()))
188+
189+
s = 7
190+
191+
cl = R2Diffop(r1, r2, s, maximum_order=4, bias=True, rbffd=True)
192+
193+
cl.eval()
194+
cl.check_equivariance()
195+
76196
def test_flip(self):
77197
# g = Flip2dOnR2(axis=np.pi/3)
78198
g = Flip2dOnR2(axis=np.pi/2)
@@ -87,6 +207,34 @@ def test_flip(self):
87207
cl.eval()
88208
cl.check_equivariance()
89209

210+
def test_flip_gauss(self):
211+
# g = Flip2dOnR2(axis=np.pi/3)
212+
g = Flip2dOnR2(axis=np.pi / 2)
213+
214+
r1 = FieldType(g, list(g.representations.values()))
215+
r2 = FieldType(g, list(g.representations.values()) * 3).sorted()
216+
217+
s = 9
218+
219+
cl = R2Diffop(r1, r2, s, maximum_order=4, bias=True, smoothing=1.)
220+
221+
cl.eval()
222+
cl.check_equivariance()
223+
224+
def test_flip_rbffd(self):
225+
# g = Flip2dOnR2(axis=np.pi/3)
226+
g = Flip2dOnR2(axis=np.pi / 2)
227+
228+
r1 = FieldType(g, list(g.representations.values()))
229+
r2 = FieldType(g, list(g.representations.values()) * 3).sorted()
230+
231+
s = 9
232+
233+
cl = R2Diffop(r1, r2, s, maximum_order=4, bias=True, rbffd=True)
234+
235+
cl.eval()
236+
cl.check_equivariance()
237+
90238
def test_padding_mode_reflect(self):
91239
g = Flip2dOnR2(axis=np.pi / 2)
92240

0 commit comments

Comments
 (0)