
Commit 9252315

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Disallow direct reads from TMEM refs
Leaving this piece of code there was an oversight in the CL that introduced plgpu.async_load_tmem.

PiperOrigin-RevId: 781505693
1 parent 88af892 commit 9252315

File tree

4 files changed: +115 additions, -114 deletions


jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 82 additions & 65 deletions
@@ -146,11 +146,35 @@ def __call__(
       *,
       transforms: Sequence[MemoryRefTransform] = (),
       packed: bool | None = None,
-      collective: bool | None = None
+      collective: bool | None = None,
+      layout: TMEMLayout | None = None,
   ) -> pallas_core.MemoryRef:
-    # A convenience function for constructing MemoryRef types.
+    if self == MemorySpace.TMEM:
+      if transforms:
+        raise ValueError("transforms are not supported for TMEM")
+      if collective is None:
+        collective = False
+      if layout is None:
+        if packed is None:
+          if dtypes.bit_width(dtype) != 32:
+            raise ValueError(
+                "dtypes narrower than 32-bit require either the packed argument"
+                " or an explicit TMEM layout"
+            )
+          packed = False
+        mgpu_layout = infer_tmem_layout(
+            shape, dtype, packed=packed, collective=collective
+        )
+      else:
+        if packed is not None:
+          raise ValueError("packed cannot be specified if layout is specified.")
+        mgpu_layout = layout.to_mgpu()
+    else:
+      if packed is not None or collective is not None or layout is not None:
+        raise ValueError("packed, collective and layout arguments are only supported for TMEM.")
+      mgpu_layout = None
     return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms,
-                        packed=packed, collective=collective)
+                        layout=mgpu_layout, collective=collective)
 
 
 class SemaphoreType(enum.Enum):
@@ -223,38 +247,26 @@ def cmap_body():
 class GPUMemoryRef(pallas_core.MemoryRef):
   transforms: Sequence[MemoryRefTransform] = ()
 
-  # Whether to allow TMEM packing for sub 32-bit dtypes.
-  packed: bool | None = dataclasses.field(default=None, kw_only=True)
+  layout: tcgen05.TMEMLayout | None = dataclasses.field(default=None, kw_only=True)
   collective: bool | None = dataclasses.field(default=None, kw_only=True)
 
   def __post_init__(self):
-    if self.memory_space == MemorySpace.TMEM:
-      if dtypes.bit_width(self.dtype) < 32 and self.packed is None:
-        raise ValueError(
-            "Packed option must be specified for sub-32 bit dtypes.")
-    else:
-      if self.packed is not None:
-        raise ValueError("Packed option is only supported for TMEM.")
-      if self.collective is not None:
-        raise ValueError("Collective option is only supported for TMEM.")
+    is_tmem = self.memory_space == MemorySpace.TMEM
+    assert (self.layout is not None) == is_tmem
+    assert (self.collective is not None) == is_tmem
+    assert not (self.transforms and is_tmem)
 
   def get_ref_aval(self) -> _Ref:
-    aval = jax_core.ShapedArray(self.shape, self.dtype)
+    aval: Any = jax_core.ShapedArray(self.shape, self.dtype)
     for t in self.transforms:
       aval = t(aval)
     if self.memory_space == MemorySpace.TMEM:
-      collective = self.collective if self.collective is not None else False
-      packed = self.packed if self.packed is not None else False
-      ref = pallas_core.TransformedRef(
-          AbstractTMEMRef(aval,
-                          memory_space=self.memory_space,
-                          packed=packed,
-                          collective=collective), ()
+      aval = AbstractTMEMRef(
+          aval, self.memory_space, self.layout, self.collective
       )
     else:
-      ref = pallas_core.TransformedRef(
-          state.AbstractRef(aval, memory_space=self.memory_space), ()
-      )
+      aval = state.AbstractRef(aval, memory_space=self.memory_space)
+    ref = pallas_core.TransformedRef(aval, ())
     for t in reversed(self.transforms):
       ref = t.undo(ref)
     if not ref.transforms:
@@ -295,32 +307,22 @@ def _ref_group_tmem_col_size(refs: _GPUMemoryRefTree) -> int:
   """
   ncols = 0
   for ref in jax.tree.leaves(refs):
-    ncols += infer_tmem_cols_layout(ref.shape, ref.dtype,
-                                    collective=ref.collective,
-                                    packed=ref.packed)[0]
+    ncols += ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
   return ncols
 
 
-def infer_tmem_cols_layout(
+def infer_tmem_layout(
     shape: tuple[int, ...],
     dtype: jnp.dtype,
     *,
     packed: bool,
-    collective: bool,
-    layout: tcgen05.TMEMLayout | None = None) -> tuple[int, tcgen05.TMEMLayout]:
+    collective: bool) -> tcgen05.TMEMLayout:
   """Infers the number of columns used and layout for allocating TMEM Refs."""
   if packed:
     packing = 32 // dtypes.bit_width(dtype)
   else:
     packing = 1
-  if layout is None:
-    layout = tcgen05._infer_tmem_layout(shape,  # type: ignore[arg-type]
-                                        collective=collective,
-                                        packing=packing)
-  with ir.Context():
-    ir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
-    cols_used = layout.cols_in_shape(shape, ir_dtype)  # type: ignore[arg-type]
-  return cols_used, layout
+  return tcgen05._infer_tmem_layout(shape, collective=collective, packing=packing)  # typing: ignore
 
 
 def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
@@ -365,11 +367,9 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
     for ref in jax.tree.leaves(ref_group):
       if not isinstance(ref, pallas_core.TransformedRef):
         ref = pallas_core.TransformedRef(ref, transforms=())
-      ncols, _ = infer_tmem_cols_layout(
-          ref.shape, ref.dtype,  # type: ignore[arg-type]
-          packed=ref.packed, collective=ref.collective)
+      ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
       transform = ExtractAliasedRef.from_transformed_ref(
-          ref, col_offset, packed=ref.packed, collective=ref.collective)
+          ref, col_offset, layout=ref.layout)
       flat_refs.append(
           pallas_core.TransformedRef(
               ref_union, transforms=(transform, *ref.transforms)
@@ -409,19 +409,23 @@ def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractRefUnion(ref.inner_aval, self.refs, self.memory_space)
 
+  @functools.cached_property
+  def layout(self) -> tcgen05.TMEMLayout:
+    if self.memory_space != TMEM:
+      raise ValueError("layout attribute is only defined for TMEM refs")
+    return tcgen05.tmem_default_layout(packing=1)
+
   @functools.cached_property
   def collective(self) -> bool:
     if self.memory_space != TMEM:
-      raise ValueError("Collective is only supported for TMEM.")
+      raise ValueError("collective attribute is only defined for TMEM refs")
     ref_leaves = jax.tree.leaves(self.refs)
     first_ref = ref_leaves[0]
-    # Check if all Refs have the same collective attribute.
-    if not all(ref.collective == first_ref.collective for ref in ref_leaves):
-      raise ValueError(f"All Refs must be either collective/not collective."
-                       f" Got: {[ref.collective for ref in ref_leaves]}")
+    assert all(ref.collective == first_ref.collective for ref in ref_leaves)
     return first_ref.collective
 
 
+
 @dataclasses.dataclass(init=False, frozen=True)
 class RefUnion(GPUMemoryRef):
   """A sequence of trees of refs that are allowed to reuse the same memory.
@@ -450,11 +454,18 @@ def __init__(self, *refs: _GPUMemoryRefTree):
     elif all(ref.memory_space == TMEM for ref in ref_leaves):
       object.__setattr__(self, "refs", refs)
       max_cols = max(map(_ref_group_tmem_col_size, self.refs))
+      is_collective = ref_leaves[0].collective
+      if any(r.collective != is_collective for r in ref_leaves):
+        raise ValueError(
+            "Some aliased TMEM references are collective and some are not."
+        )
       super().__init__(
           shape=(128, max_cols,),
          dtype=jnp.int32,
           memory_space=TMEM,
           transforms=(),
+          layout=tcgen05.tmem_default_layout(packing=1),
+          collective=all(ref.collective for ref in ref_leaves),
       )
     else:
       raise NotImplementedError(
@@ -752,20 +763,16 @@ class ExtractAliasedRef(state_types.Transform):
   shape: tuple[int, ...]
   offset: int
   # TMEM-specific params
-  packed: bool | None
-  collective: bool | None
+  layout: tcgen05.TMEMLayout | None
 
   @classmethod
   def from_transformed_ref(
-      cls, ref: pallas_core.TransformedRef, byte_offset: int,
-      packed: bool | None = None,
-      collective: bool | None = None,
+      cls,
+      ref: pallas_core.TransformedRef,
+      byte_offset: int,
+      layout: tcgen05.TMEMLayout | None = None,
   ):
-    return cls(
-        dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset,
-        packed=packed,
-        collective=collective,
-    )
+    return cls(dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset, layout)
 
   def transform_shape(self, shape):
     if shape is None:
@@ -777,8 +784,7 @@ def transform_dtype(self, dtype):
     return self.dtype
 
   def tree_flatten(self):
-    return (), (self.dtype, self.shape, self.offset,
-                self.packed, self.collective)
+    return (), (self.dtype, self.shape, self.offset, self.layout)
 
   @classmethod
   def tree_unflatten(cls, metadata, arrays):
@@ -1040,20 +1046,20 @@ def _getitem(self, tracer, idx):
 
 
 class AbstractTMEMRef(state.AbstractRef):
-  __slots__ = ["inner_aval", "memory_space", "packed", "collective"]
+  __slots__ = ["inner_aval", "memory_space", "layout", "collective"]
 
-  def __init__(self, inner_aval, memory_space, packed, collective):
+  def __init__(self, inner_aval, memory_space, layout, collective):
     super().__init__(inner_aval, memory_space)
-    self.packed = packed
+    self.layout = layout
     self.collective = collective
 
   def __repr__(self) -> str:
-    return f'TMEM({self.inner_aval.str_short()},packed={self.packed})'
+    return f'TMEM({self.inner_aval.str_short()}, layout={self.layout}, collective={self.collective})'
 
   def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractTMEMRef(
-        ref.inner_aval, ref.memory_space, self.packed, self.collective
+        ref.inner_aval, ref.memory_space, self.layout, self.collective
     )
 
 
@@ -1246,6 +1252,7 @@ def to_mgpu(self) -> mgpu.FragmentedLayout:
       raise ValueError("Only TiledLayout supports reductions.")
     return layout.reduce(self.axes)
 
+
 class Layout(SomeLayout, enum.Enum):
   #: [m, n] matrix, where m % 64 == 0 == n % 8.
   WGMMA = enum.auto()
@@ -1297,3 +1304,13 @@ def check_no_args():
 Layout.TCGEN05_ROW = Layout.TCGEN05.reduce(1)
 Layout.TCGEN05_COL = Layout.TCGEN05.reduce(0)
 Layout.TCGEN05_TMEM_NATIVE_ROW = Layout.TCGEN05_TMEM_NATIVE.reduce(1)
+
+
+class TMEMLayout(enum.Enum):
+  """Layout for TMEM references."""
+  SCALES_LAYOUT = enum.auto()
+
+  def to_mgpu(self) -> mgpu.FragmentedLayout:
+    match self:
+      case TMEMLayout.SCALES_LAYOUT:
+        return tcgen05.scales_layout()
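The net effect of the core.py hunks above is that a TMEM ref is now parameterized by a concrete tcgen05.TMEMLayout and a collective flag on GPUMemoryRef, with the packed convenience flag resolved up front in MemorySpace.__call__. A hedged sketch of the resulting constructor behavior, assuming the memory space is exposed to users as plgpu.TMEM (shapes are illustrative):

    import jax.numpy as jnp
    from jax.experimental.pallas import mosaic_gpu as plgpu

    acc = plgpu.TMEM((128, 128), jnp.float32)                    # 32-bit dtype: packed defaults to False
    bf16_acc = plgpu.TMEM((128, 64), jnp.bfloat16, packed=True)  # sub-32-bit dtype: packing requested explicitly
    coll_acc = plgpu.TMEM((128, 128), jnp.float32, collective=True)
    # plgpu.TMEM((128, 64), jnp.bfloat16)            # raises: narrow dtypes need `packed` or an explicit layout
    # plgpu.TMEM((128, 64), jnp.bfloat16,
    #            packed=True, layout=...)            # raises: `packed` and `layout` are mutually exclusive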

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 22 additions & 40 deletions
@@ -313,8 +313,10 @@ def _run_scoped_resource_estimator(
         assert aval.shape[0] == 128
         cols_used = aval.shape[1]
       else:
-        cols_used, _ = gpu_core.infer_tmem_cols_layout(
-            aval.shape, aval.dtype, packed=aval.packed, collective=aval.collective)
+        cols_used = aval.layout.cols_in_shape(
+            aval.shape, dtypes.bit_width(aval.dtype)
+        )
+      # TODO(apaszke): Remove this. We only need to align the outermost allocation.
       cols_used = tcgen05._alloc_ncols(cols_used, exact=False)
       if aval.collective:
         rs += Resources(tmem_collective_scratch_cols=cols_used)
@@ -441,12 +443,9 @@ def alloc_tmem(
       self,
       struct: jax.ShapeDtypeStruct,
       *,
-      layout: tcgen05.TMEMLayout | None = None,
-      collective: bool = False,
-      packed: bool = False,
+      layout: tcgen05.TMEMLayout,
+      collective: bool,
   ) -> Iterator[ir.Value]:
-    cols_used, layout = gpu_core.infer_tmem_cols_layout(
-        struct.shape, struct.dtype, packed=packed, collective=collective, layout=layout)
     if collective:
       off = arith_dialect.addi(
           self.tmem_collective_base_ptr,
@@ -461,6 +460,9 @@
         shape=struct.shape,
         dtype=mgpu_utils.dtype_to_ir_type(struct.dtype),
         layout=layout)
+    cols_used = layout.cols_in_shape(
+        struct.shape, dtypes.bit_width(struct.dtype)
+    )
     if collective:
       self.tmem_collective_used_cols += cols_used
     yield tmem_ref
@@ -745,7 +747,10 @@ def ref_for_aval(aval: ShapedAbstractValue):
   if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
     return gpu_core.WGMMAAccumulatorRef(aval.shape, aval.dtype)
   elif isinstance(aval, gpu_core.AbstractTMEMRef):
-    return gpu_core.TMEM(aval.shape, aval.dtype, packed=aval.packed)
+    return gpu_core.GPUMemoryRef(
+        aval.shape, aval.dtype, gpu_core.TMEM,
+        transforms=(), layout=aval.layout, collective=aval.collective,
+    )
   elif isinstance(aval, state_types.AbstractRef):
     return pallas_core.MemoryRef(aval.shape, aval.dtype, aval.memory_space)
   else:
@@ -1309,35 +1314,27 @@ def _extract_aliased_ref(
   match transforms:
     case (
         gpu_core.ExtractAliasedRef(
-            dtype, transformed_shape, offset, packed, collective
+            dtype, transformed_shape, offset, layout
         ),
        *other_transforms,
     ):
      mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
      if isinstance(ref, tcgen05.TMEMRef):
-        assert packed is not None
-        assert collective is not None
+        assert layout is not None
        if ref.shape[0] != transformed_shape[0]:
          raise ValueError(
              "TMEM aliasing only supported for Refs with the same first"
              f" dimension, got {ref.shape[0]} != {transformed_shape[0]}."
          )
        address = arith_dialect.addi(ref.address, _i32_constant(offset))
-        _, tmem_layout = gpu_core.infer_tmem_cols_layout(
-            transformed_shape, dtype, packed=packed, collective=collective
-        )
        ref = tcgen05.TMEMRef(
-            address=address,
-            shape=transformed_shape,
-            dtype=mgpu_utils.dtype_to_ir_type(dtype),
-            layout=tmem_layout,
-        )
+            address=address,
+            shape=transformed_shape,
+            dtype=mgpu_utils.dtype_to_ir_type(dtype),
+            layout=layout)
      else:
-        assert packed is None
-        assert collective is None
-        ref_bits = math.prod(transformed_shape) * mgpu_utils.bitwidth(
-            mlir_dtype
-        )
+        assert layout is None
+        ref_bits = math.prod(transformed_shape) * mgpu_utils.bitwidth(mlir_dtype)
        if ref_bits % 8:
          raise NotImplementedError("Only byte-aligned bitcasts are supported.")
        assert offset % gpu_core.SMEM_ALIGNMENT == 0
@@ -1464,17 +1461,6 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...
 def _get_lowering_rule(
     ctx: LoweringRuleContext, x_ref, *leaves, tree, optimized=True
 ):
-  if isinstance(x_ref, tcgen05.TMEMRef):
-    transforms = jax.tree.unflatten(tree, leaves)
-    x_tmem, transforms = _handle_transforms(
-        ctx, x_ref, transforms, handle_transposes=False, handle_reshapes=False,
-    )
-    if transforms:
-      raise NotImplementedError(
-          f"Unimplemented transforms for TMEM refs. {transforms=}"
-      )
-    return x_tmem.load(layout=ctx.out_layout_hint)
-
   if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref):
     raise TypeError(f"Can only load from references (got {x_ref}).")
   dtype = ctx.avals_out[0].dtype
@@ -2546,14 +2532,10 @@ def _run_scoped_lowering_rule(
       input_refs.append(input_ref)
       should_discharge.append(False)
     elif aval.memory_space == gpu_core.TMEM:
-      if isinstance(aval, gpu_core.AbstractRefUnion):
-        packed = False
-      else:
-        packed = aval.packed
       input_ref = alloc_stack.enter_context(
           ctx.module_ctx.alloc_tmem(
              jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype),
-              packed=packed,
+              layout=aval.layout,
              collective=aval.collective,
          )
      )
