Skip to content

Commit 73aadbf

Browse files
author
jax authors
committed
Adding expectations to e2e test with expectations for mesh creation allowing splitting physical axes.
PiperOrigin-RevId: 618028683
1 parent 9523547 commit 73aadbf

File tree

1 file changed

+66
-8
lines changed

1 file changed

+66
-8
lines changed

tests/mesh_utils_test.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from jax.sharding import Mesh # pylint: disable=g-importing-member
2727
import numpy as np
2828

29+
# pyformat: disable
2930

3031
@dataclasses.dataclass(frozen=True)
3132
class MockTpuDevice:
@@ -241,22 +242,79 @@ def test_create_device_mesh_for_nd_torus_split_axes_backward_compatible(
241242
self.assertArraysEqual(assignment, expected_assignment_matrix)
242243

243244
@parameterized.named_parameters(
244-
('4x4x4a', mock_4x4x4_devices, [2, 1, 32]),
245-
('4x4x4b', mock_4x4x4_devices, [8, 8, 1]),
246-
('4x4x8a', mock_4x4x8_devices, [2, 2, 8, 4]),
247-
('4x4x8b', mock_4x4x8_devices, [2, 4, 1, 16]),
248-
('4x8x8', mock_4x8x8_devices, [1, 128, 2]),
249-
('8x8', mock_8x8_devices, [2, 1, 32, 1]),
245+
(
246+
'4x4x4a',
247+
mock_4x4x4_devices,
248+
[2, 1, 32],
249+
[
250+
[1, 1, 4],
251+
[1, 1, 4],
252+
[2, 1, 2],
253+
],
254+
),
255+
(
256+
'4x4x4b',
257+
mock_4x4x4_devices,
258+
[8, 8, 1],
259+
[
260+
[1, 4, 1],
261+
[2, 2, 1],
262+
[4, 1, 1],
263+
],
264+
),
265+
(
266+
'4x4x8a',
267+
mock_4x4x8_devices,
268+
[2, 2, 8, 4],
269+
[
270+
[1, 1, 1, 4],
271+
[2, 2, 1, 1],
272+
[1, 1, 8, 1],
273+
],
274+
),
275+
(
276+
'4x4x8b',
277+
mock_4x4x8_devices,
278+
[2, 4, 1, 16],
279+
[
280+
[1, 1, 1, 4],
281+
[1, 1, 1, 4],
282+
[2, 4, 1, 1],
283+
],
284+
),
285+
(
286+
'4x8x8',
287+
mock_4x8x8_devices,
288+
[1, 128, 2],
289+
[
290+
[1, 2, 2],
291+
[1, 8, 1],
292+
[1, 8, 1],
293+
],
294+
),
295+
(
296+
'8x8',
297+
mock_8x8_devices,
298+
[2, 1, 32, 1],
299+
[
300+
[1, 1, 8, 1],
301+
[2, 1, 4, 1],
302+
[1, 1, 1, 1],
303+
],
304+
),
250305
)
251306
def test_create_device_mesh_for_nd_torus_split_axes_can_handle_axes_split(
252-
self, devices, mesh_shape
307+
self, devices, mesh_shape, assignment_matrix
253308
):
254309
jax_devices = devices(True)
255310
physical_mesh = mesh_utils._get_physical_tpu_mesh(jax_devices)
256-
logical_mesh, _ = mesh_utils._create_device_mesh_for_nd_torus(
311+
logical_mesh, assignment = mesh_utils._create_device_mesh_for_nd_torus(
257312
physical_mesh, mesh_shape, allow_split_physical_axes=True
258313
)
259314
self.assertEqual(logical_mesh.shape, tuple(mesh_shape))
315+
self.assertArraysEqual(
316+
assignment, np.array(assignment_matrix, dtype=np.int64)
317+
)
260318

261319
@parameterized.named_parameters(
262320
('2X4x4x4a', (1, 16, 4), (2, 1, 1)),

0 commit comments

Comments
 (0)