Skip to content

Commit ae1cfc2

Browse files
superbobry authored and jax authors committed
Added a lowering rule for lax.sign_p and improved test coverage for binary ops
Closes #17317

PiperOrigin-RevId: 615038353
1 parent d0eae05 commit ae1cfc2

File tree

2 files changed

+85
-32
lines changed

2 files changed

+85
-32
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,10 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
837837

838838
def _floordiv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
839839
assert x.type == y.type, (str(x.type), str(y.type))
840-
if not isinstance(_element_type(x.type), ir.IntegerType):
840+
x_element_type = _element_type(x.type)
841+
if isinstance(x_element_type, ir.FloatType):
842+
return arith_dialect.divf(x, y)
843+
if not isinstance(x_element_type, ir.IntegerType):
841844
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
842845
if signed:
843846
return arith_dialect.divsi(x, y)
@@ -859,7 +862,10 @@ def _truediv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
859862

860863
def _mod(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
861864
assert x.type == y.type, (str(x.type), str(y.type))
862-
if not isinstance(_element_type(x.type), ir.IntegerType):
865+
x_element_type = _element_type(x.type)
866+
if isinstance(x_element_type, ir.FloatType):
867+
return arith_dialect.remf(x, y)
868+
if not isinstance(x_element_type, ir.IntegerType):
863869
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
864870
if signed:
865871
return arith_dialect.remsi(x, y)
@@ -1108,6 +1114,19 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
11081114
triton_lowering_rules[lax.div_p] = _div_lowering_rule
11091115

11101116

1117+
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
1118+
[x_aval] = ctx.avals_in
1119+
signed = np.issubdtype(x_aval.dtype, jnp.signedinteger)
1120+
zero = _full(x.type, 0)
1121+
return _sub(
1122+
_cast(_greater_than(x, zero, signed=signed), x.type, signed=signed),
1123+
_cast(_less_than(x, zero, signed=signed), x.type, signed=signed),
1124+
)
1125+
1126+
1127+
triton_lowering_rules[lax.sign_p] = _sign_lowering_rule
1128+
1129+
11111130
def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension):
11121131
if dimension != 0:
11131132
raise NotImplementedError

tests/pallas/pallas_test.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,6 +1517,7 @@ def add_one(x_ref, o_ref):
15171517
class PallasCallInterpreterVmapTest(PallasCallVmapTest):
15181518
INTERPRET = True
15191519

1520+
15201521
class PallasOpsTest(PallasTest):
15211522

15221523
def test_pow_weak_dtype(self):
@@ -1528,17 +1529,31 @@ def square(x_ref, o_ref):
15281529
x = jnp.array(42.0)
15291530
np.testing.assert_allclose(square(x), x*x)
15301531

1531-
def test_ne(self):
1532+
COMPARISON_OPS = [
1533+
jnp.equal,
1534+
jnp.not_equal,
1535+
jnp.less,
1536+
jnp.less_equal,
1537+
jnp.greater,
1538+
jnp.greater_equal,
1539+
]
1540+
1541+
@parameterized.named_parameters(
1542+
(f"{fn.__name__}_{dtype}", fn, dtype)
1543+
for fn, dtype in itertools.product(
1544+
COMPARISON_OPS, ["int32", "uint32", "float32"]
1545+
)
1546+
)
1547+
def test_comparison(self, fn, dtype):
15321548
@functools.partial(
15331549
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_),
15341550
grid=1)
1535-
def ne(x_ref, y_ref, o_ref):
1536-
o_ref[:] = x_ref[...] != y_ref[...]
1551+
def kernel(x_ref, y_ref, o_ref):
1552+
o_ref[:] = fn(x_ref[...], y_ref[...])
15371553

1538-
x = jnp.ones(8, dtype=jnp.int32)
1539-
y = jnp.arange(8, dtype=jnp.int32)
1540-
not_equal = ne(x, y)
1541-
np.testing.assert_allclose(not_equal, x != y)
1554+
x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
1555+
y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
1556+
np.testing.assert_allclose(kernel(x, y), fn(x, y))
15421557

15431558
def test_isnan(self):
15441559
@functools.partial(
@@ -1551,33 +1566,52 @@ def isnan(x_ref, o_ref):
15511566
x = x.at[3].set(jnp.nan)
15521567
np.testing.assert_allclose(isnan(x), jnp.isnan(x))
15531568

1554-
@parameterized.named_parameters(*(
1555-
(fn.__name__, fn, out_dtype)
1556-
for fn, out_dtype in [
1557-
(jnp.add, jnp.int32),
1558-
(jnp.subtract, jnp.int32),
1559-
(jnp.multiply, jnp.int32),
1560-
(jnp.true_divide, jnp.float32),
1561-
(jnp.remainder, jnp.int32),
1562-
(jnp.less, jnp.bool_),
1563-
(jnp.less_equal, jnp.bool_),
1564-
(jnp.greater, jnp.bool_),
1565-
(jnp.greater_equal, jnp.bool_),
1566-
(jnp.equal, jnp.bool_),
1567-
(jnp.not_equal, jnp.bool_),
1568-
]
1569-
))
1570-
def test_signed_int_ops(self, f, out_dtype):
1569+
def test_true_divide(self):
15711570
@functools.partial(
15721571
self.pallas_call,
1573-
out_shape=jax.ShapeDtypeStruct((8,), out_dtype),
1574-
grid=1)
1575-
def f_i32(x_ref, y_ref, o_ref):
1572+
out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
1573+
grid=1,
1574+
)
1575+
def kernel(x_ref, y_ref, o_ref):
1576+
o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
1577+
1578+
x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7], dtype=jnp.int32)
1579+
y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4], dtype=jnp.int32)
1580+
np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y))
1581+
1582+
BINARY_OPS = [
1583+
([jnp.floor_divide], ["int32", "uint32"]),
1584+
(
1585+
[jnp.add, jnp.subtract, jnp.multiply, jnp.remainder],
1586+
["int32", "uint32", "float32"],
1587+
),
1588+
(
1589+
[
1590+
jnp.bitwise_and,
1591+
jnp.bitwise_or,
1592+
jnp.bitwise_xor,
1593+
jnp.bitwise_left_shift,
1594+
jnp.bitwise_right_shift,
1595+
],
1596+
["int32", "uint32"],
1597+
),
1598+
]
1599+
1600+
@parameterized.named_parameters(
1601+
(f"{fn.__name__}_{dtype}", fn, dtype)
1602+
for args in BINARY_OPS
1603+
for fn, dtype in itertools.product(*args)
1604+
)
1605+
def test_binary(self, f, dtype):
1606+
@functools.partial(
1607+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1
1608+
)
1609+
def kernel(x_ref, y_ref, o_ref):
15761610
o_ref[...] = f(x_ref[...], y_ref[...])
15771611

1578-
x = jnp.int32([1, 3, -4, -6, 2, 5, 4, -7])
1579-
y = jnp.int32([3, 1, -4, -5, 2, -2, 0, 4])
1580-
np.testing.assert_allclose(f(x, y), f_i32(x, y))
1612+
x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
1613+
y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
1614+
np.testing.assert_allclose(f(x, y), kernel(x, y))
15811615

15821616

15831617
class PallasOpsInterpretTest(PallasOpsTest):

0 commit comments

Comments
 (0)