@@ -54,7 +54,9 @@ def forward(self, X):
5454 )
5555 X = self .pi (X )
5656 X = self .bjorck (X )
57- X = X .view (self .out_channels , self .in_channels // self .groups , self .k1 , self .k2 )
57+ X = X .reshape (
58+ self .out_channels , self .in_channels // self .groups , self .k1 , self .k2
59+ )
5860 return X * self .scale
5961
6062 def right_inverse (self , X ):
@@ -120,11 +122,10 @@ def __init__(
120122 bias ,
121123 padding_mode ,
122124 )
123- if self .dilation [0 ] > 1 or self .dilation [1 ] > 1 :
124- raise RuntimeWarning (
125- "Dilation must be 1 in the RKO convolution."
126- "Use RkoConvTranspose2d instead."
127- )
125+ if (self .dilation [0 ] > 1 or self .dilation [1 ] > 1 ) and (
126+ self .stride [0 ] != 1 or self .stride [1 ] != 1
127+ ):
128+ raise ValueError ("dilation must be 1 when stride is not 1" )
128129 # torch.nn.init.orthogonal_(self.weight)
129130 self .scale = 1 / math .sqrt (
130131 math .ceil (self .dilation [0 ] * self .kernel_size [0 ] / self .stride [0 ])
@@ -172,7 +173,7 @@ def singular_values(self):
172173 return sv_min , sv_max , stable_rank
173174 elif self .stride [0 ] > 1 or self .stride [1 ] > 1 :
174175 raise RuntimeError (
175- "Not able to compute singular values for this " " configuration"
176+ "Not able to compute singular values for this configuration"
176177 )
177178 # Implements interface required by LipschitzModuleL2
178179 sv_min , sv_max , stable_rank = conv_singular_values_numpy (
@@ -220,15 +221,14 @@ def __init__(
220221 )
221222
222223 # raise runtime error if kernel size >= stride
223- if self .kernel_size [0 ] > self .stride [0 ] or self .kernel_size [1 ] > self .stride [1 ]:
224- raise RuntimeError (
225- "kernel size must be smaller than stride. The set of orthonal convolutions is empty in this setting."
224+ if self .kernel_size [0 ] < self .stride [0 ] or self .kernel_size [1 ] < self .stride [1 ]:
225+ raise ValueError (
226+ "kernel size must be smaller than stride. The set of orthogonal convolutions is empty in this setting."
226227 )
227- if (in_channels % groups != 0 ) and (out_channels % groups != 0 ):
228- raise RuntimeError (
229- "in_channels and out_channels must be divisible by groups"
230- )
231-
228+ if (self .dilation [0 ] > 1 or self .dilation [1 ] > 1 ) and (
229+ self .stride [0 ] != 1 or self .stride [1 ] != 1
230+ ):
231+ raise ValueError ("dilation must be 1 when stride is not 1" )
232232 if (
233233 self .stride [0 ] != self .kernel_size [0 ]
234234 or self .stride [1 ] != self .kernel_size [1 ]
@@ -271,7 +271,7 @@ def singular_values(self):
271271 return sv_min , sv_max , stable_rank
272272 elif self .stride [0 ] > 1 or self .stride [1 ] > 1 :
273273 raise RuntimeError (
274- "Not able to compute singular values for this " " configuration"
274+ "Not able to compute singular values for this configuration"
275275 )
276276 # Implements interface required by LipschitzModuleL2
277277 sv_min , sv_max , stable_rank = conv_singular_values_numpy (
0 commit comments