77
77
78
78
@dataclasses.dataclass
class MeshContext:
  """Static description of the device mesh made available during lowering.

  All three fields are plain Python tuples known at lowering time; the
  ``axis_index`` lowering rule reads ``mesh_shape`` and ``axis_names`` from
  this object.
  """

  # Size of each mesh axis; indexed by an axis's position in `axis_names`.
  mesh_shape: tuple[int, ...]
  # Names of the mesh axes.
  axis_names: tuple[str, ...]
  # Strides derived from the per-axis mesh sizes.
  mesh_strides: tuple[int, ...]
83
83
@@ -298,20 +298,7 @@ def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
298
298
mesh_strides = pallas_utils .strides_from_shape (tuple (
299
299
mesh .shape [a ] for a in axis_names
300
300
))
301
- logical_to_mesh = np .empty ((mesh .size , len (axis_names )), dtype = np .int32 )
302
- for i , idx in enumerate (np .ndindex (* mesh .device_ids .shape )):
303
- logical_to_mesh [i ] = np .array (idx )
304
- self .mesh_info = MeshInfo (logical_to_mesh , axis_names , mesh_strides )
305
- l_to_m_aval = pl_core .AbstractMemoryRef (
306
- jax_core .raise_to_shaped (jax_core .get_aval (logical_to_mesh )),
307
- TPUMemorySpace .SMEM ,
308
- )
309
- # We are now passing in the logical -> mesh index mapping
310
- # TODO(sharadmv,apaszke): avoid stalling pipeline by marking the index
311
- # mapping as scalar prefetch and instead just mark it as an SMEM operand.
312
- self .scalar_prefetch_types = (
313
- _get_arg_type (l_to_m_aval , None )[0 ],
314
- * self .scalar_prefetch_types )
301
+ self .mesh_info = MeshInfo (mesh .device_ids .shape , axis_names , mesh_strides )
315
302
316
303
def maybe_compress_grid (self ):
317
304
# If we have many leading parallel dimensions, we should "compress" them
@@ -324,9 +311,7 @@ def has_communication(self) -> bool:
324
311
return bool (jax_core .used_axis_names_jaxpr (self .jaxpr ))
325
312
326
313
def get_extra_args(self) -> tuple[Any, ...]:
  """Return the tuple of extra arguments.

  There are none, so the result is always the empty tuple.
  """
  no_extra_args: tuple[Any, ...] = ()
  return no_extra_args
330
315
331
316
def get_dimension_semantics (self ) -> ir .ArrayAttr :
332
317
@@ -344,7 +329,7 @@ def _get_semantics(s: str | None) -> str:
344
329
345
330
@dataclasses.dataclass
class MeshInfo:
  """Mesh description recorded on the grid mapping.

  Its fields are forwarded unchanged into ``MeshContext`` when the lowering
  context is built.
  """

  # Shape of the mesh's device-id array (`mesh.device_ids.shape`).
  mesh_shape: tuple[int, ...]
  # Names of the mesh axes.
  # NOTE(review): annotated as a list here but as `tuple[str, ...]` on
  # MeshContext -- confirm which type callers actually pass.
  axis_names: list[str]
  # Strides computed from the per-axis mesh sizes.
  mesh_strides: tuple[int, ...]
350
335
@@ -469,9 +454,9 @@ def body_func(*args):
469
454
470
455
mesh_info = mosaic_grid_mapping .mesh_info
471
456
if mesh_info is not None :
472
- ( l_to_m ,), scalar_prefetch = split_list ( scalar_prefetch , [ 1 ])
473
- mesh_context = MeshContext ( l_to_m , mesh_info .axis_names ,
474
- mesh_info . mesh_strides )
457
+ mesh_context = MeshContext (
458
+ mesh_info . mesh_shape , mesh_info .axis_names , mesh_info . mesh_strides
459
+ )
475
460
else :
476
461
mesh_context = None
477
462
lowering_context = LoweringContext (
@@ -527,9 +512,9 @@ def body_func(*args):
527
512
if i not in mosaic_grid_mapping .mapped_dims )
528
513
mesh_info = mosaic_grid_mapping .mesh_info
529
514
if mesh_info is not None :
530
- ( l_to_m ,), scalar_prefetch = split_list ( scalar_prefetch , [ 1 ])
531
- mesh_context = MeshContext ( l_to_m , mesh_info .axis_names ,
532
- mesh_info . mesh_strides )
515
+ mesh_context = MeshContext (
516
+ mesh_info . mesh_shape , mesh_info .axis_names , mesh_info . mesh_strides
517
+ )
533
518
else :
534
519
mesh_context = None
535
520
lowering_context = LoweringContext (
@@ -2145,11 +2130,15 @@ def _device_id_lowering_rule(ctx: LoweringRuleContext):
2145
2130
lowering_rules [tpu_primitives .device_id_p ] = _device_id_lowering_rule
2146
2131
2147
2132
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: str):
  """Lower ``axis_index`` to integer arithmetic on the device id.

  With ``i`` the position of ``axis_name`` in the mesh context's
  ``axis_names``, the emitted computation is::

      (device_id // prod(mesh_shape[i + 1:])) % mesh_shape[i]

  NOTE(review): this presumes device ids enumerate the mesh in row-major
  order of ``mesh_shape`` -- confirm against the runtime's device numbering.

  Args:
    ctx: Lowering rule context; its ``lowering_context.mesh_context`` supplies
      ``mesh_shape`` and ``axis_names``.
    axis_name: Name of the mesh axis whose index is requested.

  Returns:
    The result of the emitted ``arith.remsi`` operation.

  Raises:
    ValueError: If no mesh context is available, or if ``axis_name`` is not
      one of the mesh axis names.
  """
  mesh_context = ctx.lowering_context.mesh_context
  if mesh_context is None:
    # The lowering context carries `None` when no mesh was provided. Fail with
    # a clear message before emitting any IR, instead of an AttributeError.
    raise ValueError(
        f"Cannot lower axis_index for axis {axis_name!r}: mesh context is not"
        " set."
    )
  device_id = tpu.DeviceIdOp().result
  mesh_shape = mesh_context.mesh_shape
  axis_names = mesh_context.axis_names
  axis_index = axis_names.index(axis_name)
  axis_size = ir_constant(mesh_shape[axis_index])
  # Product of the sizes of all axes after this one; for the last axis the
  # slice is empty and the product is 1.
  minor_divisor = ir_constant(
      np.prod(mesh_shape[axis_index + 1:], dtype=np.int32)
  )
  return arith.remsi(arith.divsi(device_id, minor_divisor), axis_size)
2153
2142
lowering_rules [lax .axis_index_p ] = _axis_index_rule
2154
2143
2155
2144
def _get_barrier_semaphore_rule (ctx : LoweringRuleContext ):
0 commit comments