Skip to content

Commit ab83469

Browse files
blakehechtmanjax authors
authored andcommitted
[PALLAS] add test for large indexing.
PiperOrigin-RevId: 611925093
1 parent 51a31e5 commit ab83469

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/pallas/pallas_call_tpu_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,32 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem):
10351035
)(x)
10361036
np.testing.assert_array_equal(y, x)
10371037

1038+
def test_large_array_indexing(self):
1039+
n = 6
1040+
dtype = jnp.bfloat16
1041+
x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0)
1042+
1043+
def kernel(index, x, y, sem):
1044+
pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait()
1045+
1046+
run = pl.pallas_call(kernel,
1047+
grid_spec=pltpu.PrefetchScalarGridSpec(
1048+
num_scalar_prefetch=1,
1049+
in_specs=[
1050+
pl.BlockSpec(
1051+
memory_space=pltpu.TPUMemorySpace.ANY)],
1052+
out_specs=pl.BlockSpec(
1053+
memory_space=pltpu.TPUMemorySpace.ANY),
1054+
scratch_shapes=[pltpu.SemaphoreType.DMA],
1055+
),
1056+
out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype),
1057+
)
1058+
1059+
for i in range(x.shape[0]):
1060+
y = run(jnp.array([i], dtype=jnp.int32), x)
1061+
np.testing.assert_array_equal(y, i)
1062+
del y
1063+
10381064

10391065
class PallasCallRemoteDMATest(parameterized.TestCase):
10401066

0 commit comments

Comments
 (0)