Skip to content

Commit 9862236

Browse files
superbobryjax authors
authored andcommitted
Use jnp.issubdtype instead of np.issubdtype
PiperOrigin-RevId: 617610920
1 parent 089651f commit 9862236

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
11031103
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger) or jnp.issubdtype(
11041104
y_aval.dtype, jnp.signedinteger
11051105
)
1106-
if np.issubdtype(x_aval.dtype, np.floating) or np.issubdtype(
1106+
if jnp.issubdtype(x_aval.dtype, np.floating) or jnp.issubdtype(
11071107
y_aval.dtype, np.floating
11081108
):
11091109
return _truediv(x, y, signed=signed)
@@ -1115,7 +1115,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
11151115

11161116
def _sign_lowering_rule(ctx: LoweringRuleContext, x):
11171117
[x_aval] = ctx.avals_in
1118-
signed = np.issubdtype(x_aval.dtype, jnp.signedinteger)
1118+
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
11191119
zero = _full(x.type, 0)
11201120
return _sub(
11211121
_cast(_greater_than(x, zero, signed=signed), x.type, signed=signed),
@@ -2445,7 +2445,7 @@ def _i64_constant(v: int) -> ir.Value:
24452445

24462446

24472447
def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type:
2448-
if np.issubdtype(dtype, np.integer):
2448+
if jnp.issubdtype(dtype, np.integer):
24492449
# All integer types in Triton are signless.
24502450
return ir.IntegerType.get_signless(dtype.itemsize * 8)
24512451
return mlir.dtype_to_ir_type(dtype)

0 commit comments

Comments
 (0)