Skip to content

Commit d0eae05

Browse files
apaszkejax authors
authored andcommitted
Add a test for grid overflows in dynamic grid lowering.
PiperOrigin-RevId: 614980113
1 parent 9dbf758 commit d0eae05

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/pallas/pallas_call_tpu_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,30 @@ def dynamic_kernel(steps):
341341
dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32)
342342
)
343343

344+
def test_dynamic_grid_overflow(self):
345+
# If we pad statically the dynamic grid dims to max int32, then the product
346+
# of this grid size will overflow int64 and can cause failing checks in XLA.
347+
shape = (8, 128)
348+
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
349+
350+
def kernel(y_ref):
351+
@pl.when(sum(pl.program_id(i) for i in range(3)) == 0)
352+
def _init():
353+
y_ref[...] = jnp.zeros_like(y_ref)
354+
y_ref[...] += 1
355+
356+
@jax.jit
357+
def dynamic_kernel(steps):
358+
return self.pallas_call(
359+
kernel,
360+
grid=(steps * 2, steps + 1, 3),
361+
out_specs=pl.BlockSpec(lambda *_: (0, 0), shape),
362+
out_shape=result_ty,
363+
)()
364+
np.testing.assert_array_equal(
365+
dynamic_kernel(jnp.int32(4)), np.full(shape, 120.0, np.float32)
366+
)
367+
344368
# TODO(apaszke): Add tests for scalar_prefetch too
345369
def test_dynamic_grid_scalar_input(self):
346370
shape = (8, 128)

0 commit comments

Comments
 (0)