Commit a3ef52e

Cache HLO in xb.call_jax and support non-tensor args (#8878)
1 parent 2aa11cc commit a3ef52e

3 files changed: +254 -27 lines changed
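Before the diffs, a minimal usage sketch of what this commit enables, adapted from the new test_call_jax_non_tensor_args test below; the function f and its argument values are illustrative, not part of the commit.

import torch
import torch_xla
import torch_xla.core.xla_builder as xb

def f(a, num: float, string: str, dictionary: dict, none):
  # Tensor arguments arrive as JAX arrays; non-tensor arguments arrive as plain Python values.
  import jax.numpy as jnp
  return a + jnp.sin(num) + int(string) + dictionary['x']

a = torch.ones((3, 3), device='xla')
b = xb.call_jax(
    f, (a, 1.0, "10", {"x": torch.tensor(0.25, device='xla')}),
    kwargs={"none": None})
torch_xla.sync()  # repeated calls with the same shapes and static values reuse the cached HLO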

test/test_jax_interop.py

Lines changed: 111 additions & 0 deletions
@@ -103,6 +103,117 @@ def f_jax(a, b):
     # backward should produce same gradient
     torch.testing.assert_close(out_grad_torch, out_grad_jax)
 
+  def test_call_jax_non_tensor_args(self):
+    """Test that call_jax works with non-tensor arguments."""
+
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+
+    def f(a, num: float, string: str, dictionary: dict, none):
+      assert isinstance(string, str)
+      import jax.numpy as jnp
+      if none is None:
+        return a + jnp.sin(num) + int(string) + dictionary['x']
+      raise ValueError('none should be None')
+
+    b = xb.call_jax(
+        f, (
+            a,
+            1.0,
+            "10",
+            {
+                "x": torch.tensor(0.25, device=dev)
+            },
+        ),
+        kwargs={"none": None})
+    torch_xla.sync()
+    torch.testing.assert_close(
+        b, torch.sin(torch.ones(3, 3)) + 1 + 10 + 0.25, check_device=False)
+
+  def test_call_jax_cache_hlo(self):
+    """Test that the HLO of a jax function should be cached."""
+
+    starting_cache_misses = xb._jax_to_hlo_cache_num_misses()
+
+    # Let's trace two different jax functions a couple of times.
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+
+    def f(a, b):
+      import jax.numpy as jnp
+      return a + jnp.sin(b)
+
+    def g(a, b):
+      import jax.numpy as jnp
+      return a + jnp.cos(b)
+
+    xb.call_jax(f, (a, a))
+    xb.call_jax(f, (a, a))
+    xb.call_jax(g, (a, a))
+    xb.call_jax(g, (a, a))
+
+    ending_cache_misses = xb._jax_to_hlo_cache_num_misses()
+    self.assertEqual(ending_cache_misses - starting_cache_misses, 2)
+
+  def test_call_jax_cache_by_shape(self):
+    """Test that the same function may be traced again if the shape of its arguments changes."""
+
+    starting_cache_misses = xb._jax_to_hlo_cache_num_misses()
+
+    # Let's trace the same jax function with different shapes.
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+    b = torch.ones((2, 2), device=dev)
+
+    def f(a, b):
+      import jax.numpy as jnp
+      return a + jnp.sin(b)
+
+    xb.call_jax(f, (a, a))
+    xb.call_jax(f, (b, b))
+
+    ending_cache_misses = xb._jax_to_hlo_cache_num_misses()
+    self.assertEqual(ending_cache_misses - starting_cache_misses, 2)
+
+  def test_call_jax_cache_by_tree_spec(self):
+    """Test that the same function may be traced again if the tree spec of its arguments changes."""
+    starting_cache_misses = xb._jax_to_hlo_cache_num_misses()
+
+    # Let's trace the same jax function with different tree specs.
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+    b = torch.ones((3, 2), device=dev)
+
+    def f(inputs):
+      a = inputs['a']
+      b = inputs['b']
+      return a @ b
+
+    xb.call_jax(f, ({'a': a, 'b': a},))
+    xb.call_jax(f, ({'a': a, 'b': b},))
+
+    ending_cache_misses = xb._jax_to_hlo_cache_num_misses()
+    self.assertEqual(ending_cache_misses - starting_cache_misses, 2)
+
+  def test_call_jax_cache_by_static_args(self):
+    """Test that the same function may be traced again if a non-tensor argument changes."""
+    starting_cache_misses = xb._jax_to_hlo_cache_num_misses()
+
+    # Let's trace the same jax function with different static args.
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+
+    def f(a, num: float):
+      import jax.numpy as jnp
+      return a + jnp.sin(num)
+
+    xb.call_jax(f, (a, 1.0))
+    xb.call_jax(f, (a, 2.0))
+    xb.call_jax(f, (a, 3.0))
+
+    ending_cache_misses = xb._jax_to_hlo_cache_num_misses()
+    self.assertEqual(ending_cache_misses - starting_cache_misses, 3)
+
 
 if __name__ == "__main__":
   absltest.main()
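A practical reading of test_call_jax_cache_by_static_args above: non-tensor arguments are treated as static and become part of the cache key, so every distinct value forces a retrace. The sketch below is illustrative and not part of the commit; it relies on the abstraction rules added in xla_builder.py further down (tensors are keyed only by shape and dtype), and scale_and_sin is a hypothetical name.

import torch
import torch_xla.core.xla_builder as xb

def scale_and_sin(a, scale):
  import jax.numpy as jnp
  return a * jnp.sin(scale)

a = torch.ones((3, 3), device='xla')

# A Python float is a static arg: three distinct values should mean three traces.
for step in range(3):
  xb.call_jax(scale_and_sin, (a, float(step)))

# Wrapping the value in a 0-d XLA tensor keeps the shape/dtype cache key stable,
# so only the first call should trace; later values reuse the cached HLO.
for step in range(3):
  xb.call_jax(scale_and_sin, (a, torch.tensor(float(step), device='xla')))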

torch_xla/core/xla_builder.py

Lines changed: 142 additions & 27 deletions
@@ -1,7 +1,10 @@
+from copy import copy
+from typing import Any, Optional
+from weakref import WeakKeyDictionary
 import torch
 import torch_xla
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torch_xla.experimental.custom_kernel import jax_import_guard
+from torch_xla.experimental.custom_kernel import _jax_env_context, jax_import_guard
 
 
 class Type:
@@ -827,21 +830,23 @@ def get_computation_hlo(computation):
 
 class XlaComputation:
 
-  def __init__(self, name, hlo_module, flattened_inputs):
+  def __init__(self, name, hlo_module, flattened_inputs, pick_tensor_args):
     self.num_inputs = len(flattened_inputs)
     builder = create_builder(name)
     computation = computation_from_module_proto(name, hlo_module)
     params = []
     for idx, val in enumerate(flattened_inputs):
       params.append(mkparam(builder, idx, tensor_shape(val)))
     call_op = Op.call(computation, params)
-    call_computation = call_op.build('call_jax')
+    call_computation = call_op.build(f'call_jax_{name}')
     self.call_computation = call_computation
     self.name = name
+    self.pick_tensor_args = pick_tensor_args
 
   def __call__(self, input_list):
+    input_tensors = self.pick_tensor_args(input_list)
     result = torch_xla._XLAC._xla_user_computation(f'xla::call_jax_{self.name}',
-                                                   input_list,
+                                                   input_tensors,
                                                    self.call_computation)
     if isinstance(result, list) and len(result) == 1:
       return result[0]
@@ -855,32 +860,142 @@ def jax_func_to_xla_computation(jax_func, args, kwargs, name=None):
   # If we don't do this before calling jax, any torch_xla operation will hang.
   jax_import_guard()
 
-  import jax
-  import torchax.ops.mappings as mappings
-
-  flattened, spec = tree_flatten((args, kwargs))
-
-  def fn_flattened_inputs(*flattened):
-    args, kwargs = tree_unflatten(flattened, spec)
-    return jax_func(*args, **kwargs)
-
-  sample_input_shapes = tuple(
-      jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
-      for a in flattened)
-  # `as_serialized_hlo_module_proto` is mentioned at
-  # https://github.com/jax-ml/jax/discussions/22266
-  hlo_module = jax.jit(
-      fn_flattened_inputs,
-      keep_unused=True).lower(*sample_input_shapes).compiler_ir(
-          'hlo').as_serialized_hlo_module_proto()  # type: ignore
-
-  return XlaComputation(name, hlo_module, flattened)
-
-
-def call_jax(jax_func, args, kwargs=None, name=None):
+  # Prevent JAX from discovering MegaScale devices a second time. If we don't do this,
+  # then the MegaScale device discovery will hang.
+  with _jax_env_context():
+    import jax
+    import torchax.ops.mappings as mappings
+
+    flattened_inputs, spec = jax.tree.flatten((args, kwargs))
+
+    def abstractify(a):  # make a pytree leaf abstract
+      import jax
+      import torch_xla
+      if a is None:
+        return None
+      if isinstance(a, torch.Tensor):
+        assert a.device == torch_xla.device(
+        ), f"Inputs must be XLA tensors. Got {a.device}"
+        return jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
+      return a
+
+    sample_inputs = tuple(abstractify(a) for a in flattened_inputs)
+
+    # Pick out the non-static args.
+    # Consider anything that is not a `jax.ShapeDtypeStruct` as a static arg.
+    def pick_tensor_args(flattened_args):
+      tensor_args = []
+      for i in range(len(sample_inputs)):
+        if isinstance(sample_inputs[i], jax.ShapeDtypeStruct):
+          tensor_args.append(flattened_args[i])
+      return tensor_args
+
+    sample_tensor_args = pick_tensor_args(sample_inputs)
+    tensor_args = pick_tensor_args(flattened_inputs)
+
+    # This function only takes in tensor arguments because its signature must
+    # match the signature of the HLO module lowered from JAX, allowing us to
+    # wrap it in an XLA user computation.
+    def fn(*tensor_args):
+      # Go from a list of tensor args to the full list of flattened arguments,
+      # by referencing the original flattened inputs.
+      new_flattened = copy(flattened_inputs)
+      tensor_args_iter = iter(tensor_args)
+      for i in range(len(sample_inputs)):
+        if isinstance(sample_inputs[i], jax.ShapeDtypeStruct):
+          new_flattened[i] = next(tensor_args_iter)
+      args, kwargs = jax.tree.unflatten(spec, new_flattened)
+      return jax_func(*args, **kwargs)
+
+    def get_hlo():
+      import torch_xla.debug.profiler as xp
+      # If we see this trace span in the profiler, we'll know that there's a cache miss.
+      with xp.Trace('jax_to_hlo'):
+        hlo_ir = jax.jit(
+            fn, keep_unused=True).lower(*sample_tensor_args).compiler_ir('hlo')
+
+        # Get a protobuf representation of the HLO. `as_serialized_hlo_module_proto` is
+        # mentioned at https://github.com/jax-ml/jax/discussions/22266
+        return hlo_ir.as_serialized_hlo_module_proto()  # type: ignore
+
+    hlo_module = _jax_to_hlo_cache_get_or_insert(jax_func, sample_inputs, spec,
+                                                 get_hlo)
+    return XlaComputation(name, hlo_module, tensor_args, pick_tensor_args)
+
+
+def _jax_to_hlo_cache_get_or_insert(jax_func, sample_inputs: tuple[Any, ...],
+                                    input_tree_spec, get_hlo):
+  global _JAX_TO_HLO_CACHE
+  # Use two layers of dictionary lookup.
+  # The first layer uses the `jax_func`, which is only weakly referenced.
+  # The second layer uses the sample inputs and the tree spec, which is strongly referenced.
+  inner_dict = _JAX_TO_HLO_CACHE.get(jax_func, None)
+  if inner_dict is not None:
+    hlo = inner_dict.get((sample_inputs, input_tree_spec), None)
+    if hlo is not None:
+      return hlo
+
+  # Compile the jax function to HLO.
+  hlo = get_hlo()
+  if inner_dict is None:
+    _JAX_TO_HLO_CACHE[jax_func] = {}
+  _JAX_TO_HLO_CACHE[jax_func][(sample_inputs, input_tree_spec)] = hlo
+  return hlo
+
+
+def _jax_to_hlo_cache_num_misses() -> int:
+  size = 0
+  for inner_dict in _JAX_TO_HLO_CACHE.values():
+    size += len(inner_dict)
+  return size
+
+
+_JAX_TO_HLO_CACHE = WeakKeyDictionary()
+
+
+def call_jax(jax_func,
+             args: tuple[Any, ...],
+             kwargs: Optional[dict[str, Any]] = None,
+             name=None):
   """
   Call a JAX function `jax_func` with the given `args` and `kwargs` that may contain
   XLA tensors.
+
+  Args:
+    jax_func: a functionally pure Python callable that does some math on JAX arrays.
+      It needs to be `jax.jit` traceable.
+
+    args: a tuple of arguments to pass to `jax_func`. Any XLA tensors are turned into
+      JAX arrays before being passed to `jax_func`.
+
+    kwargs: a dictionary of keyword arguments to pass to `jax_func`. Any XLA tensors are
+      turned into JAX arrays before being passed to `jax_func`.
+
+  ## Example
+
+  >>> import torch
+  >>> import torch_xla
+  >>> import torch_xla.core.xla_builder as xb
+  >>>
+  >>> def f(a, b):
+  >>>   # Call any JAX functionality here.
+  >>>   import jax.numpy as jnp
+  >>>   return a + jnp.sin(b)
+  >>>
+  >>> # Pass PyTorch/XLA tensors to the JAX function this way.
+  >>> a = torch.ones((3, 3), device='xla')
+  >>> b = xb.call_jax(f, (a, a))
+  >>>
+  >>> # Result is the same as if we ran the equivalent torch ops.
+  >>> torch.testing.assert_close(b.cpu(), torch.sin(torch.ones(3, 3)) + 1)
+
+  ## Caching
+
+  In order to call `jax_func`, we will jit compile it into HLO, which involves tracing
+  the function. The address of `jax_func` and the shapes of `args` and `kwargs` are used
+  as the key into a cache to avoid repeated tracing/compilation, similar to how `jax.jit`
+  works. If you get tracing overhead, check if `jax_func` is being redefined all the time.
+  A common mistake is defining `jax_func` as a local function, e.g. during a training step.
   """
 
   kwargs = kwargs or {}
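The "## Caching" note added to the call_jax docstring above warns against redefining jax_func on every call, since the HLO cache is keyed weakly on the function object plus the argument shapes and tree spec. A hedged sketch of that pitfall and its fix; jax_add_sin and the train_step_* names are illustrative, not part of the commit.

import torch
import torch_xla.core.xla_builder as xb

def jax_add_sin(x, y):
  # Defined once at module scope: calls after the first hit the cached HLO.
  import jax.numpy as jnp
  return x + jnp.sin(y)

def train_step_slow(a):
  def local_fn(x, y):  # a fresh function object every step...
    import jax.numpy as jnp
    return x + jnp.sin(y)
  # ...so this should re-trace and re-lower to HLO on every call.
  return xb.call_jax(local_fn, (a, a))

def train_step_fast(a):
  return xb.call_jax(jax_add_sin, (a, a))  # reuses the cached HLO

a = torch.ones((3, 3), device='xla')
for _ in range(3):
  train_step_fast(a)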

torch_xla/experimental/custom_kernel.py

Lines changed: 1 addition & 0 deletions
@@ -160,6 +160,7 @@ def _extract_backend_config(
 
 @contextmanager
 def _jax_env_context():
+  # TODO(b/374631442): Get rid of this hack.
   try:
     previous_skip_megascale_env = os.environ.get('SKIP_MEGASCALE_PJRT_CLIENT',
                                                  None)
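The hunk above shows only the first lines of _jax_env_context. For readers without the full file, a minimal sketch of the save/set/restore pattern it appears to implement; the body below (in particular the value written and the restore logic) is an assumption, not the verbatim implementation, and the function is renamed to make that clear.

import os
from contextlib import contextmanager

@contextmanager
def _jax_env_context_sketch():
  # TODO(b/374631442): Get rid of this hack.
  previous_skip_megascale_env = os.environ.get('SKIP_MEGASCALE_PJRT_CLIENT', None)
  try:
    # Assumed: tell JAX to skip MegaScale device discovery while tracing/lowering.
    os.environ['SKIP_MEGASCALE_PJRT_CLIENT'] = 'true'
    yield
  finally:
    # Restore whatever the caller had set before entering the context.
    if previous_skip_megascale_env is None:
      os.environ.pop('SKIP_MEGASCALE_PJRT_CLIENT', None)
    else:
      os.environ['SKIP_MEGASCALE_PJRT_CLIENT'] = previous_skip_megascale_env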
