Skip to content

Commit b398e4b

Browse files
author
jax authors
committed
Merge pull request #20273 from mattjj:fix-forwarding
PiperOrigin-RevId: 616186555
2 parents cdafb8f + 8a7c604 commit b398e4b

File tree

3 files changed

+52
-29
lines changed

3 files changed

+52
-29
lines changed

jax/_src/pjit.py

Lines changed: 41 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@
7676
from jax._src.util import (
7777
HashableFunction, safe_map, safe_zip, wraps,
7878
distributed_debug_log, split_list, weakref_lru_cache,
79-
merge_lists, flatten, unflatten, subs_list2)
79+
merge_lists, flatten, unflatten)
8080

8181
map, unsafe_map = safe_map, map
8282
zip, unsafe_zip = safe_zip, zip
@@ -1798,8 +1798,8 @@ def _pjit_partial_eval(trace, *in_tracers,
17981798

17991799
known_ins = tuple(pv.is_known() for pv in in_pvals)
18001800
unknown_ins = tuple(not k for k in known_ins)
1801-
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = pe.partial_eval_jaxpr_nounits(
1802-
jaxpr, unknown_ins, instantiate=False)
1801+
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
1802+
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
18031803
unknown_outs = tuple(unknown_outs)
18041804
known_outs = tuple(not uk for uk in unknown_outs)
18051805
num_residuals = len(res_avals)
@@ -1808,28 +1808,37 @@ def _pjit_partial_eval(trace, *in_tracers,
18081808
def keep_where(l, should_keep):
18091809
return tuple(x for x, keep in zip(l, should_keep) if keep)
18101810

1811-
# Compute which outputs are just forwarded inputs.
1812-
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1813-
in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1814-
1815-
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1816-
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1817-
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1818-
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1819-
] + in_fwd_res
1820-
del in_fwd_primal, in_fwd_res
1821-
1822-
# Compute which residuals are just primal outputs.
1823-
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1824-
idx_map = {id(v): i for i, v in enumerate(out_vars)}
1825-
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1826-
1827-
# Prune jaxpr outputs and out_shardings by removing forwards.
1828-
keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
1829-
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
18301811
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
1831-
known_out_shardings = keep_where(known_out_shardings, keep)
1832-
del keep, num_out_primals
1812+
1813+
# TODO(mattjj): un-disable this optimization after we have more tests
1814+
# # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1815+
# num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1816+
# in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1817+
# # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1818+
# in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1819+
# in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1820+
# zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1821+
# ] + in_fwd_res
1822+
# del in_fwd_primal, in_fwd_res
1823+
# # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1824+
# keep = [f is None for f in in_fwd]
1825+
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1826+
# known_out_shardings = keep_where(known_out_shardings, keep)
1827+
# # Update num_out_primals to reflect pruning.
1828+
# kept_primals, kept_res = split_list(keep, [num_out_primals])
1829+
# num_out_primals = sum(f is None for f in kept_primals)
1830+
# del keep, kept_primals, kept_res
1831+
1832+
# TODO(mattjj): un-disable this optimization after we have more tests
1833+
# # Output-to-output forwarding: compute which residuals are just primal outputs
1834+
# out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1835+
# idx_map = {id(v): i for i, v in enumerate(out_vars)}
1836+
# out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1837+
# # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1838+
# keep = [f is None for f in out_fwd]
1839+
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1840+
# known_out_shardings = keep_where(known_out_shardings, keep)
1841+
# del keep
18331842

18341843
known_params = dict(
18351844
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
@@ -1841,16 +1850,19 @@ def keep_where(l, should_keep):
18411850
# Bind known things to pjit_p.
18421851
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
18431852
all_known_outs = pjit_p.bind(*known_inputs, **known_params)
1844-
all_known_outs = subs_list2(in_fwd, out_fwd, known_inputs, all_known_outs,
1845-
all_known_outs)
1853+
# TODO(mattjj): un-disable this optimization after we have more tests
1854+
# # Add back in the output fwds.
1855+
# all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1856+
# # Add back in the input fwds.
1857+
# all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
18461858

18471859
known_out_vals, residual_vals = \
18481860
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
18491861
residual_tracers = map(trace.new_instantiated_const, residual_vals)
18501862

1851-
# The convention of partial_eval_jaxpr_nounits is to place residual binders
1852-
# at the front of the jaxpr produced, so we move them to the back since both
1853-
# the jaxpr equation built below and the pjit transpose rule assume a
1863+
# The convention of partial_eval_jaxpr_nounits is to place residual binders at
1864+
# the front of the jaxpr produced, so we move them to the back since both the
1865+
# jaxpr equation built below and the pjit transpose rule assume a
18541866
# residual-inputs-last convention.
18551867
unknown_jaxpr = pe.move_binders_to_back(
18561868
unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins))

tests/api_test.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4565,6 +4565,15 @@ def foo(self):
45654565
gc.collect()
45664566
assert a() is None
45674567

4568+
def test_forwarding_bug(self):
4569+
# Test for issue #20267.
4570+
def f(x):
4571+
@jax.jit
4572+
def inner(a, x):
4573+
return a, jnp.exp(x)
4574+
return inner(0., x)[0]
4575+
jax.grad(f)(1.) # don't crash
4576+
45684577

45694578
class RematTest(jtu.JaxTestCase):
45704579

tests/core_test.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -368,6 +368,8 @@ def body(c, _):
368368
dropvar, b = jaxpr.eqns[0].outvars
369369
self.assertEqual(dropvar.aval, aval)
370370

371+
# TODO(mattjj): un-skip
372+
@unittest.skip('temporarily skipping until we can add more tests')
371373
def test_input_residual_forwarding(self):
372374
# https://github.com/google/jax/pull/11151
373375
x = jnp.arange(3 * 4.).reshape(3, 4)

0 commit comments

Comments
 (0)