You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[Pallas TPU] Convert pattern_match_while_to_fori_loop to return (Jaxpr, str) rather than throw exceptions.
Currently, pattern_match_while_to_fori_loop attempts to convert a while_loop jaxpr into a type of fori_loop which Pallas can lower.
To do so, it validates the conditions which would block the jaxpr from being lowered successfully. Because Pallas presently only supports "fori convertable" loops, this matching code also throws Exceptions when the supported conditions are violated.
In the near future, we aim to have support for more ordinary while loops -- but we still would like to perform this match-and-convert procedure when possible.
To facilitate that, this updates the error handling in pattern_match_while_to_fori_loop to simply return errors when hit, so the calling code can determine if they should be thrown.
PiperOrigin-RevId: 623274837
0 commit comments