Skip to content

Commit bc2674d

Browse files
justinjfu authored and Google-ML-Automation committed
[Pallas TPU] Refactor casting logic to use bits instead of bytes and allow uint upcasts.
PiperOrigin-RevId: 781255443
1 parent 03af274 commit bc2674d

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 23 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2240,27 +2240,29 @@ def _dot_general_lowering_rule(
22402240
def _convert_helper(x: Array, *, to_dtype: jnp.dtype) -> Array:
22412241
# Helper function for dtype conversion
22422242
from_dtype = x.dtype
2243+
from_bitwidth = pallas_utils.dtype_bitwidth(from_dtype)
2244+
to_bitwidth = pallas_utils.dtype_bitwidth(to_dtype)
22432245
if from_dtype == jnp.bool_:
22442246
x = x.astype(jnp.int32)
22452247
return _convert_helper(x, to_dtype=to_dtype)
22462248
if to_dtype == jnp.bool_:
22472249
# Lower float32 or (u)int32 -> bool to cmp neq %in, 0
22482250
# TODO(apaszke,mvoz): Move the upcasts for cmpi to the Mosaic canonicalizer.
22492251
if jnp.issubdtype(from_dtype, jnp.floating):
2250-
if from_dtype.itemsize < 4:
2252+
if from_bitwidth < 32:
22512253
x = x.astype(jnp.float32)
22522254
elif jnp.issubdtype(from_dtype, jnp.integer):
2253-
if from_dtype.itemsize < 4:
2255+
if from_bitwidth < 32:
22542256
x = x.astype(jnp.int32)
22552257
return x != jnp.asarray(0, dtype=x.dtype)
22562258
if jnp.issubdtype(from_dtype, jnp.signedinteger):
2257-
if from_dtype.itemsize < 4:
2259+
if from_bitwidth < 32:
22582260
x = x.astype(jnp.int32)
2259-
if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
2261+
if jnp.issubdtype(to_dtype, jnp.floating) and to_bitwidth < 32:
22602262
x = x.astype(jnp.float32)
22612263
return x.astype(to_dtype)
22622264
if jnp.issubdtype(from_dtype, jnp.unsignedinteger):
2263-
if from_dtype.itemsize < 4:
2265+
if from_bitwidth < 32:
22642266
x = x.astype(jnp.uint32)
22652267
# unsigned -> float is unsupported. We fall through and raise at the bottom.
22662268
if not jnp.issubdtype(to_dtype, jnp.floating):
@@ -2294,25 +2296,30 @@ def _convert_element_type_lowering_rule(
22942296
floating = jnp.floating
22952297
integer = jnp.integer
22962298
signed = jnp.signedinteger
2297-
both_32bit = old_dtype.itemsize == 4 and new_dtype.itemsize == 4
2299+
unsigned = jnp.unsignedinteger
2300+
old_bitwidth = pallas_utils.dtype_bitwidth(old_dtype)
2301+
new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype)
2302+
both_32bit = old_bitwidth == 32 and new_bitwidth == 32
22982303
if _from(floating) and _to(floating):
22992304
forward_compat = ctx.forward_compatible or is_cloud_tpu_older_than(
23002305
2025, 6, 29
23012306
)
2302-
if old_dtype.itemsize < new_dtype.itemsize and (
2303-
new_dtype.itemsize == 4 or not forward_compat
2307+
if old_bitwidth < new_bitwidth and (
2308+
new_bitwidth == 32 or not forward_compat
23042309
):
23052310
return arith.extf(out_type, x)
2306-
elif old_dtype.itemsize > new_dtype.itemsize and (
2307-
old_dtype.itemsize == 4 or not forward_compat
2311+
elif old_bitwidth > new_bitwidth and (
2312+
old_bitwidth == 32 or not forward_compat
23082313
):
23092314
return arith.truncf(out_type, x)
23102315
elif _from(integer) and _to(integer):
2311-
if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
2312-
if not (_from(signed) and _to(signed)):
2313-
raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
2314-
return arith.extsi(out_type, x)
2315-
elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
2316+
if old_bitwidth < new_bitwidth and new_bitwidth == 32:
2317+
if (_from(unsigned) and _to(unsigned)):
2318+
return arith.extui(out_type, x)
2319+
if (_from(signed) and _to(signed)):
2320+
return arith.extsi(out_type, x)
2321+
raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
2322+
elif old_bitwidth > new_bitwidth and old_bitwidth == 32:
23162323
return arith.trunci(out_type, x)
23172324
elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits:
23182325
# This case triggers when casting signed to unsigned or vice versa.
@@ -2325,7 +2332,7 @@ def _convert_element_type_lowering_rule(
23252332
or both_32bit
23262333
):
23272334
return arith.sitofp(out_type, x)
2328-
elif old_dtype == jnp.bool_ and _to(integer) and new_dtype.itemsize == 4:
2335+
elif old_dtype == jnp.bool_ and _to(integer) and new_bitwidth == 32:
23292336
return arith.extui(out_type, x)
23302337
return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
23312338
multiple_results=False)(ctx, x)

tests/pallas/tpu_ops_test.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -244,17 +244,18 @@ def body(x_ref, o_ref):
244244
result = self.pallas_call(body, out_shape=out)(x)
245245
np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0)
246246

247-
def test_tpu_signed_int_upcast(self):
247+
@parameterized.parameters([jnp.uint4, jnp.int4])
248+
def test_tpu_int4_upcast(self, dtype):
248249
if not jtu.is_device_tpu_at_least(version=5):
249250
self.skipTest("TPUv5+ needed for integer matmuls")
250251

251252
def body(x_ref, o_ref):
252-
# Test cast from int4 -> int8
253+
# Test cast from (u)int4 -> int8
253254
ux = lax.convert_element_type(x_ref[...], jnp.int8)
254255
o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32)
255256

256257
out = jax.ShapeDtypeStruct((128, 128), jnp.int32)
257-
x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128))
258+
x = jnp.arange(128 * 128, dtype=dtype).reshape((128, 128))
258259
result = self.pallas_call(body, out_shape=out)(x)
259260
np.testing.assert_array_equal(
260261
result,

0 commit comments

Comments
 (0)