File tree Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Expand file tree Collapse file tree 1 file changed +7
-8
lines changed Original file line number Diff line number Diff line change @@ -361,14 +361,13 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
361
361
362
362
363
363
def physical_hlo_sharding (aval , hlo_sharding : xc .HloSharding ) -> xc .HloSharding :
364
- key_shape = aval .dtype ._impl .key_shape
365
- new_op_sharding = hlo_sharding .to_proto ().clone () # type: ignore
366
- partitions , num_replicas = op_shardings .get_num_ways_dim_sharded (
367
- hlo_sharding )
368
- suffix = [] if num_replicas == 1 else [num_replicas ]
369
- tad = partitions + [1 ] * len (key_shape ) + suffix
370
- new_op_sharding .tile_assignment_dimensions = tad
371
- return xc .HloSharding .from_proto (new_op_sharding )
364
+ key_shape = aval .dtype ._impl .key_shape
365
+ new_op_sharding = hlo_sharding .to_proto ().clone () # type: ignore
366
+ partitions , num_replicas = op_shardings .get_num_ways_dim_sharded (hlo_sharding )
367
+ suffix = [] if num_replicas == 1 else [num_replicas ]
368
+ tad = partitions + [1 ] * len (key_shape ) + suffix
369
+ new_op_sharding .tile_assignment_dimensions = tad
370
+ return xc .HloSharding .from_proto (new_op_sharding )
372
371
373
372
374
373
class KeyTyRules :
You can’t perform that action at this time.
0 commit comments