Skip to content

Commit e0753a8

Browse files
authored
Add ones() and ones_like() with tests. (#64)
1 parent af22024 commit e0753a8

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

e3nn_jax/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
as_irreps_array,
5858
zeros,
5959
zeros_like,
60+
ones,
61+
ones_like,
6062
concatenate,
6163
stack,
6264
mean,
@@ -176,6 +178,8 @@
176178
"as_irreps_array",
177179
"zeros",
178180
"zeros_like",
181+
"ones",
182+
"ones_like",
179183
"concatenate",
180184
"stack",
181185
"mean",

e3nn_jax/_src/basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def as_irreps_array(array: Union[jnp.ndarray, e3nn.IrrepsArray], *, backend=None
101101
return e3nn.IrrepsArray(f"{array.shape[-1]}x0e", array)
102102

103103

104-
def zeros(irreps: IntoIrreps, leading_shape, dtype=None) -> e3nn.IrrepsArray:
104+
def zeros(
105+
irreps: IntoIrreps, leading_shape: Tuple = (), dtype: jnp.dtype = None
106+
) -> e3nn.IrrepsArray:
105107
r"""Create an IrrepsArray of zeros."""
106108
irreps = e3nn.Irreps(irreps)
107109
array = jnp.zeros(leading_shape + (irreps.dim,), dtype=dtype)
@@ -113,6 +115,20 @@ def zeros_like(irreps_array: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
113115
return e3nn.zeros(irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype)
114116

115117

118+
def ones(
119+
irreps: IntoIrreps, leading_shape: Tuple = (), dtype: jnp.dtype = None
120+
) -> e3nn.IrrepsArray:
121+
r"""Create an IrrepsArray of ones."""
122+
irreps = e3nn.Irreps(irreps)
123+
array = jnp.ones(leading_shape + (irreps.dim,), dtype=dtype)
124+
return e3nn.IrrepsArray(irreps, array, zero_flags=(False,) * len(irreps))
125+
126+
127+
def ones_like(irreps_array: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
128+
r"""Create an IrrepsArray of ones with the same shape as another IrrepsArray."""
129+
return e3nn.ones(irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype)
130+
131+
116132
def _align_two_irreps(
117133
irreps1: e3nn.Irreps, irreps2: e3nn.Irreps
118134
) -> Tuple[e3nn.Irreps, e3nn.Irreps]:

tests/_src/basic_test.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,42 @@
55

66

77
def assert_array_equals_chunks(x: e3nn.IrrepsArray):
8-
y = e3nn.from_chunks(x.irreps, x.chunks, x.shape[:-1])
8+
y = e3nn.from_chunks(x.irreps, x.chunks, x.shape[:-1], x.dtype)
99
np.testing.assert_array_equal(x.array, y.array)
1010

1111

12+
def test_zeros():
13+
x = e3nn.zeros("0e + 1e", leading_shape=(3, 5))
14+
assert jnp.all(x.array == 0)
15+
assert x.shape == (3, 5, 4)
16+
assert x.irreps == "0e + 1e"
17+
assert_array_equals_chunks(x)
18+
19+
20+
def test_zeros_like():
21+
x = e3nn.ones("0e + 1e", leading_shape=(3, 5))
22+
y = e3nn.zeros_like(x)
23+
assert jnp.all(y.array == 0)
24+
assert y.shape == x.shape
25+
assert y.irreps == x.irreps
26+
27+
28+
def test_ones():
29+
x = e3nn.ones("0e + 1e", leading_shape=(3, 5))
30+
assert jnp.all(x.array == 1)
31+
assert x.shape == (3, 5, 4)
32+
assert x.irreps == "0e + 1e"
33+
assert_array_equals_chunks(x)
34+
35+
36+
def test_ones_like():
37+
x = e3nn.zeros("0e + 1e", leading_shape=(3, 5))
38+
y = e3nn.ones_like(x)
39+
assert jnp.all(y.array == 1)
40+
assert y.shape == x.shape
41+
assert y.irreps == x.irreps
42+
43+
1244
def test_concatenate1(keys):
1345
x1 = e3nn.normal("0e + 1e", keys[0], (3,))
1446
x2 = e3nn.normal("0e + 1e", keys[0], (2,))

0 commit comments

Comments
 (0)