Skip to content

Commit 71ec6e3

Browse files
apaszke authored and jax authors committed
Make pl.num_programs lowering take the vmapped axes into account
Otherwise the size of the wrong axis is returned.

PiperOrigin-RevId: 614677218
1 parent de455e7 commit 71ec6e3

File tree

2 files changed

+43
-6
lines changed

2 files changed

+43
-6
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ class MeshContext:
8686
@dataclasses.dataclass
8787
class LoweringContext:
8888
ir_context: ir.Context
89-
grid_indices: Sequence[ir.Value] | None
89+
grid_rank: int # Includes both user and vmap axes.
90+
mapped_dims: tuple[int, ...] # Indices of vmapped grid dimensions.
91+
user_grid_indices: Sequence[ir.Value] | None
9092
block_shapes: list[tuple[int | pl_core.Mapped, ...]]
9193
name_stack: source_info_util.NameStack
9294
mesh_context: MeshContext | None
@@ -475,6 +477,8 @@ def body_func(*args):
475477
mesh_context = None
476478
lowering_context = LoweringContext(
477479
ctx,
480+
len(mosaic_grid_mapping.grid),
481+
mosaic_grid_mapping.mapped_dims,
478482
None,
479483
arg_block_shapes,
480484
source_info_util.NameStack(),
@@ -531,6 +535,8 @@ def body_func(*args):
531535
mesh_context = None
532536
lowering_context = LoweringContext(
533537
ctx,
538+
len(mosaic_grid_mapping.grid),
539+
mosaic_grid_mapping.mapped_dims,
534540
jaxpr_indices,
535541
arg_block_shapes,
536542
source_info_util.NameStack(),
@@ -1846,22 +1852,32 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
18461852

18471853

18481854
def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
1849-
if ctx.lowering_context.grid_indices is None:
1855+
if ctx.lowering_context.user_grid_indices is None:
18501856
raise ValueError(
18511857
f"program id: {axis} was passed, but user did not provide a grid."
18521858
)
1853-
length = len(ctx.lowering_context.grid_indices)
1859+
length = len(ctx.lowering_context.user_grid_indices)
18541860
if not (0 <= axis < length):
18551861
raise ValueError(
18561862
f"user passed in program id with axis: {axis}, but grid only has"
18571863
f" length: {length}"
18581864
)
1859-
return ctx.lowering_context.grid_indices[axis]
1865+
return ctx.lowering_context.user_grid_indices[axis]
18601866
lowering_rules[primitives.program_id_p] = _program_id_lowering_rule
18611867

18621868
def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
1863-
del ctx
1864-
return tpu.iteration_bound(axis)
1869+
mapped_axes = set(ctx.lowering_context.mapped_dims)
1870+
seen_user_axes = 0
1871+
for i in range(ctx.lowering_context.grid_rank):
1872+
seen_user_axes += int(i not in mapped_axes)
1873+
if seen_user_axes == axis + 1:
1874+
break
1875+
else:
1876+
raise ValueError(
1877+
f"user passed in program id with axis: {axis}, but grid only has"
1878+
f" length: {len(ctx.lowering_context.grid_rank)}"
1879+
)
1880+
return tpu.iteration_bound(i)
18651881
lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule
18661882

18671883

tests/pallas/pallas_call_tpu_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,27 @@ def dynamic_kernel(steps):
450450

451451
self.assertEqual(dynamic_kernel(4), 8)
452452

453+
@parameterized.parameters(range(1, 4))
454+
def test_vmap_num_programs(self, num_vmaps):
455+
result_ty = jax.ShapeDtypeStruct((8, 128), jnp.int32)
456+
457+
def kernel(y_ref):
458+
y_ref[...] = jnp.full_like(y_ref, pl.num_programs(0))
459+
460+
kernel_call = self.pallas_call(
461+
kernel,
462+
grid=(8,),
463+
out_specs=pl.BlockSpec(lambda i: (0, 0), result_ty.shape),
464+
out_shape=result_ty,
465+
)
466+
467+
out_shape = (*(2 for _ in range(num_vmaps)), *result_ty.shape)
468+
f = kernel_call
469+
for _ in range(num_vmaps):
470+
f = lambda impl=f: jax.vmap(impl, axis_size=2)()
471+
out = jax.jit(f)()
472+
np.testing.assert_array_equal(out, np.full(out_shape, 8.0))
473+
453474
def test_num_programs_block_spec(self):
454475
def kernel(x_ref, y_ref):
455476
y_ref[...] = x_ref[...]

0 commit comments

Comments
 (0)