Skip to content

Commit c515f15

Browse files
committed
fix residual forwarding bug, fixes #20267
1 parent cdafb8f commit c515f15

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

jax/_src/pjit.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from jax._src.util import (
7777
HashableFunction, safe_map, safe_zip, wraps,
7878
distributed_debug_log, split_list, weakref_lru_cache,
79-
merge_lists, flatten, unflatten, subs_list2)
79+
merge_lists, flatten, unflatten, subs_list)
8080

8181
map, unsafe_map = safe_map, map
8282
zip, unsafe_zip = safe_zip, zip
@@ -1798,8 +1798,8 @@ def _pjit_partial_eval(trace, *in_tracers,
17981798

17991799
known_ins = tuple(pv.is_known() for pv in in_pvals)
18001800
unknown_ins = tuple(not k for k in known_ins)
1801-
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = pe.partial_eval_jaxpr_nounits(
1802-
jaxpr, unknown_ins, instantiate=False)
1801+
known_jaxpr, unknown_jaxpr, unknown_outs, res_avals = \
1802+
pe.partial_eval_jaxpr_nounits(jaxpr, unknown_ins, instantiate=False)
18031803
unknown_outs = tuple(unknown_outs)
18041804
known_outs = tuple(not uk for uk in unknown_outs)
18051805
num_residuals = len(res_avals)
@@ -1808,28 +1808,34 @@ def _pjit_partial_eval(trace, *in_tracers,
18081808
def keep_where(l, should_keep):
18091809
return tuple(x for x, keep in zip(l, should_keep) if keep)
18101810

1811-
# Compute which outputs are just forwarded inputs.
1811+
# Input-to-output forwarding: compute which outputs are just forwarded inputs.
18121812
num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1813-
in_fwd = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1814-
1813+
in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
18151814
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
18161815
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
18171816
in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
18181817
zip(keep_where(out_shardings, known_outs), in_fwd_primal)
18191818
] + in_fwd_res
18201819
del in_fwd_primal, in_fwd_res
1820+
# Prune jaxpr outputs and out_shardings by removing the input-forwards.
1821+
keep = [f is None for f in in_fwd]
1822+
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1823+
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
1824+
known_out_shardings = keep_where(known_out_shardings, keep)
1825+
# Update num_out_primals to reflect pruning.
1826+
kept_primals, kept_res = split_list(keep, [num_out_primals])
1827+
num_out_primals = sum(kept_primals)  # kept_primals is a list of bools (True = output kept), so count the Trues
1828+
del keep, kept_primals, kept_res
18211829

1822-
# Compute which residuals are just primal outputs.
1830+
# Output-to-output forwarding: compute which residuals are just primal outputs.
18231831
out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
18241832
idx_map = {id(v): i for i, v in enumerate(out_vars)}
18251833
out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1826-
1827-
# Prune jaxpr outputs and out_shardings by removing forwards.
1828-
keep = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
1834+
# Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1835+
keep = [f is None for f in out_fwd]
18291836
known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1830-
known_out_shardings = keep_where(out_shardings, known_outs) + res_shardings
18311837
known_out_shardings = keep_where(known_out_shardings, keep)
1832-
del keep, num_out_primals
1838+
del keep
18331839

18341840
known_params = dict(
18351841
jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
@@ -1841,16 +1847,18 @@ def keep_where(l, should_keep):
18411847
# Bind known things to pjit_p.
18421848
known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
18431849
all_known_outs = pjit_p.bind(*known_inputs, **known_params)
1844-
all_known_outs = subs_list2(in_fwd, out_fwd, known_inputs, all_known_outs,
1845-
all_known_outs)
1850+
# Add back in the output fwds.
1851+
all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1852+
# Add back in the input fwds.
1853+
all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
18461854

18471855
known_out_vals, residual_vals = \
18481856
split_list(all_known_outs, [len(all_known_outs) - num_residuals])
18491857
residual_tracers = map(trace.new_instantiated_const, residual_vals)
18501858

1851-
# The convention of partial_eval_jaxpr_nounits is to place residual binders
1852-
# at the front of the jaxpr produced, so we move them to the back since both
1853-
# the jaxpr equation built below and the pjit transpose rule assume a
1859+
# The convention of partial_eval_jaxpr_nounits is to place residual binders at
1860+
# the front of the jaxpr produced, so we move them to the back since both the
1861+
# jaxpr equation built below and the pjit transpose rule assume a
18541862
# residual-inputs-last convention.
18551863
unknown_jaxpr = pe.move_binders_to_back(
18561864
unknown_jaxpr, [True] * num_residuals + [False] * sum(unknown_ins))

0 commit comments

Comments
 (0)