
Commit cdafb8f

superbobry authored and jax authors committed
Update the lowering for div_p to require f32/f64 for floating point inputs

PTX has no div instruction for other floating point types. See
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div.

PiperOrigin-RevId: 616113396
1 parent c94ea14 commit cdafb8f
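For kernels that still need half-precision division after this change, one workaround is to do the divide in f32 and cast back. The snippet below is a minimal sketch of that pattern using the public pallas API; the kernel and array names are illustrative and not part of this commit.

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float16),
    grid=1,
)
def div_f16_kernel(x_ref, y_ref, o_ref):
  # PTX only provides div for f32/f64, so do the divide in f32 and
  # cast the result back to f16 for the store.
  x = x_ref[...].astype(jnp.float32)
  y = y_ref[...].astype(jnp.float32)
  o_ref[...] = (x / y).astype(jnp.float16)

x = jnp.arange(1.0, 9.0, dtype=jnp.float16)
y = jnp.full((8,), 2.0, dtype=jnp.float16)
out = div_f16_kernel(x, y)  # divides in f32 via the Triton lowering, stores f16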

2 files changed: +30 −6 lines

jax/_src/pallas/triton/lowering.py

Lines changed: 2 additions & 2 deletions
@@ -835,7 +835,7 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
 def _floordiv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
   assert x.type == y.type, (str(x.type), str(y.type))
   x_element_type = _element_type(x.type)
-  if isinstance(x_element_type, ir.FloatType):
+  if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
     return arith_dialect.divf(x, y)
   if not isinstance(x_element_type, ir.IntegerType):
     raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
@@ -852,7 +852,7 @@ def _truediv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
     x_element_type = ir.F32Type.get()
     x = _int_float_cast(x, x_element_type, signed=signed)
     y = _int_float_cast(y, x_element_type, signed=signed)
-  if isinstance(x_element_type, ir.FloatType):
+  if isinstance(x_element_type, (ir.F32Type, ir.F64Type)):
     return arith_dialect.divf(x, y)
   raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
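The behavioral difference is entirely in the isinstance predicate: ir.FloatType also matches f16/bf16 element types, which previously slipped through and only failed once Triton tried to emit a PTX div for them. A minimal sketch of what the tightened check rejects, assuming the MLIR Python bindings bundled with jaxlib expose FloatType as a common base class and downcast concrete float types:

from jax._src.lib.mlir import ir

with ir.Context():
  f16 = ir.F16Type.get()
  f32 = ir.F32Type.get()
  # Old predicate: any float element type passes, including f16/bf16.
  assert isinstance(f16, ir.FloatType)
  # New predicate: only f32/f64 pass; other float types fall through to
  # the NotImplementedError branch and surface as a lowering error.
  assert not isinstance(f16, (ir.F32Type, ir.F64Type))
  assert isinstance(f32, (ir.F32Type, ir.F64Type))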

tests/pallas/pallas_test.py

Lines changed: 28 additions & 4 deletions
@@ -40,12 +40,14 @@
 from jax.experimental.pallas.ops import rms_norm
 from jax.experimental.pallas.ops import softmax
 try:
+  from jax._src.pallas.triton.lowering import LoweringError
   from jax._src.pallas.triton.pallas_call_registration import (
       compile_jaxpr,
       _TRITON_COMPILE_VIA_XLA,
   )
   from jax.experimental.pallas import gpu as plgpu
 except ModuleNotFoundError:
+  LoweringError = Exception
   compile_jaxpr = None
   _TRITON_COMPILE_VIA_XLA = None
 import numpy as np
@@ -1634,19 +1636,41 @@ def isnan(x_ref, o_ref):
     x = x.at[3].set(jnp.nan)
     np.testing.assert_allclose(isnan(x), jnp.isnan(x))

-  def test_true_divide(self):
+  @parameterized.parameters(
+      ("int32", "float32"),
+      ("float32", "float32"),
+  )
+  def test_true_divide(self, dtype, out_dtype):
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((8,), out_dtype),
         grid=1,
     )
     def kernel(x_ref, y_ref, o_ref):
       o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])

-    x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7], dtype=jnp.int32)
-    y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4], dtype=jnp.int32)
+    x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
+    y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
     np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y))

+  @parameterized.parameters("float16", "bfloat16")
+  def test_true_divide_unsupported(self, dtype):
+    if self.INTERPRET:
+      self.skipTest("No lowering in interpreter mode")
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((2,), dtype),
+        grid=1,
+    )
+    def kernel(x_ref, y_ref, o_ref):
+      o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
+
+    x = jnp.array([2.4, 4.2]).astype(dtype)
+    y = jnp.array([4.2, 2.4]).astype(dtype)
+    with self.assertRaises(LoweringError):
+      kernel(x, y)
+
   BINARY_OPS = [
       ([jnp.floor_divide], ["int32", "uint32"]),
       (
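The new test skips itself in interpreter mode because the restriction lives only in the Triton lowering. A rough sketch of that distinction, assuming interpret=True evaluates the kernel without Triton so the same f16 divide still runs:

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.float16),
    grid=1,
    interpret=True,  # no Triton lowering, so the f32/f64 requirement does not apply
)
def kernel(x_ref, y_ref, o_ref):
  o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])

x = jnp.array([2.4, 4.2], dtype=jnp.float16)
y = jnp.array([4.2, 2.4], dtype=jnp.float16)
print(kernel(x, y))  # evaluates on the host/XLA path instead of raising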
