Skip to content

Commit b865c5b

Browse files
author
jax authors
committed
[Pallas TPU] Convert pattern_match_while_to_fori_loop to return (Jaxpr, str) rather than throw exceptions.
Currently, pattern_match_while_to_fori_loop attempts to convert a while_loop jaxpr into a type of fori_loop which Pallas can lower. To do so, it validates the conditions which would block the jaxpr from being lowered successfully. Because Pallas presently only supports "fori convertable" loops, this matching code also throws Exceptions when the supported conditions are violated. In the near future, we aim to have support for more ordinary while loops -- but we still would like to perform this match-and-convert procedure when possible. To facilitate that, this updates the error handling in pattern_match_while_to_fori_loop to simply return errors when hit, so the calling code can determine if they should be thrown. PiperOrigin-RevId: 623274837
1 parent 967c38d commit b865c5b

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1822,9 +1822,12 @@ def _while_lowering_rule(
18221822
body_nconsts,
18231823
body_jaxpr,
18241824
):
1825-
jaxpr = pallas_utils.pattern_match_while_to_fori_loop(
1825+
jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop(
18261826
cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
18271827
)
1828+
if jaxpr is None:
1829+
raise NotImplementedError(err)
1830+
18281831
_, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
18291832
(lb, ub), args = carry[:2], carry[2:]
18301833
for_out = _lower_jaxpr_to_for_loop(

jax/_src/pallas/utils.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Pallas utility functions."""
1616

17+
from __future__ import annotations
18+
1719
from jax import lax
1820
from jax._src import core as jax_core
1921
from jax._src.util import split_list
@@ -98,44 +100,42 @@ def pattern_match_while_to_fori_loop(
98100
cond_nconsts: int,
99101
body_jaxpr: jax_core.Jaxpr,
100102
body_nconsts: int,
101-
) -> tuple[jax_core.Jaxpr, bool]:
103+
) -> tuple[jax_core.Jaxpr | None, str | None]:
102104
# Try to pattern match to fori loop.
105+
# Successful matches produce (jaxpr, None), while failures use the str
106+
# component of the return tuple to capture information about the failure.
103107
if cond_nconsts:
104-
raise NotImplementedError("Conditional jaxpr can't contain consts.")
108+
return (None, "Conditional jaxpr can't contain consts.")
105109
_, cond_invars = split_list(cond_jaxpr.jaxpr.invars, [cond_nconsts])
106110
cond_in_avals = [v.aval for v in cond_invars]
107111
if len(cond_in_avals) < 2:
108-
raise NotImplementedError("Conditional jaxpr have only two carry args.")
112+
return (None, "Conditional jaxpr have only two carry args.")
109113
# Check that the first two carry values are scalar ints
110114
a1, a2 = cond_in_avals[:2]
111115
if a1.shape or a1.dtype not in (jnp.int32, jnp.int64):
112-
raise NotImplementedError(
113-
"First conditional jaxpr carry arg is not a scalar int."
114-
)
116+
return (None, "First conditional jaxpr carry arg is not a scalar int.")
115117
if a2.shape or a2.dtype not in (jnp.int32, jnp.int64):
116-
raise NotImplementedError(
117-
"Second conditional jaxpr carry arg is not a scalar int."
118-
)
118+
return (None, "Second conditional jaxpr carry arg is not a scalar int.")
119119
# Check that the only eqn in the cond checks the loop index condition
120120
v1, v2 = cond_invars[:2]
121121
outvar = cond_jaxpr.jaxpr.outvars[0]
122122
assert outvar.aval.dtype == jnp.bool_
123123
if len(cond_jaxpr.jaxpr.eqns) != 1:
124-
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
124+
return (None, "Non-trivial conditional jaxprs not supported.")
125125
eqn = cond_jaxpr.jaxpr.eqns[0]
126126
if eqn.primitive != lax.lt_p:
127-
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
127+
return (None, "Non-trivial conditional jaxprs not supported.")
128128
if eqn.outvars != [outvar]:
129-
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
129+
return (None, "Non-trivial conditional jaxprs not supported.")
130130
if eqn.invars != [v1, v2]:
131-
raise NotImplementedError("Non-trivial conditional jaxprs not supported.")
131+
return (None, "Non-trivial conditional jaxprs not supported.")
132132
# Check that the carry is updated in the body appropriately
133133
_, body_invars = split_list(body_jaxpr.jaxpr.invars, [body_nconsts])
134134
v1, v2 = body_invars[:2]
135135
vo1, vo2 = body_jaxpr.jaxpr.outvars[:2]
136136
# Upper bound should be constant
137137
if v2 is not vo2:
138-
raise NotImplementedError("Loop upper bound is not constant.")
138+
return (None, "Loop upper bound is not constant.")
139139
# Check that we increment the loop index in the body
140140
for i, eqn in enumerate(body_jaxpr.jaxpr.eqns):
141141
if eqn.primitive is lax.add_p:
@@ -146,7 +146,7 @@ def pattern_match_while_to_fori_loop(
146146
eqn_index = i
147147
break
148148
else:
149-
raise NotImplementedError("Loop index not incremented in body.")
149+
return (None, "Loop index not incremented in body.")
150150
jaxpr = body_jaxpr.jaxpr
151151
new_invars = (
152152
*jaxpr.invars[:body_nconsts],
@@ -159,4 +159,4 @@ def pattern_match_while_to_fori_loop(
159159
invars=new_invars,
160160
outvars=new_outvars,
161161
)
162-
return jaxpr
162+
return jaxpr, None

0 commit comments

Comments
 (0)