Skip to content

Commit 3898546

Browse files
authored
Merge pull request #79 from e3nn/get_parameter
Fix documentation build errors.
2 parents 818bfd5 + 129706a commit 3898546

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

e3nn_jax/_src/basic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,8 +648,11 @@ def normal(
648648
raise ValueError("Normalization needs to be 'norm' or 'component'")
649649

650650

651-
def where(mask: jax.Array, x: e3nn.IrrepsArray, y: e3nn.IrrepsArray):
652-
"""Selects elements from `x` or `y`, depending on `mask`.
651+
def where(
652+
mask: jax.Array, x: e3nn.IrrepsArray, y: e3nn.IrrepsArray
653+
) -> e3nn.IrrepsArray:
654+
"""
655+
Selects elements from `x` or `y`, depending on `mask`.
653656
654657
Equivalent to:
655658
>>> e3nn.IrrepsArray(x.irreps, jnp.where(mask, x.array, y.array))

e3nn_jax/_src/batchnorm/bn_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class BatchNorm(nn.Module):
2525
2626
Args:
2727
use_running_average: if True, the statistics stored in batch_stats will be
28-
used instead of computing the batch statistics on the input.
28+
used instead of computing the batch statistics on the input.
2929
eps (float): epsilon for numerical stability, has to be between 0 and 1.
3030
the field norm is transformed to ``(1 - eps) * norm + eps``
3131
leading to a slower convergence toward norm 1.

e3nn_jax/_src/irreps_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _is_none_slice(x):
3838
return isinstance(x, slice) and x == slice(None)
3939

4040

41-
@attrs(frozen=True, init=True, repr=False)
41+
@attrs(frozen=True, init=True, repr=False, cmp=False)
4242
class IrrepsArray:
4343
r"""Array with a representation of rotations.
4444

0 commit comments

Comments
 (0)