@@ -1654,6 +1654,58 @@ def _false():
1654
1654
return
1655
1655
1656
1656
1657
class PallasCallWhileLoopTest(PallasTPUTest):
  """Checks `jax.lax.while_loop` inside a Pallas kernel run via `pallas_call`."""

  def setUp(self):
    """Runs the parent setup, then skips the test unless the device is a TPU."""
    super().setUp()
    if jtu.device_under_test() != 'tpu':
      self.skipTest('Test only works on TPU')

  def test_range_while_loop(self):
    """Tests lowering of a while_loop which can reduce to a fori_loop.

    The kernel walks the first 1024 elements of its input block one at a
    time with a counted `while_loop`, accumulating them into a single output
    cell. The result is compared against `sum(arange(1024))`.
    """

    def kernel(x_ref, r_ref):
      # Zero the accumulator cell on the first grid step.
      @pl.when(pl.program_id(0) == 0)
      def _():
        pl.store(r_ref, (0, 0), 0)

      def cond(carry):
        # Continue while the counter is below the upper bound.
        i, j = carry
        return i < j

      def body(carry):
        i, j = carry
        # Split the flat counter into indices for the two trailing block
        # dimensions (the last dimension has 128 entries).
        row = jax.lax.div(i, 128)
        col = jax.lax.rem(i, 128)
        v = x_ref[0, row, col]
        # Read-modify-write the running sum held in the output ref.
        s = pl.load(r_ref, (0, 0))
        pl.store(r_ref, (0, 0), s + v)
        # Only the counter advances; the bound is carried through unchanged.
        return i + 1, j

      start = 0
      stop = 1024
      # The loop is executed for its side effect on r_ref; the final carry is
      # not needed.
      jax.lax.while_loop(cond, body, (start, stop))

    x = jnp.arange(4096)
    x = jnp.reshape(x, [4, 8, 128])

    # A one-step grid with a (1, 8, 128) block maps the kernel onto x[0:1],
    # i.e. the values 0..1023. Input and output blocks are both requested in
    # SMEM.
    r = pl.pallas_call(
        kernel,
        grid=(1,),
        out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
        out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
        in_specs=[
            pl.BlockSpec(
                lambda i: (i, 0, 0),
                block_shape=(1, 8, 128),
                memory_space=pltpu.SMEM,
            )
        ],
    )(x)
    expected = jnp.sum(jnp.arange(1024))
    np.testing.assert_array_equal(r, expected)
1707
+
1708
+
1657
1709
class PallasCallPipelineTest (parameterized .TestCase ):
1658
1710
1659
1711
def setUp (self ):
0 commit comments