Skip to content

Commit 4d9efff

Browse files
superbobry, jax authors
authored and committed
_cast() now takes JAX dtypes
The MLIR-level cast, which infers the source type from the ir.Value, is now called _ir_cast. Hopefully, this makes the casting logic a bit easier to follow.

PiperOrigin-RevId: 623654848
1 parent e3018db commit 4d9efff

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
580580
for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types):
581581
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
582582
if aval.weak_type and aval.dtype.name != arg_type:
583-
bcast_arg = _cast(
584-
bcast_arg,
585-
_dtype_to_ir_type(jnp.dtype(arg_type)),
586-
signed=jnp.issubdtype(aval.dtype, jnp.signedinteger),
587-
)
583+
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type))
588584
bcast_args.append(bcast_arg)
589585
return h.lower(ctx, *bcast_args)
590586

@@ -1162,8 +1158,8 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x):
11621158
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
11631159
zero = _full(x.type, 0)
11641160
return _sub(
1165-
_cast(_greater_than(x, zero, signed=signed), x.type, signed=signed),
1166-
_cast(_less_than(x, zero, signed=signed), x.type, signed=signed),
1161+
_cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
1162+
_cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype),
11671163
)
11681164

11691165

@@ -1172,7 +1168,7 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x):
11721168

11731169
def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension):
11741170
iota = _make_range(0, shape[dimension])
1175-
iota = _cast(iota, _dtype_to_ir_type(dtype), signed=False)
1171+
iota = _cast(iota, jnp.int32, dtype)
11761172
for i in range(len(shape)):
11771173
if i != dimension:
11781174
iota = _expand_dims(iota, i)
@@ -1298,7 +1294,19 @@ def _int_float_cast(
12981294
return arith_dialect.sitofp(dst_type, src)
12991295

13001296

1301-
def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
1297+
def _cast(
1298+
src: ir.Value,
1299+
src_type: jax.typing.DTypeLike,
1300+
dst_type: jax.typing.DTypeLike,
1301+
) -> ir.Value:
1302+
return _ir_cast(
1303+
src,
1304+
_dtype_to_ir_type(dst_type),
1305+
signed=jnp.issubdtype(src_type, jnp.signedinteger),
1306+
)
1307+
1308+
1309+
def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
13021310
if ir.RankedTensorType.isinstance(
13031311
src.type
13041312
) and not ir.RankedTensorType.isinstance(dst_type):
@@ -1322,8 +1330,8 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
13221330
if isinstance(src_element_type, (ir.F16Type, ir.BF16Type)) and not isinstance(
13231331
dst_element_type, ir.F32Type
13241332
):
1325-
return _cast(
1326-
_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
1333+
return _ir_cast(
1334+
_ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
13271335
)
13281336

13291337
if isinstance(src_element_type, ir.FloatType) and isinstance(
@@ -1350,10 +1358,10 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
13501358
):
13511359
if dst_element_type.width == 64:
13521360
return tt_dialect.ptr_to_int(dst_type, src)
1353-
else:
1354-
x = _cast(src, ir.IntegerType.get_signless(64), signed=signed)
1361+
elif dst_element_type.width == 1:
1362+
x = _ir_cast(src, ir.IntegerType.get_signless(64), signed=signed)
13551363
zero = _full(x.type, 0)
1356-
return _cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
1364+
return _ir_cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
13571365
if isinstance(
13581366
src_element_type, ir.IntegerType
13591367
) and tt_dialect.PointerType.isinstance(dst_element_type):
@@ -1373,8 +1381,7 @@ def _convert_element_type_lowering_rule(
13731381
x = _ensure_ir_value(x, x_aval)
13741382
if new_dtype == x_aval.dtype:
13751383
return x
1376-
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
1377-
return _cast(x, _dtype_to_ir_type(new_dtype), signed=signed)
1384+
return _cast(x, x_aval.dtype, new_dtype)
13781385

13791386

13801387
triton_lowering_rules[lax.convert_element_type_p] = (
@@ -1519,7 +1526,7 @@ def _compute_pointers_from_indices(
15191526
else:
15201527
ptr_dim_offset = _add(
15211528
_bcast_to(index.start, [index.size]),
1522-
_cast(_make_range(0, index.size), index.start.type, signed=False),
1529+
_ir_cast(_make_range(0, index.size), index.start.type, signed=False),
15231530
)
15241531
# We need to add broadcastable dimensions for the advanced int indexing
15251532
# and for previous slices
@@ -1557,7 +1564,7 @@ def _compute_pointers_from_indices(
15571564
ptr_dim_offset = _bcast_to(ptr_dim_offset, indexer_shape)
15581565
index_type = ir.IntegerType(_element_type(ptr_dim_offset.type))
15591566
if start_offset is not None:
1560-
start_offset = _cast(start_offset, index_type, signed=False)
1567+
start_offset = _ir_cast(start_offset, index_type, signed=False)
15611568
ptr_dim_offset = _add(
15621569
ptr_dim_offset, _bcast_to(start_offset, indexer_shape)
15631570
)
@@ -1660,14 +1667,14 @@ def _load(
16601667
is_int1 = isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1
16611668
if is_int1:
16621669
pointee_type = ir.IntegerType.get_signless(8)
1663-
ptr = _cast(
1670+
ptr = _ir_cast(
16641671
ptr,
16651672
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
16661673
signed=False,
16671674
)
16681675

16691676
if other is not None:
1670-
other = _cast(other, pointee_type, signed=False)
1677+
other = _ir_cast(other, pointee_type, signed=False)
16711678

16721679
result = tt_dialect.load(
16731680
_infer_load_return_type(ptr),
@@ -1681,7 +1688,7 @@ def _load(
16811688
return (
16821689
result
16831690
if not is_int1
1684-
else _cast(result, ir.IntegerType.get_signless(1), signed=False)
1691+
else _ir_cast(result, ir.IntegerType.get_signless(1), signed=False)
16851692
)
16861693

16871694

@@ -1782,13 +1789,13 @@ def _store(
17821789
pointee_type = ptr_type.pointee_type
17831790
if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
17841791
pointee_type = ir.IntegerType.get_signless(8)
1785-
ptr = _cast(
1792+
ptr = _ir_cast(
17861793
ptr,
17871794
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
17881795
signed=False,
17891796
)
17901797

1791-
value = _cast(value, pointee_type, signed=False)
1798+
value = _ir_cast(value, pointee_type, signed=False)
17921799
return tt_dialect.store(
17931800
ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
17941801
)
@@ -1955,8 +1962,8 @@ def _dot_general_lowering(
19551962
allow_tf32=allow_tf32,
19561963
out_type=_dtype_to_ir_type(acc_dtype),
19571964
),
1958-
_dtype_to_ir_type(out_dtype),
1959-
signed=jnp.issubdtype(out_aval.dtype, jnp.signedinteger),
1965+
acc_dtype,
1966+
out_dtype,
19601967
)
19611968

19621969

0 commit comments

Comments
 (0)