@@ -146,11 +146,35 @@ def __call__(
       *,
       transforms: Sequence[MemoryRefTransform] = (),
       packed: bool | None = None,
-      collective: bool | None = None
+      collective: bool | None = None,
+      layout: TMEMLayout | None = None,
   ) -> pallas_core.MemoryRef:
-    # A convenience function for constructing MemoryRef types.
+    if self == MemorySpace.TMEM:
+      if transforms:
+        raise ValueError("transforms are not supported for TMEM")
+      if collective is None:
+        collective = False
+      if layout is None:
+        if packed is None:
+          if dtypes.bit_width(dtype) != 32:
+            raise ValueError(
+                "dtypes narrower than 32-bit require either the packed argument"
+                " or an explicit TMEM layout"
+            )
+          packed = False
+        mgpu_layout = infer_tmem_layout(
+            shape, dtype, packed=packed, collective=collective
+        )
+      else:
+        if packed is not None:
+          raise ValueError("packed cannot be specified if layout is specified.")
+        mgpu_layout = layout.to_mgpu()
+    else:
+      if packed is not None or collective is not None or layout is not None:
+        raise ValueError("packed, collective and layout arguments are only supported for TMEM.")
+      mgpu_layout = None
     return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms,
-                        packed=packed, collective=collective)
+                        layout=mgpu_layout, collective=collective)


 class SemaphoreType(enum.Enum):
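For orientation, here is a minimal usage sketch of the reworked TMEM allocation arguments from the hunk above. It assumes the usual `plgpu` alias for `jax.experimental.pallas.mosaic_gpu` and that `TMEMLayout` is re-exported from there; those import paths are assumptions, not part of this diff.

```python
import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu  # alias/export paths are assumptions

# 32-bit dtype: neither packed nor layout is needed, a TMEM layout is inferred.
acc = plgpu.TMEM((128, 128), jnp.float32)

# Sub-32-bit dtype: either request packing so a layout can be inferred...
lhs = plgpu.TMEM((128, 128), jnp.bfloat16, packed=True)

# ...or pass an explicit layout; per the new checks, packed and layout are
# mutually exclusive and transforms are rejected for TMEM refs.
scales = plgpu.TMEM((128, 4), jnp.float8_e8m0fnu,
                    layout=plgpu.TMEMLayout.SCALES_LAYOUT)
```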
@@ -223,38 +247,26 @@ def cmap_body():
 class GPUMemoryRef(pallas_core.MemoryRef):
   transforms: Sequence[MemoryRefTransform] = ()

-  # Whether to allow TMEM packing for sub 32-bit dtypes.
-  packed: bool | None = dataclasses.field(default=None, kw_only=True)
+  layout: tcgen05.TMEMLayout | None = dataclasses.field(default=None, kw_only=True)
   collective: bool | None = dataclasses.field(default=None, kw_only=True)

   def __post_init__(self):
-    if self.memory_space == MemorySpace.TMEM:
-      if dtypes.bit_width(self.dtype) < 32 and self.packed is None:
-        raise ValueError(
-            "Packed option must be specified for sub-32 bit dtypes.")
-    else:
-      if self.packed is not None:
-        raise ValueError("Packed option is only supported for TMEM.")
-      if self.collective is not None:
-        raise ValueError("Collective option is only supported for TMEM.")
+    is_tmem = self.memory_space == MemorySpace.TMEM
+    assert (self.layout is not None) == is_tmem
+    assert (self.collective is not None) == is_tmem
+    assert not (self.transforms and is_tmem)

   def get_ref_aval(self) -> _Ref:
-    aval = jax_core.ShapedArray(self.shape, self.dtype)
+    aval: Any = jax_core.ShapedArray(self.shape, self.dtype)
     for t in self.transforms:
       aval = t(aval)
     if self.memory_space == MemorySpace.TMEM:
-      collective = self.collective if self.collective is not None else False
-      packed = self.packed if self.packed is not None else False
-      ref = pallas_core.TransformedRef(
-          AbstractTMEMRef(aval,
-                          memory_space=self.memory_space,
-                          packed=packed,
-                          collective=collective), ()
+      aval = AbstractTMEMRef(
+          aval, self.memory_space, self.layout, self.collective
       )
     else:
-      ref = pallas_core.TransformedRef(
-          state.AbstractRef(aval, memory_space=self.memory_space), ()
-      )
+      aval = state.AbstractRef(aval, memory_space=self.memory_space)
+    ref = pallas_core.TransformedRef(aval, ())
     for t in reversed(self.transforms):
       ref = t.undo(ref)
     if not ref.transforms:
@@ -295,32 +307,22 @@ def _ref_group_tmem_col_size(refs: _GPUMemoryRefTree) -> int:
   """
   ncols = 0
   for ref in jax.tree.leaves(refs):
-    ncols += infer_tmem_cols_layout(ref.shape, ref.dtype,
-                                    collective=ref.collective,
-                                    packed=ref.packed)[0]
+    ncols += ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
   return ncols


-def infer_tmem_cols_layout(
+def infer_tmem_layout(
     shape: tuple[int, ...],
     dtype: jnp.dtype,
     *,
     packed: bool,
-    collective: bool,
-    layout: tcgen05.TMEMLayout | None = None) -> tuple[int, tcgen05.TMEMLayout]:
+    collective: bool) -> tcgen05.TMEMLayout:
   """Infers the number of columns used and layout for allocating TMEM Refs."""
   if packed:
     packing = 32 // dtypes.bit_width(dtype)
   else:
     packing = 1
-  if layout is None:
-    layout = tcgen05._infer_tmem_layout(shape,  # type: ignore[arg-type]
-                                        collective=collective,
-                                        packing=packing)
-  with ir.Context():
-    ir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
-    cols_used = layout.cols_in_shape(shape, ir_dtype)  # type: ignore[arg-type]
-  return cols_used, layout
+  return tcgen05._infer_tmem_layout(shape, collective=collective, packing=packing)  # typing: ignore


 def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
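As a quick sanity check on the packing arithmetic in `infer_tmem_layout` above: `packing` is the number of elements that share one 32-bit TMEM cell when packing is requested, and 1 otherwise. The helper below is only an illustration of that branch (bit widths computed from the NumPy itemsize for the example), not part of the diff.

```python
import jax.numpy as jnp

def packing_factor(dtype, packed: bool) -> int:
  # Mirrors the branch in infer_tmem_layout: packed sub-32-bit dtypes share a
  # 32-bit cell, everything else gets one cell per element.
  bit_width = jnp.dtype(dtype).itemsize * 8
  return 32 // bit_width if packed else 1

assert packing_factor(jnp.float32, packed=False) == 1
assert packing_factor(jnp.bfloat16, packed=True) == 2      # two bf16 per cell
assert packing_factor(jnp.float8_e4m3fn, packed=True) == 4  # four fp8 per cell
```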
@@ -365,11 +367,9 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
     for ref in jax.tree.leaves(ref_group):
       if not isinstance(ref, pallas_core.TransformedRef):
         ref = pallas_core.TransformedRef(ref, transforms=())
-      ncols, _ = infer_tmem_cols_layout(
-          ref.shape, ref.dtype,  # type: ignore[arg-type]
-          packed=ref.packed, collective=ref.collective)
+      ncols = ref.layout.cols_in_shape(ref.shape, dtypes.bit_width(ref.dtype))
       transform = ExtractAliasedRef.from_transformed_ref(
-          ref, col_offset, packed=ref.packed, collective=ref.collective)
+          ref, col_offset, layout=ref.layout)
       flat_refs.append(
           pallas_core.TransformedRef(
               ref_union, transforms=(transform, *ref.transforms)
@@ -409,19 +409,23 @@ def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractRefUnion(ref.inner_aval, self.refs, self.memory_space)

+  @functools.cached_property
+  def layout(self) -> tcgen05.TMEMLayout:
+    if self.memory_space != TMEM:
+      raise ValueError("layout attribute is only defined for TMEM refs")
+    return tcgen05.tmem_default_layout(packing=1)
+
   @functools.cached_property
   def collective(self) -> bool:
     if self.memory_space != TMEM:
-      raise ValueError("Collective is only supported for TMEM.")
+      raise ValueError("collective attribute is only defined for TMEM refs")
     ref_leaves = jax.tree.leaves(self.refs)
     first_ref = ref_leaves[0]
-    # Check if all Refs have the same collective attribute.
-    if not all(ref.collective == first_ref.collective for ref in ref_leaves):
-      raise ValueError(f"All Refs must be either collective/not collective."
-                       f" Got: {[ref.collective for ref in ref_leaves]}")
+    assert all(ref.collective == first_ref.collective for ref in ref_leaves)
     return first_ref.collective


+
 @dataclasses.dataclass(init=False, frozen=True)
 class RefUnion(GPUMemoryRef):
   """A sequence of trees of refs that are allowed to reuse the same memory.
@@ -450,11 +454,18 @@ def __init__(self, *refs: _GPUMemoryRefTree):
     elif all(ref.memory_space == TMEM for ref in ref_leaves):
       object.__setattr__(self, "refs", refs)
       max_cols = max(map(_ref_group_tmem_col_size, self.refs))
+      is_collective = ref_leaves[0].collective
+      if any(r.collective != is_collective for r in ref_leaves):
+        raise ValueError(
+            "Some aliased TMEM references are collective and some are not."
+        )
       super().__init__(
           shape=(128, max_cols,),
           dtype=jnp.int32,
           memory_space=TMEM,
           transforms=(),
+          layout=tcgen05.tmem_default_layout(packing=1),
+          collective=all(ref.collective for ref in ref_leaves),
       )
     else:
       raise NotImplementedError(
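A hedged usage sketch of the aliasing path touched above: TMEM refs that are never live at the same time can share storage through `RefUnion`, whose footprint is the largest member's column count, and per the new check all members must agree on `collective`. The `plgpu` alias and re-export names below are assumptions.

```python
import jax.numpy as jnp
from jax.experimental.pallas import mosaic_gpu as plgpu  # alias/export paths are assumptions

acc = plgpu.TMEM((128, 128), jnp.float32)     # 128 TMEM columns
scratch = plgpu.TMEM((128, 64), jnp.float32)  # 64 TMEM columns
# The union is sized by the larger member (128 columns), not the sum, and
# would raise if one ref were collective and the other were not.
aliased = plgpu.RefUnion(acc, scratch)
```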
@@ -752,20 +763,16 @@ class ExtractAliasedRef(state_types.Transform):
   shape: tuple[int, ...]
   offset: int
   # TMEM-specific params
-  packed: bool | None
-  collective: bool | None
+  layout: tcgen05.TMEMLayout | None

   @classmethod
   def from_transformed_ref(
-      cls, ref: pallas_core.TransformedRef, byte_offset: int,
-      packed: bool | None = None,
-      collective: bool | None = None,
+      cls,
+      ref: pallas_core.TransformedRef,
+      byte_offset: int,
+      layout: tcgen05.TMEMLayout | None = None,
   ):
-    return cls(
-        dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset,
-        packed=packed,
-        collective=collective,
-    )
+    return cls(dtypes.dtype(ref.dtype), ref.ref.shape, byte_offset, layout)

   def transform_shape(self, shape):
     if shape is None:
@@ -777,8 +784,7 @@ def transform_dtype(self, dtype):
     return self.dtype

   def tree_flatten(self):
-    return (), (self.dtype, self.shape, self.offset,
-                self.packed, self.collective)
+    return (), (self.dtype, self.shape, self.offset, self.layout)

   @classmethod
   def tree_unflatten(cls, metadata, arrays):
@@ -1040,20 +1046,20 @@ def _getitem(self, tracer, idx):


 class AbstractTMEMRef(state.AbstractRef):
-  __slots__ = ["inner_aval", "memory_space", "packed", "collective"]
+  __slots__ = ["inner_aval", "memory_space", "layout", "collective"]

-  def __init__(self, inner_aval, memory_space, packed, collective):
+  def __init__(self, inner_aval, memory_space, layout, collective):
     super().__init__(inner_aval, memory_space)
-    self.packed = packed
+    self.layout = layout
     self.collective = collective

   def __repr__(self) -> str:
-    return f'TMEM({self.inner_aval.str_short()},packed={self.packed})'
+    return f'TMEM({self.inner_aval.str_short()}, layout={self.layout}, collective={self.collective})'

   def update(self, inner_aval=None, memory_space=None):
     ref = super().update(inner_aval, memory_space)
     return AbstractTMEMRef(
-        ref.inner_aval, ref.memory_space, self.packed, self.collective
+        ref.inner_aval, ref.memory_space, self.layout, self.collective
     )


@@ -1246,6 +1252,7 @@ def to_mgpu(self) -> mgpu.FragmentedLayout:
       raise ValueError("Only TiledLayout supports reductions.")
     return layout.reduce(self.axes)

+
 class Layout(SomeLayout, enum.Enum):
   #: [m, n] matrix, where m % 64 == 0 == n % 8.
   WGMMA = enum.auto()
@@ -1297,3 +1304,13 @@ def check_no_args():
 Layout.TCGEN05_ROW = Layout.TCGEN05.reduce(1)
 Layout.TCGEN05_COL = Layout.TCGEN05.reduce(0)
 Layout.TCGEN05_TMEM_NATIVE_ROW = Layout.TCGEN05_TMEM_NATIVE.reduce(1)
+
+
+class TMEMLayout(enum.Enum):
+  """Layout for TMEM references."""
+  SCALES_LAYOUT = enum.auto()
+
+  def to_mgpu(self) -> mgpu.FragmentedLayout:
+    match self:
+      case TMEMLayout.SCALES_LAYOUT:
+        return tcgen05.scales_layout()