Skip to content

Commit 0bd7070

Browse files
superbobryjax authors
authored andcommitted
Fixed lowering of binary ops for signed dtypes
All integers in Trition are signless, so we need to manually forward the signedness of the abstract values. I wonder if we should avoid relying on MLIR types altogether and change _cast and similar APIs to accept JAX dtypes instead? PiperOrigin-RevId: 614803683
1 parent b6e985f commit 0bd7070

File tree

2 files changed

+116
-53
lines changed

2 files changed

+116
-53
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 85 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping):
215215
for i, s in enumerate(collapse_dims):
216216
out_idx = launch_grid_to_pallas_grid[i]
217217
s = _i32_constant(s)
218-
out_indices[out_idx] = _mod(grid0, s)
219-
grid0 = _floordiv(grid0, s)
218+
out_indices[out_idx] = _mod(grid0, s, signed=False)
219+
grid0 = _floordiv(grid0, s, signed=False)
220220

221221
for i in range(len(prog_id_dims)):
222222
out_idx = launch_grid_to_pallas_grid[num_collapse + i]
@@ -558,7 +558,11 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
558558
for aval, arg, arg_type in zip(ctx.avals_in, args, extern.arg_types):
559559
bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape)
560560
if aval.weak_type and aval.dtype.name != arg_type:
561-
bcast_arg = _cast(bcast_arg, _dtype_to_ir_type(jnp.dtype(arg_type)))
561+
bcast_arg = _cast(
562+
bcast_arg,
563+
_dtype_to_ir_type(jnp.dtype(arg_type)),
564+
signed=jnp.issubdtype(aval.dtype, jnp.signedinteger),
565+
)
562566
bcast_args.append(bcast_arg)
563567

564568
result_type = _dtype_to_ir_type(jnp.dtype(extern.result_type))
@@ -831,35 +835,33 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
831835
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
832836

833837

834-
def _floordiv(x: ir.Value, y: ir.Value) -> ir.Value:
838+
def _floordiv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
835839
assert x.type == y.type, (str(x.type), str(y.type))
836-
x_element_type = _element_type(x.type)
837-
if not isinstance(x_element_type, ir.IntegerType):
840+
if not isinstance(_element_type(x.type), ir.IntegerType):
838841
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
839-
if x_element_type.is_signed:
842+
if signed:
840843
return arith_dialect.divsi(x, y)
841844
else:
842845
return arith_dialect.divui(x, y)
843846

844847

