
Commit 2b7b4d6

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Align TMEM allocations to 16 bytes
This does not seem to be documented very well, but many tcgen05 instructions appear to assume that the TMEM addresses they receive are aligned to 16-byte boundaries.

PiperOrigin-RevId: 781488939
1 parent cc5bbb5 commit 2b7b4d6
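To make the 16-byte requirement concrete, here is a minimal sketch of the column-rounding arithmetic the change introduces (illustrative only, not the library code): TMEM is addressed in 32-bit columns, so rounding every allocation up to a multiple of TMEM_COL_ALIGNMENT = 4 columns keeps each TMEM base address on a 16-byte boundary.

# Sketch of the alignment arithmetic, assuming each TMEM column is a 32-bit
# cell per lane, so 4 columns correspond to 16 bytes.
TMEM_COL_ALIGNMENT = 4  # columns

def align_to(x: int, alignment: int) -> int:
  # Round x up to the next multiple of `alignment`.
  return (x + alignment - 1) // alignment * alignment

assert align_to(5, TMEM_COL_ALIGNMENT) == 8            # a 5-column ref reserves 8 columns
assert align_to(8, TMEM_COL_ALIGNMENT) == 8            # already aligned, left unchanged
assert align_to(9, TMEM_COL_ALIGNMENT) * 4 % 16 == 0   # byte offsets stay 16-byte aligned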

5 files changed: +148 additions, -104 deletions

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 85 additions & 64 deletions
@@ -57,6 +57,7 @@
 # sensitive to alignment and while this is quite conservative, it gets the job
 # done. We should make this more refined in the future.
 SMEM_ALIGNMENT = 1024
+TMEM_COL_ALIGNMENT = 4


 def is_trivial_index(idx, shape) -> bool:
@@ -146,11 +147,36 @@ def __call__(
       *,
       transforms: Sequence[MemoryRefTransform] = (),
       packed: bool | None = None,
-      collective: bool | None = None
+      collective: bool | None = None,
+      layout: TMEMLayout | None = None,
   ) -> pallas_core.MemoryRef:
-    # A convenience function for constructing MemoryRef types.
+    if self == MemorySpace.TMEM:
+      if transforms:
+        raise ValueError("transforms are not supported for TMEM")
+      if collective is None:
+        collective = False
+      if layout is None:
+        if packed is None:
+          if dtypes.bit_width(dtype) != 32:
+            raise ValueError(
+                "dtypes narrower than 32-bit require either the packed argument"
+                " or an explicit TMEM layout"
+            )
+          packed = False
+        layout = infer_tmem_layout(
+            shape, dtype, packed=packed, collective=collective
+        )
+      else:
+        if packed is not None:
+          raise ValueError("packed cannot be specified if layout is specified.")
+        # We allow tcgen05.TMEMLayout to be passed in from our internal APIs.
+        if not isinstance(layout, tcgen05.TMEMLayout):
+          layout = layout.to_mgpu()
+    else:
+      if packed is not None or collective is not None or layout is not None:
+        raise ValueError("packed, collective and layout arguments are only supported for TMEM.")
     return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms,
-                        packed=packed, collective=collective)
+                        layout=layout, collective=collective)


 class SemaphoreType(enum.Enum):
@@ -223,38 +249,26 @@ def cmap_body():
 class GPUMemoryRef(pallas_core.MemoryRef):
   transforms: Sequence[MemoryRefTransform] = ()

-  # Whether to allow TMEM packing for sub 32-bit dtypes.
-  packed: bool | None = dataclasses.field(default=None, kw_only=True)
+  layout: tcgen05.TMEMLayout | None = dataclasses.field(default=None, kw_only=True)
   collective: bool | None = dataclasses.field(default=None, kw_only=True)

   def __post_init__(self):
-    if self.memory_space == MemorySpace.TMEM:
-      if dtypes.bit_width(self.dtype) < 32 and self.packed is None:
-        raise ValueError(
-            "Packed option must be specified for sub-32 bit dtypes.")
-    else:
-      if self.packed is not None:
-        raise ValueError("Packed option is only supported for TMEM.")
-      if self.collective is not None:
-        raise ValueError("Collective option is only supported for TMEM.")
+    is_tmem = self.memory_space == MemorySpace.TMEM
+    assert (self.layout is not None) == is_tmem
+    assert (self.collective is not None) == is_tmem
+    assert not (self.transforms and is_tmem)

   def get_ref_aval(self) -> _Ref:
     aval = jax_core.ShapedArray(self.shape, self.dtype)
     for t in self.transforms:
       aval = t(aval)
     if self.memory_space == MemorySpace.TMEM:
-      collective = self.collective if self.collective is not None else False
-      packed = self.packed if self.packed is not None else False
-      ref = pallas_core.TransformedRef(
-          AbstractTMEMRef(aval,
-                          memory_space=self.memory_space,
-                          packed=packed,
-                          collective=collective), ()
+      aval = AbstractTMEMRef(
+          aval, self.memory_space, self.layout, self.collective
       )
     else:
-      ref = pallas_core.TransformedRef(
-          state.AbstractRef(aval, memory_space=self.memory_space), ()
-      )
+      aval = state.AbstractRef(aval, memory_space=self.memory_space)
+    ref = pallas_core.TransformedRef(aval, ())
     for t in reversed(self.transforms):
       ref = t.undo(ref)
     if not ref.transforms:
@@ -295,32 +309,23 @@ def _ref_group_tmem_col_size(refs: _GPUMemoryRefTree) -> int:
   """
   ncols = 0
   for ref in jax.tree.leaves(refs):
-    ncols += infer_tmem_cols_layout(ref.shape, ref.dtype,
-                                    collective=ref.collective,
-                                    packed=ref.packed)[0]
+    ref_ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
+    ncols += align_to(ref_ncols, TMEM_COL_ALIGNMENT)
   return ncols


-def infer_tmem_cols_layout(
+def infer_tmem_layout(
     shape: tuple[int, ...],
     dtype: jnp.dtype,
     *,
     packed: bool,
-    collective: bool,
-    layout: tcgen05.TMEMLayout | None = None) -> tuple[int, tcgen05.TMEMLayout]:
+    collective: bool) -> tcgen05.TMEMLayout:
   """Infers the number of columns used and layout for allocating TMEM Refs."""
   if packed:
     packing = 32 // dtypes.bit_width(dtype)
   else:
     packing = 1
-  if layout is None:
-    layout = tcgen05._infer_tmem_layout(shape,  # type: ignore[arg-type]
-                                        collective=collective,
-                                        packing=packing)
-  with ir.Context():
-    ir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
-    cols_used = layout.cols_in_shape(shape, ir_dtype)  # type: ignore[arg-type]
-  return cols_used, layout
+  return tcgen05._infer_tmem_layout(shape, collective=collective, packing=packing)


 def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
@@ -363,13 +368,12 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
     for ref_group in ref_union.refs:
       col_offset = 0
       for ref in jax.tree.leaves(ref_group):
+        col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT)
         if not isinstance(ref, pallas_core.TransformedRef):
           ref = pallas_core.TransformedRef(ref, transforms=())
-        ncols, _ = infer_tmem_cols_layout(
-            ref.shape, ref.dtype,  # type: ignore[arg-type]
-            packed=ref.packed, collective=ref.collective)
+        ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
         transform = ExtractAliasedRef.from_transformed_ref(
-            ref, col_offset, packed=ref.packed, collective=ref.collective)
+            ref, col_offset, layout=ref.layout)
         flat_refs.append(
             pallas_core.TransformedRef(
                 ref_union, transforms=(transform, *ref.transforms)
@@ -409,19 +413,23 @@ def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractRefUnion(ref.inner_aval, self.refs, self.memory_space)

+  @functools.cached_property
+  def layout(self) -> tcgen05.TMEMLayout:
+    if self.memory_space != TMEM:
+      raise ValueError("layout attribute is only defined for TMEM refs")
+    return tcgen05.tmem_default_layout(packing=1)
+
   @functools.cached_property
   def collective(self) -> bool:
     if self.memory_space != TMEM:
-      raise ValueError("Collective is only supported for TMEM.")
+      raise ValueError("collective attribute is only defined for TMEM refs")
     ref_leaves = jax.tree.leaves(self.refs)
     first_ref = ref_leaves[0]
-    # Check if all Refs have the same collective attribute.
-    if not all(ref.collective == first_ref.collective for ref in ref_leaves):
-      raise ValueError(f"All Refs must be either collective/not collective."
-                       f" Got: {[ref.collective for ref in ref_leaves]}")
+    assert all(ref.collective == first_ref.collective for ref in ref_leaves)
     return first_ref.collective


+
 @dataclasses.dataclass(init=False, frozen=True)
 class RefUnion(GPUMemoryRef):
   """A sequence of trees of refs that are allowed to reuse the same memory.
@@ -450,11 +458,18 @@ def __init__(self, *refs: _GPUMemoryRefTree):
     elif all(ref.memory_space == TMEM for ref in ref_leaves):
       object.__setattr__(self, "refs", refs)
       max_cols = max(map(_ref_group_tmem_col_size, self.refs))
+      is_collective = ref_leaves[0].collective
+      if any(r.collective != is_collective for r in ref_leaves):
+        raise ValueError(
+            "Some aliased TMEM references are collective and some are not."
+        )
       super().__init__(
           shape=(128, max_cols,),
           dtype=jnp.int32,
           memory_space=TMEM,
           transforms=(),
+          layout=tcgen05.tmem_default_layout(packing=1),
+          collective=all(ref.collective for ref in ref_leaves),
       )
     else:
       raise NotImplementedError(
@@ -752,20 +767,16 @@ class ExtractAliasedRef(state_types.Transform):
   shape: tuple[int, ...]
   offset: int
   # TMEM-specific params
-  packed: bool | None
-  collective: bool | None
+  layout: tcgen05.TMEMLayout | None

   @classmethod
   def from_transformed_ref(
-      cls, ref: pallas_core.TransformedRef, byte_offset: int,
-      packed: bool | None = None,
-      collective: bool | None = None,
+      cls,
+      ref: pallas_core.TransformedRef,
+      byte_offset: int,
+      layout: tcgen05.TMEMLayout | None = None,
   ):
-    return cls(
-        dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset,
-        packed=packed,
-        collective=collective,
-    )
+    return cls(dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset, layout)

   def transform_shape(self, shape):
     if shape is None:
@@ -777,8 +788,7 @@ def transform_dtype(self, dtype):
     return self.dtype

   def tree_flatten(self):
-    return (), (self.dtype, self.shape, self.offset,
-                self.packed, self.collective)
+    return (), (self.dtype, self.shape, self.offset, self.layout)

   @classmethod
   def tree_unflatten(cls, metadata, arrays):
@@ -1040,20 +1050,20 @@ def _getitem(self, tracer, idx):


 class AbstractTMEMRef(state.AbstractRef):
-  __slots__ = ["inner_aval", "memory_space", "packed", "collective"]
+  __slots__ = ["inner_aval", "memory_space", "layout", "collective"]

-  def __init__(self, inner_aval, memory_space, packed, collective):
+  def __init__(self, inner_aval, memory_space, layout, collective):
     super().__init__(inner_aval, memory_space)
-    self.packed = packed
+    self.layout = layout
     self.collective = collective

   def __repr__(self) -> str:
-    return f'TMEM({self.inner_aval.str_short()},packed={self.packed})'
+    return f'TMEM({self.inner_aval.str_short()}, layout={self.layout}, collective={self.collective})'

   def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractTMEMRef(
-        ref.inner_aval, ref.memory_space, self.packed, self.collective
+        ref.inner_aval, ref.memory_space, self.layout, self.collective
     )

@@ -1246,6 +1256,7 @@ def to_mgpu(self) -> mgpu.FragmentedLayout:
       raise ValueError("Only TiledLayout supports reductions.")
     return layout.reduce(self.axes)

+
 class Layout(SomeLayout, enum.Enum):
   #: [m, n] matrix, where m % 64 == 0 == n % 8.
   WGMMA = enum.auto()
@@ -1297,3 +1308,13 @@ def check_no_args():
 Layout.TCGEN05_ROW = Layout.TCGEN05.reduce(1)
 Layout.TCGEN05_COL = Layout.TCGEN05.reduce(0)
 Layout.TCGEN05_TMEM_NATIVE_ROW = Layout.TCGEN05_TMEM_NATIVE.reduce(1)
+
+
+class TMEMLayout(enum.Enum):
+  """Layout for TMEM references."""
+  SCALES_LAYOUT = enum.auto()
+
+  def to_mgpu(self) -> mgpu.FragmentedLayout:
+    match self:
+      case TMEMLayout.SCALES_LAYOUT:
+        return tcgen05.scales_layout()
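For context, a hypothetical usage sketch of the reworked TMEM constructor validated above (assuming the memory space is exposed publicly as plgpu.TMEM; the import path is an assumption, not part of this commit):

import jax.numpy as jnp
import jax.experimental.pallas.mosaic_gpu as plgpu  # assumed public alias for MemorySpace.TMEM

# 32-bit dtype: packed/layout may be omitted; a default TMEM layout is inferred.
acc = plgpu.TMEM((128, 128), jnp.float32)

# Sub-32-bit dtype: either packed=... or an explicit layout= is now required.
half = plgpu.TMEM((128, 128), jnp.float16, packed=True)

# Passing both packed= and layout= raises
# "packed cannot be specified if layout is specified."

When these refs are later allocated or aliased in a RefUnion, their column footprints are rounded up to TMEM_COL_ALIGNMENT columns, which is what provides the 16-byte alignment named in the commit title.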

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 19 additions & 31 deletions
@@ -305,17 +305,14 @@ def _run_scoped_resource_estimator(
     if aval.memory_space == gpu_core.TMEM:
       if len(aval.shape) != 2:
         raise ValueError(f"TMEM allocations must be 2D. Got {aval.shape}")
-      if aval.shape[0] not in (64, 128):
-        raise ValueError(
-            f"TMEM shape[0] must be 64 or 128. Got {aval.shape[0]}.")
       # Estimate columns used.
       if isinstance(aval, gpu_core.AbstractRefUnion):
         assert aval.shape[0] == 128
         cols_used = aval.shape[1]
       else:
-        cols_used, _ = gpu_core.infer_tmem_cols_layout(
-            aval.shape, aval.dtype, packed=aval.packed, collective=aval.collective)
-        cols_used = tcgen05._alloc_ncols(cols_used, exact=False)
+        cols_used = aval.layout.cols_in_shape(
+            aval.shape, dtypes.bit_width(aval.dtype)
+        )
       if aval.collective:
         rs += Resources(tmem_collective_scratch_cols=cols_used)
       else:
@@ -443,10 +440,7 @@ def alloc_tmem(
       *,
       layout: tcgen05.TMEMLayout | None = None,
       collective: bool = False,
-      packed: bool = False,
   ) -> Iterator[ir.Value]:
-    cols_used, layout = gpu_core.infer_tmem_cols_layout(
-        struct.shape, struct.dtype, packed=packed, collective=collective, layout=layout)
     if collective:
       off = arith_dialect.addi(
           self.tmem_collective_base_ptr,
@@ -461,6 +455,10 @@
         shape=struct.shape,
         dtype=mgpu_utils.dtype_to_ir_type(struct.dtype),
         layout=layout)
+    cols_used = layout.cols_in_shape(
+        struct.shape, dtypes.bit_width(struct.dtype)
+    )
+    cols_used = gpu_core.align_to(cols_used, gpu_core.TMEM_COL_ALIGNMENT)
     if collective:
       self.tmem_collective_used_cols += cols_used
     yield tmem_ref
@@ -745,7 +743,9 @@ def ref_for_aval(aval: ShapedAbstractValue):
   if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
     return gpu_core.WGMMAAccumulatorRef(aval.shape, aval.dtype)
   elif isinstance(aval, gpu_core.AbstractTMEMRef):
-    return gpu_core.TMEM(aval.shape, aval.dtype, packed=aval.packed)
+    return gpu_core.TMEM(
+        aval.shape, aval.dtype, layout=aval.layout, collective=aval.collective
+    )
   elif isinstance(aval, state_types.AbstractRef):
     return pallas_core.MemoryRef(aval.shape, aval.dtype, aval.memory_space)
   else:
@@ -1309,35 +1309,27 @@ def _extract_aliased_ref(
   match transforms:
     case (
         gpu_core.ExtractAliasedRef(
-            dtype, transformed_shape, offset, packed, collective
+            dtype, transformed_shape, offset, layout
         ),
         *other_transforms,
     ):
       mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
       if isinstance(ref, tcgen05.TMEMRef):
-        assert packed is not None
-        assert collective is not None
+        assert layout is not None
         if ref.shape[0] != transformed_shape[0]:
           raise ValueError(
               "TMEM aliasing only supported for Refs with the same first"
               f" dimension, got {ref.shape[0]} != {transformed_shape[0]}."
           )
         address = arith_dialect.addi(ref.address, _i32_constant(offset))
-        _, tmem_layout = gpu_core.infer_tmem_cols_layout(
-            transformed_shape, dtype, packed=packed, collective=collective
-        )
         ref = tcgen05.TMEMRef(
-            address=address,
-            shape=transformed_shape,
-            dtype=mgpu_utils.dtype_to_ir_type(dtype),
-            layout=tmem_layout,
-        )
+            address=address,
+            shape=transformed_shape,
+            dtype=mgpu_utils.dtype_to_ir_type(dtype),
+            layout=layout)
       else:
-        assert packed is None
-        assert collective is None
-        ref_bits = math.prod(transformed_shape) * mgpu_utils.bitwidth(
-            mlir_dtype
-        )
+        assert layout is None
+        ref_bits = math.prod(transformed_shape) * mgpu_utils.bitwidth(mlir_dtype)
         if ref_bits % 8:
           raise NotImplementedError("Only byte-aligned bitcasts are supported.")
         assert offset % gpu_core.SMEM_ALIGNMENT == 0
@@ -2546,14 +2538,10 @@ def _run_scoped_lowering_rule(
         input_refs.append(input_ref)
         should_discharge.append(False)
       elif aval.memory_space == gpu_core.TMEM:
-        if isinstance(aval, gpu_core.AbstractRefUnion):
-          packed = False
-        else:
-          packed = aval.packed
         input_ref = alloc_stack.enter_context(
             ctx.module_ctx.alloc_tmem(
                 jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype),
-                packed=packed,
+                layout=aval.layout,
                 collective=aval.collective,
             )
         )
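Tying the two files together, a hypothetical kernel-side sketch of scoped TMEM allocation that goes through the updated alloc_tmem path (the surrounding pallas_call wiring is omitted; treat the exact API spelling as an assumption rather than part of this commit):

import jax.numpy as jnp
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu  # assumed public alias

def kernel(x_ref, o_ref):
  def scoped(acc_tmem):
    # acc_tmem is a (128, 128) float32 TMEM ref; after this commit its base
    # address is rounded up to a 16-byte (4-column) boundary before any
    # tcgen05 instruction uses it.
    del acc_tmem
  pl.run_scoped(scoped, plgpu.TMEM((128, 128), jnp.float32))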

0 commit comments
