Skip to content

Commit 808455e

Browse files
author
jax authors
committed
Merge pull request #20277 from mattjj:fix-forwarding-2
PiperOrigin-RevId: 616243845
2 parents fea2665 + 8c2f6b3 commit 808455e

File tree

3 files changed

+49
-38
lines changed

3 files changed

+49
-38
lines changed

jax/_src/pjit.py

Lines changed: 32 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@
7676
from jax._src.util import (
7777
HashableFunction, safe_map, safe_zip, wraps,
7878
distributed_debug_log, split_list, weakref_lru_cache,
79-
merge_lists, flatten, unflatten)
79+
merge_lists, flatten, unflatten, subs_list)
8080

8181
map, unsafe_map = safe_map, map
8282
zip, unsafe_zip = safe_zip, zip
@@ -1810,35 +1810,33 @@ def keep_where(l, should_keep):
18101810

18111811
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
18121812

1813-
# TODO(mattjj): un-disable this optimization after we have more tests
1814-
# # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1815-
# num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1816-
# in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1817-
# # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1818-
# in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1819-
# in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1820-
# zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1821-
# ] + in_fwd_res
1822-
# del in_fwd_primal, in_fwd_res
1823-
# # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1824-
# keep = [f is None for f in in_fwd]
1825-
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1826-
# known_out_shardings = keep_where(known_out_shardings, keep)
1827-
# # Update num_out_primals to reflect pruning.
1828-
# kept_primals, kept_res = split_list(keep, [num_out_primals])
1829-
# num_out_primals = sum(f is None for f in kept_primals)
1830-
# del keep, kept_primals, kept_res
1831-
1832-
# TODO(mattjj): un-disable this optimization after we have more tests
1833-
# # Output-to-output forwarding: compute which residuals are just primal outputs
1834-
# out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1835-
# idx_map = {id(v): i for i, v in enumerate(out_vars)}
1836-
# out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1837-
# # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1838-
# keep = [f is None for f in out_fwd]
1839-
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1840-
# known_out_shardings = keep_where(known_out_shardings, keep)
1841-
# del keep
1813+
# Input-to-output forwarding: compute which outputs are just forwarded inputs.
1814+
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1815+
in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1816+
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1817+
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1818+
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1819+
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1820+
] + in_fwd_res
1821+
del in_fwd_primal, in_fwd_res
1822+
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
1823+
keep = [f is None for f in in_fwd]
1824+
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1825+
known_out_shardings = keep_where(known_out_shardings, keep)
1826+
# Update num_out_primals to reflect pruning.
1827+
kept_primals, kept_res = split_list(keep, [num_out_primals])
1828+
num_out_primals = sum(kept_primals)
1829+
del keep, kept_primals, kept_res
1830+
1831+
# Output-to-output forwarding: compute which residuals are just primal outputs
1832+
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1833+
idx_map = {id(v): i for i, v in enumerate(out_vars)}
1834+
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1835+
# Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1836+
keep = [f is None for f in out_fwd]
1837+
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1838+
known_out_shardings = keep_where(known_out_shardings, keep)
1839+
del keep
18421840

18431841
known_params = dict(
18441842
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
@@ -1850,11 +1848,10 @@ def keep_where(l, should_keep):
18501848
# Bind known things to pjit_p.
18511849
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
18521850
all_known_outs = pjit_p.bind(*known_inputs, **known_params)
1853-
# TODO(mattjj): un-disable this optimization after we have more tests
1854-
# # Add back in the output fwds.
1855-
# all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1856-
# # Add back in the input fwds.
1857-
# all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
1851+
# Add back in the output fwds.
1852+
all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1853+
# Add back in the input fwds.
1854+
all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
18581855

18591856
known_out_vals, residual_vals = \
18601857
split_list(all_known_outs, [len(all_known_outs) - num_residuals])

tests/api_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4588,6 +4588,23 @@ def inner(a, x):
45884588
return inner(0., x)[0]
45894589
jax.grad(f)(1.) # don't crash
45904590

4591+
@parameterized.parameters(it.product(range(4), repeat=3))
4592+
@jtu.run_on_devices("cpu")
4593+
def test_jit_forwarding_correctness(self, seed, num_input_fwd, num_output_fwd):
4594+
num_args = 3
4595+
rng = np.random.RandomState(seed)
4596+
4597+
@jax.jit
4598+
def f(inputs):
4599+
inputs = [inputs[i] for i in rng.permutation(num_args)]
4600+
outputs = inputs[:num_input_fwd] + [
4601+
jnp.exp(inputs[i]) if i < num_output_fwd else jnp.sin(inputs[i])
4602+
for i in range(num_args - num_input_fwd)]
4603+
return [outputs[i] for i in rng.permutation(num_args)]
4604+
4605+
jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1,
4606+
modes=['rev'], atol=1e-3, rtol=1e-3)
4607+
45914608

45924609
class RematTest(jtu.JaxTestCase):
45934610

@@ -9312,7 +9329,6 @@ def foo_bwd(_, g):
93129329
r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'):
93139330
jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4))
93149331

9315-
93169332
def transpose_unary(f, x_example):
93179333
def transposed(y):
93189334
x, = api.linear_transpose(f, x_example)(y)

tests/core_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,6 @@ def body(c, _):
368368
dropvar, b = jaxpr.eqns[0].outvars
369369
self.assertEqual(dropvar.aval, aval)
370370

371-
# TODO(mattjj): un-skip
372-
@unittest.skip('temporarily skipping until we can add more tests')
373371
def test_input_residual_forwarding(self):
374372
# https://github.com/google/jax/pull/11151
375373
x = jnp.arange(3 * 4.).reshape(3, 4)

0 commit comments

Comments
 (0)