76
76
from jax ._src .util import (
77
77
HashableFunction , safe_map , safe_zip , wraps ,
78
78
distributed_debug_log , split_list , weakref_lru_cache ,
79
- merge_lists , flatten , unflatten , subs_list2 )
79
+ merge_lists , flatten , unflatten )
80
80
81
81
map , unsafe_map = safe_map , map
82
82
zip , unsafe_zip = safe_zip , zip
@@ -1798,8 +1798,8 @@ def _pjit_partial_eval(trace, *in_tracers,
1798
1798
1799
1799
known_ins = tuple (pv .is_known () for pv in in_pvals )
1800
1800
unknown_ins = tuple (not k for k in known_ins )
1801
- known_jaxpr , unknown_jaxpr , unknown_outs , res_avals = pe . partial_eval_jaxpr_nounits (
1802
- jaxpr , unknown_ins , instantiate = False )
1801
+ known_jaxpr , unknown_jaxpr , unknown_outs , res_avals = \
1802
+ pe . partial_eval_jaxpr_nounits ( jaxpr , unknown_ins , instantiate = False )
1803
1803
unknown_outs = tuple (unknown_outs )
1804
1804
known_outs = tuple (not uk for uk in unknown_outs )
1805
1805
num_residuals = len (res_avals )
@@ -1808,28 +1808,37 @@ def _pjit_partial_eval(trace, *in_tracers,
1808
1808
def keep_where (l , should_keep ):
1809
1809
return tuple (x for x , keep in zip (l , should_keep ) if keep )
1810
1810
1811
- # Compute which outputs are just forwarded inputs.
1812
- num_out_primals = len (known_jaxpr .out_avals ) - num_residuals
1813
- in_fwd = pe ._jaxpr_forwarding (known_jaxpr .jaxpr )
1814
-
1815
- # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1816
- in_fwd_primal , in_fwd_res = split_list (in_fwd , [num_out_primals ])
1817
- in_fwd = [fwd if is_unspecified (os ) else None for os , fwd in
1818
- zip (keep_where (out_shardings , known_outs ), in_fwd_primal )
1819
- ] + in_fwd_res
1820
- del in_fwd_primal , in_fwd_res
1821
-
1822
- # Compute which residuals are just primal outputs.
1823
- out_vars , res_vars = split_list (known_jaxpr .jaxpr .outvars , [num_out_primals ])
1824
- idx_map = {id (v ): i for i , v in enumerate (out_vars )}
1825
- out_fwd = [None ] * num_out_primals + [idx_map .get (id (v )) for v in res_vars ]
1826
-
1827
- # Prune jaxpr outputs and out_shardings by removing forwards.
1828
- keep = [f1 is None and f2 is None for f1 , f2 in zip (in_fwd , out_fwd )]
1829
- known_jaxpr = pe .prune_closed_jaxpr_outputs (known_jaxpr , keep )
1830
1811
known_out_shardings = keep_where (out_shardings , known_outs ) + res_shardings
1831
- known_out_shardings = keep_where (known_out_shardings , keep )
1832
- del keep , num_out_primals
1812
+
1813
+ # TODO(mattjj): un-disable this optimization after we have more tests
1814
+ # # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1815
+ # num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1816
+ # in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1817
+ # # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1818
+ # in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1819
+ # in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1820
+ # zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1821
+ # ] + in_fwd_res
1822
+ # del in_fwd_primal, in_fwd_res
1823
+ # # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1824
+ # keep = [f is None for f in in_fwd]
1825
+ # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1826
+ # known_out_shardings = keep_where(known_out_shardings, keep)
1827
+ # # Update num_out_primals to reflect pruning.
1828
+ # kept_primals, kept_res = split_list(keep, [num_out_primals])
1829
+ # num_out_primals = sum(f is None for f in kept_primals)
1830
+ # del keep, kept_primals, kept_res
1831
+
1832
+ # TODO(mattjj): un-disable this optimization after we have more tests
1833
+ # # Output-to-output forwarding: compute which residuals are just primal outputs
1834
+ # out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1835
+ # idx_map = {id(v): i for i, v in enumerate(out_vars)}
1836
+ # out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1837
+ # # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1838
+ # keep = [f is None for f in out_fwd]
1839
+ # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1840
+ # known_out_shardings = keep_where(known_out_shardings, keep)
1841
+ # del keep
1833
1842
1834
1843
known_params = dict (
1835
1844
jaxpr = known_jaxpr , in_shardings = keep_where (in_shardings , known_ins ),
@@ -1841,16 +1850,19 @@ def keep_where(l, should_keep):
1841
1850
# Bind known things to pjit_p.
1842
1851
known_inputs = [pv .get_known () for pv in in_pvals if pv .is_known ()]
1843
1852
all_known_outs = pjit_p .bind (* known_inputs , ** known_params )
1844
- all_known_outs = subs_list2 (in_fwd , out_fwd , known_inputs , all_known_outs ,
1845
- all_known_outs )
1853
+ # TODO(mattjj): un-disable this optimization after we have more tests
1854
+ # # Add back in the output fwds.
1855
+ # all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1856
+ # # Add back in the input fwds.
1857
+ # all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
1846
1858
1847
1859
known_out_vals , residual_vals = \
1848
1860
split_list (all_known_outs , [len (all_known_outs ) - num_residuals ])
1849
1861
residual_tracers = map (trace .new_instantiated_const , residual_vals )
1850
1862
1851
- # The convention of partial_eval_jaxpr_nounits is to place residual binders
1852
- # at the front of the jaxpr produced, so we move them to the back since both
1853
- # the jaxpr equation built below and the pjit transpose rule assume a
1863
+ # The convention of partial_eval_jaxpr_nounits is to place residual binders at
1864
+ # the front of the jaxpr produced, so we move them to the back since both the
1865
+ # jaxpr equation built below and the pjit transpose rule assume a
1854
1866
# residual-inputs-last convention.
1855
1867
unknown_jaxpr = pe .move_binders_to_back (
1856
1868
unknown_jaxpr , [True ] * num_residuals + [False ] * sum (unknown_ins ))
0 commit comments