Skip to content

Commit 9c9e805

Browse files
author
jax authors
committed
[Pallas TPU] Generalize while_loop lowering in Pallas -> Mosaic.
The existing lowering path supports only while_loops which can be converted to fori_loop. That path makes it significantly easier to optimize and unroll, but cannot support a large class of interesting loop formulations. This patch draws from the Pallas -> Triton while_loop lowering rule to support such loops in Pallas. Matching is still performed against fori_loop, to lower via that mechanism if possible -- as it is likely more straightforwardly optimizable compared to general "while". PiperOrigin-RevId: 626089180
1 parent 6ca69f3 commit 9c9e805

File tree

2 files changed

+274
-8
lines changed

2 files changed

+274
-8
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,28 +1833,23 @@ def _scan_lowering_rule(
18331833
skip_mlir_conversions.add(lax.scan_p)
18341834

18351835

1836-
def _while_lowering_rule(
1836+
def _lower_while_via_fori(
18371837
ctx: LoweringRuleContext,
18381838
*args,
1839+
fori_jaxpr,
18391840
cond_nconsts,
18401841
cond_jaxpr,
18411842
body_nconsts,
18421843
body_jaxpr,
18431844
):
1844-
jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
1845-
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
1846-
)
1847-
if jaxpr is None:
1848-
raise NotImplementedError(err)
1849-
18501845
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
18511846
(lb, ub), args = carry[:2], carry[2:]
18521847
for_out = _lower_jaxpr_to_for_loop(
18531848
ctx.replace(
18541849
block_shapes=ctx.block_shapes[: body_nconsts + 1]
18551850
+ ctx.block_shapes[body_nconsts + 2 :],
18561851
),
1857-
jaxpr,
1852+
fori_jaxpr,
18581853
lb,
18591854
arith.subi(ub, lb),
18601855
body_consts,
@@ -1865,6 +1860,84 @@ def _while_lowering_rule(
18651860
return [ub, ub, *for_out]
18661861

18671862

1863+
def _while_lowering_rule(
1864+
ctx: LoweringRuleContext,
1865+
*args,
1866+
cond_nconsts,
1867+
cond_jaxpr,
1868+
body_nconsts,
1869+
body_jaxpr,
1870+
):
1871+
# First try to lower via a simpler fori loop, which may optimize better.
1872+
fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
1873+
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
1874+
)
1875+
if fori_jaxpr is not None:
1876+
return _lower_while_via_fori(
1877+
ctx,
1878+
*args,
1879+
fori_jaxpr=fori_jaxpr,
1880+
cond_nconsts=cond_nconsts,
1881+
cond_jaxpr=cond_jaxpr,
1882+
body_nconsts=body_nconsts,
1883+
body_jaxpr=body_jaxpr,
1884+
)
1885+
1886+
# If we fail conversion to fori, fallback to an ordinary while loop.
1887+
cond_consts, body_consts, carry = split_list(
1888+
args, [cond_nconsts, body_nconsts]
1889+
)
1890+
cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = (
1891+
split_list(ctx.block_shapes, [cond_nconsts, body_nconsts])
1892+
)
1893+
cond_const_types = [a.type for a in cond_consts]
1894+
body_const_types = [a.type for a in body_consts]
1895+
carry_types = [a.type for a in carry]
1896+
all_types = [*cond_const_types, *body_const_types, *carry_types]
1897+
while_op = scf.WhileOp(all_types, args)
1898+
1899+
before_block = while_op.before.blocks.append(*all_types)
1900+
cond_consts_, _, carry_ = split_list(
1901+
before_block.arguments,
1902+
[cond_nconsts, body_nconsts],
1903+
)
1904+
cond_args = [*cond_consts_, *carry_]
1905+
with ir.InsertionPoint.at_block_begin(before_block):
1906+
[cond] = jaxpr_subcomp(
1907+
ctx.lowering_context.replace(
1908+
block_shapes=[*cond_const_block_shapes, *carry_block_shapes]
1909+
),
1910+
cond_jaxpr.jaxpr,
1911+
*cond_args,
1912+
)
1913+
scf.condition(cond, before_block.arguments)
1914+
1915+
after_block = while_op.after.blocks.append(*all_types)
1916+
cond_consts_, body_consts_, carry_ = split_list(
1917+
after_block.arguments,
1918+
[cond_nconsts, body_nconsts],
1919+
)
1920+
all_args = [*cond_consts_, *body_consts_, *carry_]
1921+
cond_const_args, body_const_args, carry_args = split_list(
1922+
all_args, [cond_nconsts, body_nconsts]
1923+
)
1924+
with ir.InsertionPoint.at_block_begin(after_block):
1925+
loop_out = jaxpr_subcomp(
1926+
ctx.lowering_context.replace(
1927+
block_shapes=[*body_const_block_shapes, *carry_block_shapes],
1928+
),
1929+
body_jaxpr.jaxpr,
1930+
*body_const_args,
1931+
*carry_args,
1932+
)
1933+
all_handles = [*cond_const_args, *body_const_args, *loop_out]
1934+
if all_handles:
1935+
scf.yield_(all_handles)
1936+
1937+
all_out = list(while_op.results_)
1938+
return all_out[cond_nconsts + body_nconsts :]
1939+
1940+
18681941
lowering_rules[lax.while_p] = _while_lowering_rule
18691942

18701943
def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches, linear):

tests/pallas/pallas_call_tpu_test.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,199 @@ def body(i, _):
17511751
)(*(jnp.array([[x]]) for x in (2, 6)))
17521752
np.testing.assert_array_equal(r, 4)
17531753