845-
def _truediv(x: ir.Value, y: ir.Value) -> ir.Value:
848+
def _truediv(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
846849
assert x.type == y.type, (str(x.type), str(y.type))
847850
x_element_type = _element_type(x.type)
848851
if isinstance(x_element_type, ir.IntegerType):
849852
x_element_type = ir.F32Type.get()
850-
x = _int_float_cast(x, x_element_type)
851-
y = _int_float_cast(y, x_element_type)
853+
x = _int_float_cast(x, x_element_type, signed=signed)
854+
y = _int_float_cast(y, x_element_type, signed=signed)
852855
if isinstance(x_element_type, ir.FloatType):
853856
return arith_dialect.divf(x, y)
854857
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
855858

856859

857-
def _mod(x: ir.Value, y: ir.Value) -> ir.Value:
860+
def _mod(x: ir.Value, y: ir.Value, *, signed: bool) -> ir.Value:
858861
assert x.type == y.type, (str(x.type), str(y.type))
859-
x_element_type = _element_type(x.type)
860-
if not isinstance(x_element_type, ir.IntegerType):
862+
if not isinstance(_element_type(x.type), ir.IntegerType):
861863
raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
862-
if x_element_type.is_signed:
864+
if signed:
863865
return arith_dialect.remsi(x, y)
864866
else:
865867
return arith_dialect.remui(x, y)
@@ -871,13 +873,13 @@ def _cmp(
871873
si_pred: arith_dialect.CmpIPredicate,
872874
ui_pred: arith_dialect.CmpIPredicate,
873875
f_pred: arith_dialect.CmpFPredicate,
876+
*,
877+
signed: bool,
874878
) -> ir.Value:
875879
assert x.type == y.type, (str(x.type), str(y.type))
876880
x_element_type = _element_type(x.type)
877881
if isinstance(x_element_type, ir.IntegerType):
878-
return arith_dialect.cmpi(
879-
si_pred if x_element_type.is_signed else ui_pred, x, y
880-
)
882+
return arith_dialect.cmpi(si_pred if signed else ui_pred, x, y)
881883
elif isinstance(x_element_type, ir.FloatType):
882884
return arith_dialect.cmpf(f_pred, x, y)
883885
else:
@@ -926,29 +928,42 @@ def _cmp(
926928
lax.add_p: _add,
927929
lax.sub_p: _sub,
928930
lax.mul_p: _mul,
929-
lax.rem_p: _mod,
930931
lax.and_p: arith_dialect.andi,
931932
lax.or_p: arith_dialect.ori,
932933
lax.xor_p: arith_dialect.xori,
933934
lax.shift_left_p: arith_dialect.shli,
934935
lax.shift_right_arithmetic_p: arith_dialect.shrsi,
935936
lax.shift_right_logical_p: arith_dialect.shrui,
937+
ad_util.add_any_p: _add,
938+
}
939+
940+
for prim, fn in _JAX_TO_TRITON_BINARY.items():
941+
942+
def signless_rule(ctx: LoweringRuleContext, x, y, fn=fn):
943+
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
944+
return fn(x, y)
945+
946+
triton_lowering_rules[prim] = signless_rule
947+
948+
949+
_JAX_TO_TRITON_SIGNED_BINARY = {
950+
lax.rem_p: _mod,
936951
lax.eq_p: _equal,
937952
lax.ne_p: _not_equal,
938953
lax.gt_p: _greater_than,
939954
lax.ge_p: _greater_equal,
940955
lax.lt_p: _less_than,
941956
lax.le_p: _less_equal,
942-
ad_util.add_any_p: _add,
943957
}
944958

945-
for prim, fn in _JAX_TO_TRITON_BINARY.items():
959+
for prim, fn in _JAX_TO_TRITON_SIGNED_BINARY.items():
946960

947-
def rule(ctx: LoweringRuleContext, x, y, fn=fn):
961+
def signed_rule(ctx: LoweringRuleContext, x, y, fn=fn):
962+
x_aval, _ = ctx.avals_in
948963
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
949-
return fn(x, y)
964+
return fn(x, y, signed=jnp.issubdtype(x_aval.dtype, jnp.signedinteger))
950965

951-
triton_lowering_rules[prim] = rule
966+
triton_lowering_rules[prim] = signed_rule
952967

953968

954969
def _set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None:
@@ -1080,11 +1095,14 @@ def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
10801095
def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
10811096
x_aval, y_aval = ctx.avals_in
10821097
x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
1098+
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger) or jnp.issubdtype(
1099+
y_aval.dtype, jnp.signedinteger
1100+
)
10831101
if np.issubdtype(x_aval.dtype, np.floating) or np.issubdtype(
10841102
y_aval.dtype, np.floating
10851103
):
1086-
return _truediv(x, y)
1087-
return _floordiv(x, y)
1104+
return _truediv(x, y, signed=signed)
1105+
return _floordiv(x, y, signed=signed)
10881106

10891107

10901108
triton_lowering_rules[lax.div_p] = _div_lowering_rule
@@ -1093,7 +1111,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
10931111
def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension):
10941112
if dimension != 0:
10951113
raise NotImplementedError
1096-
return _cast(_make_range(0, *shape), _dtype_to_ir_type(dtype))
1114+
return _cast(_make_range(0, *shape), _dtype_to_ir_type(dtype), signed=False)
10971115

10981116

10991117
triton_lowering_rules[lax.iota_p] = _iota_lowering_rule
@@ -1168,49 +1186,54 @@ def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
11681186
raise NotImplementedError
11691187

11701188

1171-
def _int_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
1189+
def _int_int_cast(src: ir.Value, dst_type: ir.Type, signed: bool) -> ir.Value:
11721190
src_element_type = ir.IntegerType(_element_type(src.type))
11731191
dst_element_type = ir.IntegerType(_element_type(dst_type))
11741192
assert src_element_type != dst_element_type
11751193
if dst_element_type.width == 1:
1176-
return _not_equal(src, _full(src.type, 0))
1194+
return _not_equal(src, _full(src.type, 0), signed=signed)
11771195

