|
26 | 26 | from jax.sharding import Mesh # pylint: disable=g-importing-member
|
27 | 27 | import numpy as np
|
28 | 28 |
|
| 29 | +# pyformat: disable |
29 | 30 |
|
30 | 31 | @dataclasses.dataclass(frozen=True)
|
31 | 32 | class MockTpuDevice:
|
@@ -241,22 +242,79 @@ def test_create_device_mesh_for_nd_torus_split_axes_backward_compatible(
|
241 | 242 | self.assertArraysEqual(assignment, expected_assignment_matrix)
|
242 | 243 |
|
243 | 244 | @parameterized.named_parameters(
|
244 |
| - ('4x4x4a', mock_4x4x4_devices, [2, 1, 32]), |
245 |
| - ('4x4x4b', mock_4x4x4_devices, [8, 8, 1]), |
246 |
| - ('4x4x8a', mock_4x4x8_devices, [2, 2, 8, 4]), |
247 |
| - ('4x4x8b', mock_4x4x8_devices, [2, 4, 1, 16]), |
248 |
| - ('4x8x8', mock_4x8x8_devices, [1, 128, 2]), |
249 |
| - ('8x8', mock_8x8_devices, [2, 1, 32, 1]), |
| 245 | + ( |
| 246 | + '4x4x4a', |
| 247 | + mock_4x4x4_devices, |
| 248 | + [2, 1, 32], |
| 249 | + [ |
| 250 | + [1, 1, 4], |
| 251 | + [1, 1, 4], |
| 252 | + [2, 1, 2], |
| 253 | + ], |
| 254 | + ), |
| 255 | + ( |
| 256 | + '4x4x4b', |
| 257 | + mock_4x4x4_devices, |
| 258 | + [8, 8, 1], |
| 259 | + [ |
| 260 | + [1, 4, 1], |
| 261 | + [2, 2, 1], |
| 262 | + [4, 1, 1], |
| 263 | + ], |
| 264 | + ), |
| 265 | + ( |
| 266 | + '4x4x8a', |
| 267 | + mock_4x4x8_devices, |
| 268 | + [2, 2, 8, 4], |
| 269 | + [ |
| 270 | + [1, 1, 1, 4], |
| 271 | + [2, 2, 1, 1], |
| 272 | + [1, 1, 8, 1], |
| 273 | + ], |
| 274 | + ), |
| 275 | + ( |
| 276 | + '4x4x8b', |
| 277 | + mock_4x4x8_devices, |
| 278 | + [2, 4, 1, 16], |
| 279 | + [ |
| 280 | + [1, 1, 1, 4], |
| 281 | + [1, 1, 1, 4], |
| 282 | + [2, 4, 1, 1], |
| 283 | + ], |
| 284 | + ), |
| 285 | + ( |
| 286 | + '4x8x8', |
| 287 | + mock_4x8x8_devices, |
| 288 | + [1, 128, 2], |
| 289 | + [ |
| 290 | + [1, 2, 2], |
| 291 | + [1, 8, 1], |
| 292 | + [1, 8, 1], |
| 293 | + ], |
| 294 | + ), |
| 295 | + ( |
| 296 | + '8x8', |
| 297 | + mock_8x8_devices, |
| 298 | + [2, 1, 32, 1], |
| 299 | + [ |
| 300 | + [1, 1, 8, 1], |
| 301 | + [2, 1, 4, 1], |
| 302 | + [1, 1, 1, 1], |
| 303 | + ], |
| 304 | + ), |
250 | 305 | )
|
251 | 306 | def test_create_device_mesh_for_nd_torus_split_axes_can_handle_axes_split(
|
252 |
| - self, devices, mesh_shape |
| 307 | + self, devices, mesh_shape, assignment_matrix |
253 | 308 | ):
|
254 | 309 | jax_devices = devices(True)
|
255 | 310 | physical_mesh = mesh_utils._get_physical_tpu_mesh(jax_devices)
|
256 |
| - logical_mesh, _ = mesh_utils._create_device_mesh_for_nd_torus( |
| 311 | + logical_mesh, assignment = mesh_utils._create_device_mesh_for_nd_torus( |
257 | 312 | physical_mesh, mesh_shape, allow_split_physical_axes=True
|
258 | 313 | )
|
259 | 314 | self.assertEqual(logical_mesh.shape, tuple(mesh_shape))
|
| 315 | + self.assertArraysEqual( |
| 316 | + assignment, np.array(assignment_matrix, dtype=np.int64) |
| 317 | + ) |
260 | 318 |
|
261 | 319 | @parameterized.named_parameters(
|
262 | 320 | ('2X4x4x4a', (1, 16, 4), (2, 1, 1)),
|
|
0 commit comments