
Commit a58d72c

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Align TMEM allocations to 16 bytes
This does not seem to be well documented, but many tcgen05 instructions appear to assume that the TMEM addresses they receive are aligned to 16-byte boundaries.

PiperOrigin-RevId: 781488939
1 parent 70cdf17 · commit a58d72c
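For context on the magic number below: TMEM is addressed in 32-bit (4-byte) columns, so a 16-byte boundary corresponds to 4 columns, which is where the TMEM_COL_ALIGNMENT = 4 introduced in core.py comes from. A minimal sketch of the round-up arithmetic, assuming an align_to helper equivalent to the one the diff calls:

    # Sketch of the rounding this commit relies on. TMEM cells are 32 bits
    # wide, so 4 columns span exactly 16 bytes. `align_to` mirrors the helper
    # the diff calls in core.py; its exact definition there is assumed.
    def align_to(x: int, alignment: int) -> int:
      """Round x up to the next multiple of alignment."""
      return ((x + alignment - 1) // alignment) * alignment

    TMEM_COL_ALIGNMENT = 4  # columns: 4 columns * 4 bytes/column = 16 bytes

    assert align_to(1, TMEM_COL_ALIGNMENT) == 4    # a 1-column ref pads to 4
    assert align_to(4, TMEM_COL_ALIGNMENT) == 4    # already aligned
    assert align_to(130, TMEM_COL_ALIGNMENT) == 132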

File tree

3 files changed, +38 -6 lines

jax/_src/pallas/mosaic_gpu/core.py (4 additions, 1 deletion)

@@ -57,6 +57,7 @@
 # sensitive to alignment and while this is quite conservative, it gets the job
 # done. We should make this more refined in the future.
 SMEM_ALIGNMENT = 1024
+TMEM_COL_ALIGNMENT = 4


 def is_trivial_index(idx, shape) -> bool:
@@ -307,7 +308,8 @@ def _ref_group_tmem_col_size(refs: _GPUMemoryRefTree) -> int:
   """
   ncols = 0
   for ref in jax.tree.leaves(refs):
-    ncols += ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
+    ref_ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
+    ncols += align_to(ref_ncols, TMEM_COL_ALIGNMENT)
   return ncols


@@ -365,6 +367,7 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
   for ref_group in ref_union.refs:
     col_offset = 0
     for ref in jax.tree.leaves(ref_group):
+      col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT)
       if not isinstance(ref, pallas_core.TransformedRef):
         ref = pallas_core.TransformedRef(ref, transforms=())
       ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
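The effect of the two core.py changes: _ref_group_tmem_col_size now pads each leaf ref up to the column alignment when sizing a ref group, and flatten_ref_union applies the same padding to the running column offset, so every member of a TMEM union starts on a 4-column (16-byte) boundary. A hypothetical sketch of the sizing logic, with column counts invented for the example:

    # Hypothetical illustration of the padded sizing (column counts made up;
    # the real code derives them from layout.cols_in_shape).
    def group_col_size(leaf_ncols: list[int], alignment: int = 4) -> int:
      ncols = 0
      for ref_ncols in leaf_ncols:
        # Round each leaf up, exactly as the new align_to call does.
        ncols += ((ref_ncols + alignment - 1) // alignment) * alignment
      return ncols

    # Refs of 1, 5, and 128 columns now reserve 4 + 8 + 128 = 140 columns,
    # where the unpadded sum would have been 134.
    assert group_col_size([1, 5, 128]) == 140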

jax/_src/pallas/mosaic_gpu/lowering.py (1 addition, 5 deletions)

@@ -305,9 +305,6 @@ def _run_scoped_resource_estimator(
     if aval.memory_space == gpu_core.TMEM:
       if len(aval.shape) != 2:
         raise ValueError(f"TMEM allocations must be 2D. Got {aval.shape}")
-      if aval.shape[0] not in (64, 128):
-        raise ValueError(
-            f"TMEM shape[0] must be 64 or 128. Got {aval.shape[0]}.")
       # Estimate columns used.
       if isinstance(aval, gpu_core.AbstractRefUnion):
         assert aval.shape[0] == 128
@@ -316,8 +313,6 @@ def _run_scoped_resource_estimator(
         cols_used = aval.layout.cols_in_shape(
             aval.shape, dtypes.bit_width(aval.dtype)
         )
-        # TODO(apaszke): Remove this. We only need to align the outermost allocation.
-        cols_used = tcgen05._alloc_ncols(cols_used, exact=False)
       if aval.collective:
         rs += Resources(tmem_collective_scratch_cols=cols_used)
       else:
@@ -463,6 +458,7 @@ def alloc_tmem(
     cols_used = layout.cols_in_shape(
         struct.shape, dtypes.bit_width(struct.dtype)
     )
+    cols_used = gpu_core.align_to(cols_used, gpu_core.TMEM_COL_ALIGNMENT)
     if collective:
       self.tmem_collective_used_cols += cols_used
     yield tmem_ref
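On the lowering.py side, the per-allocation rounding that tcgen05._alloc_ncols used to do in the resource estimator is removed (along with the shape[0] restriction), and cols_used is instead aligned once at allocation time, so the running TMEM column counter always lands on an aligned offset. A simplified sketch of that bookkeeping, omitting ref materialization and collective handling:

    # Simplified bookkeeping: because cols_used is rounded up before bumping
    # the counter, every allocation starts on a 4-column (16-byte) boundary.
    class TmemColumnCounter:
      def __init__(self):
        self.used_cols = 0

      def alloc(self, cols_used: int, alignment: int = 4) -> int:
        offset = self.used_cols  # aligned by construction
        self.used_cols += ((cols_used + alignment - 1) // alignment) * alignment
        return offset

    counter = TmemColumnCounter()
    assert counter.alloc(1) == 0    # e.g. a (128, 1) float32 scratch
    assert counter.alloc(128) == 4  # next ref starts at column 4 = byte 16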

tests/pallas/mosaic_gpu_test.py (33 additions, 0 deletions)

@@ -2929,6 +2929,39 @@ def kernel(a_smem, b_smem, out_ref, acc_tmem, scratch_smem, barrier_ref,
     expected = x @ y
     np.testing.assert_allclose(result, expected, rtol=1e-3)
 
+  def test_matmul_alignment(self):
+    self.skip_if_wg_semantics()
+    m = k = n = 128
+    dtype = jnp.float16
+    transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
+
+    def kernel(a_smem, b_smem, out_ref, _, acc_tmem, barrier_ref):
+      plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, barrier_ref, accumulate=False)
+      plgpu.barrier_wait(barrier_ref)
+      # We don't await the load because acc_tmem is never modified again.
+      out_ref[...] = plgpu.async_load_tmem(acc_tmem).astype(dtype)
+
+    spec = plgpu.BlockSpec(transforms=transforms, memory_space=plgpu.SMEM)
+    f = self.pallas_call(
+        kernel,
+        in_specs=(spec, spec),
+        out_specs=spec,
+        out_shape=jax.ShapeDtypeStruct((m, n), dtype),
+        # Add a one-column scratch to test that we align the accumulator.
+        scratch_shapes=(
+            plgpu.TMEM((128, 1), jnp.float32),
+            plgpu.TMEM((m, n), jnp.float32),
+            plgpu.Barrier(orders_tensor_core=True),
+        ),
+    )
+    lhs_shape = (m, k)
+    rhs_shape = (k, n)
+    x = jax.random.uniform(jax.random.key(0), shape=lhs_shape, dtype=dtype)
+    y = jax.random.uniform(jax.random.key(1), shape=rhs_shape, dtype=dtype)
+    result = f(x, y)
+    expected = x @ y
+    np.testing.assert_allclose(result, expected, rtol=1e-3)
+
   @parameterized.parameters(
       (128, jnp.float16)
   )
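The test is constructed so that misalignment would actually occur: the (128, 1) float32 scratch occupies a single TMEM column, so without the new padding the (m, n) accumulator allocated after it would start at column 1, a 4-byte TMEM offset, and tcgen05_mma would receive a misaligned address. A rough model of the column footprint, assuming the straightforward 2D case where layout.cols_in_shape reduces to width times element width over the 32-bit cell width:

    # Rough model of the column footprint of a plain 2D TMEM ref; the real
    # layout.cols_in_shape handles more layouts than this.
    def cols_in_shape(shape: tuple[int, int], bit_width: int) -> int:
      return shape[1] * bit_width // 32

    assert cols_in_shape((128, 1), 32) == 1      # the scratch: 1 column
    assert cols_in_shape((128, 128), 32) == 128  # the accumulator
    # Unpadded, the accumulator would begin at column 1 (byte offset 4);
    # with TMEM_COL_ALIGNMENT = 4 it begins at column 4 (byte offset 16).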
