@@ -1751,6 +1751,199 @@ def body(i, _):
1751
1751
)(* (jnp .array ([[x ]]) for x in (2 , 6 )))
1752
1752
np .testing .assert_array_equal (r , 4 )
1753
1753
1754
def test_non_range_while_loop(self):
  """Tests lowering of a while_loop which cannot reduce to a fori_loop.

  The loop condition depends on the running sum as well as on the index, so
  the trip count is data dependent and the loop is not a plain counted range.

  The input is `arange(4096)` split into four (1, 8, 128) blocks. Summing the
  first block in order, the first partial sum that reaches 1024 is
  0 + 1 + ... + 45 = 1035, which is the expected result.
  """

  def kernel(x_ref, r_ref):
    # Only the first grid step zeroes the accumulator.
    # NOTE(review): the expected value assumes that later grid steps read
    # back the value stored by the first one, so their loops exit at once.
    @pl.when(pl.program_id(0) == 0)
    def _():
      pl.store(r_ref, (0, 0), 0)

    def cond(state):
      i, s = state
      return jnp.logical_and(i < 1024, s < 1024)

    def body(state):
      i, s = state
      # Split the flat index into a (row, col) position in the 8x128 block.
      row = jax.lax.div(i, 128)
      col = jax.lax.rem(i, 128)
      v = pl.load(x_ref, (0, row, col))
      return i + 1, s + v

    i = jnp.int32(0)
    s = pl.load(r_ref, (0, 0))

    # Only the accumulated sum is needed once the loop has finished.
    _, s = jax.lax.while_loop(cond, body, (i, s))
    pl.store(r_ref, (0, 0), s)

  x = jnp.arange(4096)
  x = jnp.reshape(x, [4, 8, 128])

  r = pl.pallas_call(
      kernel,
      grid=(4,),
      out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
      out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
      in_specs=[
          pl.BlockSpec(
              lambda i: (i, 0, 0),
              block_shape=(1, 8, 128),
              memory_space=pltpu.SMEM,
          )
      ],
  )(x)
  np.testing.assert_array_equal(r, [[1035]])
1796
+
1797
def test_vector_carry_while_loop(self):
  """Tests lowering of a while_loop which carries a vector quantity.

  The whole (8, 128) block is the loop carry. It is doubled until its first
  element is no longer below 16, then written to the output.
  """
  block = (8, 128)

  def kernel(x_ref, r_ref):
    # The carry is the full block; the condition inspects one element of it.
    r_ref[:] = jax.lax.while_loop(
        lambda v: v[0, 0] < 16,
        lambda v: v * 2,
        x_ref[:],
    )

  doubled = pl.pallas_call(
      kernel,
      grid=(1,),
      in_specs=[pl.BlockSpec(lambda i: (0, 0), block)],
      out_specs=pl.BlockSpec(lambda i: (0, 0), block),
      out_shape=jax.ShapeDtypeStruct(block, jnp.int32),
  )(jnp.full(block, 3, dtype=jnp.int32))

  # Every element goes 3 -> 6 -> 12 -> 24, and there are 8 * 128 = 1024 of
  # them.
  np.testing.assert_array_equal(jnp.sum(doubled), 1024 * 24)
1822
+
1823
@parameterized.named_parameters(
    ('1x128', (1, 128)),
    ('2x128', (2, 128)),
    ('4x128', (4, 128)),
    ('8x128', (8, 128)),
    ('8x256', (8, 256)),
)
def test_while_loop_carry_memref(self, shape):
  """Tests a while loop carrying a memref.

  The kernel has no inputs. Its loop body writes 2 into successive columns
  of row 0 of the output ref, which the loop functions capture by closure.

  Args:
    shape: (rows, cols) of the output block; `cols` is the loop bound.
  """

  # TODO(hmckenzie): Investigate further why this occurs.
  if shape == (1, 128):
    self.skipTest('memref<1x128> inexplicably doubles to 2x128.')

  def kernel(out_ref, bound):
    def cond(i):
      return i < bound

    def body(i):
      out_ref[0, i] = 2
      return i + 1

    jax.lax.while_loop(cond, body, 0)

  # The loop bound is the static column count of the output block.
  kernel_fn = partial(kernel, bound=shape[1])

  fn = pl.pallas_call(
      kernel_fn,
      grid=(1,),
      out_specs=[
          pl.BlockSpec(
              lambda i: (0, 0), block_shape=shape, memory_space=pltpu.SMEM
          ),
      ],
      out_shape=[
          jax.ShapeDtypeStruct(shape, jnp.int32),
      ],
  )
  y = fn()[0]
  # Spot-check the first four columns of the row the loop wrote.
  np.testing.assert_array_equal(y[0, :4], 2)
