|
1 | 1 | import copy
|
2 | 2 |
|
3 |
| -import unittest |
4 |
| -from unittest.mock import patch |
| 3 | +from collections import OrderedDict |
5 | 4 | import math
|
6 | 5 | import numpy as np
|
| 6 | +import unittest |
| 7 | +from unittest.mock import patch |
7 | 8 | import sys
|
8 | 9 |
|
9 | 10 | import torch
|
@@ -762,9 +763,12 @@ def test_hybrid_mesh_shape(self):
|
762 | 763 | "Crash on TPU v2")
|
763 | 764 | @patch('torch_xla.runtime.global_runtime_device_attributes')
|
764 | 765 | @patch('torch_xla.core.xla_model.xla_device_hw')
|
765 |
| - def test_hybrid_mesh(self, xla_device_mock, device_attributes_mock): |
| 766 | + @patch('torch_xla.runtime.global_runtime_device_count') |
| 767 | + def test_hybrid_mesh(self, device_count_mock, xla_device_mock, |
| 768 | + device_attributes_mock): |
766 | 769 | # mock device attributes for 2 slices of v4-8
|
767 | 770 | num_slices = 2
|
| 771 | + device_count_mock.return_value = 8 |
768 | 772 | xla_device_mock.return_value = "TPU"
|
769 | 773 | device_attributes_mock.return_value = [{
|
770 | 774 | 'coords': [0, 0, 0],
|
@@ -1565,6 +1569,97 @@ def test_mark_sharding_with_gradients_annotation(self):
|
1565 | 1569 | # Check that the gradient has sharding.
|
1566 | 1570 | self.assertIn(sharding_spec, x_grad_sharding)
|
1567 | 1571 |
|
def test_valid_mesh_creation(self):
  # A 1 x n_devices mesh built with one name per axis should expose the
  # device ids, shape and axis names it was constructed with.
  names = ('data', 'model')
  shape = (1, self.n_devices)
  built = xs.Mesh(self.device_ids, shape, names)

  expected_ids = list(range(self.n_devices))
  self.assertEqual(built.device_ids.tolist(), expected_ids)
  self.assertEqual(built.mesh_shape, shape)
  self.assertEqual(built.axis_names, names)
| 1580 | + |
def test_valid_mesh_without_axis_names(self):
  # Axis names are optional: when omitted, the mesh reports None for them
  # while still carrying the device ids and shape.
  shape = (1, self.n_devices)
  unnamed = xs.Mesh(self.device_ids, shape)

  expected_ids = list(range(self.n_devices))
  self.assertEqual(unnamed.device_ids.tolist(), expected_ids)
  self.assertEqual(unnamed.mesh_shape, shape)
  self.assertIsNone(unnamed.axis_names)
| 1588 | + |
def test_invalid_axis_names_length(self):
  # Three axis names for a two-dimensional mesh must be rejected.
  too_many_names = ('data', 'model', 'extra')
  shape = (1, self.n_devices)
  expected_msg = "Number of axis names .* must match mesh dimensions"

  with self.assertRaisesRegex(AssertionError, expected_msg):
    xs.Mesh(self.device_ids, shape, too_many_names)
| 1596 | + |
def test_duplicate_axis_names(self):
  # Reusing the same name for both axes must be rejected.
  repeated_names = ('data', 'data')
  shape = (1, self.n_devices)
  expected_msg = "Axis names must be unique"

  with self.assertRaisesRegex(AssertionError, expected_msg):
    xs.Mesh(self.device_ids, shape, repeated_names)
| 1603 | + |
def test_invalid_device_count(self):
  # A (2, n_devices) mesh has twice as many slots as the ids supplied, so
  # construction must fail.
  oversized_shape = (2, self.n_devices)
  expected_msg = "Number of device IDs .* must match mesh size"

  with self.assertRaisesRegex(AssertionError, expected_msg):
    xs.Mesh(self.device_ids, oversized_shape)
| 1610 | + |
@unittest.skipIf(xr.global_runtime_device_count() == 1,
                 "Multiple devices needed for duplicated device IDs")
def test_duplicate_device_ids(self):
  # Every slot is assigned device 0, so with more than one device the ids
  # are necessarily repeated and the mesh must be rejected.
  shape = (1, self.n_devices)
  all_zero_ids = np.array([0] * self.n_devices)
  expected_msg = "Device IDs must be unique"

  with self.assertRaisesRegex(AssertionError, expected_msg):
    xs.Mesh(all_zero_ids, shape)
| 1619 | + |
def test_device_ids_out_of_bounds(self):
  # The ids run from n_devices + 1 through 2 * n_devices, so every one of
  # them is larger than the mesh size.
  shape = (1, self.n_devices)
  first_bad_id = self.n_devices + 1
  past_last_bad_id = self.n_devices * 2 + 1
  out_of_range_ids = np.arange(first_bad_id, past_last_bad_id)
  expected_msg = "Device IDs must be less than mesh size"

  with self.assertRaisesRegex(AssertionError, expected_msg):
    xs.Mesh(out_of_range_ids, shape)
| 1627 | + |
def test_mesh_size(self):
  # size() of a 1 x n_devices mesh should equal n_devices.
  flat_mesh = xs.Mesh(self.device_ids, (1, self.n_devices))
  self.assertEqual(flat_mesh.size(), self.n_devices)
| 1632 | + |
def test_mesh_shape_method(self):
  # shape() should map each axis name to its extent, in axis order.
  names = ('data', 'model')
  dims = (1, self.n_devices)
  named_mesh = xs.Mesh(self.device_ids, dims, names)

  # Pairing names with dims yields [('data', 1), ('model', n_devices)].
  expected = OrderedDict(zip(names, dims))
  self.assertEqual(named_mesh.shape(), expected)
| 1640 | + |
@unittest.skipIf(xr.global_runtime_device_count() == 1,
                 "Multiple devices needed")
def test_mismatch_global_devices(self):
  # Build a self-consistent mesh over only half of the devices; the
  # constructor must still reject it for not covering all global devices.
  half = self.n_devices // 2
  partial_ids = np.arange(half)
  partial_shape = (1, half)
  expected_msg = (
      "Number of device IDs .* must match the global number of devices")

  with self.assertRaisesRegex(AssertionError, expected_msg):
    xs.Mesh(partial_ids, partial_shape)
| 1651 | + |
@unittest.skipIf(xr.global_runtime_device_count() == 1,
                 "Multiple devices needed")
def test_get_logical_mesh(self):
  # The (2, n_devices // 2) layout below only holds n_devices entries when
  # n_devices is even. With an odd count the Mesh constructor would fail on
  # a size mismatch, which is unrelated to get_logical_mesh, so skip.
  if self.n_devices % 2 != 0:
    self.skipTest("Even number of devices needed for a 2 x (n / 2) mesh")

  device_ids = np.arange(self.n_devices)
  mesh_shape = (2, self.n_devices // 2)
  mesh = xs.Mesh(device_ids, mesh_shape)

  logical_mesh = mesh.get_logical_mesh()
  # The logical mesh must have the requested shape...
  self.assertEqual(logical_mesh.shape, mesh_shape)
  # ...and contain every device id exactly once, in any arrangement.
  np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)
| 1662 | + |
1568 | 1663 |
|
1569 | 1664 | if __name__ == '__main__':
|
1570 | 1665 | test = unittest.main()
|
|
0 commit comments