@@ -2240,27 +2240,29 @@ def _dot_general_lowering_rule(
2240
2240
def _convert_helper (x : Array , * , to_dtype : jnp .dtype ) -> Array :
2241
2241
# Helper function for dtype conversion
2242
2242
from_dtype = x .dtype
2243
+ from_bitwidth = pallas_utils .dtype_bitwidth (from_dtype )
2244
+ to_bitwidth = pallas_utils .dtype_bitwidth (to_dtype )
2243
2245
if from_dtype == jnp .bool_ :
2244
2246
x = x .astype (jnp .int32 )
2245
2247
return _convert_helper (x , to_dtype = to_dtype )
2246
2248
if to_dtype == jnp .bool_ :
2247
2249
# Lower float32 or (u)int32 -> bool to cmp neq %in, 0
2248
2250
# TODO(apaszke,mvoz): Move the upcasts for cmpi to the Mosaic canonicalizer.
2249
2251
if jnp .issubdtype (from_dtype , jnp .floating ):
2250
- if from_dtype . itemsize < 4 :
2252
+ if from_bitwidth < 32 :
2251
2253
x = x .astype (jnp .float32 )
2252
2254
elif jnp .issubdtype (from_dtype , jnp .integer ):
2253
- if from_dtype . itemsize < 4 :
2255
+ if from_bitwidth < 32 :
2254
2256
x = x .astype (jnp .int32 )
2255
2257
return x != jnp .asarray (0 , dtype = x .dtype )
2256
2258
if jnp .issubdtype (from_dtype , jnp .signedinteger ):
2257
- if from_dtype . itemsize < 4 :
2259
+ if from_bitwidth < 32 :
2258
2260
x = x .astype (jnp .int32 )
2259
- if jnp .issubdtype (to_dtype , jnp .floating ) and to_dtype . itemsize < 4 :
2261
+ if jnp .issubdtype (to_dtype , jnp .floating ) and to_bitwidth < 32 :
2260
2262
x = x .astype (jnp .float32 )
2261
2263
return x .astype (to_dtype )
2262
2264
if jnp .issubdtype (from_dtype , jnp .unsignedinteger ):
2263
- if from_dtype . itemsize < 4 :
2265
+ if from_bitwidth < 32 :
2264
2266
x = x .astype (jnp .uint32 )
2265
2267
# unsigned -> float is unsupported. We fall through and raise at the bottom.
2266
2268
if not jnp .issubdtype (to_dtype , jnp .floating ):
@@ -2294,25 +2296,30 @@ def _convert_element_type_lowering_rule(
2294
2296
floating = jnp .floating
2295
2297
integer = jnp .integer
2296
2298
signed = jnp .signedinteger
2297
- both_32bit = old_dtype .itemsize == 4 and new_dtype .itemsize == 4
2299
+ unsigned = jnp .unsignedinteger
2300
+ old_bitwidth = pallas_utils .dtype_bitwidth (old_dtype )
2301
+ new_bitwidth = pallas_utils .dtype_bitwidth (new_dtype )
2302
+ both_32bit = old_bitwidth == 32 and new_bitwidth == 32
2298
2303
if _from (floating ) and _to (floating ):
2299
2304
forward_compat = ctx .forward_compatible or is_cloud_tpu_older_than (
2300
2305
2025 , 6 , 29
2301
2306
)
2302
- if old_dtype . itemsize < new_dtype . itemsize and (
2303
- new_dtype . itemsize == 4 or not forward_compat
2307
+ if old_bitwidth < new_bitwidth and (
2308
+ new_bitwidth == 32 or not forward_compat
2304
2309
):
2305
2310
return arith .extf (out_type , x )
2306
- elif old_dtype . itemsize > new_dtype . itemsize and (
2307
- old_dtype . itemsize == 4 or not forward_compat
2311
+ elif old_bitwidth > new_bitwidth and (
2312
+ old_bitwidth == 32 or not forward_compat
2308
2313
):
2309
2314
return arith .truncf (out_type , x )
2310
2315
elif _from (integer ) and _to (integer ):
2311
- if old_dtype .itemsize < new_dtype .itemsize and new_dtype .itemsize == 4 :
2312
- if not (_from (signed ) and _to (signed )):
2313
- raise NotImplementedError (f"Unsupported cast: { old_dtype } -> { new_dtype } " )
2314
- return arith .extsi (out_type , x )
2315
- elif old_dtype .itemsize > new_dtype .itemsize and old_dtype .itemsize == 4 :
2316
+ if old_bitwidth < new_bitwidth and new_bitwidth == 32 :
2317
+ if (_from (unsigned ) and _to (unsigned )):
2318
+ return arith .extui (out_type , x )
2319
+ if (_from (signed ) and _to (signed )):
2320
+ return arith .extsi (out_type , x )
2321
+ raise NotImplementedError (f"Unsupported cast: { old_dtype } -> { new_dtype } " )
2322
+ elif old_bitwidth > new_bitwidth and old_bitwidth == 32 :
2316
2323
return arith .trunci (out_type , x )
2317
2324
elif jnp .iinfo (old_dtype ).bits == jnp .iinfo (new_dtype ).bits :
2318
2325
# This case triggers when casting signed to unsigned or vice versa.
@@ -2325,7 +2332,7 @@ def _convert_element_type_lowering_rule(
2325
2332
or both_32bit
2326
2333
):
2327
2334
return arith .sitofp (out_type , x )
2328
- elif old_dtype == jnp .bool_ and _to (integer ) and new_dtype . itemsize == 4 :
2335
+ elif old_dtype == jnp .bool_ and _to (integer ) and new_bitwidth == 32 :
2329
2336
return arith .extui (out_type , x )
2330
2337
return lower_fun (functools .partial (_convert_helper , to_dtype = new_dtype ),
2331
2338
multiple_results = False )(ctx , x )
0 commit comments