Skip to content

[Pallas TPU] Refactor casting logic to use bits instead of bytes and allow uint upcasts. #30103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,27 +2240,29 @@ def _dot_general_lowering_rule(
def _convert_helper(x: Array, *, to_dtype: jnp.dtype) -> Array:
# Helper function for dtype conversion
from_dtype = x.dtype
from_bitwidth = pallas_utils.dtype_bitwidth(from_dtype)
to_bitwidth = pallas_utils.dtype_bitwidth(to_dtype)
if from_dtype == jnp.bool_:
x = x.astype(jnp.int32)
return _convert_helper(x, to_dtype=to_dtype)
if to_dtype == jnp.bool_:
# Lower float32 or (u)int32 -> bool to cmp neq %in, 0
# TODO(apaszke,mvoz): Move the upcasts for cmpi to the Mosaic canonicalizer.
if jnp.issubdtype(from_dtype, jnp.floating):
if from_dtype.itemsize < 4:
if from_bitwidth < 32:
x = x.astype(jnp.float32)
elif jnp.issubdtype(from_dtype, jnp.integer):
if from_dtype.itemsize < 4:
if from_bitwidth < 32:
x = x.astype(jnp.int32)
return x != jnp.asarray(0, dtype=x.dtype)
if jnp.issubdtype(from_dtype, jnp.signedinteger):
if from_dtype.itemsize < 4:
if from_bitwidth < 32:
x = x.astype(jnp.int32)
if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
if jnp.issubdtype(to_dtype, jnp.floating) and to_bitwidth < 32:
x = x.astype(jnp.float32)
return x.astype(to_dtype)
if jnp.issubdtype(from_dtype, jnp.unsignedinteger):
if from_dtype.itemsize < 4:
if from_bitwidth < 32:
x = x.astype(jnp.uint32)
# unsigned -> float is unsupported. We fall through and raise at the bottom.
if not jnp.issubdtype(to_dtype, jnp.floating):
Expand Down Expand Up @@ -2294,25 +2296,30 @@ def _convert_element_type_lowering_rule(
floating = jnp.floating
integer = jnp.integer
signed = jnp.signedinteger
both_32bit = old_dtype.itemsize == 4 and new_dtype.itemsize == 4
unsigned = jnp.unsignedinteger
old_bitwidth = pallas_utils.dtype_bitwidth(old_dtype)
new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype)
both_32bit = old_bitwidth == 32 and new_bitwidth == 32
if _from(floating) and _to(floating):
forward_compat = ctx.forward_compatible or is_cloud_tpu_older_than(
2025, 6, 29
)
if old_dtype.itemsize < new_dtype.itemsize and (
new_dtype.itemsize == 4 or not forward_compat
if old_bitwidth < new_bitwidth and (
new_bitwidth == 32 or not forward_compat
):
return arith.extf(out_type, x)
elif old_dtype.itemsize > new_dtype.itemsize and (
old_dtype.itemsize == 4 or not forward_compat
elif old_bitwidth > new_bitwidth and (
old_bitwidth == 32 or not forward_compat
):
return arith.truncf(out_type, x)
elif _from(integer) and _to(integer):
if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
if not (_from(signed) and _to(signed)):
raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
return arith.extsi(out_type, x)
elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
if old_bitwidth < new_bitwidth and new_bitwidth == 32:
if (_from(unsigned) and _to(unsigned)):
return arith.extui(out_type, x)
if (_from(signed) and _to(signed)):
return arith.extsi(out_type, x)
raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
elif old_bitwidth > new_bitwidth and old_bitwidth == 32:
return arith.trunci(out_type, x)
elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits:
# This case triggers when casting signed to unsigned or vice versa.
Expand All @@ -2325,7 +2332,7 @@ def _convert_element_type_lowering_rule(
or both_32bit
):
return arith.sitofp(out_type, x)
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
elif old_dtype == jnp.bool_ and _to(integer) and new_bitwidth == 32:
return arith.extui(out_type, x)
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
multiple_results=False)(ctx, x)
Expand Down
7 changes: 4 additions & 3 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,18 @@ def body(x_ref, o_ref):
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0)

def test_tpu_signed_int_upcast(self):
@parameterized.parameters([jnp.uint4, jnp.int4])
def test_tpu_int4_upcast(self, dtype):
if not jtu.is_device_tpu_at_least(version=5):
self.skipTest("TPUv5+ needed for integer matmuls")

def body(x_ref, o_ref):
# Test cast from int4 -> int8
# Test cast from (u)int4 -> int8
ux = lax.convert_element_type(x_ref[...], jnp.int8)
o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32)

out = jax.ShapeDtypeStruct((128, 128), jnp.int32)
x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128))
x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128))
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(
result,
Expand Down
Loading