
Commit 9e44e27

ameya98 and mariogeiger authored
Add Linear layer in Equinox with all tests passing! (#62)
* Add Linear in Equinox.
* Formatting.
* Add to __init__.
* Remove print()
* Add tests for equinox Linear and flax Linear.
* Fix equinox linear.
* Formatting.
* Fix imports.
* Formatting.
* Undo accidental change.
* flake8 fixes.
* Fixes for channels + one more test.
* Formatting.
* Try running pytest-xdist.
* Minor fix.
* force kw arguments

Co-authored-by: Mario Geiger <geiger.mario@gmail.com>
1 parent e0753a8 commit 9e44e27

File tree

9 files changed: +521 −11 lines changed
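For orientation, here is a minimal usage sketch of the new layer, paraphrasing the docstring examples added in e3nn_jax/_src/linear_equinox.py below (the irreps strings and PRNG keys are illustrative, not part of the diff):

import jax
import e3nn_jax as e3nn

x = e3nn.normal("0e + 1o")
linear = e3nn.equinox.Linear(
    irreps_out="2x0e + 1o + 2e",
    irreps_in=x.irreps,        # Equinox modules are frozen, so irreps_in and the key go in at construction
    key=jax.random.PRNGKey(0),
)
y = linear(x)                  # the unreachable 2e output is dropped unless force_irreps_out=True
print(y.irreps, y.shape)       # 2x0e+1x1o (5,)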

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
@@ -35,11 +35,11 @@ jobs:
        pip install ".[dev]"
    - name: Install pytest
      run: |
-        pip install pytest pytest-cov
+        pip install pytest pytest-cov pytest-xdist
        pip install coveralls
    - name: Test with pytest
      run: |
-        coverage run --source=e3nn_jax -m pytest --doctest-modules --ignore=docs/ --ignore=tests/noxfile.py tests examples
+        coverage run --source=e3nn_jax -m pytest -n auto --doctest-modules --ignore=docs/ --ignore=tests/noxfile.py tests examples
    - name: Upload to coveralls
      if: github.event_name == 'push'
      run: |
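The two workflow edits above install pytest-xdist and pass -n auto so the suite is split across all available CPU cores. Any test collected by pytest benefits without modification; a sketch of such a test and a local run (the file and test names are hypothetical, for illustration only):

# tests/test_linear_equinox_smoke.py  (hypothetical)
import jax
import e3nn_jax as e3nn


def test_equinox_linear_output_shape():
    x = e3nn.normal("0e + 1o")
    linear = e3nn.equinox.Linear(
        irreps_out="2x0e + 1o",
        irreps_in=x.irreps,
        key=jax.random.PRNGKey(0),
    )
    assert linear(x).shape == (5,)

# run locally with:  pip install pytest pytest-xdist && pytest -n auto tests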

e3nn_jax/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -116,7 +116,7 @@
from e3nn_jax._src.utils.vmap import vmap

# make submodules flax and haiku available
-from e3nn_jax import flax, haiku
+from e3nn_jax import flax, haiku, equinox
from e3nn_jax import utils

__all__ = [
@@ -229,5 +229,6 @@
    "vmap",
    "flax",
    "haiku",
+    "equinox",
    "utils",
]
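After this change the submodule is re-exported at the package level alongside flax and haiku, so it can be reached directly from the top-level namespace (a quick sketch):

import e3nn_jax as e3nn

assert "equinox" in e3nn.__all__    # added by this commit
layer_cls = e3nn.equinox.Linear     # reachable without importing the submodule explicitly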

e3nn_jax/_src/linear_equinox.py

Lines changed: 342 additions & 0 deletions
@@ -0,0 +1,342 @@
from typing import Optional, Union, Tuple, Dict

import equinox as eqx
import chex
import jax
import jax.numpy as jnp

import e3nn_jax as e3nn
from e3nn_jax._src.utils.dtype import get_pytree_dtype

from .linear import (
    FunctionalLinear,
    linear_indexed,
    linear_mixed,
    linear_mixed_per_channel,
    linear_vanilla,
)


def _get_gradient_normalization(
    gradient_normalization: Optional[Union[float, str]]
) -> float:
    """Get the gradient normalization from the config or from the argument."""
    if gradient_normalization is None:
        gradient_normalization = e3nn.config("gradient_normalization")
    if isinstance(gradient_normalization, str):
        return {"element": 0.0, "path": 1.0}[gradient_normalization]
    return gradient_normalization


class Linear(eqx.Module):
    r"""Equivariant Linear Equinox module.

    Args:
        irreps_out (`Irreps`): output representations, if allowed by Schur's lemma.
        channel_out (optional int): if specified, the last axis before the irreps
            is assumed to be the channel axis and is mixed with the irreps.
        irreps_in (`Irreps`): input representations; must be supplied at initialization.
        channel_in (optional int): required when using 'mixed_per_channel' linear_type,
            indicating the size of the last axis before the irreps in the input.
        biases (bool): whether to add a bias to the output.
        path_normalization (str or float): Normalization of the paths, ``element`` or ``path``.
            0/1 corresponds to a normalization where each element/path has an equal contribution to the forward.
        gradient_normalization (str or float): Normalization of the gradients, ``element`` or ``path``.
            0/1 corresponds to a normalization where each element/path has an equal contribution to the learning.
        num_indexed_weights (optional int): number of indexed weights. See example below.
        weights_per_channel (bool): whether to have one set of weights per channel.
        force_irreps_out (bool): whether to force the output irreps to be the one specified in ``irreps_out``.

    Due to how Equinox is implemented, the random key, irreps_in and irreps_out must be supplied at initialization.
    The type of the linear layer must also be supplied at initialization:
    'vanilla', 'indexed', 'mixed', 'mixed_per_channel'.
    Also, depending on which type of linear layer is used, additional options
    (e.g. 'num_indexed_weights', 'weights_per_channel', 'weights_dim', 'channel_in')
    must be supplied.

    Examples:
        Vanilla::

            >>> import e3nn_jax as e3nn
            >>> import jax
            >>> import jax.numpy as jnp

            >>> x = e3nn.normal("0e + 1o")
            >>> linear = e3nn.equinox.Linear(
            ...     irreps_out="2x0e + 1o + 2e",
            ...     irreps_in=x.irreps,
            ...     key=jax.random.PRNGKey(0),
            ... )
            >>> linear(x).irreps  # Note that the 2e is discarded. Avoid this by setting force_irreps_out=True.
            2x0e+1x1o
            >>> linear(x).shape
            (5,)

        External weights::

            >>> linear = e3nn.equinox.Linear(
            ...     irreps_out="2x0e + 1o",
            ...     irreps_in=x.irreps,
            ...     linear_type="mixed",
            ...     weights_dim=4,
            ...     key=jax.random.PRNGKey(0),
            ... )
            >>> e = jnp.array([1., 2., 3., 4.])
            >>> linear(e, x).irreps
            2x0e+1x1o
            >>> linear(e, x).shape
            (5,)

        Indexed weights::

            >>> linear = e3nn.equinox.Linear(
            ...     irreps_out="2x0e + 1o + 2e",
            ...     irreps_in=x.irreps,
            ...     linear_type="indexed",
            ...     num_indexed_weights=3,
            ...     key=jax.random.PRNGKey(0),
            ... )
            >>> i = jnp.array(2)
            >>> linear(i, x).irreps
            2x0e+1x1o
            >>> linear(i, x).shape
            (5,)
    """
    irreps_out: e3nn.Irreps
    irreps_in: e3nn.Irreps
    channel_out: int
    channel_in: int
    gradient_normalization: Optional[Union[float, str]]
    path_normalization: Optional[Union[float, str]]
    biases: bool
    num_indexed_weights: Optional[int]
    weights_per_channel: bool
    force_irreps_out: bool
    weights_dim: Optional[int]
    linear_type: str

    # These are used internally.
    _linear: FunctionalLinear
    _weights: Dict[str, jnp.ndarray]
    _input_dtype: jnp.dtype

    def __init__(
        self,
        *,
        irreps_out: e3nn.Irreps,
        irreps_in: e3nn.Irreps,
        channel_out: Optional[int] = None,
        channel_in: Optional[int] = None,
        biases: bool = False,
        path_normalization: Optional[Union[str, float]] = None,
        gradient_normalization: Optional[Union[str, float]] = None,
        num_indexed_weights: Optional[int] = None,
        weights_per_channel: bool = False,
        force_irreps_out: bool = False,
        weights_dim: Optional[int] = None,
        input_dtype: jnp.dtype = jnp.float32,
        linear_type: str = "vanilla",
        key: chex.PRNGKey,
    ):
        irreps_in_regrouped = e3nn.Irreps(irreps_in).regroup()
        irreps_out = e3nn.Irreps(irreps_out)

        self.irreps_in = irreps_in_regrouped
        self.channel_in = channel_in
        self.channel_out = channel_out
        self.biases = biases
        self.path_normalization = path_normalization
        self.num_indexed_weights = num_indexed_weights
        self.weights_per_channel = weights_per_channel
        self.force_irreps_out = force_irreps_out
        self.linear_type = linear_type
        self.weights_dim = weights_dim
        self._input_dtype = input_dtype

        self.gradient_normalization = _get_gradient_normalization(
            gradient_normalization
        )

        channel_irrep_multiplier = 1
        if self.channel_out is not None:
            assert not self.weights_per_channel
            channel_irrep_multiplier = self.channel_out

        if not self.force_irreps_out:
            irreps_out = irreps_out.filter(keep=irreps_in_regrouped)
            irreps_out = irreps_out.simplify()
        self.irreps_out = irreps_out

        self._linear = FunctionalLinear(
            irreps_in_regrouped,
            channel_irrep_multiplier * irreps_out,
            biases=self.biases,
            path_normalization=self.path_normalization,
            gradient_normalization=self.gradient_normalization,
        )
        self._weights = self._get_weights(key)

    def _get_weights(self, key: chex.PRNGKey):
        """Constructs the weights for the linear module."""
        irreps_in = self._linear.irreps_in
        irreps_out = self._linear.irreps_out

        weights = {}
        for ins in self._linear.instructions:
            weight_key, key = jax.random.split(key)
            if ins.i_in == -1:
                name = f"b[{ins.i_out}] {irreps_out[ins.i_out]}"
            else:
                name = f"w[{ins.i_in},{ins.i_out}] {irreps_in[ins.i_in]},{irreps_out[ins.i_out]}"

            if self.linear_type == "vanilla":
                weight_shape = ins.path_shape
                weight_std = ins.weight_std

            if self.linear_type == "indexed":
                if self.num_indexed_weights is None:
                    raise ValueError(
                        "num_indexed_weights must be provided when 'linear_type' is 'indexed'"
                    )

                weight_shape = (self.num_indexed_weights,) + ins.path_shape
                weight_std = ins.weight_std

            if self.linear_type in ["mixed", "mixed_per_channel"]:
                if self.weights_dim is None:
                    raise ValueError(
                        "weights_dim must be provided when 'linear_type' is 'mixed'"
                    )

                d = self.weights_dim
                if self.linear_type == "mixed":
                    weight_shape = (d,) + ins.path_shape

                if self.linear_type == "mixed_per_channel":
                    if self.channel_in is None:
                        raise ValueError(
                            "channel_in must be provided when 'linear_type' is 'mixed_per_channel'"
                        )
                    weight_shape = (d, self.channel_in) + ins.path_shape

                alpha = 1 / d
                stddev = jnp.sqrt(alpha) ** (1.0 - self.gradient_normalization)
                weight_std = stddev * ins.weight_std

            weights[name] = weight_std * jax.random.normal(
                weight_key,
                weight_shape,
                self._input_dtype,
            )
        return weights

    def __call__(self, weights_or_input, input_or_none=None) -> e3nn.IrrepsArray:
        """Apply the linear operator.

        Args:
            weights (optional IrrepsArray or jnp.ndarray): scalar weights that are contracted with free parameters.
                An array of shape ``(..., contracted_axis)``. Broadcasting with `input` is supported.
            input (IrrepsArray): input irreps-array of shape ``(..., [channel_in,] irreps_in.dim)``.
                Broadcasting with `weights` is supported.

        Returns:
            IrrepsArray: output irreps-array of shape ``(..., [channel_out,] irreps_out.dim)``.
            Properly normalized assuming that the weights and input are properly normalized.
        """
        if input_or_none is None:
            weights = None
            input: e3nn.IrrepsArray = weights_or_input
        else:
            weights: jnp.ndarray = weights_or_input
            input: e3nn.IrrepsArray = input_or_none
        del weights_or_input, input_or_none

        input = e3nn.as_irreps_array(input)

        dtype = get_pytree_dtype(weights, input)
        if dtype.kind == "i":
            dtype = jnp.float32
        input = input.astype(dtype)

        if self.irreps_in != input.irreps.regroup():
            raise ValueError(
                f"e3nn.equinox.Linear: The input irreps ({input.irreps}) "
                f"do not match the expected irreps ({self.irreps_in})."
            )

        if self.channel_in is not None:
            if self.channel_in != input.shape[-2]:
                raise ValueError(
                    f"e3nn.equinox.Linear: The input channel ({input.shape[-2]}) "
                    f"does not match the expected channel ({self.channel_in})."
                )

        input = input.remove_zero_chunks().regroup()

        def get_parameter(
            name: str,
            path_shape: Tuple[int, ...],
            weight_std: float,
            dtype: jnp.dtype = jnp.float32,
        ):
            del path_shape, weight_std, dtype
            return self._weights[name]

        assertion_message = (
            "Weights cannot be provided when 'linear_type' is 'vanilla'."
            "Otherwise, weights must be provided."
            "If weights are provided, they must be either: \n"
            "* integers and num_indexed_weights must be provided, or \n"
            "* floats and num_indexed_weights must not be provided.\n"
            f"weights.dtype={weights.dtype if weights is not None else None}, "
            f"num_indexed_weights={self.num_indexed_weights}"
        )

        if self.linear_type == "vanilla":
            assert weights is None, assertion_message
            output = linear_vanilla(input, self._linear, get_parameter)

        if self.linear_type in ["indexed", "mixed", "mixed_per_channel"]:
            assert weights is not None, assertion_message
            if isinstance(weights, e3nn.IrrepsArray):
                if not weights.irreps.is_scalar():
                    raise ValueError("weights must be scalar")
                weights = weights.array

            if self.linear_type == "indexed":
                assert weights.dtype.kind == "i", assertion_message
                if self.weights_per_channel:
                    raise NotImplementedError(
                        "weights_per_channel not implemented for indexed weights"
                    )

                output = linear_indexed(
                    input, self._linear, get_parameter, weights, self.num_indexed_weights
                )

            if self.linear_type in ["mixed", "mixed_per_channel"]:
                assert weights.dtype.kind in "fc", assertion_message
                assert self.num_indexed_weights is None, assertion_message

                if self.linear_type == "mixed":
                    output = linear_mixed(
                        input,
                        self._linear,
                        get_parameter,
                        weights,
                        self.gradient_normalization,
                    )

                if self.linear_type == "mixed_per_channel":
                    output = linear_mixed_per_channel(
                        input,
                        self._linear,
                        get_parameter,
                        weights,
                        self.gradient_normalization,
                    )

        if self.channel_out is not None:
            output = output.mul_to_axis(self.channel_out)

        return output.rechunk(self.irreps_out)
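Because Linear above is an eqx.Module (a frozen dataclass registered as a JAX pytree), it composes with Equinox's filtered transforms out of the box. A minimal sketch, assuming the standard equinox API (the toy loss is illustrative):

import jax
import jax.numpy as jnp
import equinox as eqx
import e3nn_jax as e3nn

x = e3nn.normal("0e + 1o")
linear = e3nn.equinox.Linear(
    irreps_out="2x0e + 1o",
    irreps_in=x.irreps,
    key=jax.random.PRNGKey(0),
)

@eqx.filter_jit
def loss_fn(model, x):
    out = model(x)                  # IrrepsArray with irreps 2x0e+1x1o
    return jnp.sum(out.array ** 2)  # toy scalar loss on the raw array

grads = eqx.filter_grad(loss_fn)(linear, x)  # gradients w.r.t. the arrays stored in _weights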

e3nn_jax/_src/linear_flax.py

Lines changed: 1 addition & 2 deletions
@@ -113,10 +113,9 @@ def __call__(self, weights_or_input, input_or_none=None) -> e3nn.IrrepsArray:
            output_irreps = e3nn.Irreps(self.irreps_out).simplify()
        else:
            output_irreps_unsimplified = e3nn.Irreps(self.irreps_out).filter(
-                input.irreps
+                keep=input.irreps
            )
            output_irreps = output_irreps_unsimplified.simplify()
-
        if self.channel_out is not None:
            assert not self.weights_per_channel
            input = input.axis_to_mul()
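The one-line fix above passes the kept irreps through the keep= keyword of Irreps.filter, matching the call used by the new Equinox layer (irreps_out.filter(keep=irreps_in_regrouped)). A small sketch of the intended behaviour, assuming the Irreps.filter(keep=...) API:

import e3nn_jax as e3nn

out = e3nn.Irreps("2x0e + 1o + 2e").filter(keep=e3nn.Irreps("0e + 1o"))
print(out)  # expected: 2x0e+1x1o -- the 2e is dropped because the input cannot produce it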

0 commit comments