|
40 | 40 | from jax.experimental.pallas.ops import rms_norm
|
41 | 41 | from jax.experimental.pallas.ops import softmax
|
42 | 42 | try:
|
| 43 | + from jax._src.pallas.triton.lowering import LoweringError |
43 | 44 | from jax._src.pallas.triton.pallas_call_registration import (
|
44 | 45 | compile_jaxpr,
|
45 | 46 | _TRITON_COMPILE_VIA_XLA,
|
46 | 47 | )
|
47 | 48 | from jax.experimental.pallas import gpu as plgpu
|
48 | 49 | except ModuleNotFoundError:
|
| 50 | + LoweringError = Exception |
49 | 51 | compile_jaxpr = None
|
50 | 52 | _TRITON_COMPILE_VIA_XLA = None
|
51 | 53 | import numpy as np
|
@@ -1634,19 +1636,41 @@ def isnan(x_ref, o_ref):
|
1634 | 1636 | x = x.at[3].set(jnp.nan)
|
1635 | 1637 | np.testing.assert_allclose(isnan(x), jnp.isnan(x))
|
1636 | 1638 |
|
1637 |
| - def test_true_divide(self): |
| 1639 | + @parameterized.parameters( |
| 1640 | + ("int32", "float32"), |
| 1641 | + ("float32", "float32"), |
| 1642 | + ) |
| 1643 | + def test_true_divide(self, dtype, out_dtype): |
1638 | 1644 | @functools.partial(
|
1639 | 1645 | self.pallas_call,
|
1640 |
| - out_shape=jax.ShapeDtypeStruct((8,), jnp.float32), |
| 1646 | + out_shape=jax.ShapeDtypeStruct((8,), out_dtype), |
1641 | 1647 | grid=1,
|
1642 | 1648 | )
|
1643 | 1649 | def kernel(x_ref, y_ref, o_ref):
|
1644 | 1650 | o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])
|
1645 | 1651 |
|
1646 |
| - x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7], dtype=jnp.int32) |
1647 |
| - y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4], dtype=jnp.int32) |
| 1652 | + x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype) |
| 1653 | + y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype) |
1648 | 1654 | np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y))
|
1649 | 1655 |
|
| 1656 | + @parameterized.parameters("float16", "bfloat16") |
| 1657 | + def test_true_divide_unsupported(self, dtype): |
| 1658 | + if self.INTERPRET: |
| 1659 | + self.skipTest("No lowering in interpreter mode") |
| 1660 | + |
| 1661 | + @functools.partial( |
| 1662 | + self.pallas_call, |
| 1663 | + out_shape=jax.ShapeDtypeStruct((2,), dtype), |
| 1664 | + grid=1, |
| 1665 | + ) |
| 1666 | + def kernel(x_ref, y_ref, o_ref): |
| 1667 | + o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...]) |
| 1668 | + |
| 1669 | + x = jnp.array([2.4, 4.2]).astype(dtype) |
| 1670 | + y = jnp.array([4.2, 2.4]).astype(dtype) |
| 1671 | + with self.assertRaises(LoweringError): |
| 1672 | + kernel(x, y) |
| 1673 | + |
1650 | 1674 | BINARY_OPS = [
|
1651 | 1675 | ([jnp.floor_divide], ["int32", "uint32"]),
|
1652 | 1676 | (
|
|
0 commit comments