1870
+
1871
def test_nested_while_loop(self):
  """Tests lowering a nested while_loop.

  The kernel scans a flat sequence of keys stored as an (8, 128) block and
  records the length of every run of equal consecutive keys. The outer loop
  visits one run per iteration; the inner loop advances to the end of it.
  """

  def kernel(in_key_ref, out_segment_count, out_size_ref, key_count):
    # Compute the length of contiguous segments of keys.

    def inner_cond(carry):
      i, prev_key = carry
      row = jax.lax.div(i, 128)
      col = jax.lax.rem(i, 128)
      # `i` can reach `key_count` here, one past the last key, so the load
      # is guarded and a placeholder key of -1 is used instead.
      key = jax.lax.cond(
          i < key_count, lambda i: in_key_ref[row, col], lambda i: -1, i
      )
      return jnp.logical_and(i < key_count, key == prev_key)

    def inner_body(carry):
      i, key = carry
      return i + 1, key

    def outer_cond(carry):
      i, _ = carry
      return i < key_count

    def outer_body(carry):
      i, next_out_idx = carry
      row = jax.lax.div(i, 128)
      col = jax.lax.rem(i, 128)
      key = in_key_ref[row, col]
      # `end` is the index of the first key that differs from `key`, or
      # `key_count` when the run extends to the end of the input.
      end, _ = jax.lax.while_loop(inner_cond, inner_body, (i + 1, key))

      out_row = jax.lax.div(next_out_idx, 128)
      out_col = jax.lax.rem(next_out_idx, 128)
      out_size_ref[out_row, out_col] = end - i
      return end, next_out_idx + 1

    _, count = jax.lax.while_loop(outer_cond, outer_body, (0, 0))
    out_segment_count[0, 0] = count

  keys = [4, 4, 4, 3, 2, 2, 7, 7, 7, 7]
  keys = jnp.asarray(keys)
  real_keys = keys.shape[0]
  key_count = 1024
  # The padding value differs from every real key, so the padding forms one
  # final run of length `key_count - real_keys`.
  keys = jnp.pad(keys, (0, key_count - real_keys), constant_values=32768)
  keys = jnp.reshape(keys, (8, 128))
  kernel_fn = partial(kernel, key_count=key_count)

  fn = pl.pallas_call(
      kernel_fn,
      grid=(1,),
      in_specs=[
          # keys.
          pl.BlockSpec(
              lambda i: (0, 0),
              block_shape=(8, 128),
              memory_space=pltpu.SMEM,
          ),
      ],
      out_specs=[
          # Segments found.
          pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
          # Segment sizes.
          pl.BlockSpec(block_shape=(8, 128), memory_space=pltpu.SMEM),
      ],
      out_shape=[
          jax.ShapeDtypeStruct((1, 1), jnp.int32),
          jax.ShapeDtypeStruct((8, 128), jnp.int32),
      ],
  )
  count, sizes = fn(keys)
  # Runs: 4 4 4 | 3 | 2 2 | 7 7 7 7 | padding.
  np.testing.assert_equal(count[0, 0], jnp.asarray(5))
  np.testing.assert_equal(sizes[0, 0], jnp.asarray(3))
  np.testing.assert_equal(sizes[0, 1], jnp.asarray(1))
  np.testing.assert_equal(sizes[0, 2], jnp.asarray(2))
  np.testing.assert_equal(sizes[0, 3], jnp.asarray(4))
  np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys))
1946
+
1754
1947
1755
1948
class PallasCallPipelineTest (parameterized .TestCase ):
1756
1949
0 commit comments