
Commit 96ad8f5

Encapsulate Mesh invariants (#8882)
1 parent 8dc5b49 commit 96ad8f5

File tree

2 files changed: +121, -16 lines changed

test/spmd/test_xla_sharding.py

Lines changed: 98 additions & 3 deletions
@@ -1,9 +1,10 @@
 import copy
 
-import unittest
-from unittest.mock import patch
+from collections import OrderedDict
 import math
 import numpy as np
+import unittest
+from unittest.mock import patch
 import sys
 
 import torch
@@ -762,9 +763,12 @@ def test_hybrid_mesh_shape(self):
                    "Crash on TPU v2")
   @patch('torch_xla.runtime.global_runtime_device_attributes')
   @patch('torch_xla.core.xla_model.xla_device_hw')
-  def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock):
+  @patch('torch_xla.runtime.global_runtime_device_count')
+  def test_hybrid_mesh(self, device_count_mock, xla_device_mock,
+                       device_attributes_mock):
     # mock device attributes for 2 slices of v4-8
     num_slices = 2
+    device_count_mock.return_value = 8
     xla_device_mock.return_value = "TPU"
     device_attributes_mock.return_value = [{
         'coords': [0, 0, 0],
@@ -1565,6 +1569,97 @@ def test_mark_sharding_with_gradients_annotation(self):
     # Check that the gradient has sharding.
     self.assertIn(sharding_spec, x_grad_sharding)
 
+  def test_valid_mesh_creation(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model')
+    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
+    self.assertEqual(mesh.mesh_shape, mesh_shape)
+    self.assertEqual(mesh.axis_names, axis_names)
+
+  def test_valid_mesh_without_axis_names(self):
+    mesh_shape = (1, self.n_devices)
+    mesh = xs.Mesh(self.device_ids, mesh_shape)
+
+    self.assertEqual(mesh.device_ids.tolist(), list(range(self.n_devices)))
+    self.assertEqual(mesh.mesh_shape, mesh_shape)
+    self.assertIsNone(mesh.axis_names)
+
+  def test_invalid_axis_names_length(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model', 'extra')
+
+    with self.assertRaisesRegex(
+        AssertionError, "Number of axis names .* must match mesh dimensions"):
+      xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+  def test_duplicate_axis_names(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'data')
+
+    with self.assertRaisesRegex(AssertionError, "Axis names must be unique"):
+      xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+  def test_invalid_device_count(self):
+    mesh_shape = (2, self.n_devices)
+
+    with self.assertRaisesRegex(AssertionError,
+                                "Number of device IDs .* must match mesh size"):
+      xs.Mesh(self.device_ids, mesh_shape)
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed for duplicated device IDs")
+  def test_duplicate_device_ids(self):
+    mesh_shape = (1, self.n_devices)
+    duplicate_ids = np.array([0] * self.n_devices)
+
+    with self.assertRaisesRegex(AssertionError, "Device IDs must be unique"):
+      xs.Mesh(duplicate_ids, mesh_shape)
+
+  def test_device_ids_out_of_bounds(self):
+    mesh_shape = (1, self.n_devices)
+    invalid_ids = np.arange(self.n_devices + 1, self.n_devices * 2 + 1)
+
+    with self.assertRaisesRegex(AssertionError,
+                                "Device IDs must be less than mesh size"):
+      xs.Mesh(invalid_ids, mesh_shape)
+
+  def test_mesh_size(self):
+    mesh_shape = (1, self.n_devices)
+    mesh = xs.Mesh(self.device_ids, mesh_shape)
+    self.assertEqual(mesh.size(), self.n_devices)
+
+  def test_mesh_shape_method(self):
+    mesh_shape = (1, self.n_devices)
+    axis_names = ('data', 'model')
+    mesh = xs.Mesh(self.device_ids, mesh_shape, axis_names)
+
+    expected_shape = OrderedDict([('data', 1), ('model', self.n_devices)])
+    self.assertEqual(mesh.shape(), expected_shape)
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed")
+  def test_mismatch_global_devices(self):
+    partial_num_devices = self.n_devices // 2
+    device_ids = np.arange(partial_num_devices)
+    mesh_shape = (1, partial_num_devices)
+    with self.assertRaisesRegex(
+        AssertionError,
+        "Number of device IDs .* must match the global number of devices"):
+      xs.Mesh(device_ids, mesh_shape)
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed")
+  def test_get_logical_mesh(self):
+    device_ids = np.arange(self.n_devices)
+    mesh_shape = (2, self.n_devices // 2)
+    mesh = xs.Mesh(device_ids, mesh_shape)
+
+    logical_mesh = mesh.get_logical_mesh()
+    self.assertEqual(logical_mesh.shape, mesh_shape)
+    np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
+
 
 if __name__ == '__main__':
   test = unittest.main()
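
The new tests above pin down the invariants that Mesh.__init__ now enforces itself. As a rough sketch of the behavior they check (illustrative only, not part of the commit; it assumes an XLA runtime exposing more than one device), a malformed Mesh now fails at construction time with a descriptive AssertionError:

# Sketch under assumptions: requires torch_xla with SPMD support and more
# than one XLA device visible to the runtime.
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)

# Valid mesh: a 1 x num_devices logical mesh with named axes.
mesh = xs.Mesh(device_ids, (1, num_devices), ('data', 'model'))
print(mesh.shape())  # OrderedDict([('data', 1), ('model', num_devices)])

# Duplicated device IDs are rejected immediately, matching
# test_duplicate_device_ids above.
try:
  xs.Mesh(np.zeros(num_devices, dtype=int), (1, num_devices))
except AssertionError as e:
  print(e)  # "Device IDs must be unique, got: ..."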

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 23 additions & 13 deletions
@@ -69,14 +69,32 @@ def __init__(self,
                axis_names: Optional[tuple[str, ...]] = None):
     if not isinstance(device_ids, np.ndarray):
       device_ids = np.array(device_ids)
-    assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
-    assert axis_names is None or (len(set(axis_names)) == len(axis_names))
-    assert (len(device_ids) == np.prod(mesh_shape))
-    assert len(device_ids) == len(np.unique(device_ids))
+
+    # At the moment, XLA requires that the Mesh uses the global number of
+    # devices.
+    num_devices = xr.global_runtime_device_count()
+    assert num_devices > 0, "This requires XLA supported device(s)."
+    assert num_devices == len(
+        device_ids
+    ), f"Number of device IDs ({len(device_ids)}) must match the global number of devices ({num_devices})"
+
+    if axis_names is not None:
+      assert len(mesh_shape) == len(axis_names), \
+          f"Number of axis names ({len(axis_names)}) must match mesh dimensions ({len(mesh_shape)})"
+      assert len(set(axis_names)) == len(axis_names), \
+          f"Axis names must be unique, got: {axis_names}"
+
+    expected_devices = np.prod(mesh_shape)
+    assert len(device_ids) == expected_devices, \
+        f"Number of device IDs ({len(device_ids)}) must match mesh size ({expected_devices})"
+    assert len(device_ids) == len(np.unique(device_ids)), \
+        f"Device IDs must be unique, got: {device_ids}"
+
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
-    assert all(d < self.size() for d in device_ids)
+    assert all(d < self.size() for d in device_ids), \
+        f"Device IDs must be less than mesh size ({self.size()}), got: {device_ids}"
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -555,10 +573,6 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     >>> linear = nn.Linear(32, 10).to(xm.xla_device())
     >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
-  num_devices = xr.global_runtime_device_count()
-  assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
   # should be of the same rank as `t`. This is to support partial replication
@@ -603,10 +617,6 @@ def mark_sharding_with_gradients(
 
     This version can also be used in AOTAutograd.
   """
-  num_devices = xr.global_runtime_device_count()
-  assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
   # should be of the same rank as `t`. This is to support partial replication
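
Because Mesh.__init__ now asserts that the device IDs match the global runtime device count, mark_sharding and mark_sharding_with_gradients drop their own num_devices checks: any Mesh that reaches them has already been validated. A minimal usage sketch under that assumption (tensor shape and names are illustrative only):

# Sketch under assumptions: the Mesh is validated when constructed, so
# mark_sharding no longer repeats the global-device-count assertions.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the SPMD execution mode before sharding tensors

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices), ('data', 'model'))

# Fully specified partition spec with the same rank as the tensor.
t = torch.randn(16, 32).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))  # dim 0 over axis 'data', dim 1 over 'model'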

0 commit comments
