Skip to content

Commit 9cf2fbe

Browse files
author
jax authors
committed
Merge pull request #20278 from mattjj:fix-forwarding-3
PiperOrigin-RevId: 616264716
2 parents 808455e + c021117 commit 9cf2fbe

File tree

3 files changed

+26
-3
lines changed

3 files changed

+26
-3
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ def trace_to_subjaxpr_nounits_fwd2(
847847
out_pvals = [t.pval for t in out_tracers]
848848

849849
# Which consts (aka residuals) are just forwarded inputs? Check obj id.
850-
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
850+
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
851851
id_map = {id(c): i for i, c in enumerate(in_consts)}
852852
input_fwds: list[int | None] = [id_map.get(id(c)) for c in consts]
853853

tests/api_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4593,14 +4593,16 @@ def inner(a, x):
45934593
def test_jit_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd):
45944594
num_args = 3
45954595
rng = np.random.RandomState(seed)
4596+
in_perm = rng.permutation(num_args)
4597+
out_perm = rng.permutation(num_args)
45964598

45974599
@jax.jit
45984600
def f(inputs):
4599-
inputs = [inputs[i] for i in rng.permutation(num_args)]
4601+
inputs = [inputs[i] for i in in_perm]
46004602
outputs = inputs[:num_input_fwd] + [
46014603
jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i])
46024604
for i in range(num_args - num_input_fwd)]
4603-
return [outputs[i] for i in rng.permutation(num_args)]
4605+
return [outputs[i] for i in out_perm]
46044606

46054607
jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1,
46064608
modes=['rev'], atol=1e-3, rtol=1e-3)

tests/shard_map_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,27 @@ def f(x):
16041604
with jax.disable_jit():
16051605
f(x) # don't crash
16061606

1607+
@parameterized.parameters(it.product(range(4), repeat=3))
1608+
@jtu.run_on_devices("cpu")
1609+
def test_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd):
1610+
num_args = 3
1611+
rng = np.random.RandomState(seed)
1612+
mesh = Mesh(np.array(jax.devices()[:1]), ('i',))
1613+
1614+
in_perm = rng.permutation(num_args)
1615+
out_perm = rng.permutation(num_args)
1616+
1617+
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
1618+
def f(inputs):
1619+
inputs = [inputs[i] for i in in_perm]
1620+
outputs = inputs[:num_input_fwd] + [
1621+
jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i])
1622+
for i in range(num_args - num_input_fwd)]
1623+
return [outputs[i] for i in out_perm]
1624+
1625+
jtu.check_grads(f, (list(jnp.arange(float(num_args))[:,None]),), order=1,
1626+
modes=['rev'], atol=1e-3, rtol=1e-3)
1627+
16071628

16081629
class FunSpec(NamedTuple):
16091630
name: str

0 commit comments

Comments
 (0)