
Commit 8e6a5e5

composability of assume_pure and call_jax (#8989)
Parent commit: e9cb086

11 files changed: +323 −54 lines

.github/workflows/torchax.yml

Lines changed: 1 addition & 0 deletions
@@ -56,4 +56,5 @@ jobs:
   pytest test/test_libraries.py
   pytest test/test_symbolic_shapes.py
   pytest test/test_exports.py
+  pytest test/test_util.py
   XLA_FLAGS=--xla_force_host_platform_device_count=4 pytest -n 0 test_dist/

test/scan/test_scan_spmd.py

Lines changed: 14 additions & 0 deletions
@@ -7,6 +7,7 @@
 import torch_xla
 import torch.nn as nn
 from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear, Mesh
+from torch_xla.experimental.assume_pure import assume_pure
 from torch_xla.experimental.scan import scan
 from torch_xla.experimental.scan_layers import scan_layers
 from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, get_global_mesh
@@ -229,6 +230,19 @@ def check_dots_in_model(self, model, x, expect_pattern):
   def count_regex(self, hlo_text, regex_str):
     return len(re.findall(regex_str, hlo_text))

+  def test_assume_pure_works_with_mark_sharding(self):
+    x = torch.randn((3, 4, 5, 128), device='xla')
+    assume_pure(mark_sharding)(x, self.spmd_mesh, ("model", None, None, None))
+    # assert not throwing
+
+  def test_convert_to_jax_mesh(self):
+    jax_mesh = self.spmd_mesh.maybe_convert_and_get_jax_mesh()
+    self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
+    np.testing.assert_equal(
+        np.array([dev.id for dev in jax_mesh.devices.flatten()]),
+        self.spmd_mesh.device_ids)
+    # assert not throwing
+

 if __name__ == '__main__':
   test = unittest.main()
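For context, `test_assume_pure_works_with_mark_sharding` relies on an SPMD mesh created in the test class's setup, which is not shown in this hunk. A minimal self-contained sketch of the same call, assuming a 1-D global mesh whose single axis is named "model" (set up via `get_1d_mesh`/`set_global_mesh`, both imported above):

import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
from torch_xla.experimental.assume_pure import assume_pure

xr.use_spmd()                 # enable SPMD mode before creating XLA tensors
mesh = get_1d_mesh("model")   # 1-D mesh over all devices, axis named "model"
set_global_mesh(mesh)

x = torch.randn((3, 4, 5, 128), device='xla')
# The partition spec length must match the tensor rank: dim 0 is sharded
# over the "model" axis, the remaining dims are replicated (None).
assume_pure(mark_sharding)(x, mesh, ("model", None, None, None))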

test/test_assume_pure.py

Lines changed: 39 additions & 0 deletions
@@ -369,6 +369,44 @@ def original_func(a, b):
     self.assertIsNone(a_pure.grad)
     self.assertIsNone(b_pure.grad)

+  def test_composibility_with_call_jax(self):
+
+    def jax_func(a, b):
+      return jnp.dot(a, b)
+
+    def f(a, b):
+      return xb.call_jax(jax_func, (a, b))
+
+    a = torch.randn(3, 3, device='xla')
+    b = torch.randn(3, 3, device='xla')
+
+    output_pure = assume_pure(f)(a, b)
+    torch.testing.assert_close(
+        output_pure,
+        a @ b,
+        msg="Forward outputs do not match",
+        check_device=False)
+
+  def test_assume_pure_recursive(self):
+
+    @assume_pure
+    def torch_func(a, b):
+      return torch.matmul(a, b)
+
+    def f(a, b):
+      y = torch_func(a, b)
+      return y + 1
+
+    a = torch.randn(3, 3, device='xla')
+    b = torch.randn(3, 3, device='xla')
+
+    output_pure = assume_pure(f)(a, b)
+    torch.testing.assert_close(
+        output_pure,
+        a @ b + 1,
+        msg="Forward outputs do not match",
+        check_device=False)
+

 FLAGS = flags.FLAGS
 flags.DEFINE_integer(
@@ -436,5 +474,6 @@ def pure_call(params, x):
   torch_xla._XLAC._xla_set_mat_mul_precision('highest')
   jax_import_guard()
   import torchax
+  import jax.numpy as jnp
   torchax.enable_accuracy_mode()
   absltest.main()

torch_xla/_internal/jax_workarounds.py

Lines changed: 13 additions & 0 deletions
@@ -42,3 +42,16 @@ def jax_env_context():
     os.environ['SKIP_MEGASCALE_PJRT_CLIENT'] = previous_skip_megascale_env
   else:
     os.environ.pop('SKIP_MEGASCALE_PJRT_CLIENT', None)
+
+
+def maybe_get_torchax():
+  try:
+    jax_import_guard()
+    with jax_env_context():
+      import torchax
+      import torchax.tensor
+      import torchax.interop
+      import torchax.ops.mappings
+      return torchax
+  except ImportError:
+    return None
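`maybe_get_torchax` is the optional-dependency guard used by the call sites below (`call_jax`, `mark_sharding`) to take a torchax-specific branch only when torchax is importable. A small illustrative check built on nothing beyond what the helper provides (the function name `is_torchax_tensor` is made up for this sketch):

from torch_xla._internal.jax_workarounds import maybe_get_torchax

def is_torchax_tensor(t):
  # Returns False when torchax (or jax) cannot be imported, so the default
  # torch_xla code path keeps working without the optional dependency.
  tx = maybe_get_torchax()
  return tx is not None and isinstance(t, tx.tensor.Tensor)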

torch_xla/core/xla_builder.py

Lines changed: 7 additions & 3 deletions
@@ -4,7 +4,7 @@
 import torch
 import torch_xla
 from torch.utils._pytree import tree_flatten
-from torch_xla._internal.jax_workarounds import jax_env_context, jax_import_guard, requires_jax
+from torch_xla._internal.jax_workarounds import jax_env_context, jax_import_guard, requires_jax, maybe_get_torchax


 class Type:
@@ -869,7 +869,7 @@ def jax_func_to_xla_computation(jax_func, args, kwargs, name=None):
   # then the MegaScale device discovery will hang.
   with jax_env_context():
     import jax
-    import torchax.ops.mappings as mappings
+    tx = maybe_get_torchax()

     flattened_inputs, spec = jax.tree.flatten((args, kwargs))

@@ -878,7 +878,7 @@ def abstractify(a):  # make a pytree leaf abstract
        return None
      if isinstance(a, torch.Tensor):
        assert a.device.type == 'xla', f"Inputs must be XLA tensors. Got {a.device}"
-       return jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
+       return jax.ShapeDtypeStruct(a.shape, tx.ops.mappings.t2j_dtype(a.dtype))
      return a

    sample_inputs = tuple(abstractify(a) for a in flattened_inputs)
@@ -1019,6 +1019,10 @@ def call_jax(jax_func,
   import jax
   kwargs = kwargs or {}
   flattened, _spec = jax.tree.flatten((args, kwargs))
+  tx = maybe_get_torchax()
+  if tx is not None and any(isinstance(a, tx.tensor.Tensor) for a in flattened):
+    return tx.interop.call_jax(jax_func, *args, **kwargs)
+
   xla_computation = jax_func_to_xla_computation(jax_func, args, kwargs, name)
   return xla_computation(flattened)
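After this change `call_jax` has two paths: inputs containing torchax tensors (e.g. the tracers seen inside `assume_pure`) are forwarded to `torchax.interop.call_jax`, while plain XLA tensors are still lowered through `jax_func_to_xla_computation`. A minimal sketch of the unchanged XLA-tensor path (the function body and shapes are illustrative):

import torch
import torch_xla
import jax.numpy as jnp
import torch_xla.core.xla_builder as xb

def jax_fn(x, y):
  return jnp.sin(x) + jnp.cos(y)

x = torch.randn(4, 4, device='xla')
y = torch.randn(4, 4, device='xla')
# No torchax tensors among the inputs, so the JAX function is lowered to an
# XLA computation and executed on the XLA tensors directly.
out = xb.call_jax(jax_fn, (x, y))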

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 32 additions & 1 deletion
@@ -13,7 +13,7 @@
 from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
 import torch_xla.runtime as xr
 import torch_xla.debug.profiler as xp
-from torch_xla._internal.jax_workarounds import requires_jax
+from torch_xla._internal.jax_workarounds import requires_jax, maybe_get_torchax

 import numpy as np
 import functools
@@ -181,6 +181,29 @@ def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
     except (ValueError, SyntaxError, KeyError, TypeError):
       return None

+  @requires_jax
+  def maybe_convert_and_get_jax_mesh(self):
+    # Construct a JAX mesh object with the same device ids shape and ordering
+    # from torch_xla device mesh.
+    import jax
+    import numpy as np
+    from jax._src import mesh as mesh_lib
+
+    axis_names = self.axis_names or tuple(
+        str(i) for i in range(len(self.mesh_shape)))
+
+    # Create a mapping from device ID to device object
+    all_devices = jax.devices()
+    device_id_to_device = {device.id: device for device in all_devices}
+    device_ids_array = self.device_ids.reshape(*self.mesh_shape)
+    device_array = np.empty(device_ids_array.shape, dtype=object)
+    device_array = np.vectorize(device_id_to_device.get)(device_ids_array)
+    if np.any(device_array == None):
+      raise ValueError(
+          f"torch_xla device ID {device_ids_array[device_array == None]} not found in available JAX devices"
+      )
+    return mesh_lib.Mesh(device_array, axis_names=axis_names)
+

 _GLOBAL_MESH: Optional[Mesh] = None

@@ -584,6 +607,14 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."

+  tx = maybe_get_torchax()
+  if tx is not None and isinstance(t, tx.tensor.Tensor):
+    from jax.sharding import PartitionSpec as P, NamedSharding
+    op_sharding = tuple(str(i) if i is not None else i for i in partition_spec)
+    jmesh = mesh.maybe_convert_and_get_jax_mesh()
+    t.shard_(NamedSharding(jmesh, P(*op_sharding)))
+    return t
+
   op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
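The new `Mesh.maybe_convert_and_get_jax_mesh` is what lets the torchax branch of `mark_sharding` build a `NamedSharding`: it rebuilds the torch_xla mesh as a `jax.sharding.Mesh` with the same device ordering. A short sketch mirroring the new `test_convert_to_jax_mesh` test (the 1-D mesh layout and axis name here are assumptions for illustration):

import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

n = xr.global_runtime_device_count()
xla_mesh = Mesh(np.arange(n), mesh_shape=(n,), axis_names=("model",))

jax_mesh = xla_mesh.maybe_convert_and_get_jax_mesh()
# Same layout, but JAX device objects instead of integer device ids.
assert jax_mesh.devices.shape == xla_mesh.mesh_shape
assert [d.id for d in jax_mesh.devices.flatten()] == list(xla_mesh.device_ids)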

torch_xla/experimental/splash_attention.py

Lines changed: 1 addition & 27 deletions
@@ -63,32 +63,6 @@ def list_to_tuple(x):
   converted_data = {k: list_to_tuple(v) for k, v in json_data.items()}
   return SplashAttentionConfig(**converted_data)

-  @requires_jax
-  def maybe_convert_and_get_jax_mesh(self):
-    # Construct a JAX mesh object with the same device ids shape and ordering
-    # from torch_xla device mesh.
-    mesh = Mesh.from_str(self.mesh)
-    import jax
-    import numpy as np
-    from jax._src import mesh as mesh_lib
-
-    assert mesh.axis_names is not None, "Omitting axis names is not yet supported"
-
-    # Create a mapping from device ID to device object
-    all_devices = jax.devices()
-    device_id_to_device = {device.id: device for device in all_devices}
-    device_ids_array = mesh.device_ids.reshape(*mesh.mesh_shape)
-    device_array = np.empty(device_ids_array.shape, dtype=object)
-    for idx in np.ndindex(device_ids_array.shape):
-      device_id = device_ids_array[idx]
-      if device_id in device_id_to_device:
-        device_array[idx] = device_id_to_device[device_id]
-      else:
-        raise ValueError(
-            f"torch_xla device ID {device_id} not found in available JAX devices"
-        )
-    return mesh_lib.Mesh(device_array, axis_names=mesh.axis_names)
-

 @xp.trace_me("splash_attention_kernel_wrapper")
 def splash_attention_jax_wrapper(
@@ -112,7 +86,7 @@ def splash_attention_jax_wrapper(
       splash_attention_kernel,
       splash_attention_mask,
   )
-  mesh = config.maybe_convert_and_get_jax_mesh()
+  mesh = Mesh.from_str(config.mesh).maybe_convert_and_get_jax_mesh()
   # input q,k,v shape: [batch, #head, seq_len, head_dim]
   if decoder_segment_ids is not None and not decoder_segment_ids.shape:
     decoder_segment_ids = None

torchax/test/test_util.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import unittest
+from torchax.util import partition, merge
+
+# Helper predicate functions for testing partition
+def is_even(n):
+  return isinstance(n, int) and n % 2 == 0
+
+def is_positive(n):
+  return isinstance(n, (int, float)) and n > 0
+
+def is_string(s):
+  return isinstance(s, str)
+
+
+class TestListUtils(unittest.TestCase):
+
+  # --- Tests for partition ---
+
+  def test_partition_empty_list(self):
+    """Test partition with an empty list."""
+    self.assertEqual(partition([], is_even), ([], []))
+
+  def test_partition_even_odd(self):
+    """Test partitioning numbers into even and odd."""
+    nums = [1, 2, 3, 4, 5, 6]
+    expected_truthy = [None, 2, None, 4, None, 6]
+    expected_falsy = [1, None, 3, None, 5, None]
+    self.assertEqual(partition(nums, is_even), (expected_truthy, expected_falsy))
+
+  def test_partition_all_true(self):
+    """Test partition when the predicate is always true."""
+    evens = [2, 4, 6, 8]
+    expected_truthy = [2, 4, 6, 8]
+    expected_falsy = [None, None, None, None]
+    self.assertEqual(partition(evens, is_even), (expected_truthy, expected_falsy))
+
+  def test_partition_all_false(self):
+    """Test partition when the predicate is always false."""
+    odds = [1, 3, 5, 7]
+    expected_truthy = [None, None, None, None]
+    expected_falsy = [1, 3, 5, 7]
+    self.assertEqual(partition(odds, is_even), (expected_truthy, expected_falsy))
+
+  def test_partition_mixed_types(self):
+    """Test partition with a list of mixed types."""
+    mixed = [1, "hello", 2.5, "world", 3, None]
+    # Using is_string as the predicate
+    expected_truthy = [None, "hello", None, "world", None, None]
+    expected_falsy = [1, None, 2.5, None, 3, None]  # Note: None itself is not a string
+    self.assertEqual(partition(mixed, is_string), (expected_truthy, expected_falsy))
+
+  def test_partition_with_lambda(self):
+    """Test partition using a lambda function as the predicate."""
+    nums = [-2, -1, 0, 1, 2]
+    expected_truthy = [None, None, None, 1, 2]
+    expected_falsy = [-2, -1, 0, None, None]
+    self.assertEqual(partition(nums, lambda x: isinstance(x, int) and x > 0), (expected_truthy, expected_falsy))
+
+  # --- Tests for merge ---
+
+  def test_merge_empty_lists(self):
+    """Test merge with empty lists."""
+    self.assertEqual(merge([], []), [])
+
+  def test_merge_basic(self):
+    """Test basic merging with None values in the first list."""
+    list1 = [1, None, 3, None, 5]
+    list2 = [None, 2, None, 4, None]
+    expected = [1, 2, 3, 4, 5]
+    self.assertEqual(merge(list1, list2), expected)
+
+  def test_merge_no_none_in_list1(self):
+    """Test merge when list1 has no None values."""
+    list1 = ['a', 'b', 'c']
+    list2 = [1, 2, 3]
+    expected = ['a', 'b', 'c']  # Should be identical to list1
+    self.assertEqual(merge(list1, list2), expected)
+
+  def test_merge_all_none_in_list1(self):
+    """Test merge when list1 contains only None."""
+    list1 = [None, None, None]
+    list2 = ['x', 'y', 'z']
+    expected = ['x', 'y', 'z']  # Should be identical to list2
+    self.assertEqual(merge(list1, list2), expected)
+
+  def test_merge_mixed_types(self):
+    """Test merge with mixed data types."""
+    list1 = [1, None, "hello", None]
+    list2 = [None, 2.5, None, True]
+    expected = [1, 2.5, "hello", True]
+    self.assertEqual(merge(list1, list2), expected)
+
+  def test_merge_unequal_lengths(self):
+    """Test that merge raises AssertionError for lists of unequal length."""
+    list1 = [1, 2, 3]
+    list2 = [4, 5]
+    # Use assertRaises as a context manager
+    with self.assertRaises(AssertionError) as cm:
+      merge(list1, list2)
+
+    list3 = [6, 7]
+    list4 = [8, 9, 10]
+    with self.assertRaises(AssertionError):
+      merge(list3, list4)  # No need to check message again if already checked
+
+if __name__ == '__main__':
+  unittest.main()  # For running from command line
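The `torchax.util` implementations of `partition` and `merge` are not shown in this commit view; a sketch consistent with the tests above (every element stays at its original index, the "other" list gets `None`, and `merge` takes the non-`None` element and asserts equal lengths) could look like:

def partition(lst, pred):
  """Split lst into (truthy, falsy) lists of the same length as lst,
  keeping each element at its index and placing None in the other list."""
  truthy = [x if pred(x) else None for x in lst]
  falsy = [None if pred(x) else x for x in lst]
  return truthy, falsy


def merge(list1, list2):
  """Inverse of partition: take the element from list1 unless it is None."""
  assert len(list1) == len(list2), "lists must have the same length"
  return [a if a is not None else b for a, b in zip(list1, list2)]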

torchax/torchax/environment.py

Lines changed: 0 additions & 2 deletions
This file was deleted.
