Skip to content

Commit 28b81be

Browse files
author
jax authors
committed
[Pallas TPU] Pallas while loop -> fori test.
PiperOrigin-RevId: 623204164
1 parent 77db7a6 commit 28b81be

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

tests/pallas/pallas_call_tpu_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,58 @@ def _false():
16541654
return
16551655

16561656

1657+
class PallasCallWhileLoopTest(PallasTPUTest):
1658+
1659+
def setUp(self):
1660+
super().setUp()
1661+
if jtu.device_under_test() != 'tpu':
1662+
self.skipTest('Test only works on TPU')
1663+
1664+
def test_range_while_loop(self):
1665+
"""Tests lowering of a while_loop which can reduce to a fori_loop."""
1666+
1667+
def kernel(x_ref, r_ref):
1668+
@pl.when(pl.program_id(0) == 0)
1669+
def _():
1670+
pl.store(r_ref, (0, 0), 0)
1671+
1672+
def cond(carry):
1673+
i, j = carry
1674+
return i < j
1675+
1676+
def body(carry):
1677+
i, j = carry
1678+
sl = sl = jax.lax.div(i, 128)
1679+
l = jax.lax.rem(i, 128)
1680+
v = x_ref[0, sl, l]
1681+
s = pl.load(r_ref, (0, 0))
1682+
pl.store(r_ref, (0, 0), s + v)
1683+
return i + 1, j
1684+
1685+
i = 0
1686+
j = 1024
1687+
i, j = jax.lax.while_loop(cond, body, (i, j))
1688+
1689+
x = jnp.arange(4096)
1690+
x = jnp.reshape(x, [4, 8, 128])
1691+
1692+
r = pl.pallas_call(
1693+
kernel,
1694+
grid=(1,),
1695+
out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
1696+
out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
1697+
in_specs=[
1698+
pl.BlockSpec(
1699+
lambda i: (i, 0, 0),
1700+
block_shape=(1, 8, 128),
1701+
memory_space=pltpu.SMEM,
1702+
)
1703+
],
1704+
)(x)
1705+
expected = jnp.sum(jnp.arange(1024))
1706+
np.testing.assert_array_equal(r, expected)
1707+
1708+
16571709
class PallasCallPipelineTest(parameterized.TestCase):
16581710

16591711
def setUp(self):

0 commit comments

Comments
 (0)