@@ -215,8 +215,8 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping):
215
215
for i , s in enumerate (collapse_dims ):
216
216
out_idx = launch_grid_to_pallas_grid [i ]
217
217
s = _i32_constant (s )
218
- out_indices [out_idx ] = _mod (grid0 , s )
219
- grid0 = _floordiv (grid0 , s )
218
+ out_indices [out_idx ] = _mod (grid0 , s , signed = False )
219
+ grid0 = _floordiv (grid0 , s , signed = False )
220
220
221
221
for i in range (len (prog_id_dims )):
222
222
out_idx = launch_grid_to_pallas_grid [num_collapse + i ]
@@ -558,7 +558,11 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
558
558
for aval , arg , arg_type in zip (ctx .avals_in , args , extern .arg_types ):
559
559
bcast_arg = _bcast_to (_ensure_ir_value (arg , aval ), out_aval .shape )
560
560
if aval .weak_type and aval .dtype .name != arg_type :
561
- bcast_arg = _cast (bcast_arg , _dtype_to_ir_type (jnp .dtype (arg_type )))
561
+ bcast_arg = _cast (
562
+ bcast_arg ,
563
+ _dtype_to_ir_type (jnp .dtype (arg_type )),
564
+ signed = jnp .issubdtype (aval .dtype , jnp .signedinteger ),
565
+ )
562
566
bcast_args .append (bcast_arg )
563
567
564
568
result_type = _dtype_to_ir_type (jnp .dtype (extern .result_type ))
@@ -831,35 +835,33 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
831
835
raise NotImplementedError (f"unsupported types: { x .type } and { y .type } " )
832
836
833
837
834
- def _floordiv (x : ir .Value , y : ir .Value ) -> ir .Value :
838
+ def _floordiv (x : ir .Value , y : ir .Value , * , signed : bool ) -> ir .Value :
835
839
assert x .type == y .type , (str (x .type ), str (y .type ))
836
- x_element_type = _element_type (x .type )
837
- if not isinstance (x_element_type , ir .IntegerType ):
840
+ if not isinstance (_element_type (x .type ), ir .IntegerType ):
838
841
raise NotImplementedError (f"unsupported types: { x .type } and { y .type } " )
839
- if x_element_type . is_signed :
842
+ if signed :
840
843
return arith_dialect .divsi (x , y )
841
844
else :
842
845
return arith_dialect .divui (x , y )
843
846
844
847
845
- def _truediv (x : ir .Value , y : ir .Value ) -> ir .Value :
848
+ def _truediv (x : ir .Value , y : ir .Value , * , signed : bool ) -> ir .Value :
846
849
assert x .type == y .type , (str (x .type ), str (y .type ))
847
850
x_element_type = _element_type (x .type )
848
851
if isinstance (x_element_type , ir .IntegerType ):
849
852
x_element_type = ir .F32Type .get ()
850
- x = _int_float_cast (x , x_element_type )
851
- y = _int_float_cast (y , x_element_type )
853
+ x = _int_float_cast (x , x_element_type , signed = signed )
854
+ y = _int_float_cast (y , x_element_type , signed = signed )
852
855
if isinstance (x_element_type , ir .FloatType ):
853
856
return arith_dialect .divf (x , y )
854
857
raise NotImplementedError (f"unsupported types: { x .type } and { y .type } " )
855
858
856
859
857
- def _mod (x : ir .Value , y : ir .Value ) -> ir .Value :
860
+ def _mod (x : ir .Value , y : ir .Value , * , signed : bool ) -> ir .Value :
858
861
assert x .type == y .type , (str (x .type ), str (y .type ))
859
- x_element_type = _element_type (x .type )
860
- if not isinstance (x_element_type , ir .IntegerType ):
862
+ if not isinstance (_element_type (x .type ), ir .IntegerType ):
861
863
raise NotImplementedError (f"unsupported types: { x .type } and { y .type } " )
862
- if x_element_type . is_signed :
864
+ if signed :
863
865
return arith_dialect .remsi (x , y )
864
866
else :
865
867
return arith_dialect .remui (x , y )
@@ -871,13 +873,13 @@ def _cmp(
871
873
si_pred : arith_dialect .CmpIPredicate ,
872
874
ui_pred : arith_dialect .CmpIPredicate ,
873
875
f_pred : arith_dialect .CmpFPredicate ,
876
+ * ,
877
+ signed : bool ,
874
878
) -> ir .Value :
875
879
assert x .type == y .type , (str (x .type ), str (y .type ))
876
880
x_element_type = _element_type (x .type )
877
881
if isinstance (x_element_type , ir .IntegerType ):
878
- return arith_dialect .cmpi (
879
- si_pred if x_element_type .is_signed else ui_pred , x , y
880
- )
882
+ return arith_dialect .cmpi (si_pred if signed else ui_pred , x , y )
881
883
elif isinstance (x_element_type , ir .FloatType ):
882
884
return arith_dialect .cmpf (f_pred , x , y )
883
885
else :
@@ -926,29 +928,42 @@ def _cmp(
926
928
lax .add_p : _add ,
927
929
lax .sub_p : _sub ,
928
930
lax .mul_p : _mul ,
929
- lax .rem_p : _mod ,
930
931
lax .and_p : arith_dialect .andi ,
931
932
lax .or_p : arith_dialect .ori ,
932
933
lax .xor_p : arith_dialect .xori ,
933
934
lax .shift_left_p : arith_dialect .shli ,
934
935
lax .shift_right_arithmetic_p : arith_dialect .shrsi ,
935
936
lax .shift_right_logical_p : arith_dialect .shrui ,
937
+ ad_util .add_any_p : _add ,
938
+ }
939
+
940
+ for prim , fn in _JAX_TO_TRITON_BINARY .items ():
941
+
942
+ def signless_rule (ctx : LoweringRuleContext , x , y , fn = fn ):
943
+ x , y = _bcast (x , y , * ctx .avals_in , * ctx .avals_out )
944
+ return fn (x , y )
945
+
946
+ triton_lowering_rules [prim ] = signless_rule
947
+
948
+
949
+ _JAX_TO_TRITON_SIGNED_BINARY = {
950
+ lax .rem_p : _mod ,
936
951
lax .eq_p : _equal ,
937
952
lax .ne_p : _not_equal ,
938
953
lax .gt_p : _greater_than ,
939
954
lax .ge_p : _greater_equal ,
940
955
lax .lt_p : _less_than ,
941
956
lax .le_p : _less_equal ,
942
- ad_util .add_any_p : _add ,
943
957
}
944
958
945
- for prim , fn in _JAX_TO_TRITON_BINARY .items ():
959
+ for prim , fn in _JAX_TO_TRITON_SIGNED_BINARY .items ():
946
960
947
- def rule (ctx : LoweringRuleContext , x , y , fn = fn ):
961
+ def signed_rule (ctx : LoweringRuleContext , x , y , fn = fn ):
962
+ x_aval , _ = ctx .avals_in
948
963
x , y = _bcast (x , y , * ctx .avals_in , * ctx .avals_out )
949
- return fn (x , y )
964
+ return fn (x , y , signed = jnp . issubdtype ( x_aval . dtype , jnp . signedinteger ) )
950
965
951
- triton_lowering_rules [prim ] = rule
966
+ triton_lowering_rules [prim ] = signed_rule
952
967
953
968
954
969
def _set_attr (v : ir .Value , name : str , attr : ir .Attribute ) -> None :
@@ -1080,11 +1095,14 @@ def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
1080
1095
def _div_lowering_rule (ctx : LoweringRuleContext , x , y ):
1081
1096
x_aval , y_aval = ctx .avals_in
1082
1097
x , y = _bcast (x , y , * ctx .avals_in , * ctx .avals_out )
1098
+ signed = jnp .issubdtype (x_aval .dtype , jnp .signedinteger ) or jnp .issubdtype (
1099
+ y_aval .dtype , jnp .signedinteger
1100
+ )
1083
1101
if np .issubdtype (x_aval .dtype , np .floating ) or np .issubdtype (
1084
1102
y_aval .dtype , np .floating
1085
1103
):
1086
- return _truediv (x , y )
1087
- return _floordiv (x , y )
1104
+ return _truediv (x , y , signed = signed )
1105
+ return _floordiv (x , y , signed = signed )
1088
1106
1089
1107
1090
1108
triton_lowering_rules [lax .div_p ] = _div_lowering_rule
@@ -1093,7 +1111,7 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
1093
1111
def _iota_lowering_rule (ctx : LoweringRuleContext , * , dtype , shape , dimension ):
1094
1112
if dimension != 0 :
1095
1113
raise NotImplementedError
1096
- return _cast (_make_range (0 , * shape ), _dtype_to_ir_type (dtype ))
1114
+ return _cast (_make_range (0 , * shape ), _dtype_to_ir_type (dtype ), signed = False )
1097
1115
1098
1116
1099
1117
triton_lowering_rules [lax .iota_p ] = _iota_lowering_rule
@@ -1168,49 +1186,54 @@ def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
1168
1186
raise NotImplementedError
1169
1187
1170
1188
1171
- def _int_int_cast (src : ir .Value , dst_type : ir .Type ) -> ir .Value :
1189
+ def _int_int_cast (src : ir .Value , dst_type : ir .Type , signed : bool ) -> ir .Value :
1172
1190
src_element_type = ir .IntegerType (_element_type (src .type ))
1173
1191
dst_element_type = ir .IntegerType (_element_type (dst_type ))
1174
1192
assert src_element_type != dst_element_type
1175
1193
if dst_element_type .width == 1 :
1176
- return _not_equal (src , _full (src .type , 0 ))
1194
+ return _not_equal (src , _full (src .type , 0 ), signed = signed )
1177
1195
1178
- is_signed = src_element_type .is_signed and src_element_type .width != 1
1179
1196
if src_element_type .width == dst_element_type .width :
1180
1197
return arith_dialect .bitcast (dst_type , src )
1181
1198
elif src_element_type .width > dst_element_type .width :
1182
1199
return arith_dialect .trunci (dst_type , src )
1183
- elif is_signed :
1200
+ elif signed and src_element_type . width != 1 :
1184
1201
return arith_dialect .extsi (dst_type , src )
1185
1202
else :
1186
1203
return arith_dialect .extui (dst_type , src )
1187
1204
1188
1205
1189
- def _float_int_cast (src : ir .Value , dst_type : ir .Type ) -> ir .Value :
1206
+ def _float_int_cast (
1207
+ src : ir .Value , dst_type : ir .Type , * , signed : bool
1208
+ ) -> ir .Value :
1190
1209
src_element_type = _element_type (src .type )
1191
1210
if not isinstance (src_element_type , (ir .BF16Type , ir .F16Type , ir .F32Type , ir .F64Type )):
1192
1211
raise NotImplementedError (f"cannot cast { src } tp { dst_type } " )
1193
1212
dst_element_type = ir .IntegerType (_element_type (dst_type ))
1194
1213
if dst_element_type .width == 1 :
1195
- return _not_equal (src , _full (src .type , 0 ))
1196
- elif dst_element_type . is_signed :
1214
+ return _not_equal (src , _full (src .type , 0 ), signed = signed )
1215
+ elif signed :
1197
1216
return arith_dialect .fptosi (dst_type , src )
1198
1217
else :
1199
1218
return arith_dialect .fptoui (dst_type , src )
1200
1219
1201
1220
1202
- def _int_float_cast (src : ir .Value , dst_type : ir .Type ) -> ir .Value :
1221
+ def _int_float_cast (
1222
+ src : ir .Value , dst_type : ir .Type , * , signed : bool
1223
+ ) -> ir .Value :
1203
1224
src_element_type = ir .IntegerType (_element_type (src .type ))
1204
1225
dst_element_type = _element_type (dst_type )
1205
- if not isinstance (dst_element_type , (ir .BF16Type , ir .F16Type , ir .F32Type , ir .F64Type )):
1226
+ if not isinstance (
1227
+ dst_element_type , (ir .BF16Type , ir .F16Type , ir .F32Type , ir .F64Type )
1228
+ ):
1206
1229
raise NotImplementedError (f"cannot cast { src } tp { dst_type } " )
1207
- if src_element_type .width == 1 or not src_element_type . is_signed :
1230
+ if src_element_type .width == 1 or not signed :
1208
1231
return arith_dialect .uitofp (dst_type , src )
1209
1232
else :
1210
1233
return arith_dialect .sitofp (dst_type , src )
1211
1234
1212
1235
1213
- def _cast (src : ir .Value , dst_type : ir .Type ) -> ir .Value :
1236
+ def _cast (src : ir .Value , dst_type : ir .Type , * , signed : bool ) -> ir .Value :
1214
1237
if ir .RankedTensorType .isinstance (
1215
1238
src .type
1216
1239
) and not ir .RankedTensorType .isinstance (dst_type ):
@@ -1234,7 +1257,9 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
1234
1257
if isinstance (src_element_type , (ir .F16Type , ir .BF16Type )) and not isinstance (
1235
1258
dst_element_type , ir .F32Type
1236
1259
):
1237
- return _cast (_cast (src , ir .F32Type .get ()), dst_type )
1260
+ return _cast (
1261
+ _cast (src , ir .F32Type .get (), signed = False ), dst_type , signed = False
1262
+ )
1238
1263
1239
1264
if isinstance (src_element_type , ir .FloatType ) and isinstance (
1240
1265
dst_element_type , ir .FloatType
@@ -1244,26 +1269,26 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
1244
1269
if isinstance (src_element_type , ir .IntegerType ) and isinstance (
1245
1270
dst_element_type , ir .IntegerType
1246
1271
):
1247
- return _int_int_cast (src , dst_type )
1272
+ return _int_int_cast (src , dst_type , signed = signed )
1248
1273
1249
1274
if isinstance (src_element_type , ir .FloatType ) and isinstance (
1250
1275
dst_element_type , ir .IntegerType
1251
1276
):
1252
- return _float_int_cast (src , dst_type )
1277
+ return _float_int_cast (src , dst_type , signed = signed )
1253
1278
if isinstance (src_element_type , ir .IntegerType ) and isinstance (
1254
1279
dst_element_type , ir .FloatType
1255
1280
):
1256
- return _int_float_cast (src , dst_type )
1281
+ return _int_float_cast (src , dst_type , signed = signed )
1257
1282
1258
1283
if tt_dialect .PointerType .isinstance (src_element_type ) and isinstance (
1259
1284
dst_element_type , ir .IntegerType
1260
1285
):
1261
1286
if dst_element_type .width == 64 :
1262
1287
return tt_dialect .ptr_to_int (dst_type , src )
1263
1288
else :
1264
- x = _cast (src , ir .IntegerType .get_signless (64 ))
1289
+ x = _cast (src , ir .IntegerType .get_signless (64 ), signed = signed )
1265
1290
zero = _full (x .type , 0 )
1266
- return _cast (_not_equal (x , zero ), dst_type )
1291
+ return _cast (_not_equal (x , zero , signed = signed ), dst_type , signed = signed )
1267
1292
if isinstance (
1268
1293
src_element_type , ir .IntegerType
1269
1294
) and tt_dialect .PointerType .isinstance (dst_element_type ):
@@ -1283,7 +1308,8 @@ def _convert_element_type_lowering_rule(
1283
1308
x = _ensure_ir_value (x , x_aval )
1284
1309
if new_dtype == x_aval .dtype :
1285
1310
return x
1286
- return _cast (x , _dtype_to_ir_type (new_dtype ))
1311
+ signed = jnp .issubdtype (x_aval .dtype , jnp .signedinteger )
1312
+ return _cast (x , _dtype_to_ir_type (new_dtype ), signed = signed )
1287
1313
1288
1314
1289
1315
triton_lowering_rules [lax .convert_element_type_p ] = (
@@ -1428,7 +1454,7 @@ def _compute_pointers_from_indices(
1428
1454
else :
1429
1455
ptr_dim_offset = _add (
1430
1456
_bcast_to (index .start , [index .size ]),
1431
- _cast (_make_range (0 , index .size ), index .start .type ),
1457
+ _cast (_make_range (0 , index .size ), index .start .type , signed = False ),
1432
1458
)
1433
1459
# We need to add broadcastable dimensions for the advanced int indexing
1434
1460
# and for previous slices
@@ -1466,7 +1492,7 @@ def _compute_pointers_from_indices(
1466
1492
ptr_dim_offset = _bcast_to (ptr_dim_offset , indexer_shape )
1467
1493
index_type = ir .IntegerType (_element_type (ptr_dim_offset .type ))
1468
1494
if start_offset is not None :
1469
- start_offset = _cast (start_offset , index_type )
1495
+ start_offset = _cast (start_offset , index_type , signed = False )
1470
1496
ptr_dim_offset = _add (
1471
1497
ptr_dim_offset , _bcast_to (start_offset , indexer_shape )
1472
1498
)
@@ -1578,11 +1604,13 @@ def _load(
1578
1604
if is_int1 :
1579
1605
pointee_type = ir .IntegerType .get_signless (8 )
1580
1606
ptr = _cast (
1581
- ptr , tt_dialect .PointerType .get (pointee_type , ptr_type .address_space )
1607
+ ptr ,
1608
+ tt_dialect .PointerType .get (pointee_type , ptr_type .address_space ),
1609
+ signed = False ,
1582
1610
)
1583
1611
1584
1612
if other is not None :
1585
- other = _cast (other , pointee_type )
1613
+ other = _cast (other , pointee_type , signed = False )
1586
1614
1587
1615
result = tt_dialect .load (
1588
1616
_infer_load_return_type (ptr ),
@@ -1594,7 +1622,9 @@ def _load(
1594
1622
is_volatile = is_volatile ,
1595
1623
)
1596
1624
return (
1597
- result if not is_int1 else _cast (result , ir .IntegerType .get_signless (1 ))
1625
+ result
1626
+ if not is_int1
1627
+ else _cast (result , ir .IntegerType .get_signless (1 ), signed = False )
1598
1628
)
1599
1629
1600
1630
@@ -1696,10 +1726,12 @@ def _store(
1696
1726
if isinstance (pointee_type , ir .IntegerType ) and pointee_type .width == 1 :
1697
1727
pointee_type = ir .IntegerType .get_signless (8 )
1698
1728
ptr = _cast (
1699
- ptr , tt_dialect .PointerType .get (pointee_type , ptr_type .address_space )
1729
+ ptr ,
1730
+ tt_dialect .PointerType .get (pointee_type , ptr_type .address_space ),
1731
+ signed = False ,
1700
1732
)
1701
1733
1702
- value = _cast (value , pointee_type )
1734
+ value = _cast (value , pointee_type , signed = False )
1703
1735
return tt_dialect .store (
1704
1736
ptr , value , mask = mask , cache = cache_modifier , evict = eviction_policy
1705
1737
)
@@ -1867,6 +1899,7 @@ def _dot_general_lowering(
1867
1899
out_type = _dtype_to_ir_type (acc_dtype ),
1868
1900
),
1869
1901
_dtype_to_ir_type (out_dtype ),
1902
+ signed = jnp .issubdtype (out_aval .dtype , jnp .signedinteger ),
1870
1903
)
1871
1904
1872
1905
@@ -2340,7 +2373,7 @@ def to_type(out_aval):
2340
2373
2341
2374
out_types = [to_type (out ) for out in ctx .avals_out ]
2342
2375
2343
- use_branch0 = _equal (index , _ir_constant (0 , index .type ))
2376
+ use_branch0 = _equal (index , _ir_constant (0 , index .type ), signed = False )
2344
2377
# TODO(bjp): Switch to scf.index_switch once exposed in triton.cc
2345
2378
if_op = scf_dialect .IfOp (use_branch0 , out_types , hasElse = True )
2346
2379
with ir .InsertionPoint .at_block_begin (if_op .then_block ):
0 commit comments