Skip to content

Commit ab2e906

Browse files
yashk2810jax authors
authored andcommitted
Fix the indentation of the physical_hlo_sharding function
PiperOrigin-RevId: 616280971
1 parent cd1e55a commit ab2e906

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

jax/_src/prng.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,14 +361,13 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
361361

362362

363363
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
364-
key_shape = aval.dtype._impl.key_shape
365-
new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore
366-
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
367-
hlo_sharding)
368-
suffix = [] if num_replicas == 1 else [num_replicas]
369-
tad = partitions + [1] * len(key_shape) + suffix
370-
new_op_sharding.tile_assignment_dimensions = tad
371-
return xc.HloSharding.from_proto(new_op_sharding)
364+
key_shape = aval.dtype._impl.key_shape
365+
new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore
366+
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
367+
suffix = [] if num_replicas == 1 else [num_replicas]
368+
tad = partitions + [1] * len(key_shape) + suffix
369+
new_op_sharding.tile_assignment_dimensions = tad
370+
return xc.HloSharding.from_proto(new_op_sharding)
372371

373372

374373
class KeyTyRules:

0 commit comments

Comments
 (0)