+from copy import copy
+from typing import Any, Optional
+from weakref import WeakKeyDictionary
import torch
import torch_xla
from torch.utils._pytree import tree_flatten, tree_unflatten
-from torch_xla.experimental.custom_kernel import jax_import_guard
+from torch_xla.experimental.custom_kernel import _jax_env_context, jax_import_guard


class Type:
@@ -827,21 +830,23 @@ def get_computation_hlo(computation):

class XlaComputation:

-  def __init__(self, name, hlo_module, flattened_inputs):
+  def __init__(self, name, hlo_module, flattened_inputs, pick_tensor_args):
    self.num_inputs = len(flattened_inputs)
    builder = create_builder(name)
    computation = computation_from_module_proto(name, hlo_module)
    params = []
    for idx, val in enumerate(flattened_inputs):
      params.append(mkparam(builder, idx, tensor_shape(val)))
    call_op = Op.call(computation, params)
-    call_computation = call_op.build('call_jax')
+    call_computation = call_op.build(f'call_jax_{name}')
    self.call_computation = call_computation
    self.name = name
+    self.pick_tensor_args = pick_tensor_args

  def __call__(self, input_list):
+    input_tensors = self.pick_tensor_args(input_list)
    result = torch_xla._XLAC._xla_user_computation(f'xla::call_jax_{self.name}',
-                                                   input_list,
+                                                   input_tensors,
                                                   self.call_computation)
    if isinstance(result, list) and len(result) == 1:
      return result[0]
@@ -855,32 +860,142 @@ def jax_func_to_xla_computation(jax_func, args, kwargs, name=None):
  # If we don't do this before calling jax, any torch_xla operation will hang.
  jax_import_guard()

-  import jax
-  import torchax.ops.mappings as mappings
-
-  flattened, spec = tree_flatten((args, kwargs))
-
-  def fn_flattened_inputs(*flattened):
-    args, kwargs = tree_unflatten(flattened, spec)
-    return jax_func(*args, **kwargs)
-
-  sample_input_shapes = tuple(
-      jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
-      for a in flattened)
-  # `as_serialized_hlo_module_proto` is mentioned at
-  # https://github.com/jax-ml/jax/discussions/22266
-  hlo_module = jax.jit(
-      fn_flattened_inputs,
-      keep_unused=True).lower(*sample_input_shapes).compiler_ir(
-          'hlo').as_serialized_hlo_module_proto()  # type: ignore
-
-  return XlaComputation(name, hlo_module, flattened)
-
-
-def call_jax(jax_func, args, kwargs=None, name=None):
+  # Prevent JAX from discovering MegaScale devices a second time. If we don't do this,
+  # then the MegaScale device discovery will hang.
+  with _jax_env_context():
+    import jax
+    import torchax.ops.mappings as mappings
+
+    flattened_inputs, spec = jax.tree.flatten((args, kwargs))
+
+    def abstractify(a):  # make a pytree leaf abstract
+      import jax
+      import torch_xla
+      if a is None:
+        return None
+      if isinstance(a, torch.Tensor):
+        assert a.device == torch_xla.device(
+        ), f"Inputs must be XLA tensors. Got {a.device}"
+        return jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
+      return a
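+
+    # For instance, a float32 tensor of shape (3, 3) on the XLA device becomes a
+    # `jax.ShapeDtypeStruct` with shape (3, 3) and the matching JAX dtype, while
+    # non-tensor leaves such as Python ints or strings pass through unchanged.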
+
+    sample_inputs = tuple(abstractify(a) for a in flattened_inputs)
+
+    # Pick out the non-static args.
+    # Consider anything that is not a `jax.ShapeDtypeStruct` as a static arg.
+    def pick_tensor_args(flattened_args):
+      tensor_args = []
+      for i in range(len(sample_inputs)):
+        if isinstance(sample_inputs[i], jax.ShapeDtypeStruct):
+          tensor_args.append(flattened_args[i])
+      return tensor_args
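+
+    # For example, if `(args, kwargs)` flattens to `[x, 2, 'sum']` where only `x` is an
+    # XLA tensor, `pick_tensor_args` keeps just `[x]`; the remaining leaves stay in
+    # `flattened_inputs` and are effectively baked into the traced computation as
+    # static values.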
+
+    sample_tensor_args = pick_tensor_args(sample_inputs)
+    tensor_args = pick_tensor_args(flattened_inputs)
+
+    # This function only takes in tensor arguments because its signature must
+    # match the signature of the HLO module lowered from JAX, allowing us to
+    # wrap it in an XLA user computation.
+    def fn(*tensor_args):
+      # Go from a list of tensor args to the full list of flattened arguments,
+      # by referencing the original flattened inputs.
+      new_flattened = copy(flattened_inputs)
+      tensor_args_iter = iter(tensor_args)
+      for i in range(len(sample_inputs)):
+        if isinstance(sample_inputs[i], jax.ShapeDtypeStruct):
+          new_flattened[i] = next(tensor_args_iter)
+      args, kwargs = jax.tree.unflatten(spec, new_flattened)
+      return jax_func(*args, **kwargs)
+
+    def get_hlo():
+      import torch_xla.debug.profiler as xp
+      # If we see this trace span in the profiler, we'll know that there's a cache miss.
+      with xp.Trace('jax_to_hlo'):
+        hlo_ir = jax.jit(
+            fn, keep_unused=True).lower(*sample_tensor_args).compiler_ir('hlo')
+
+        # Get a protobuf representation of the HLO. `as_serialized_hlo_module_proto` is
+        # mentioned at https://github.com/jax-ml/jax/discussions/22266
+        return hlo_ir.as_serialized_hlo_module_proto()  # type: ignore
+
+    hlo_module = _jax_to_hlo_cache_get_or_insert(jax_func, sample_inputs, spec,
+                                                 get_hlo)
+    return XlaComputation(name, hlo_module, tensor_args, pick_tensor_args)
+
+
+def _jax_to_hlo_cache_get_or_insert(jax_func, sample_inputs: tuple[Any, ...],
+                                    input_tree_spec, get_hlo):
+  global _JAX_TO_HLO_CACHE
+  # Use two layers of dictionary lookup.
+  # The first layer uses the `jax_func`, which is only weakly referenced.
+  # The second layer uses the sample inputs and the tree spec, which are strongly referenced.
+  inner_dict = _JAX_TO_HLO_CACHE.get(jax_func, None)
+  if inner_dict is not None:
+    hlo = inner_dict.get((sample_inputs, input_tree_spec), None)
+    if hlo is not None:
+      return hlo
+
+  # Compile the jax function to HLO.
+  hlo = get_hlo()
+  if inner_dict is None:
+    _JAX_TO_HLO_CACHE[jax_func] = {}
+  _JAX_TO_HLO_CACHE[jax_func][(sample_inputs, input_tree_spec)] = hlo
+  return hlo
+
+
+def _jax_to_hlo_cache_num_misses() -> int:
+  size = 0
+  for inner_dict in _JAX_TO_HLO_CACHE.values():
+    size += len(inner_dict)
+  return size
+
+
+_JAX_TO_HLO_CACHE = WeakKeyDictionary()
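+# Note: since the cache is a `WeakKeyDictionary` keyed on `jax_func`, its entries are
+# dropped automatically once the function object is garbage collected, so the cache does
+# not keep user functions (or their closures) alive.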
+
+
+def call_jax(jax_func,
+             args: tuple[Any, ...],
+             kwargs: Optional[dict[str, Any]] = None,
+             name=None):
  """
  Call a JAX function `jax_func` with the given `args` and `kwargs` that may contain
  XLA tensors.
+
+  Args:
+    jax_func: a functionally pure Python callable that does some math on JAX arrays.
+      It needs to be `jax.jit` traceable.
+
+    args: a tuple of arguments to pass to `jax_func`. Any XLA tensors are turned into
+      JAX arrays before being passed to `jax_func`.
+
+    kwargs: a dictionary of keyword arguments to pass to `jax_func`. Any XLA tensors are
+      turned into JAX arrays before being passed to `jax_func`.
+
+  ## Example
+
+    >>> import torch
+    >>> import torch_xla
+    >>> import torch_xla.core.xla_builder as xb
+    >>>
+    >>> def f(a, b):
+    >>>   # Call any JAX functionality here.
+    >>>   import jax.numpy as jnp
+    >>>   return a + jnp.sin(b)
+    >>>
+    >>> # Pass PyTorch/XLA tensors to JAX function this way.
+    >>> a = torch.ones((3, 3), device='xla')
+    >>> b = xb.call_jax(f, (a, a))
+    >>>
+    >>> # Result is the same as if we ran the equivalent torch ops.
+    >>> torch.testing.assert_close(b.cpu(), torch.sin(torch.ones(3, 3)) + 1)
+
+  ## Caching
+
+  In order to call `jax_func`, we will jit compile it into HLO, which involves tracing
+  the function. The address of `jax_func` and the shapes of `args` and `kwargs` are used
+  as the key into a cache to avoid repeated tracing/compilation, similar to how `jax.jit`
+  works. If you see tracing overhead, check whether `jax_func` is being redefined all the
+  time. A common mistake is defining `jax_func` as a local function, e.g. during a
+  training step.
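+
+  For example (illustrative only), reusing one function object hits the cache, while
+  creating a fresh lambda on every step forces a re-trace:
+
+    >>> # Good: reuses `f` from the example above; traced once per input shape.
+    >>> c = xb.call_jax(f, (a, a))
+    >>> # Bad: a new lambda is a new cache key, so this re-traces on every call.
+    >>> d = xb.call_jax(lambda x, y: x + y, (a, a))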
  """

  kwargs = kwargs or {}