1754+
def test_non_range_while_loop(self):
1755+
"""Tests lowering of a while_loop which cannot reduce to a fori_loop."""
1756+
1757+
def kernel(x_ref, r_ref):
1758+
@pl.when(pl.program_id(0) == 0)
1759+
def _():
1760+
pl.store(r_ref, (0, 0), 0)
1761+
1762+
def cond(state):
1763+
i, s = state
1764+
return jnp.logical_and(i < 1024, s < 1024)
1765+
1766+
def body(state):
1767+
i, s = state
1768+
sl = sl = jax.lax.div(i, 128)
1769+
l = jax.lax.rem(i, 128)
1770+
v = pl.load(x_ref, (0, sl, l))
1771+
return i + 1, s + v
1772+
1773+
i = jnp.int32(0)
1774+
s = pl.load(r_ref, (0, 0))
1775+
1776+
i, s = jax.lax.while_loop(cond, body, (i, s))
1777+
pl.store(r_ref, (0, 0), s)
1778+
1779+
x = jnp.arange(4096)
1780+
x = jnp.reshape(x, [4, 8, 128])
1781+
1782+
r = pl.pallas_call(
1783+
kernel,
1784+
grid=(4,),
1785+
out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
1786+
out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
1787+
in_specs=[
1788+
pl.BlockSpec(
1789+
lambda i: (i, 0, 0),
1790+
block_shape=(1, 8, 128),
1791+
memory_space=pltpu.SMEM,
1792+
)
1793+
],
1794+
)(x)
1795+
np.testing.assert_array_equal(r, [[1035]])
1796+
1797+
def test_vector_carry_while_loop(self):
1798+
"""Tests lowering of a while_loop which carries a vector quantity."""
1799+
1800+
def kernel(x_ref, r_ref):
1801+
1802+
def cond(v):
1803+
return v[0, 0] < 16
1804+
1805+
def body(v):
1806+
return v * 2
1807+
1808+
r_ref[:] = jax.lax.while_loop(cond, body, x_ref[:])
1809+
1810+
x = jnp.full((8, 128), 3, dtype=jnp.int32)
1811+
fn = pl.pallas_call(
1812+
kernel,
1813+
grid=(1,),
1814+
in_specs=[pl.BlockSpec(lambda i: (0, 0), (8, 128))],
1815+
out_specs=pl.BlockSpec(lambda i: (0, 0), (8, 128)),
1816+
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32),
1817+
)
1818+
r = fn(x)
1819+
reduced = jnp.sum(r)
1820+
# 3 -> 6 -> 12 -> 24
1821+
np.testing.assert_array_equal(reduced, 1024 * 24)
1822+
1823+
@parameterized.named_parameters(
1824+
('1x128', (1, 128)),
1825+
('2x128', (2, 128)),
1826+
('4x128', (4, 128)),
1827+
('8x128', (8, 128)),
1828+
('8x256', (8, 256)),
1829+
)
1830+
def test_while_loop_carry_memref(self, shape):
1831+
"""Tests a while loop carrying a memref."""
1832+
1833+
# TODO(hmckenzie): Investigate further why this occurs.
1834+
if shape == (1, 128):
1835+
self.skipTest('memref<1x128> inexplicably doubles to 2x128.')
1836+
1837+
def kernel(out_ref, bound):
1838+
def cond(i):
1839+
return i < bound
1840+
1841+
def body(i):
1842+
out_ref[0, i] = 2
1843+
return i + 1
1844+
1845+
jax.lax.while_loop(cond, body, 0)
1846+
1847+
x = jnp.asarray([1, 1, 1, 1])
1848+
x = jnp.asarray(x)
1849+
x = jnp.pad(x, (0, np.prod(shape) - 4), constant_values=0)
1850+
x = jnp.reshape(x, shape)
1851+
kernel = partial(kernel, bound=x.shape[1])
1852+
1853+
fn = pl.pallas_call(
1854+
kernel,
1855+
grid=(1,),
1856+
out_specs=[
1857+
pl.BlockSpec(
1858+
lambda i: (0, 0), block_shape=shape, memory_space=pltpu.SMEM
1859+
),
1860+
],
1861+
out_shape=[
1862+
jax.ShapeDtypeStruct(shape, jnp.int32),
1863+
],
1864+
)
1865+
y = fn()[0]
1866+
np.testing.assert_array_equal(y[0, 0], 2)
1867+
np.testing.assert_array_equal(y[0, 1], 2)
1868+
np.testing.assert_array_equal(y[0, 2], 2)
1869+
np.testing.assert_array_equal(y[0, 3], 2)
1870+
1871+
def test_nested_while_loop(self):
1872+
"""Tests lowering a nested while_loop."""
1873+
1874+
def kernel(in_key_ref, out_segment_count, out_size_ref, key_count):
1875+
# Compute the length of contiguous segments of keys.
1876+
1877+
def inner_cond(carry):
1878+
i, prev_key = carry
1879+
sl = sl = jax.lax.div(i, 128)
1880+
l = jax.lax.rem(i, 128)
1881+
key = jax.lax.cond(
1882+
i < key_count, lambda i: in_key_ref[sl, l], lambda i: -1, i
1883+
)
1884+
return jnp.logical_and(i < key_count, key == prev_key)
1885+
1886+
def inner_body(carry):
1887+
i, key = carry
1888+
return i + 1, key
1889+
1890+
def outer_cond(carry):
1891+
i, _ = carry
1892+
return i < key_count
1893+
1894+
def outer_body(carry):
1895+
i, next_out_idx = carry
1896+
sl = sl = jax.lax.div(i, 128)
1897+
l = jax.lax.rem(i, 128)
1898+
key = in_key_ref[sl, l]
1899+
end, _ = jax.lax.while_loop(inner_cond, inner_body, (i + 1, key))
1900+
1901+
sl = sl = jax.lax.div(next_out_idx, 128)
1902+
l = jax.lax.rem(next_out_idx, 128)
1903+
out_size_ref[sl, l] = end - i
1904+
return end, next_out_idx + 1
1905+
1906+
_, count = jax.lax.while_loop(outer_cond, outer_body, (0, 0))
1907+
out_segment_count[0, 0] = count
1908+
1909+
keys = [4, 4, 4, 3, 2, 2, 7, 7, 7, 7]
1910+
keys = jnp.asarray(keys)
1911+
real_keys = keys.shape[0]
1912+
key_count = 1024
1913+
keys = jnp.pad(keys, (0, key_count - real_keys), constant_values=32768)
1914+
keys = jnp.reshape(keys, (8, 128))
1915+
kernel_fn = partial(kernel, key_count=key_count)
1916+
1917+
fn = pl.pallas_call(
1918+
kernel_fn,
1919+
grid=(1,),
1920+
in_specs=[
1921+
# keys.
1922+
pl.BlockSpec(
1923+
lambda i: (0, 0),
1924+
block_shape=(8, 128),
1925+
memory_space=pltpu.SMEM,
1926+
),
1927+
],
1928+
out_specs=[
1929+
# Segments found.
1930+
pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
1931+
# Segment sizes.
1932+
pl.BlockSpec(block_shape=(8, 128), memory_space=pltpu.SMEM),
1933+
],
1934+
out_shape=[
1935+
jax.ShapeDtypeStruct((1, 1), jnp.int32),
1936+
jax.ShapeDtypeStruct((8, 128), jnp.int32),
1937+
],
1938+
)
1939+
count, sizes = fn(keys)
1940+
np.testing.assert_equal(count[0, 0], jnp.asarray(5))
1941+
np.testing.assert_equal(sizes[0, 0], jnp.asarray(3))
1942+
np.testing.assert_equal(sizes[0, 1], jnp.asarray(1))
1943+
np.testing.assert_equal(sizes[0, 2], jnp.asarray(2))
1944+
np.testing.assert_equal(sizes[0, 3], jnp.asarray(4))
1945+
np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys))
1946+
17541947

17551948
class PallasCallPipelineTest(parameterized.TestCase):
17561949

0 commit comments

Comments
 (0)