What is OpShardingSharding? #13705
Regarding jax.sharding, what is OpShardingSharding? cc: @yashk2810
Hey, good question!

OpShardingSharding is the base XLACompatibleSharding that every subclass of XLACompatibleSharding eventually lowers to inside JAX. The main thing is that it consists of an OpSharding. OpSharding is the representation that XLA understands. It is more or less documented here, though the docs may not make it very clear how things work: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/xla_data.proto;l=715-765?q=xla_data.proto&ss=tensorflow%2Ftensorflow

To read an OpSharding, you look at its type, tile_assignment_dimensions and tile_assignment_devices. tile_assignment_dimensions tells you how many ways each dimension is sharded or replicated.

If you see something like …

If you see: …

This is a simple function we use to figure this information out: https://github.com/google/jax/blob/main/jax/interpreters/pxla.py#L211-L223.
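To make the reading of tile_assignment_dimensions and tile_assignment_devices concrete, here is a minimal pure-Python sketch. It does not use jax or xla_client: FakeOpSharding, num_ways_dim_sharded and device_slices are hypothetical names, and the dataclass only mirrors the proto fields named in this thread. It assumes a tiled (OTHER) sharding, that tile_assignment_devices lists the tile array flattened in row-major order, that replicate_on_last_tile_dim marks the last tile dimension as a replication factor, and that every array dimension divides evenly. It is loosely in the spirit of the linked pxla.py helper, not a copy of it.

```python
from dataclasses import dataclass, field
from itertools import product
from math import prod
from typing import Dict, List, Sequence, Tuple


# Hypothetical stand-in holding the OpSharding fields discussed in this thread.
# This is NOT jax's or xla_client's OpSharding class.
@dataclass
class FakeOpSharding:
    type: str  # "REPLICATED" or "OTHER" (tiled); other proto types are not modeled
    tile_assignment_dimensions: List[int] = field(default_factory=list)
    tile_assignment_devices: List[int] = field(default_factory=list)
    replicate_on_last_tile_dim: bool = False


def num_ways_dim_sharded(s: FakeOpSharding) -> Tuple[List[int], int]:
    """Return (shards per array dimension, replication factor)."""
    dims = list(s.tile_assignment_dimensions)
    if s.replicate_on_last_tile_dim:
        # The extra trailing tile dimension counts replicas, not shards.
        return dims[:-1], dims[-1]
    return dims, 1


def device_slices(s: FakeOpSharding, shape: Sequence[int]) -> Dict[int, Tuple[slice, ...]]:
    """Map each device id to the block of a global array of `shape` it holds."""
    if s.type != "OTHER":
        raise ValueError("sketch only handles tiled (OTHER) shardings")
    partitions, _ = num_ways_dim_sharded(s)
    if len(partitions) != len(shape):
        raise ValueError("tile rank does not match array rank")
    if len(s.tile_assignment_devices) != prod(s.tile_assignment_dimensions):
        raise ValueError("device count does not match tile assignment size")
    result: Dict[int, Tuple[slice, ...]] = {}
    # Assumed row-major flattening: itertools.product varies the last index fastest.
    tile_indices = product(*(range(n) for n in s.tile_assignment_dimensions))
    for device, tile_index in zip(s.tile_assignment_devices, tile_indices):
        block = []
        # zip stops after len(shape) entries, so a trailing replica index is ignored.
        for size, ways, i in zip(shape, partitions, tile_index):
            if size % ways:
                raise ValueError("sketch assumes evenly divisible dimensions")
            chunk = size // ways
            block.append(slice(i * chunk, (i + 1) * chunk))
        result[device] = tuple(block)
    return result


if __name__ == "__main__":
    # 2x2 tiling of a (4, 6) array over devices 0..3.
    tiled = FakeOpSharding("OTHER", [2, 2], [0, 1, 2, 3])
    print(num_ways_dim_sharded(tiled))
    print(device_slices(tiled, (4, 6)))
```

For example, with tile_assignment_dimensions [4, 1, 2] and replicate_on_last_tile_dim set, an (8, 4) array is split 4 ways along its rows, not split along its columns, and each row block is held by 2 devices.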