 # sensitive to alignment and while this is quite conservative, it gets the job
 # done. We should make this more refined in the future.
 SMEM_ALIGNMENT = 1024
+TMEM_COL_ALIGNMENT = 4
 
 
 def is_trivial_index(idx, shape) -> bool:
@@ -146,11 +147,36 @@ def __call__(
       *,
       transforms: Sequence[MemoryRefTransform] = (),
       packed: bool | None = None,
-      collective: bool | None = None
+      collective: bool | None = None,
+      layout: TMEMLayout | None = None,
   ) -> pallas_core.MemoryRef:
-    # A convenience function for constructing MemoryRef types.
+    if self == MemorySpace.TMEM:
+      if transforms:
+        raise ValueError("transforms are not supported for TMEM")
+      if collective is None:
+        collective = False
+      if layout is None:
+        if packed is None:
+          if dtypes.bit_width(dtype) != 32:
+            raise ValueError(
+                "dtypes narrower than 32-bit require either the packed argument"
+                " or an explicit TMEM layout"
+            )
+          packed = False
+        layout = infer_tmem_layout(
+            shape, dtype, packed=packed, collective=collective
+        )
+      else:
+        if packed is not None:
+          raise ValueError("packed cannot be specified if layout is specified.")
+        # We allow tcgen05.TMEMLayout to be passed in from our internal APIs.
+        if not isinstance(layout, tcgen05.TMEMLayout):
+          layout = layout.to_mgpu()
+    else:
+      if packed is not None or collective is not None or layout is not None:
+        raise ValueError("packed, collective and layout arguments are only supported for TMEM.")
     return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms,
-                        packed=packed, collective=collective)
+                        layout=layout, collective=collective)
 
 
 class SemaphoreType(enum.Enum):
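For orientation, a minimal usage sketch of the new allocation arguments handled above. This is not part of the diff: it assumes the memory space and the enum are re-exported as plgpu.TMEM and plgpu.TMEMLayout, and the shapes and dtypes are purely illustrative.

import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu

# 32-bit dtype: packing/layout can be omitted and a TMEM layout is inferred.
acc_ref = plgpu.TMEM((128, 128), jnp.float32)

# Sub-32-bit dtype: either request packing explicitly...
packed_ref = plgpu.TMEM((128, 128), jnp.float16, packed=True)

# ...or pass an explicit layout via the new `layout` argument (in which case
# `packed` must be left unset).
scales_ref = plgpu.TMEM((128, 4), jnp.float16, layout=plgpu.TMEMLayout.SCALES_LAYOUT)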
@@ -223,38 +249,26 @@ def cmap_body():
 class GPUMemoryRef(pallas_core.MemoryRef):
   transforms: Sequence[MemoryRefTransform] = ()
 
-  # Whether to allow TMEM packing for sub 32-bit dtypes.
-  packed: bool | None = dataclasses.field(default=None, kw_only=True)
+  layout: tcgen05.TMEMLayout | None = dataclasses.field(default=None, kw_only=True)
   collective: bool | None = dataclasses.field(default=None, kw_only=True)
 
   def __post_init__(self):
-    if self.memory_space == MemorySpace.TMEM:
-      if dtypes.bit_width(self.dtype) < 32 and self.packed is None:
-        raise ValueError(
-            "Packed option must be specified for sub-32 bit dtypes.")
-    else:
-      if self.packed is not None:
-        raise ValueError("Packed option is only supported for TMEM.")
-      if self.collective is not None:
-        raise ValueError("Collective option is only supported for TMEM.")
+    is_tmem = self.memory_space == MemorySpace.TMEM
+    assert (self.layout is not None) == is_tmem
+    assert (self.collective is not None) == is_tmem
+    assert not (self.transforms and is_tmem)
 
   def get_ref_aval(self) -> _Ref:
     aval = jax_core.ShapedArray(self.shape, self.dtype)
     for t in self.transforms:
       aval = t(aval)
     if self.memory_space == MemorySpace.TMEM:
-      collective = self.collective if self.collective is not None else False
-      packed = self.packed if self.packed is not None else False
-      ref = pallas_core.TransformedRef(
-          AbstractTMEMRef(aval,
-                          memory_space=self.memory_space,
-                          packed=packed,
-                          collective=collective), ()
+      aval = AbstractTMEMRef(
+          aval, self.memory_space, self.layout, self.collective
       )
     else:
-      ref = pallas_core.TransformedRef(
-          state.AbstractRef(aval, memory_space=self.memory_space), ()
-      )
+      aval = state.AbstractRef(aval, memory_space=self.memory_space)
+    ref = pallas_core.TransformedRef(aval, ())
     for t in reversed(self.transforms):
       ref = t.undo(ref)
     if not ref.transforms:
@@ -295,32 +309,23 @@ def _ref_group_tmem_col_size(refs: _GPUMemoryRefTree) -> int:
   """
   ncols = 0
   for ref in jax.tree.leaves(refs):
-    ncols += infer_tmem_cols_layout(ref.shape, ref.dtype,
-                                    collective=ref.collective,
-                                    packed=ref.packed)[0]
+    ref_ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
+    ncols += align_to(ref_ncols, TMEM_COL_ALIGNMENT)
   return ncols
 
 
-def infer_tmem_cols_layout(
+def infer_tmem_layout(
     shape: tuple[int, ...],
     dtype: jnp.dtype,
     *,
     packed: bool,
-    collective: bool,
-    layout: tcgen05.TMEMLayout | None = None) -> tuple[int, tcgen05.TMEMLayout]:
+    collective: bool) -> tcgen05.TMEMLayout:
   """Infers the number of columns used and layout for allocating TMEM Refs."""
   if packed:
     packing = 32 // dtypes.bit_width(dtype)
   else:
     packing = 1
-  if layout is None:
-    layout = tcgen05._infer_tmem_layout(shape,  # type: ignore[arg-type]
-                                        collective=collective,
-                                        packing=packing)
-  with ir.Context():
-    ir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
-    cols_used = layout.cols_in_shape(shape, ir_dtype)  # type: ignore[arg-type]
-  return cols_used, layout
+  return tcgen05._infer_tmem_layout(shape, collective=collective, packing=packing)
 
 
 def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
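A worked illustration of the column accounting used by _ref_group_tmem_col_size and flatten_ref_union: each aliased TMEM ref's column footprint is rounded up to TMEM_COL_ALIGNMENT before being summed or used as an offset. The standalone align_to below is a sketch of the helper this diff assumes exists elsewhere in the file, not its actual definition.

def align_to(x: int, alignment: int) -> int:
  # Round x up to the next multiple of `alignment`.
  return ((x + alignment - 1) // alignment) * alignment

TMEM_COL_ALIGNMENT = 4
# Two aliased refs occupying 6 and 9 TMEM columns respectively:
ncols = align_to(6, TMEM_COL_ALIGNMENT) + align_to(9, TMEM_COL_ALIGNMENT)
assert ncols == 8 + 12 == 20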
@@ -363,13 +368,12 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
   for ref_group in ref_union.refs:
     col_offset = 0
     for ref in jax.tree.leaves(ref_group):
+      col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT)
       if not isinstance(ref, pallas_core.TransformedRef):
         ref = pallas_core.TransformedRef(ref, transforms=())
-      ncols, _ = infer_tmem_cols_layout(
-          ref.shape, ref.dtype,  # type: ignore[arg-type]
-          packed=ref.packed, collective=ref.collective)
+      ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
       transform = ExtractAliasedRef.from_transformed_ref(
-          ref, col_offset, packed=ref.packed, collective=ref.collective)
+          ref, col_offset, layout=ref.layout)
       flat_refs.append(
           pallas_core.TransformedRef(
               ref_union, transforms=(transform, *ref.transforms)
@@ -409,19 +413,23 @@ def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractRefUnion(ref.inner_aval, self.refs, self.memory_space)
 
+  @functools.cached_property
+  def layout(self) -> tcgen05.TMEMLayout:
+    if self.memory_space != TMEM:
+      raise ValueError("layout attribute is only defined for TMEM refs")
+    return tcgen05.tmem_default_layout(packing=1)
+
   @functools.cached_property
   def collective(self) -> bool:
     if self.memory_space != TMEM:
-      raise ValueError("Collective is only supported for TMEM.")
+      raise ValueError("collective attribute is only defined for TMEM refs")
     ref_leaves = jax.tree.leaves(self.refs)
     first_ref = ref_leaves[0]
-    # Check if all Refs have the same collective attribute.
-    if not all(ref.collective == first_ref.collective for ref in ref_leaves):
-      raise ValueError(f"All Refs must be either collective/not collective."
-                       f" Got: {[ref.collective for ref in ref_leaves]}")
+    assert all(ref.collective == first_ref.collective for ref in ref_leaves)
     return first_ref.collective
 
 
+
 @dataclasses.dataclass(init=False, frozen=True)
 class RefUnion(GPUMemoryRef):
   """A sequence of trees of refs that are allowed to reuse the same memory.
@@ -450,11 +458,18 @@ def __init__(self, *refs: _GPUMemoryRefTree):
     elif all(ref.memory_space == TMEM for ref in ref_leaves):
       object.__setattr__(self, "refs", refs)
       max_cols = max(map(_ref_group_tmem_col_size, self.refs))
+      is_collective = ref_leaves[0].collective
+      if any(r.collective != is_collective for r in ref_leaves):
+        raise ValueError(
+            "Some aliased TMEM references are collective and some are not."
+        )
       super().__init__(
           shape=(128, max_cols,),
           dtype=jnp.int32,
           memory_space=TMEM,
           transforms=(),
+          layout=tcgen05.tmem_default_layout(packing=1),
+          collective=all(ref.collective for ref in ref_leaves),
       )
     else:
       raise NotImplementedError(
@@ -752,20 +767,16 @@ class ExtractAliasedRef(state_types.Transform):
   shape: tuple[int, ...]
   offset: int
   # TMEM-specific params
-  packed: bool | None
-  collective: bool | None
+  layout: tcgen05.TMEMLayout | None
 
   @classmethod
   def from_transformed_ref(
-      cls, ref: pallas_core.TransformedRef, byte_offset: int,
-      packed: bool | None = None,
-      collective: bool | None = None,
+      cls,
+      ref: pallas_core.TransformedRef,
+      byte_offset: int,
+      layout: tcgen05.TMEMLayout | None = None,
   ):
-    return cls(
-        dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset,
-        packed=packed,
-        collective=collective,
-    )
+    return cls(dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset, layout)
 
   def transform_shape(self, shape):
     if shape is None:
@@ -777,8 +788,7 @@ def transform_dtype(self, dtype):
     return self.dtype
 
   def tree_flatten(self):
-    return (), (self.dtype, self.shape, self.offset,
-                self.packed, self.collective)
+    return (), (self.dtype, self.shape, self.offset, self.layout)
 
   @classmethod
   def tree_unflatten(cls, metadata, arrays):
@@ -1040,20 +1050,20 @@ def _getitem(self, tracer, idx):
 
 
 class AbstractTMEMRef(state.AbstractRef):
-  __slots__ = ["inner_aval", "memory_space", "packed", "collective"]
+  __slots__ = ["inner_aval", "memory_space", "layout", "collective"]
 
-  def __init__(self, inner_aval, memory_space, packed, collective):
+  def __init__(self, inner_aval, memory_space, layout, collective):
     super().__init__(inner_aval, memory_space)
-    self.packed = packed
+    self.layout = layout
     self.collective = collective
 
   def __repr__(self) -> str:
-    return f'TMEM({self.inner_aval.str_short()},packed={self.packed})'
+    return f'TMEM({self.inner_aval.str_short()}, layout={self.layout}, collective={self.collective})'
 
   def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractTMEMRef(
-        ref.inner_aval, ref.memory_space, self.packed, self.collective
+        ref.inner_aval, ref.memory_space, self.layout, self.collective
     )
 
 
@@ -1246,6 +1256,7 @@ def to_mgpu(self) -> mgpu.FragmentedLayout:
       raise ValueError("Only TiledLayout supports reductions.")
     return layout.reduce(self.axes)
 
+
 class Layout(SomeLayout, enum.Enum):
   #: [m, n] matrix, where m % 64 == 0 == n % 8.
   WGMMA = enum.auto()
@@ -1297,3 +1308,13 @@ def check_no_args():
 Layout.TCGEN05_ROW = Layout.TCGEN05.reduce(1)
 Layout.TCGEN05_COL = Layout.TCGEN05.reduce(0)
 Layout.TCGEN05_TMEM_NATIVE_ROW = Layout.TCGEN05_TMEM_NATIVE.reduce(1)
+
+
+class TMEMLayout(enum.Enum):
+  """Layout for TMEM references."""
+  SCALES_LAYOUT = enum.auto()
+
+  def to_mgpu(self) -> mgpu.FragmentedLayout:
+    match self:
+      case TMEMLayout.SCALES_LAYOUT:
+        return tcgen05.scales_layout()
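For context, a small sketch (assumption, not part of the diff) of the conversion step this enum exists for: MemorySpace.__call__ above lowers the user-facing value to a Mosaic GPU layout via to_mgpu() unless a tcgen05.TMEMLayout was already supplied by internal APIs.

# Hypothetical illustration mirroring the dispatch in MemorySpace.__call__.
layout = TMEMLayout.SCALES_LAYOUT
if not isinstance(layout, tcgen05.TMEMLayout):  # user-facing enum, not an mgpu layout
  layout = layout.to_mgpu()                     # -> tcgen05.scales_layout()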