Commit b70a88a

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Expose block-scaled MMA to Pallas on Blackwell
PiperOrigin-RevId: 781488938
1 parent 68703c7 commit b70a88a

File tree: 5 files changed, +241 −20 lines

jax/_src/pallas/mosaic_gpu/BUILD

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ pytype_strict_library(
         ":lowering",
         "//jax",
         "//jax:core",
+        "//jax:dtypes",
         "//jax:lax",
         "//jax:mosaic_gpu",
         "//jax:pretty_printer",

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 157 additions & 17 deletions
@@ -25,6 +25,7 @@
 
 import jax
 from jax._src import core as jax_core
+from jax._src import dtypes
 from jax._src import pretty_printer as pp
 from jax._src import state
 from jax._src import tree_util
@@ -1164,6 +1165,9 @@ def tcgen05_mma(acc: _Ref,
                 a: _Ref,
                 b: _Ref,
                 barrier: _Ref | None = None,
+                *,
+                a_scale: _Ref | None = None,
+                b_scale: _Ref | None = None,
                 accumulate: bool | jax.Array = True,
                 collective_axis: str | None = None):
   """Asynchronous matrix-multiply accumulate for TensorCore gen 5 (Blackwell).
@@ -1178,6 +1182,9 @@ def tcgen05_mma(acc: _Ref,
       |  ACC2   |  |  LHS2   |  |         |
       -----------  -----------  -----------
 
+  To use the block-scaled matrix-multiply, provide `a_scale` and `b_scale`
+  operands (they must be both present or both unspecified).
+
   Args:
     acc: The accumulator. Must be a TMEM Ref.
     a: The left-hand side. Must be a TMEM/SMEM Ref.
@@ -1186,6 +1193,8 @@ def tcgen05_mma(acc: _Ref,
       Must have orders_tensor_core set to True. If not specified, the MMA
       completion should be explicitly observed by calling
       `tcgen05_commit_arrive`
+    a_scale: An optional scale for the ``a`` operand. Must be a TMEM Ref if present.
+    b_scale: An optional scale for the ``b`` operand. Must be a TMEM Ref if present.
     accumulate: Whether to accumulate into acc or overwrite it.
     collective_axis: The name of the cluster axis along which to perform
       a collective MMA. The cluster axis should have a size of exactly 2,
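For reference, a minimal usage sketch of the new keyword arguments (illustrative, not part of this commit). It assumes a kernel in which the accumulator and both scales already live in TMEM, the operands live in SMEM, the barrier was allocated with orders_tensor_core=True, and plgpu is jax.experimental.pallas.mosaic_gpu:

def block_scaled_mma_step(acc_tmem, a_smem, b_smem,
                          a_scale_tmem, b_scale_tmem, mma_barrier):
  # Both scales must be TMEM refs; passing only one of them raises a ValueError.
  plgpu.tcgen05_mma(
      acc_tmem, a_smem, b_smem, mma_barrier,
      a_scale=a_scale_tmem, b_scale=b_scale_tmem,
      accumulate=False,
  )
  # The MMA is asynchronous: wait on the barrier before reading acc_tmem.
  plgpu.barrier_wait(mma_barrier)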
@@ -1225,6 +1234,28 @@ def tcgen05_mma(acc: _Ref,
   else:
     b_transforms_leaves, b_transforms_tree = [], None
 
+  if (is_scaled := a_scale is not None) != (b_scale is not None):
+    raise ValueError("a_scale and b_scale must both be present or absent.")
+  scales = []
+  if isinstance(a_scale, pallas_core.TransformedRef):
+    a_scale_transforms_leaves, a_scale_transforms_tree = jax.tree.flatten(
+        a_scale.transforms
+    )
+    scales.append(a_scale.ref)
+  else:
+    a_scale_transforms_leaves, a_scale_transforms_tree = [], None
+    scales.append(a_scale)
+  if isinstance(b_scale, pallas_core.TransformedRef):
+    b_scale_transforms_leaves, b_scale_transforms_tree = jax.tree.flatten(
+        b_scale.transforms
+    )
+    scales.append(b_scale.ref)
+  else:
+    b_scale_transforms_leaves, b_scale_transforms_tree = [], None
+    scales.append(b_scale)
+  if not is_scaled:
+    scales = []
+
   if isinstance(barrier, pallas_core.TransformedRef):
     barrier_transforms_leaves, barrier_transforms_tree = jax.tree.flatten(
         barrier.transforms
@@ -1240,26 +1271,33 @@
     barrier_ref = []
     arrive = False
 
-  tcgen05_mma_p.bind(acc, a, b, accumulate, *barrier_ref,
+  tcgen05_mma_p.bind(acc, a, b, accumulate, *barrier_ref, *scales,
                      *acc_transforms_leaves, *a_transforms_leaves,
                      *b_transforms_leaves,
                      *barrier_transforms_leaves,
+                     *a_scale_transforms_leaves, *b_scale_transforms_leaves,
                      acc_transforms_tree=acc_transforms_tree,
                      a_transforms_tree=a_transforms_tree,
                      b_transforms_tree=b_transforms_tree,
                      barrier_transforms_tree=barrier_transforms_tree,
+                     a_scale_transforms_tree=a_scale_transforms_tree,
+                     b_scale_transforms_tree=b_scale_transforms_tree,
                      collective_axis=collective_axis,
-                     arrive=arrive)
+                     arrive=arrive,
+                     scaled=bool(scales))
 
 
 @tcgen05_mma_p.def_abstract_eval
 def _tcgen05_mma_abstract_eval(acc, a, b, accumulate,
-                               *barrier_and_transforms_leaves,
+                               *barrier_scales_and_transforms_leaves,
                                acc_transforms_tree, a_transforms_tree,
                                b_transforms_tree,
                                barrier_transforms_tree,
+                               a_scale_transforms_tree,
+                               b_scale_transforms_tree,
                                collective_axis,
-                               arrive):
+                               arrive,
+                               scaled):
   del (accumulate, acc_transforms_tree,
        a_transforms_tree, b_transforms_tree, barrier_transforms_tree)
 
@@ -1281,12 +1319,19 @@ def _tcgen05_mma_abstract_eval(acc, a, b, accumulate,
     raise ValueError(
         "LHS Ref must be collective if collective_axis is set.")
 
+  scales_and_transforms_leaves = barrier_scales_and_transforms_leaves
   if arrive:
-    barrier = barrier_and_transforms_leaves[0]
+    barrier, *scales_and_transforms_leaves = barrier_scales_and_transforms_leaves
     orders_tensor_core = getattr(
         barrier.inner_aval.dtype, "orders_tensor_core", False)
     if not orders_tensor_core:
       raise ValueError("MMA barrier must have orders_tensor_core set to True.")
+  if scaled:
+    a_scale, b_scale = scales_and_transforms_leaves[:2]
+    if a_scale.memory_space != gpu_core.TMEM:
+      raise ValueError("a_scale must be a TMEM Ref")
+    if b_scale.memory_space != gpu_core.TMEM:
+      raise ValueError("b_scale must be a TMEM Ref")
 
   return []
 
@@ -1299,35 +1344,52 @@ def _tcgen05_mma_lowering(
     a_ref,
     b_ref,
     accumulate: bool | ir.Value,
-    *barrier_and_transforms_leaves,
+    *barrier_scales_and_transforms_leaves,
     acc_transforms_tree,
     a_transforms_tree,
     b_transforms_tree,
     barrier_transforms_tree,
+    a_scale_transforms_tree,
+    b_scale_transforms_tree,
     collective_axis,
     arrive,
+    scaled: bool,
 ):
   _, a_aval, b_aval, *_ = ctx.avals_in
   lhs_swizzle: int | None = None
   lhs_transpose: bool = False
   if arrive:
-    barrier_ref, *transforms_leaves = barrier_and_transforms_leaves
+    barrier_ref, *scales_and_transforms_leaves = barrier_scales_and_transforms_leaves
   else:
     barrier_ref = None
-    transforms_leaves = barrier_and_transforms_leaves  # type: ignore[assignment]
+    scales_and_transforms_leaves = barrier_scales_and_transforms_leaves  # type: ignore[assignment]
+  if scaled:
+    a_scale_ref, b_scale_ref, *transforms_leaves = scales_and_transforms_leaves
+  else:
+    a_scale_ref = b_scale_ref = None
+    transforms_leaves = scales_and_transforms_leaves  # type: ignore[assignment]
 
   transforms_trees = (
       acc_transforms_tree,
       a_transforms_tree,
       b_transforms_tree,
       barrier_transforms_tree,
-  )
-  (acc_transforms_leaves, a_transforms_leaves, b_transforms_leaves, barrier_transforms_leaves, _) = (
-      util.split_list(
-          transforms_leaves,
-          [getattr(tree, "num_leaves", 0) for tree in transforms_trees],
-      )
-  )
+      a_scale_transforms_tree,
+      b_scale_transforms_tree,
+  )
+  (
+      acc_transforms_leaves,
+      a_transforms_leaves,
+      b_transforms_leaves,
+      barrier_transforms_leaves,
+      a_scale_transforms_leaves,
+      b_scale_transforms_leaves,
+      leftovers,
+  ) = util.split_list(
+      transforms_leaves,
+      [getattr(tree, "num_leaves", 0) for tree in transforms_trees],
+  )
+  assert not leftovers
 
   if acc_transforms_tree is not None:
     acc_transforms = acc_transforms_tree.unflatten(acc_transforms_leaves)
@@ -1359,7 +1421,7 @@ def _tcgen05_mma_lowering(
           f"Unsupported transforms: {a_transforms}."
         )
     if not isinstance(a_ref, tcgen05.TMEMRef):
-      swizzle_elems = lhs_swizzle // a_dtype.itemsize  # type: ignore
+      swizzle_elems = 8 * lhs_swizzle // dtypes.bit_width(a_dtype)  # type: ignore
       if lhs_tiling != (8, swizzle_elems):
         raise ValueError("MMA lhs tiling does not fit swizzle. "
                          f"{lhs_tiling=} expected={(8, swizzle_elems)}")
@@ -1383,7 +1445,7 @@ def _tcgen05_mma_lowering(
       raise NotImplementedError(
           f"Unsupported transforms: {b_transforms}."
       )
-    swizzle_elems = rhs_swizzle // b_dtype.itemsize
+    swizzle_elems = 8 * rhs_swizzle // dtypes.bit_width(b_dtype)
     if rhs_tiling != (8, swizzle_elems):
       raise ValueError(
           "MMA rhs tiling does not fit swizzle"
@@ -1417,6 +1479,25 @@ def _tcgen05_mma_lowering(
     accumulate = accumulate.registers.item()
     assert isinstance(accumulate, ir.Value)
 
+  if a_scale_transforms_tree is not None:
+    a_scale_transforms = a_scale_transforms_tree.unflatten(
+        a_scale_transforms_leaves
+    )
+    a_scale_ref, a_scale_transforms = lowering._handle_transforms(
+        ctx, a_scale_ref, a_scale_transforms
+    )
+    if a_scale_transforms:
+      raise NotImplementedError(f"Unsupported transforms: {a_scale_transforms}")
+  if b_scale_transforms_tree is not None:
+    b_scale_transforms = b_scale_transforms_tree.unflatten(
+        b_scale_transforms_leaves
+    )
+    b_scale_ref, b_scale_transforms = lowering._handle_transforms(
+        ctx, b_scale_ref, b_scale_transforms
+    )
+    if b_scale_transforms:
+      raise NotImplementedError(f"Unsupported transforms: {b_scale_transforms}")
+
   predicate = ctx.module_ctx.single_lane_predicate
   if collective_axis is not None:
     is_leader_block = _collective_mma_predicate(ctx, collective_axis)
@@ -1432,6 +1513,8 @@
       b_ref,
       a_swizzle=int(lhs_swizzle),
       b_swizzle=int(rhs_swizzle),
+      a_scale=a_scale_ref,
+      b_scale=b_scale_ref,
       accumulate=accumulate,
       collective=collective,
   )
@@ -2225,3 +2308,60 @@ def _async_store_tmem_lowering_rule(
   )
   x_tmem.store(value)
   return ()
+
+
+async_copy_scales_to_tmem_p = jax_core.Primitive("async_copy_scales_to_tmem")
+async_copy_scales_to_tmem_p.multiple_results = True
+
+def async_copy_scales_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
+  """Copies the MMA scales from SMEM to TMEM.
+
+  The copy is performed asynchronously and can be awaited by calling
+  ``tcgen05_commit_arrive`` and waiting on the specified barrier. However, if
+  the copy is consumed by an MMA operation issued in the same thread, no
+  synchronization is necessary (except for eventually awaiting the MMA operation
+  itself).
+  """
+  smem_ref, smem_transforms = state_primitives.get_ref_and_transforms(
+      smem_ref, None, "async_copy_scales_to_tmem", force_trailing_indexer=True,
+  )
+  flat_smem_transforms, smem_transforms_treedef = tree_util.tree_flatten(
+      smem_transforms
+  )
+  tmem_ref, tmem_transforms = state_primitives.get_ref_and_transforms(
+      tmem_ref, None, "async_copy_scales_to_tmem", force_trailing_indexer=True,
+  )
+  flat_tmem_transforms, tmem_transforms_treedef = tree_util.tree_flatten(
+      tmem_transforms
+  )
+  async_copy_scales_to_tmem_p.bind(
+      smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms,
+      smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef,
+  )
+
+
+@async_copy_scales_to_tmem_p.def_effectful_abstract_eval
+def _async_copy_scales_to_tmem_abstract_eval(smem_ref, tmem_ref, *avals_flat, smem_tree, tmem_tree):
+  if smem_ref.memory_space != gpu_core.MemorySpace.SMEM:
+    raise ValueError("async_copy_scales_to_tmem source must be an SMEM ref")
+  if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM:
+    raise ValueError("async_copy_scales_to_tmem target must be a TMEM ref")
+  return (), {gpu_core._memory_effect}
+
+
+@lowering.register_lowering_rule(async_copy_scales_to_tmem_p, mgpu.LoweringSemantics.Lane)
+def _async_copy_scales_to_tmem_lowering_rule(
+    ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree
+):
+  assert isinstance(tmem_ref, tcgen05.TMEMRef)
+  smem_leaves, tmem_leaves = util.split_list(leaves, [smem_tree.num_leaves])
+  smem_transforms = jax.tree.unflatten(smem_tree, smem_leaves)
+  tmem_transforms = jax.tree.unflatten(tmem_tree, tmem_leaves)
+  smem_ref, smem_transforms = lowering._handle_transforms(ctx, smem_ref, smem_transforms)
+  tmem_ref, tmem_transforms = lowering._handle_transforms(ctx, tmem_ref, tmem_transforms)
+  if smem_transforms:
+    raise NotImplementedError(f"Unimplemented transforms for SMEM refs: {smem_transforms}")
+  if tmem_transforms:
+    raise NotImplementedError(f"Unimplemented transforms for TMEM refs: {tmem_transforms}")
+  tcgen05.async_copy_scales_smem_to_tmem(smem_ref, tmem_ref)
+  return ()
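A sketch of how the new primitive composes with a block-scaled MMA (illustrative, not part of this commit); the SMEM/TMEM scale refs, operands, accumulator and barrier are assumed to be set up by the surrounding kernel, and plgpu is jax.experimental.pallas.mosaic_gpu:

def issue_block_scaled_mma(acc_tmem, a_smem, b_smem,
                           a_scales_smem, b_scales_smem,
                           a_scales_tmem, b_scales_tmem, mma_barrier):
  # Stage the scales into TMEM; the copies are asynchronous.
  plgpu.async_copy_scales_to_tmem(a_scales_smem, a_scales_tmem)
  plgpu.async_copy_scales_to_tmem(b_scales_smem, b_scales_tmem)
  # Per the docstring above, an MMA issued from the same thread may consume the
  # copied scales without any extra synchronization.
  plgpu.tcgen05_mma(
      acc_tmem, a_smem, b_smem, mma_barrier,
      a_scale=a_scales_tmem, b_scale=b_scales_tmem,
  )
  plgpu.barrier_wait(mma_barrier)  # await the MMA itself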

jax/_src/pallas/utils.py

Lines changed: 3 additions & 3 deletions
@@ -20,6 +20,7 @@
 import jax
 from jax import lax
 from jax._src import core as jax_core
+from jax._src import dtypes
 from jax._src.util import split_list
 import jax.numpy as jnp
 import numpy as np
@@ -62,10 +63,9 @@ def next_power_of_2(x: int) -> int:
     raise ValueError("`next_power_of_2` requires a non-negative integer.")
   return 1 if x == 0 else 2 ** (x - 1).bit_length()
 
+# TODO(apaszke): Inline this function into all call sites
 def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int:
-  if jnp.issubdtype(dtype, jnp.integer):
-    return jnp.iinfo(dtype).bits
-  return np.dtype(dtype).itemsize * 8
+  return dtypes.bit_width(dtype)
 
 def pattern_match_scan_to_fori_loop(
     jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int
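dtypes.bit_width reports the true element width in bits, which the removed itemsize-based branch could not do for sub-byte types; a short illustration (assuming a JAX build that exposes jnp.float4_e2m1fn):

import numpy as np
import jax.numpy as jnp
from jax._src import dtypes

print(dtypes.bit_width(np.dtype(np.int8)))            # 8
print(dtypes.bit_width(np.dtype(jnp.bfloat16)))       # 16
print(dtypes.bit_width(np.dtype(jnp.float4_e2m1fn)))  # 4 (itemsize * 8 would report 8)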

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@
 from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType
 from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
 from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
+from jax._src.pallas.mosaic_gpu.core import TMEMLayout as TMEMLayout
 from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref
 from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref
 from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref
@@ -43,6 +44,7 @@
 from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop
 from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
 from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized
+from jax._src.pallas.mosaic_gpu.primitives import async_copy_scales_to_tmem as async_copy_scales_to_tmem
 from jax._src.pallas.mosaic_gpu.primitives import async_load_tmem as async_load_tmem
 from jax._src.pallas.mosaic_gpu.primitives import async_store_tmem as async_store_tmem
 from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
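With these re-exports, kernels can reach the new functionality through the usual experimental namespace; a short illustrative snippet (not part of the commit):

import jax.experimental.pallas.mosaic_gpu as plgpu

# Both names are newly re-exported by this change.
print(plgpu.TMEMLayout)                 # TMEM layout specifier from mosaic_gpu core
print(plgpu.async_copy_scales_to_tmem)  # asynchronous SMEM -> TMEM copy of MMA scales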
