@@ -580,11 +580,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value:
580
580
for aval , arg , arg_type in zip (ctx .avals_in , args , h .arg_types ):
581
581
bcast_arg = _bcast_to (_ensure_ir_value (arg , aval ), out_aval .shape )
582
582
if aval .weak_type and aval .dtype .name != arg_type :
583
- bcast_arg = _cast (
584
- bcast_arg ,
585
- _dtype_to_ir_type (jnp .dtype (arg_type )),
586
- signed = jnp .issubdtype (aval .dtype , jnp .signedinteger ),
587
- )
583
+ bcast_arg = _cast (bcast_arg , aval .dtype , jnp .dtype (arg_type ))
588
584
bcast_args .append (bcast_arg )
589
585
return h .lower (ctx , * bcast_args )
590
586
@@ -1162,8 +1158,8 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x):
1162
1158
signed = jnp .issubdtype (x_aval .dtype , jnp .signedinteger )
1163
1159
zero = _full (x .type , 0 )
1164
1160
return _sub (
1165
- _cast (_greater_than (x , zero , signed = signed ), x . type , signed = signed ),
1166
- _cast (_less_than (x , zero , signed = signed ), x . type , signed = signed ),
1161
+ _cast (_greater_than (x , zero , signed = signed ), jnp . bool_ , x_aval . dtype ),
1162
+ _cast (_less_than (x , zero , signed = signed ), jnp . bool_ , x_aval . dtype ),
1167
1163
)
1168
1164
1169
1165
@@ -1172,7 +1168,7 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x):
1172
1168
1173
1169
def _iota_lowering_rule (ctx : LoweringRuleContext , * , dtype , shape , dimension ):
1174
1170
iota = _make_range (0 , shape [dimension ])
1175
- iota = _cast (iota , _dtype_to_ir_type ( dtype ), signed = False )
1171
+ iota = _cast (iota , jnp . int32 , dtype )
1176
1172
for i in range (len (shape )):
1177
1173
if i != dimension :
1178
1174
iota = _expand_dims (iota , i )
@@ -1298,7 +1294,19 @@ def _int_float_cast(
1298
1294
return arith_dialect .sitofp (dst_type , src )
1299
1295
1300
1296
1301
- def _cast (src : ir .Value , dst_type : ir .Type , * , signed : bool ) -> ir .Value :
1297
def _cast(
    src: ir.Value,
    src_type: jax.typing.DTypeLike,
    dst_type: jax.typing.DTypeLike,
) -> ir.Value:
  """Cast ``src`` to ``dst_type``, taking signedness from ``src_type``.

  Dtype-level front end for ``_ir_cast``: the destination dtype is converted
  to an IR type with ``_dtype_to_ir_type``, and the ``signed`` flag passed on
  is true exactly when ``src_type`` is a sub-dtype of ``jnp.signedinteger``.

  Args:
    src: The IR value to cast.
    src_type: Dtype describing ``src``; only used to decide signedness.
    dst_type: Dtype to cast to.

  Returns:
    Whatever ``_ir_cast`` returns for the converted arguments.
  """
  # Resolve the destination first, then the signedness, matching the order in
  # which the arguments of the single-expression form would be evaluated.
  target_ir_type = _dtype_to_ir_type(dst_type)
  source_is_signed = jnp.issubdtype(src_type, jnp.signedinteger)
  return _ir_cast(src, target_ir_type, signed=source_is_signed)
1307
+
1308
+
1309
+ def _ir_cast (src : ir .Value , dst_type : ir .Type , * , signed : bool ) -> ir .Value :
1302
1310
if ir .RankedTensorType .isinstance (
1303
1311
src .type
1304
1312
) and not ir .RankedTensorType .isinstance (dst_type ):
@@ -1322,8 +1330,8 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
1322
1330
if isinstance (src_element_type , (ir .F16Type , ir .BF16Type )) and not isinstance (
1323
1331
dst_element_type , ir .F32Type
1324
1332
):
1325
- return _cast (
1326
- _cast (src , ir .F32Type .get (), signed = False ), dst_type , signed = False
1333
+ return _ir_cast (
1334
+ _ir_cast (src , ir .F32Type .get (), signed = False ), dst_type , signed = False
1327
1335
)
1328
1336
1329
1337
if isinstance (src_element_type , ir .FloatType ) and isinstance (
@@ -1350,10 +1358,10 @@ def _cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
1350
1358
):
1351
1359
if dst_element_type .width == 64 :
1352
1360
return tt_dialect .ptr_to_int (dst_type , src )
1353
- else :
1354
- x = _cast (src , ir .IntegerType .get_signless (64 ), signed = signed )
1361
+ elif dst_element_type . width == 1 :
1362
+ x = _ir_cast (src , ir .IntegerType .get_signless (64 ), signed = signed )
1355
1363
zero = _full (x .type , 0 )
1356
- return _cast (_not_equal (x , zero , signed = signed ), dst_type , signed = signed )
1364
+ return _ir_cast (_not_equal (x , zero , signed = signed ), dst_type , signed = signed )
1357
1365
if isinstance (
1358
1366
src_element_type , ir .IntegerType
1359
1367
) and tt_dialect .PointerType .isinstance (dst_element_type ):
@@ -1373,8 +1381,7 @@ def _convert_element_type_lowering_rule(
1373
1381
x = _ensure_ir_value (x , x_aval )
1374
1382
if new_dtype == x_aval .dtype :
1375
1383
return x
1376
- signed = jnp .issubdtype (x_aval .dtype , jnp .signedinteger )
1377
- return _cast (x , _dtype_to_ir_type (new_dtype ), signed = signed )
1384
+ return _cast (x , x_aval .dtype , new_dtype )
1378
1385
1379
1386
1380
1387
triton_lowering_rules [lax .convert_element_type_p ] = (
@@ -1519,7 +1526,7 @@ def _compute_pointers_from_indices(
1519
1526
else :
1520
1527
ptr_dim_offset = _add (
1521
1528
_bcast_to (index .start , [index .size ]),
1522
- _cast (_make_range (0 , index .size ), index .start .type , signed = False ),
1529
+ _ir_cast (_make_range (0 , index .size ), index .start .type , signed = False ),
1523
1530
)
1524
1531
# We need to add broadcastable dimensions for the advanced int indexing
1525
1532
# and for previous slices
@@ -1557,7 +1564,7 @@ def _compute_pointers_from_indices(
1557
1564
ptr_dim_offset = _bcast_to (ptr_dim_offset , indexer_shape )
1558
1565
index_type = ir .IntegerType (_element_type (ptr_dim_offset .type ))
1559
1566
if start_offset is not None :
1560
- start_offset = _cast (start_offset , index_type , signed = False )
1567
+ start_offset = _ir_cast (start_offset , index_type , signed = False )
1561
1568
ptr_dim_offset = _add (
1562
1569
ptr_dim_offset , _bcast_to (start_offset , indexer_shape )
1563
1570
)
@@ -1660,14 +1667,14 @@ def _load(
1660
1667
is_int1 = isinstance (pointee_type , ir .IntegerType ) and pointee_type .width == 1
1661
1668
if is_int1 :
1662
1669
pointee_type = ir .IntegerType .get_signless (8 )
1663
- ptr = _cast (
1670
+ ptr = _ir_cast (
1664
1671
ptr ,
1665
1672
tt_dialect .PointerType .get (pointee_type , ptr_type .address_space ),
1666
1673
signed = False ,
1667
1674
)
1668
1675
1669
1676
if other is not None :
1670
- other = _cast (other , pointee_type , signed = False )
1677
+ other = _ir_cast (other , pointee_type , signed = False )
1671
1678
1672
1679
result = tt_dialect .load (
1673
1680
_infer_load_return_type (ptr ),
@@ -1681,7 +1688,7 @@ def _load(
1681
1688
return (
1682
1689
result
1683
1690
if not is_int1
1684
- else _cast (result , ir .IntegerType .get_signless (1 ), signed = False )
1691
+ else _ir_cast (result , ir .IntegerType .get_signless (1 ), signed = False )
1685
1692
)
1686
1693
1687
1694
@@ -1782,13 +1789,13 @@ def _store(
1782
1789
pointee_type = ptr_type .pointee_type
1783
1790
if isinstance (pointee_type , ir .IntegerType ) and pointee_type .width == 1 :
1784
1791
pointee_type = ir .IntegerType .get_signless (8 )
1785
- ptr = _cast (
1792
+ ptr = _ir_cast (
1786
1793
ptr ,
1787
1794
tt_dialect .PointerType .get (pointee_type , ptr_type .address_space ),
1788
1795
signed = False ,
1789
1796
)
1790
1797
1791
- value = _cast (value , pointee_type , signed = False )
1798
+ value = _ir_cast (value , pointee_type , signed = False )
1792
1799
return tt_dialect .store (
1793
1800
ptr , value , mask = mask , cache = cache_modifier , evict = eviction_policy
1794
1801
)
@@ -1955,8 +1962,8 @@ def _dot_general_lowering(
1955
1962
allow_tf32 = allow_tf32 ,
1956
1963
out_type = _dtype_to_ir_type (acc_dtype ),
1957
1964
),
1958
- _dtype_to_ir_type ( out_dtype ) ,
1959
- signed = jnp . issubdtype ( out_aval . dtype , jnp . signedinteger ) ,
1965
+ acc_dtype ,
1966
+ out_dtype ,
1960
1967
)
1961
1968
1962
1969
0 commit comments