Skip to content

Commit 3e12290

Browse files
committed
fix tests
1 parent e366b3c commit 3e12290

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

e3nn_jax/_src/irreps.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,11 +1058,11 @@ def rot_y(phi):
10581058

10591059
if b is not None:
10601060
if l < len(Jd):
1061-
J = Jd[l]
1061+
J = Jd[l].astype(b.dtype)
10621062
R += [J @ rot_y(b) @ J]
10631063
else:
10641064
X = generators(l)
1065-
R += [jax.scipy.linalg.expm(b * X[0])]
1065+
R += [jax.scipy.linalg.expm(b.astype(X.dtype) * X[0]).astype(b.dtype)]
10661066

10671067
if c is not None:
10681068
R += [rot_y(c)]
@@ -1094,11 +1094,12 @@ def _wigner_D_from_log_coordinates(l: int, log_coordinates: jnp.ndarray) -> jnp.
10941094
"""
10951095
X = generators(l)
10961096

1097-
def func(log_coordinates):
1098-
return jax.scipy.linalg.expm(jnp.einsum("a,aij->ij", log_coordinates, X))
1097+
def func(log):
1098+
log = log.astype(X.dtype)
1099+
return jax.scipy.linalg.expm(jnp.einsum("a,aij->ij", log, X))
10991100

11001101
f = func
11011102
for _ in range(log_coordinates.ndim - 1):
11021103
f = jax.vmap(f)
11031104

1104-
return f(log_coordinates)
1105+
return f(log_coordinates).astype(log_coordinates.dtype)

e3nn_jax/_src/irreps_test.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import jax
22
import numpy as np
33
import pytest
4-
4+
import jax.numpy as jnp
55
import e3nn_jax as e3nn
66

77

@@ -141,9 +141,31 @@ def test_D(keys, ir):
141141
jax.config.update("jax_enable_x64", True)
142142

143143
ir = e3nn.Irrep(ir)
144-
angles = e3nn.rand_angles(keys[0])
144+
angles = e3nn.rand_angles(keys[0], dtype=np.float64)
145145
Da = ir.D_from_angles(*angles)
146146
w = e3nn.angles_to_log_coordinates(*angles)
147147
Dw = ir.D_from_log_coordinates(w)
148148

149-
np.testing.assert_allclose(Da, Dw, atol=1e-10, rtol=0.0008)
149+
assert Dw.dtype == np.float64, "D_from_log_coordinates should return float64"
150+
assert Da.dtype == np.float64, "D_from_angles should return float64"
151+
np.testing.assert_allclose(Da, Dw, atol=1e-10, rtol=0.002)
152+
153+
154+
@pytest.mark.parametrize("ir", ["0e", "1e", "2e", "3e", "4e", "12e"])
155+
def test_dtype_D_from_angles(ir):
156+
jax.config.update("jax_enable_x64", True)
157+
158+
ir = e3nn.Irrep(ir)
159+
e3nn.utils.assert_output_dtype_matches_input_dtype(
160+
ir.D_from_angles, jnp.array(1.0), jnp.array(1.0), jnp.array(1.0)
161+
)
162+
163+
164+
@pytest.mark.parametrize("ir", ["0e", "1e", "2e", "3e", "4e", "12e"])
165+
def test_dtype_D_from_log_coordinates(ir):
166+
jax.config.update("jax_enable_x64", True)
167+
168+
ir = e3nn.Irrep(ir)
169+
e3nn.utils.assert_output_dtype_matches_input_dtype(
170+
ir.D_from_log_coordinates, jnp.array([1.0, 1.0, 0.0])
171+
)

0 commit comments

Comments
 (0)