76
76
from jax ._src .util import (
77
77
HashableFunction , safe_map , safe_zip , wraps ,
78
78
distributed_debug_log , split_list , weakref_lru_cache ,
79
- merge_lists , flatten , unflatten )
79
+ merge_lists , flatten , unflatten , subs_list )
80
80
81
81
map , unsafe_map = safe_map , map
82
82
zip , unsafe_zip = safe_zip , zip
@@ -1810,35 +1810,33 @@ def keep_where(l, should_keep):
1810
1810
1811
1811
known_out_shardings = keep_where (out_shardings , known_outs ) + res_shardings
1812
1812
1813
- # TODO(mattjj): un-disable this optimization after we have more tests
1814
- # # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1815
- # num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1816
- # in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1817
- # # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1818
- # in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1819
- # in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1820
- # zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1821
- # ] + in_fwd_res
1822
- # del in_fwd_primal, in_fwd_res
1823
- # # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1824
- # keep = [f is None for f in in_fwd]
1825
- # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1826
- # known_out_shardings = keep_where(known_out_shardings, keep)
1827
- # # Update num_out_primals to reflect pruning.
1828
- # kept_primals, kept_res = split_list(keep, [num_out_primals])
1829
- # num_out_primals = sum(f is None for f in kept_primals)
1830
- # del keep, kept_primals, kept_res
1831
-
1832
- # TODO(mattjj): un-disable this optimization after we have more tests
1833
- # # Output-to-output forwarding: compute which residuals are just primal outputs
1834
- # out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1835
- # idx_map = {id(v): i for i, v in enumerate(out_vars)}
1836
- # out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1837
- # # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1838
- # keep = [f is None for f in out_fwd]
1839
- # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1840
- # known_out_shardings = keep_where(known_out_shardings, keep)
1841
- # del keep
1813
+ # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1814
+ num_out_primals = len (known_jaxpr .out_avals ) - num_residuals
1815
+ in_fwd : list [int | None ] = pe ._jaxpr_forwarding (known_jaxpr .jaxpr )
1816
+ # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1817
+ in_fwd_primal , in_fwd_res = split_list (in_fwd , [num_out_primals ])
1818
+ in_fwd = [fwd if is_unspecified (os ) else None for os , fwd in
1819
+ zip (keep_where (out_shardings , known_outs ), in_fwd_primal )
1820
+ ] + in_fwd_res
1821
+ del in_fwd_primal , in_fwd_res
1822
+ # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1823
+ keep = [f is None for f in in_fwd ]
1824
+ known_jaxpr = pe .prune_closed_jaxpr_outputs (known_jaxpr , keep )
1825
+ known_out_shardings = keep_where (known_out_shardings , keep )
1826
+ # Update num_out_primals to reflect pruning.
1827
+ kept_primals , kept_res = split_list (keep , [num_out_primals ])
1828
+ num_out_primals = sum (kept_primals )
1829
+ del keep , kept_primals , kept_res
1830
+
1831
+ # Output-to-output forwarding: compute which residuals are just primal outputs
1832
+ out_vars , res_vars = split_list (known_jaxpr .jaxpr .outvars , [num_out_primals ])
1833
+ idx_map = {id (v ): i for i , v in enumerate (out_vars )}
1834
+ out_fwd = [None ] * num_out_primals + [idx_map .get (id (v )) for v in res_vars ]
1835
+ # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1836
+ keep = [f is None for f in out_fwd ]
1837
+ known_jaxpr = pe .prune_closed_jaxpr_outputs (known_jaxpr , keep )
1838
+ known_out_shardings = keep_where (known_out_shardings , keep )
1839
+ del keep
1842
1840
1843
1841
known_params = dict (
1844
1842
jaxpr = known_jaxpr , in_shardings = keep_where (in_shardings , known_ins ),
@@ -1850,11 +1848,10 @@ def keep_where(l, should_keep):
1850
1848
# Bind known things to pjit_p.
1851
1849
known_inputs = [pv .get_known () for pv in in_pvals if pv .is_known ()]
1852
1850
all_known_outs = pjit_p .bind (* known_inputs , ** known_params )
1853
- # TODO(mattjj): un-disable this optimization after we have more tests
1854
- # # Add back in the output fwds.
1855
- # all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1856
- # # Add back in the input fwds.
1857
- # all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
1851
+ # Add back in the output fwds.
1852
+ all_known_outs = subs_list (out_fwd , all_known_outs , all_known_outs )
1853
+ # Add back in the input fwds.
1854
+ all_known_outs = subs_list (in_fwd , known_inputs , all_known_outs )
1858
1855
1859
1856
known_out_vals , residual_vals = \
1860
1857
split_list (all_known_outs , [len (all_known_outs ) - num_residuals ])
0 commit comments