@@ -93,9 +93,12 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
   if not isinstance(mesh, Mesh):
     raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its "
                     f"second argument, but got {mesh} of type {type(mesh)}.")
-  _check_specs(SpecErrorType.input, in_specs)
+  if not auto.issubset(mesh.axis_names):
+    raise ValueError(f"shard_map requires auto={auto} to be a subset of "
+                     f"mesh.axis_names={mesh.axis_names}")
+  _check_specs(SpecErrorType.input, in_specs, auto)
   if not callable(out_specs):
-    _check_specs(SpecErrorType.out, out_specs)
+    _check_specs(SpecErrorType.out, out_specs, auto)

   @util.wraps(f)
   @traceback_util.api_boundary
@@ -114,7 +117,7 @@ def wrapped(*args):
     def out_names_thunk():
       if callable(out_specs):
         out_specs_ = out_specs()
-        _check_specs(SpecErrorType.out, out_specs_)
+        _check_specs(SpecErrorType.out, out_specs_, auto)
       else:
         out_specs_ = out_specs
       dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves)
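
The two hunks above thread the new `auto` argument into the spec checks and validate that every auto axis name is actually an axis of the mesh. A minimal sketch of that subset check outside of shard_map, assuming a hypothetical single-axis mesh built from whichever device is available:

import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical mesh and axis names, invented for the example.
mesh = Mesh(np.array(jax.devices()[:1]), ('i',))
auto = frozenset({'j'})  # 'j' is not a mesh axis
if not auto.issubset(mesh.axis_names):
    # shard_map would now raise a ValueError naming mesh.axis_names here.
    print(f"auto={auto} is not a subset of mesh.axis_names={mesh.axis_names}")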
@@ -162,17 +165,40 @@ def _canonicalize_spec(spec: PartitionSpec) -> AxisNames:
 
 SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
 
-def _check_specs(error_type: SpecErrorType, specs: Any) -> None:
+def _check_specs(error_type: SpecErrorType, specs: Any, auto) -> None:
   if error_type == SpecErrorType.input and specs is None:
     raise TypeError(
         "shard_map in_specs argument must be a pytree of "
         "`jax.sharding.PartitionSpec` instances, but it was None.\n"
         "Instead of `in_specs=None`, did you mean `in_specs=P()`, "
         "where `P = jax.sharding.PartitionSpec`?")
-  if all(isinstance(p, PartitionSpec) for p in tree_leaves(specs)): return
+  def check_spec(p):
+    if not isinstance(p, PartitionSpec):
+      return False
+    for names in p:
+      if not isinstance(names, tuple):
+        names = (names,)
+      for name in names:
+        if name in auto:
+          return False
+    return True
+  if all(check_spec(p) for p in tree_leaves(specs)): return
   prefix = 'in' if error_type == SpecErrorType.input else 'out'
   msgs = [f"  {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
           for key, x in generate_key_paths(specs) if not isinstance(x, P)]
+  if not msgs:
+    for key, p in generate_key_paths(specs):
+      for names in p:
+        if not isinstance(names, tuple):
+          names = (names,)
+        for name in names:
+          if name in auto:
+            msgs.append(f"  {prefix}_specs{keystr(key)} refers to {repr(name)}")
+    raise ValueError(
+        f"shard_map {prefix}_specs argument cannot refer to an axis "
+        f"marked auto ({auto}), but:\n\n"
+        + '\n\n'.join(msgs) + '\n\n'
+        f"Check the {prefix}_specs values passed to shard_map.")
   raise TypeError(
       f"shard_map {prefix}_specs argument must be a pytree of "
       f"`jax.sharding.PartitionSpec` instances, but:\n\n"
@@ -549,7 +575,7 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
   in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
                   in_avals_, in_nodes)
   new_axis_context = sharding_impls.SPMDAxisContext(
-      mesh, frozenset(mesh.axis_names)
+      mesh, frozenset(mesh.axis_names) - auto
   )
   sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
   with core.extend_axis_env_nd(tuple(mesh.shape.items())):
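
For context on the one-line change above: the SPMD axis context now lists only the mesh axes that shard_map treats as manual, leaving the axes named in `auto` under the partitioner's control. A trivial sketch of that set arithmetic, with made-up axis names:

axis_names = ('data', 'model')      # hypothetical mesh axes
auto = frozenset({'model'})
manual_axes = frozenset(axis_names) - auto
assert manual_axes == frozenset({'data'})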
@@ -575,20 +601,20 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   unspecified = set(range(aval_in.ndim)) if auto else set()
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,  # type: ignore
                                   unspecified_dims=unspecified)
-  return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]
+  return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)]

 def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
                  aval_in, aval_out, xs):
   x, = xs
-  manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
   axes = {name: i for i, ns in names.items() for name in ns}
   ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
     ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
     aval_out = core.physical_aval(aval_out)
-  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
   unspecified = set(range(aval_out.ndim)) if auto else set()
+  manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
+  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
   return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                          unspecified)  # type: ignore
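
In `_xla_unshard`, the manual-sharding wrap is now emitted after `unspecified` is computed, so the auto dimensions are marked unspecified on the way out of the body just as `_xla_shard` now does on the way in. A small sketch of that dimension set, with a made-up rank:

ndim = 3                            # hypothetical output rank
auto = frozenset({'model'})         # any non-empty auto set
unspecified = set(range(ndim)) if auto else set()
assert unspecified == {0, 1, 2}
assert (set(range(ndim)) if frozenset() else set()) == set()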