Skip to content

Commit 8a7c604

Browse files
committed
disable optimization
1 parent c515f15 commit 8a7c604

File tree

3 files changed

+47
-32
lines changed

3 files changed

+47
-32
lines changed

jax/_src/pjit.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from jax._src.util import (
7777
HashableFunction, safe_map, safe_zip, wraps,
7878
distributed_debug_log, split_list, weakref_lru_cache,
79-
merge_lists, flatten, unflatten, subs_list)
79+
merge_lists, flatten, unflatten)
8080

8181
map, unsafe_map = safe_map, map
8282
zip, unsafe_zip = safe_zip, zip
@@ -1808,34 +1808,37 @@ def _pjit_partial_eval(trace, *in_tracers,
18081808
def keep_where(l, should_keep):
18091809
return tuple(x for x, keep in zip(l, should_keep) if keep)
18101810

1811-
# Input-to-output forwarding: compute which outputs are just forwarded inputs.
1812-
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1813-
in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1814-
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1815-
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1816-
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1817-
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1818-
] + in_fwd_res
1819-
del in_fwd_primal, in_fwd_res
1820-
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
1821-
keep = [f is None for f in in_fwd]
1822-
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
18231811
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
1824-
known_out_shardings = keep_where(known_out_shardings, keep)
1825-
# Update num_out_primals to reflect pruning.
1826-
kept_primals, kept_res = split_list(keep, [num_out_primals])
1827-
num_out_primals = sum(f is None for f in kept_primals)
1828-
del keep, kept_primals, kept_res
1829-
1830-
# Output-to-output forwarding: compute which residuals are just primal outputs
1831-
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1832-
idx_map = {id(v): i for i, v in enumerate(out_vars)}
1833-
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1834-
# Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1835-
keep = [f is None for f in out_fwd]
1836-
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1837-
known_out_shardings = keep_where(known_out_shardings, keep)
1838-
del keep
1812+
1813+
# TODO(mattjj): un-disable this optimization after we have more tests
1814+
# # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1815+
# num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1816+
# in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1817+
# # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1818+
# in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1819+
# in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1820+
# zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1821+
# ] + in_fwd_res
1822+
# del in_fwd_primal, in_fwd_res
1823+
# # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1824+
# keep = [f is None for f in in_fwd]
1825+
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1826+
# known_out_shardings = keep_where(known_out_shardings, keep)
1827+
# # Update num_out_primals to reflect pruning.
1828+
# kept_primals, kept_res = split_list(keep, [num_out_primals])
1829+
# num_out_primals = sum(f is None for f in kept_primals)
1830+
# del keep, kept_primals, kept_res
1831+
1832+
# TODO(mattjj): un-disable this optimization after we have more tests
1833+
# # Output-to-output forwarding: compute which residuals are just primal outputs
1834+
# out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1835+
# idx_map = {id(v): i for i, v in enumerate(out_vars)}
1836+
# out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1837+
# # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1838+
# keep = [f is None for f in out_fwd]
1839+
# known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1840+
# known_out_shardings = keep_where(known_out_shardings, keep)
1841+
# del keep
18391842

18401843
known_params = dict(
18411844
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
@@ -1847,10 +1850,11 @@ def keep_where(l, should_keep):
18471850
# Bind known things to pjit_p.
18481851
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
18491852
all_known_outs = pjit_p.bind(*known_inputs, **known_params)
1850-
# Add back in the output fwds.
1851-
all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1852-
# Add back in the input fwds.
1853-
all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
1853+
# TODO(mattjj): un-disable this optimization after we have more tests
1854+
# # Add back in the output fwds.
1855+
# all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1856+
# # Add back in the input fwds.
1857+
# all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
18541858

18551859
known_out_vals, residual_vals = \
18561860
split_list(all_known_outs, [len(all_known_outs) - num_residuals])

tests/api_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4565,6 +4565,15 @@ def foo(self):
45654565
gc.collect()
45664566
assert a() is None
45674567

4568+
def test_forwarding_bug(self):
4569+
# Test for issue #20267.
4570+
def f(x):
4571+
@jax.jit
4572+
def inner(a, x):
4573+
return a, jnp.exp(x)
4574+
return inner(0., x)[0]
4575+
jax.grad(f)(1.) # don't crash
4576+
45684577

45694578
class RematTest(jtu.JaxTestCase):
45704579

tests/core_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ def body(c, _):
368368
dropvar, b = jaxpr.eqns[0].outvars
369369
self.assertEqual(dropvar.aval, aval)
370370

371+
# TODO(mattjj): un-skip
372+
@unittest.skip('temporarily skipping until we can add more tests')
371373
def test_input_residual_forwarding(self):
372374
# https://github.com/google/jax/pull/11151
373375
x = jnp.arange(3 * 4.).reshape(3, 4)

0 commit comments

Comments
 (0)