25
25
26
26
import jax
27
27
from jax ._src import core as jax_core
28
+ from jax ._src import dtypes
28
29
from jax ._src import pretty_printer as pp
29
30
from jax ._src import state
30
31
from jax ._src import tree_util
@@ -1164,6 +1165,9 @@ def tcgen05_mma(acc: _Ref,
1164
1165
a : _Ref ,
1165
1166
b : _Ref ,
1166
1167
barrier : _Ref | None = None ,
1168
+ * ,
1169
+ a_scale : _Ref | None = None ,
1170
+ b_scale : _Ref | None = None ,
1167
1171
accumulate : bool | jax .Array = True ,
1168
1172
collective_axis : str | None = None ):
1169
1173
"""Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell).
@@ -1178,6 +1182,9 @@ def tcgen05_mma(acc: _Ref,
1178
1182
| ACC2 | | LHS2 | | | |
1179
1183
----------- ----------- -----------
1180
1184
1185
+ To use the block-scaled matrix-multiply, provide `a_scale` and `b_scale`
1186
+ operands (they must be both present or both unspecified).
1187
+
1181
1188
Args:
1182
1189
acc: The accumulator. Must be a TMEM Ref.
1183
1190
a: The left-hand side. Must be a TMEM/SMEM Ref.
@@ -1186,6 +1193,8 @@ def tcgen05_mma(acc: _Ref,
1186
1193
Must have orders_tensor_core set to True. If not specified, the MMA
1187
1194
completion should be explicitly observed by calling
1188
1195
`tcgen05_commit_arrive`
1196
+ a_scale: An optional scale for the ``a`` operand. Must be a TMEM Ref if present.
1197
+ b_scale: An optional scale for the ``b`` operand. Must be a TMEM Ref if present.
1189
1198
accumulate: Whether to accumulate into acc or overwrite it.
1190
1199
collective_axis: The name of the cluster axis along which to perform
1191
1200
a collective MMA. The cluster axis should have a size of exactly 2,
@@ -1225,6 +1234,28 @@ def tcgen05_mma(acc: _Ref,
1225
1234
else :
1226
1235
b_transforms_leaves , b_transforms_tree = [], None
1227
1236
1237
+ if (is_scaled := a_scale is not None ) != (b_scale is not None ):
1238
+ raise ValueError ("a_scale and b_scale must both be present or absent." )
1239
+ scales = []
1240
+ if isinstance (a_scale , pallas_core .TransformedRef ):
1241
+ a_scale_transforms_leaves , a_scale_transforms_tree = jax .tree .flatten (
1242
+ a_scale .transforms
1243
+ )
1244
+ scales .append (a_scale .ref )
1245
+ else :
1246
+ a_scale_transforms_leaves , a_scale_transforms_tree = [], None
1247
+ scales .append (a_scale )
1248
+ if isinstance (b_scale , pallas_core .TransformedRef ):
1249
+ b_scale_transforms_leaves , b_scale_transforms_tree = jax .tree .flatten (
1250
+ b_scale .transforms
1251
+ )
1252
+ scales .append (b_scale .ref )
1253
+ else :
1254
+ b_scale_transforms_leaves , b_scale_transforms_tree = [], None
1255
+ scales .append (b_scale )
1256
+ if not is_scaled :
1257
+ scales = []
1258
+
1228
1259
if isinstance (barrier , pallas_core .TransformedRef ):
1229
1260
barrier_transforms_leaves , barrier_transforms_tree = jax .tree .flatten (
1230
1261
barrier .transforms
@@ -1240,26 +1271,33 @@ def tcgen05_mma(acc: _Ref,
1240
1271
barrier_ref = []
1241
1272
arrive = False
1242
1273
1243
- tcgen05_mma_p .bind (acc , a , b , accumulate , * barrier_ref ,
1274
+ tcgen05_mma_p .bind (acc , a , b , accumulate , * barrier_ref , * scales ,
1244
1275
* acc_transforms_leaves , * a_transforms_leaves ,
1245
1276
* b_transforms_leaves ,
1246
1277
* barrier_transforms_leaves ,
1278
+ * a_scale_transforms_leaves , * b_scale_transforms_leaves ,
1247
1279
acc_transforms_tree = acc_transforms_tree ,
1248
1280
a_transforms_tree = a_transforms_tree ,
1249
1281
b_transforms_tree = b_transforms_tree ,
1250
1282
barrier_transforms_tree = barrier_transforms_tree ,
1283
+ a_scale_transforms_tree = a_scale_transforms_tree ,
1284
+ b_scale_transforms_tree = b_scale_transforms_tree ,
1251
1285
collective_axis = collective_axis ,
1252
- arrive = arrive )
1286
+ arrive = arrive ,
1287
+ scaled = bool (scales ))
1253
1288
1254
1289
1255
1290
@tcgen05_mma_p .def_abstract_eval
1256
1291
def _tcgen05_mma_abstract_eval (acc , a , b , accumulate ,
1257
- * barrier_and_transforms_leaves ,
1292
+ * barrier_scales_and_transforms_leaves ,
1258
1293
acc_transforms_tree , a_transforms_tree ,
1259
1294
b_transforms_tree ,
1260
1295
barrier_transforms_tree ,
1296
+ a_scale_transforms_tree ,
1297
+ b_scale_transforms_tree ,
1261
1298
collective_axis ,
1262
- arrive ):
1299
+ arrive ,
1300
+ scaled ):
1263
1301
del (accumulate , acc_transforms_tree ,
1264
1302
a_transforms_tree , b_transforms_tree , barrier_transforms_tree )
1265
1303
@@ -1281,12 +1319,19 @@ def _tcgen05_mma_abstract_eval(acc, a, b, accumulate,
1281
1319
raise ValueError (
1282
1320
"LHS Ref must be collective if collective_axis is set." )
1283
1321
1322
+ scales_and_transforms_leaves = barrier_scales_and_transforms_leaves
1284
1323
if arrive :
1285
- barrier = barrier_and_transforms_leaves [ 0 ]
1324
+ barrier , * scales_and_transforms_leaves = barrier_scales_and_transforms_leaves
1286
1325
orders_tensor_core = getattr (
1287
1326
barrier .inner_aval .dtype , "orders_tensor_core" , False )
1288
1327
if not orders_tensor_core :
1289
1328
raise ValueError ("MMA barrier must have orders_tensor_core set to True." )
1329
+ if scaled :
1330
+ a_scale , b_scale = scales_and_transforms_leaves [:2 ]
1331
+ if a_scale .memory_space != gpu_core .TMEM :
1332
+ raise ValueError ("a_scale must be a TMEM Ref" )
1333
+ if b_scale .memory_space != gpu_core .TMEM :
1334
+ raise ValueError ("b_scale must be a TMEM Ref" )
1290
1335
1291
1336
return []
1292
1337
@@ -1299,35 +1344,52 @@ def _tcgen05_mma_lowering(
1299
1344
a_ref ,
1300
1345
b_ref ,
1301
1346
accumulate : bool | ir .Value ,
1302
- * barrier_and_transforms_leaves ,
1347
+ * barrier_scales_and_transforms_leaves ,
1303
1348
acc_transforms_tree ,
1304
1349
a_transforms_tree ,
1305
1350
b_transforms_tree ,
1306
1351
barrier_transforms_tree ,
1352
+ a_scale_transforms_tree ,
1353
+ b_scale_transforms_tree ,
1307
1354
collective_axis ,
1308
1355
arrive ,
1356
+ scaled : bool ,
1309
1357
):
1310
1358
_ , a_aval , b_aval , * _ = ctx .avals_in
1311
1359
lhs_swizzle : int | None = None
1312
1360
lhs_transpose : bool = False
1313
1361
if arrive :
1314
- barrier_ref , * transforms_leaves = barrier_and_transforms_leaves
1362
+ barrier_ref , * scales_and_transforms_leaves = barrier_scales_and_transforms_leaves
1315
1363
else :
1316
1364
barrier_ref = None
1317
- transforms_leaves = barrier_and_transforms_leaves # type: ignore[assignment]
1365
+ scales_and_transforms_leaves = barrier_scales_and_transforms_leaves # type: ignore[assignment]
1366
+ if scaled :
1367
+ a_scale_ref , b_scale_ref , * transforms_leaves = scales_and_transforms_leaves
1368
+ else :
1369
+ a_scale_ref = b_scale_ref = None
1370
+ transforms_leaves = scales_and_transforms_leaves # type: ignore[assignment]
1318
1371
1319
1372
transforms_trees = (
1320
1373
acc_transforms_tree ,
1321
1374
a_transforms_tree ,
1322
1375
b_transforms_tree ,
1323
1376
barrier_transforms_tree ,
1324
- )
1325
- (acc_transforms_leaves , a_transforms_leaves , b_transforms_leaves , barrier_transforms_leaves , _ ) = (
1326
- util .split_list (
1327
- transforms_leaves ,
1328
- [getattr (tree , "num_leaves" , 0 ) for tree in transforms_trees ],
1329
- )
1330
- )
1377
+ a_scale_transforms_tree ,
1378
+ b_scale_transforms_tree ,
1379
+ )
1380
+ (
1381
+ acc_transforms_leaves ,
1382
+ a_transforms_leaves ,
1383
+ b_transforms_leaves ,
1384
+ barrier_transforms_leaves ,
1385
+ a_scale_transforms_leaves ,
1386
+ b_scale_transforms_leaves ,
1387
+ leftovers ,
1388
+ ) = util .split_list (
1389
+ transforms_leaves ,
1390
+ [getattr (tree , "num_leaves" , 0 ) for tree in transforms_trees ],
1391
+ )
1392
+ assert not leftovers
1331
1393
1332
1394
if acc_transforms_tree is not None :
1333
1395
acc_transforms = acc_transforms_tree .unflatten (acc_transforms_leaves )
@@ -1359,7 +1421,7 @@ def _tcgen05_mma_lowering(
1359
1421
f"Unsupported transforms: { a_transforms } ."
1360
1422
)
1361
1423
if not isinstance (a_ref , tcgen05 .TMEMRef ):
1362
- swizzle_elems = lhs_swizzle // a_dtype . itemsize # type: ignore
1424
+ swizzle_elems = 8 * lhs_swizzle // dtypes . bit_width ( a_dtype ) # type: ignore
1363
1425
if lhs_tiling != (8 , swizzle_elems ):
1364
1426
raise ValueError ("MMA lhs tiling does not fit swizzle. "
1365
1427
f"{ lhs_tiling = } expected={ (8 , swizzle_elems )} " )
@@ -1383,7 +1445,7 @@ def _tcgen05_mma_lowering(
1383
1445
raise NotImplementedError (
1384
1446
f"Unsupported transforms: { b_transforms } ."
1385
1447
)
1386
- swizzle_elems = rhs_swizzle // b_dtype . itemsize
1448
+ swizzle_elems = 8 * rhs_swizzle // dtypes . bit_width ( b_dtype )
1387
1449
if rhs_tiling != (8 , swizzle_elems ):
1388
1450
raise ValueError (
1389
1451
"MMA rhs tiling does not fit swizzle"
@@ -1417,6 +1479,25 @@ def _tcgen05_mma_lowering(
1417
1479
accumulate = accumulate .registers .item ()
1418
1480
assert isinstance (accumulate , ir .Value )
1419
1481
1482
+ if a_scale_transforms_tree is not None :
1483
+ a_scale_transforms = a_scale_transforms_tree .unflatten (
1484
+ a_scale_transforms_leaves
1485
+ )
1486
+ a_scale_ref , a_scale_transforms = lowering ._handle_transforms (
1487
+ ctx , a_scale_ref , a_scale_transforms
1488
+ )
1489
+ if a_scale_transforms :
1490
+ raise NotImplementedError (f"Unsupported transforms: { a_scale_transforms } " )
1491
+ if b_scale_transforms_tree is not None :
1492
+ b_scale_transforms = b_scale_transforms_tree .unflatten (
1493
+ b_scale_transforms_leaves
1494
+ )
1495
+ b_scale_ref , b_scale_transforms = lowering ._handle_transforms (
1496
+ ctx , b_scale_ref , b_scale_transforms
1497
+ )
1498
+ if b_scale_transforms :
1499
+ raise NotImplementedError (f"Unsupported transforms: { b_scale_transforms } " )
1500
+
1420
1501
predicate = ctx .module_ctx .single_lane_predicate
1421
1502
if collective_axis is not None :
1422
1503
is_leader_block = _collective_mma_predicate (ctx , collective_axis )
@@ -1432,6 +1513,8 @@ def _tcgen05_mma_lowering(
1432
1513
b_ref ,
1433
1514
a_swizzle = int (lhs_swizzle ),
1434
1515
b_swizzle = int (rhs_swizzle ),
1516
+ a_scale = a_scale_ref ,
1517
+ b_scale = b_scale_ref ,
1435
1518
accumulate = accumulate ,
1436
1519
collective = collective ,
1437
1520
)
@@ -2225,3 +2308,60 @@ def _async_store_tmem_lowering_rule(
2225
2308
)
2226
2309
x_tmem .store (value )
2227
2310
return ()
2311
+
2312
+
2313
+ async_copy_scales_to_tmem_p = jax_core .Primitive ("async_copy_scales_to_tmem" )
2314
+ async_copy_scales_to_tmem_p .multiple_results = True
2315
+
2316
+ def async_copy_scales_to_tmem (smem_ref : _Ref , tmem_ref : _Ref ):
2317
+ """Copies the MMA scales from SMEM to TMEM.
2318
+
2319
+ The copy is performed asynchronously and can be awaited by calling
2320
+ ``tcgen05_commit_arrive`` and waiting on the specified barrier. However, if
2321
+ the copy is consumed by an MMA operation issued in the same thread, no
2322
+ synchronization is necessary (except for eventually awaiting the MMA operation
2323
+ itself).
2324
+ """
2325
+ smem_ref , smem_transforms = state_primitives .get_ref_and_transforms (
2326
+ smem_ref , None , "async_copy_scales_to_tmem" , force_trailing_indexer = True ,
2327
+ )
2328
+ flat_smem_transforms , smem_transforms_treedef = tree_util .tree_flatten (
2329
+ smem_transforms
2330
+ )
2331
+ tmem_ref , tmem_transforms = state_primitives .get_ref_and_transforms (
2332
+ tmem_ref , None , "async_copy_scales_to_tmem" , force_trailing_indexer = True ,
2333
+ )
2334
+ flat_tmem_transforms , tmem_transforms_treedef = tree_util .tree_flatten (
2335
+ tmem_transforms
2336
+ )
2337
+ async_copy_scales_to_tmem_p .bind (
2338
+ smem_ref , tmem_ref , * flat_smem_transforms , * flat_tmem_transforms ,
2339
+ smem_tree = smem_transforms_treedef , tmem_tree = tmem_transforms_treedef ,
2340
+ )
2341
+
2342
+
2343
+ @async_copy_scales_to_tmem_p .def_effectful_abstract_eval
2344
+ def _async_copy_scales_to_tmem_abstract_eval (smem_ref , tmem_ref , * avals_flat , smem_tree , tmem_tree ):
2345
+ if smem_ref .memory_space != gpu_core .MemorySpace .SMEM :
2346
+ raise ValueError ("async_copy_scales_to_tmem source must be an SMEM ref" )
2347
+ if tmem_ref .memory_space != gpu_core .MemorySpace .TMEM :
2348
+ raise ValueError ("async_copy_scales_to_tmem target must be a TMEM ref" )
2349
+ return (), {gpu_core ._memory_effect }
2350
+
2351
+
2352
+ @lowering .register_lowering_rule (async_copy_scales_to_tmem_p , mgpu .LoweringSemantics .Lane )
2353
+ def _async_copy_scales_to_tmem_lowering_rule (
2354
+ ctx : lowering .LoweringRuleContext , smem_ref , tmem_ref , * leaves , smem_tree , tmem_tree
2355
+ ):
2356
+ assert isinstance (tmem_ref , tcgen05 .TMEMRef )
2357
+ smem_leaves , tmem_leaves = util .split_list (leaves , [smem_tree .num_leaves ])
2358
+ smem_transforms = jax .tree .unflatten (smem_tree , smem_leaves )
2359
+ tmem_transforms = jax .tree .unflatten (tmem_tree , tmem_leaves )
2360
+ smem_ref , smem_transforms = lowering ._handle_transforms (ctx , smem_ref , smem_transforms )
2361
+ tmem_ref , tmem_transforms = lowering ._handle_transforms (ctx , tmem_ref , tmem_transforms )
2362
+ if smem_transforms :
2363
+ raise NotImplementedError (f"Unimplemented transforms for SMEM refs: { smem_transforms } " )
2364
+ if tmem_transforms :
2365
+ raise NotImplementedError (f"Unimplemented transforms for TMEM refs: { tmem_transforms } " )
2366
+ tcgen05 .async_copy_scales_smem_to_tmem (smem_ref , tmem_ref )
2367
+ return ()
0 commit comments