1178-
is_signed = src_element_type.is_signed and src_element_type.width != 1
11791196
if src_element_type.width == dst_element_type.width:
11801197
return arith_dialect.bitcast(dst_type, src)
11811198
elif src_element_type.width > dst_element_type.width:
11821199
return arith_dialect.trunci(dst_type, src)
1183-
elif is_signed:
1200+
elif signed and src_element_type.width != 1:
11841201
return arith_dialect.extsi(dst_type, src)
11851202
else:
11861203
return arith_dialect.extui(dst_type, src)
11871204

11881205

1189-
def _float_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
1206+
def _float_int_cast(
1207+
src: ir.Value, dst_type: ir.Type, *, signed: bool
1208+
) -> ir.Value:
11901209
src_element_type = _element_type(src.type)
11911210
if not isinstance(src_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)):
11921211
raise NotImplementedError(f"cannot cast {src} tp {dst_type}")
11931212
dst_element_type = ir.IntegerType(_element_type(dst_type))
11941213
if dst_element_type.width == 1:
1195-
return _not_equal(src, _full(src.type, 0))
1196-
elif dst_element_type.is_signed:
1214+
return _not_equal(src, _full(src.type, 0), signed=signed)
1215+
elif signed:
11971216
return arith_dialect.fptosi(dst_type, src)
11981217
else:
11991218
return arith_dialect.fptoui(dst_type, src)
12001219

12011220

1202-
def _int_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
1221+
def _int_float_cast(
1222+
src: ir.Value, dst_type: ir.Type, *, signed: bool
1223+
) -> ir.Value:
12031224
src_element_type = ir.IntegerType(_element_type(src.type))
12041225
dst_element_type = _element_type(dst_type)
1205-
if not isinstance(dst_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)):
1226+
if not isinstance(
1227+
dst_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)
1228+
):
12061229
raise NotImplementedError(f"cannot cast {src} tp {dst_type}")
1207-
if src_element_type.width == 1 or not src_element_type.is_signed:
1230+
if src_element_type.width == 1 or not signed:
12081231
return arith_dialect.uitofp(dst_type, src)
12091232
else:
12101233
return arith_dialect.sitofp(dst_type, src)
12111234

12121235

1213-
def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
1236+
def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
12141237
if ir.RankedTensorType.isinstance(
12151238
src.type
12161239
) and not ir.RankedTensorType.isinstance(dst_type):
@@ -1234,7 +1257,9 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
12341257
if isinstance(src_element_type, (ir.F16Type, ir.BF16Type)) and not isinstance(
12351258
dst_element_type, ir.F32Type
12361259
):
1237-
return _cast(_cast(src, ir.F32Type.get()), dst_type)
1260+
return _cast(
1261+
_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
1262+
)
12381263

