Skip to content

Commit b3fe940

Browse files
author
jax authors
committed
Add round lowering rule.
PiperOrigin-RevId: 621110036
1 parent 1d221d1 commit b3fe940

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1527,6 +1527,18 @@ def _log1p_lowering_rule(ctx: LoweringRuleContext, x):
15271527
lowering_rules[lax.log1p_p] = _log1p_lowering_rule
15281528

15291529

1530+
def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method):
1531+
if rounding_method == 0:
1532+
return math.RoundOp(x).result
1533+
elif rounding_method == 1:
1534+
return math.RoundEvenOp(x).result
1535+
else:
1536+
raise NotImplementedError(f"Unsupported rounding method: {rounding_method}")
1537+
1538+
1539+
lowering_rules[lax.round_p] = _round_lowering_rule
1540+
1541+
15301542
_cmpi_lowering_types = {
15311543
lax.eq_p: 0,
15321544
lax.ne_p: 1,

0 commit comments

Comments
 (0)