@@ -86,7 +86,9 @@ class MeshContext:
86
86
@dataclasses .dataclass
87
87
class LoweringContext :
88
88
ir_context : ir .Context
89
- grid_indices : Sequence [ir .Value ] | None
89
+ grid_rank : int # Includes both user and vmap axes.
90
+ mapped_dims : tuple [int , ...] # Indices of vmapped grid dimensions.
91
+ user_grid_indices : Sequence [ir .Value ] | None
90
92
block_shapes : list [tuple [int | pl_core .Mapped , ...]]
91
93
name_stack : source_info_util .NameStack
92
94
mesh_context : MeshContext | None
@@ -475,6 +477,8 @@ def body_func(*args):
475
477
mesh_context = None
476
478
lowering_context = LoweringContext (
477
479
ctx ,
480
+ len (mosaic_grid_mapping .grid ),
481
+ mosaic_grid_mapping .mapped_dims ,
478
482
None ,
479
483
arg_block_shapes ,
480
484
source_info_util .NameStack (),
@@ -531,6 +535,8 @@ def body_func(*args):
531
535
mesh_context = None
532
536
lowering_context = LoweringContext (
533
537
ctx ,
538
+ len (mosaic_grid_mapping .grid ),
539
+ mosaic_grid_mapping .mapped_dims ,
534
540
jaxpr_indices ,
535
541
arg_block_shapes ,
536
542
source_info_util .NameStack (),
@@ -1846,22 +1852,32 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
1846
1852
1847
1853
1848
1854
def _program_id_lowering_rule (ctx : LoweringRuleContext , * , axis : int ):
1849
- if ctx .lowering_context .grid_indices is None :
1855
+ if ctx .lowering_context .user_grid_indices is None :
1850
1856
raise ValueError (
1851
1857
f"program id: { axis } was passed, but user did not provide a grid."
1852
1858
)
1853
- length = len (ctx .lowering_context .grid_indices )
1859
+ length = len (ctx .lowering_context .user_grid_indices )
1854
1860
if not (0 <= axis < length ):
1855
1861
raise ValueError (
1856
1862
f"user passed in program id with axis: { axis } , but grid only has"
1857
1863
f" length: { length } "
1858
1864
)
1859
- return ctx .lowering_context .grid_indices [axis ]
1865
+ return ctx .lowering_context .user_grid_indices [axis ]
1860
1866
lowering_rules [primitives .program_id_p ] = _program_id_lowering_rule
1861
1867
1862
1868
def _num_programs_lowering_rule (ctx : LoweringRuleContext , * , axis : int ):
1863
- del ctx
1864
- return tpu .iteration_bound (axis )
1869
+ mapped_axes = set (ctx .lowering_context .mapped_dims )
1870
+ seen_user_axes = 0
1871
+ for i in range (ctx .lowering_context .grid_rank ):
1872
+ seen_user_axes += int (i not in mapped_axes )
1873
+ if seen_user_axes == axis + 1 :
1874
+ break
1875
+ else :
1876
+ raise ValueError (
1877
+ f"user passed in program id with axis: { axis } , but grid only has"
1878
+ f" length: { len (ctx .lowering_context .grid_rank )} "
1879
+ )
1880
+ return tpu .iteration_bound (i )
1865
1881
lowering_rules [primitives .num_programs_p ] = _num_programs_lowering_rule
1866
1882
1867
1883
0 commit comments