12391264
if isinstance(src_element_type, ir.FloatType) and isinstance(
12401265
dst_element_type, ir.FloatType
@@ -1244,26 +1269,26 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
12441269
if isinstance(src_element_type, ir.IntegerType) and isinstance(
12451270
dst_element_type, ir.IntegerType
12461271
):
1247-
return _int_int_cast(src, dst_type)
1272+
return _int_int_cast(src, dst_type, signed=signed)
12481273

12491274
if isinstance(src_element_type, ir.FloatType) and isinstance(
12501275
dst_element_type, ir.IntegerType
12511276
):
1252-
return _float_int_cast(src, dst_type)
1277+
return _float_int_cast(src, dst_type, signed=signed)
12531278
if isinstance(src_element_type, ir.IntegerType) and isinstance(
12541279
dst_element_type, ir.FloatType
12551280
):
1256-
return _int_float_cast(src, dst_type)
1281+
return _int_float_cast(src, dst_type, signed=signed)
12571282

12581283
if tt_dialect.PointerType.isinstance(src_element_type) and isinstance(
12591284
dst_element_type, ir.IntegerType
12601285
):
12611286
if dst_element_type.width == 64:
12621287
return tt_dialect.ptr_to_int(dst_type, src)
12631288
else:
1264-
x = _cast(src, ir.IntegerType.get_signless(64))
1289+
x = _cast(src, ir.IntegerType.get_signless(64), signed=signed)
12651290
zero = _full(x.type, 0)
1266-
return _cast(_not_equal(x, zero), dst_type)
1291+
return _cast(_not_equal(x, zero, signed=signed), dst_type, signed=signed)
12671292
if isinstance(
12681293
src_element_type, ir.IntegerType
12691294
) and tt_dialect.PointerType.isinstance(dst_element_type):
@@ -1283,7 +1308,8 @@ def _convert_element_type_lowering_rule(
12831308
x = _ensure_ir_value(x, x_aval)
12841309
if new_dtype == x_aval.dtype:
12851310
return x
1286-
return _cast(x, _dtype_to_ir_type(new_dtype))
1311+
signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger)
1312+
return _cast(x, _dtype_to_ir_type(new_dtype), signed=signed)
12871313

12881314

12891315
triton_lowering_rules[lax.convert_element_type_p] = (
@@ -1428,7 +1454,7 @@ def _compute_pointers_from_indices(
14281454
else:
14291455
ptr_dim_offset = _add(
14301456
_bcast_to(index.start, [index.size]),
1431-
_cast(_make_range(0, index.size), index.start.type),
1457+
_cast(_make_range(0, index.size), index.start.type, signed=False),
14321458
)
14331459
# We need to add broadcastable dimensions for the advanced int indexing
14341460
# and for previous slices
@@ -1466,7 +1492,7 @@ def _compute_pointers_from_indices(
14661492
ptr_dim_offset = _bcast_to(ptr_dim_offset, indexer_shape)
14671493
index_type = ir.IntegerType(_element_type(ptr_dim_offset.type))
14681494
if start_offset is not None:
1469-
start_offset = _cast(start_offset, index_type)
1495+
start_offset = _cast(start_offset, index_type, signed=False)
14701496
ptr_dim_offset = _add(
14711497
ptr_dim_offset, _bcast_to(start_offset, indexer_shape)
14721498
)
@@ -1578,11 +1604,13 @@ def _load(
15781604
if is_int1:
15791605
pointee_type = ir.IntegerType.get_signless(8)
15801606
ptr = _cast(
1581-
ptr, tt_dialect.PointerType.get(pointee_type, ptr_type.address_space)
1607+
ptr,
1608+
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
1609+
signed=False,
15821610
)
15831611

15841612
if other is not None:
1585-
other = _cast(other, pointee_type)
1613+
other = _cast(other, pointee_type, signed=False)
15861614

15871615
result = tt_dialect.load(
15881616
_infer_load_return_type(ptr),
@@ -1594,7 +1622,9 @@ def _load(
15941622
is_volatile=is_volatile,
15951623
)
15961624
return (
1597-
result if not is_int1 else _cast(result, ir.IntegerType.get_signless(1))
1625+
result
1626+
if not is_int1
1627+
else _cast(result, ir.IntegerType.get_signless(1), signed=False)
15981628
)
15991629

16001630

@@ -1696,10 +1726,12 @@ def _store(
16961726
if isinstance(pointee_type, ir.IntegerType) and pointee_type.width == 1:
16971727
pointee_type = ir.IntegerType.get_signless(8)
16981728
ptr = _cast(
1699-
ptr, tt_dialect.PointerType.get(pointee_type, ptr_type.address_space)
1729+
ptr,
1730+
tt_dialect.PointerType.get(pointee_type, ptr_type.address_space),
1731+
signed=False,
17001732
)
17011733

1702-
value = _cast(value, pointee_type)
1734+
value = _cast(value, pointee_type, signed=False)
17031735
return tt_dialect.store(
17041736
ptr, value, mask=mask, cache=cache_modifier, evict=eviction_policy
17051737
)
@@ -1867,6 +1899,7 @@ def _dot_general_lowering(
18671899
out_type=_dtype_to_ir_type(acc_dtype),
18681900
),
18691901
_dtype_to_ir_type(out_dtype),
1902+
signed=jnp.issubdtype(out_aval.dtype, jnp.signedinteger),
18701903
)
18711904

18721905

@@ -2340,7 +2373,7 @@ def to_type(out_aval):
23402373

23412374
out_types = [to_type(out) for out in ctx.avals_out]
23422375

2343-
use_branch0 = _equal(index, _ir_constant(0, index.type))
2376+
use_branch0 = _equal(index, _ir_constant(0, index.type), signed=False)
23442377
# TODO(bjp): Switch to scf.index_switch once exposed in triton.cc
23452378
if_op = scf_dialect.IfOp(use_branch0, out_types, hasElse=True)
23462379
with ir.InsertionPoint.at_block_begin(if_op.then_block):

0 commit comments

Comments
 (0)