76
76
from jax ._src .util import (
77
77
HashableFunction , safe_map , safe_zip , wraps ,
78
78
distributed_debug_log , split_list , weakref_lru_cache ,
79
- merge_lists , flatten , unflatten , subs_list )
79
+ merge_lists , flatten , unflatten )
80
80
81
81
map , unsafe_map = safe_map , map
82
82
zip , unsafe_zip = safe_zip , zip
@@ -1808,34 +1808,37 @@ def _pjit_partial_eval(trace, *in_tracers,
1808
1808
def keep_where (l , should_keep ):
1809
1809
return tuple (x for x , keep in zip (l , should_keep ) if keep )
1810
1810
1811
- # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1812
- num_out_primals = len (known_jaxpr .out_avals ) - num_residuals
1813
- in_fwd : list [int | None ] = pe ._jaxpr_forwarding (known_jaxpr .jaxpr )
1814
- # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1815
- in_fwd_primal , in_fwd_res = split_list (in_fwd , [num_out_primals ])
1816
- in_fwd = [fwd if is_unspecified (os ) else None for os , fwd in
1817
- zip (keep_where (out_shardings , known_outs ), in_fwd_primal )
1818
- ] + in_fwd_res
1819
- del in_fwd_primal , in_fwd_res
1820
- # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1821
- keep = [f is None for f in in_fwd ]
1822
- known_jaxpr = pe .prune_closed_jaxpr_outputs (known_jaxpr , keep )
1823
1811
known_out_shardings = keep_where (out_shardings , known_outs ) + res_shardings
1824
- known_out_shardings = keep_where (known_out_shardings , keep )
1825
- # Update num_out_primals to reflect pruning.
1826
- kept_primals , kept_res = split_list (keep , [num_out_primals ])
1827
- num_out_primals = sum (f is None for f in kept_primals )
1828
- del keep , kept_primals , kept_res
1829
-
1830
- # Output-to-output forwarding: compute which residuals are just primal outputs
1831
- out_vars , res_vars = split_list (known_jaxpr .jaxpr .outvars , [num_out_primals ])
1832
- idx_map = {id (v ): i for i , v in enumerate (out_vars )}
1833
- out_fwd = [None ] * num_out_primals + [idx_map .get (id (v )) for v in res_vars ]
1834
- # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1835
- keep = [f is None for f in out_fwd ]
1836
- known_jaxpr = pe .prune_closed_jaxpr_outputs (known_jaxpr , keep )
1837
- known_out_shardings = keep_where (known_out_shardings , keep )
1838
- del keep
1812
+
1813
+ # TODO(mattjj): un-disable this optimization after we have more tests
1814
+ # # Input-to-output forwarding: compute which outputs are just forwarded inputs.
1815
+ # num_out_primals = len(known_jaxpr.out_avals) - num_residuals
1816
+ # in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
1817
+ # # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
1818
+ # in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
1819
+ # in_fwd = [fwd if is_unspecified(os) else None for os, fwd in
1820
+ # zip(keep_where(out_shardings, known_outs), in_fwd_primal)
1821
+ # ] + in_fwd_res
1822
+ # del in_fwd_primal, in_fwd_res
1823
+ # # Prune jaxpr outputs and out_shardings by removing the input-forwards.
1824
+ # keep = [f is None for f in in_fwd]
1825
+ # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1826
+ # known_out_shardings = keep_where(known_out_shardings, keep)
1827
+ # # Update num_out_primals to reflect pruning.
1828
+ # kept_primals, kept_res = split_list(keep, [num_out_primals])
1829
+ # num_out_primals = sum(f is None for f in kept_primals)
1830
+ # del keep, kept_primals, kept_res
1831
+
1832
+ # TODO(mattjj): un-disable this optimization after we have more tests
1833
+ # # Output-to-output forwarding: compute which residuals are just primal outputs
1834
+ # out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
1835
+ # idx_map = {id(v): i for i, v in enumerate(out_vars)}
1836
+ # out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
1837
+ # # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
1838
+ # keep = [f is None for f in out_fwd]
1839
+ # known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
1840
+ # known_out_shardings = keep_where(known_out_shardings, keep)
1841
+ # del keep
1839
1842
1840
1843
known_params = dict (
1841
1844
jaxpr = known_jaxpr , in_shardings = keep_where (in_shardings , known_ins ),
@@ -1847,10 +1850,11 @@ def keep_where(l, should_keep):
1847
1850
# Bind known things to pjit_p.
1848
1851
known_inputs = [pv .get_known () for pv in in_pvals if pv .is_known ()]
1849
1852
all_known_outs = pjit_p .bind (* known_inputs , ** known_params )
1850
- # Add back in the output fwds.
1851
- all_known_outs = subs_list (out_fwd , all_known_outs , all_known_outs )
1852
- # Add back in the input fwds.
1853
- all_known_outs = subs_list (in_fwd , known_inputs , all_known_outs )
1853
+ # TODO(mattjj): un-disable this optimization after we have more tests
1854
+ # # Add back in the output fwds.
1855
+ # all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
1856
+ # # Add back in the input fwds.
1857
+ # all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
1854
1858
1855
1859
known_out_vals , residual_vals = \
1856
1860
split_list (all_known_outs , [len (all_known_outs ) - num_residuals ])
0 commit comments