Skip to content

Commit 44be575

Browse files
author
jax authors
committed
Compute axis_index without creating an entire grid of device IDs.
For large meshes, this numpy array can exceed the size of SMEM. We can perform the same calculation using just the grid shape. PiperOrigin-RevId: 618167202
1 parent 2d65571 commit 44be575

File tree

1 file changed

+18
-29
lines changed

1 file changed

+18
-29
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777

7878
@dataclasses.dataclass
7979
class MeshContext:
80-
logical_to_mesh: ir.Value
80+
mesh_shape: tuple[int, ...]
8181
axis_names: tuple[str, ...]
8282
mesh_strides: tuple[int, ...]
8383

@@ -298,20 +298,7 @@ def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
298298
mesh_strides = pallas_utils.strides_from_shape(tuple(
299299
mesh.shape[a] for a in axis_names
300300
))
301-
logical_to_mesh = np.empty((mesh.size, len(axis_names)), dtype=np.int32)
302-
for i, idx in enumerate(np.ndindex(*mesh.device_ids.shape)):
303-
logical_to_mesh[i] = np.array(idx)
304-
self.mesh_info = MeshInfo(logical_to_mesh, axis_names, mesh_strides)
305-
l_to_m_aval = pl_core.AbstractMemoryRef(
306-
jax_core.raise_to_shaped(jax_core.get_aval(logical_to_mesh)),
307-
TPUMemorySpace.SMEM,
308-
)
309-
# We are now passing in the logical -> mesh index mapping
310-
# TODO(sharadmv,apaszke): avoid stalling pipeline by marking the index
311-
# mapping as scalar prefetch and instead just mark it as an SMEM operand.
312-
self.scalar_prefetch_types = (
313-
_get_arg_type(l_to_m_aval, None)[0],
314-
*self.scalar_prefetch_types)
301+
self.mesh_info = MeshInfo(mesh.device_ids.shape, axis_names, mesh_strides)
315302

316303
def maybe_compress_grid(self):
317304
# If we have many leading parallel dimensions, we should "compress" them
@@ -324,9 +311,7 @@ def has_communication(self) -> bool:
324311
return bool(jax_core.used_axis_names_jaxpr(self.jaxpr))
325312

326313
def get_extra_args(self) -> tuple[Any, ...]:
327-
if self.mesh_info is None:
328-
return ()
329-
return (self.mesh_info.logical_to_mesh,)
314+
return ()
330315

331316
def get_dimension_semantics(self) -> ir.ArrayAttr:
332317

@@ -344,7 +329,7 @@ def _get_semantics(s: str | None) -> str:
344329

345330
@dataclasses.dataclass
346331
class MeshInfo:
347-
logical_to_mesh: np.ndarray
332+
mesh_shape: tuple[int, ...]
348333
axis_names: list[str]
349334
mesh_strides: tuple[int, ...]
350335

@@ -469,9 +454,9 @@ def body_func(*args):
469454

470455
mesh_info = mosaic_grid_mapping.mesh_info
471456
if mesh_info is not None:
472-
(l_to_m,), scalar_prefetch = split_list(scalar_prefetch, [1])
473-
mesh_context = MeshContext(l_to_m, mesh_info.axis_names,
474-
mesh_info.mesh_strides)
457+
mesh_context = MeshContext(
458+
mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
459+
)
475460
else:
476461
mesh_context = None
477462
lowering_context = LoweringContext(
@@ -527,9 +512,9 @@ def body_func(*args):
527512
if i not in mosaic_grid_mapping.mapped_dims)
528513
mesh_info = mosaic_grid_mapping.mesh_info
529514
if mesh_info is not None:
530-
(l_to_m,), scalar_prefetch = split_list(scalar_prefetch, [1])
531-
mesh_context = MeshContext(l_to_m, mesh_info.axis_names,
532-
mesh_info.mesh_strides)
515+
mesh_context = MeshContext(
516+
mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
517+
)
533518
else:
534519
mesh_context = None
535520
lowering_context = LoweringContext(
@@ -2145,11 +2130,15 @@ def _device_id_lowering_rule(ctx: LoweringRuleContext):
21452130
lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule
21462131

21472132
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
2148-
device_id = _make_index(tpu.DeviceIdOp().result)
2149-
l_to_m = ctx.lowering_context.mesh_context.logical_to_mesh
2133+
device_id = tpu.DeviceIdOp().result
2134+
mesh_shape = ctx.lowering_context.mesh_context.mesh_shape
21502135
axis_names = ctx.lowering_context.mesh_context.axis_names
2151-
col = _make_index(axis_names.index(axis_name))
2152-
return memref.LoadOp(l_to_m, [device_id, col]).result
2136+
axis_index = axis_names.index(axis_name)
2137+
axis_size = ir_constant(mesh_shape[axis_index])
2138+
minor_divisor = ir_constant(
2139+
np.prod(mesh_shape[axis_index + 1 :], dtype=np.int32)
2140+
)
2141+
return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size)
21532142
lowering_rules[lax.axis_index_p] = _axis_index_rule
21542143

21552144
def _get_barrier_semaphore_rule(ctx: LoweringRuleContext):

0 commit comments

Comments
 (0)