Understanding the generated MLIR in multi-devices context #30271

Answered by guy-singer
yuanfz98 asked this question in Q&A
You're observing the difference between logical and physical tensor representations in JAX's compilation pipeline.

The second MLIR module is correct.

The second module shows the logical view where:

  • The function signature uses the full tensor shape tensor<512x512xf32>
  • The sharding annotation {devices=[2,2]<=[4]} tells the compiler how to physically distribute this logical tensor
  • Each of the 4 devices stores and computes on only its 256x256 shard (the 512x512 tensor is split in two along each axis of the 2x2 device mesh)

The function signature keeps the logical shape because the compiler can perform cross-shard optimizations when it sees the full logical computation graph. It then automatically inserts collective operations where needed, based on the sharding annotations. Functions remain readable without ex…
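
To check the logical versus per-device shapes yourself, here is a minimal sketch. It is not taken from the original question: the function `scaled_sum`, the axis names `"x"` and `"y"`, and the use of 4 virtual CPU devices are all assumptions made for illustration. The exact spelling of the sharding annotation in the printed MLIR depends on your JAX version, so the sketch only checks the tensor shapes.

```python
import os

# Assumption: no 4-device accelerator is at hand, so we ask XLA to expose
# 4 virtual CPU devices. Both variables must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

assert len(jax.devices()) >= 4, "this sketch needs at least 4 devices"

# 2x2 device mesh; the axis names are arbitrary labels chosen for this example.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))


@jax.jit
def scaled_sum(a):  # hypothetical function, only here to have something to lower
    return jnp.sum(a * 2.0)


# Place a 512x512 array on the mesh, split along both axes.
a = jax.device_put(jnp.ones((512, 512), dtype=jnp.float32), sharding)

logical_shape = a.shape                      # shape the program is written against
shard_shape = sharding.shard_shape(a.shape)  # shape each device actually holds
per_device_shapes = [s.data.shape for s in a.addressable_shards]

lowered = scaled_sum.lower(a)
logical_mlir = lowered.as_text()             # MLIR before SPMD partitioning
partitioned_hlo = lowered.compile().as_text()  # compiled module, after partitioning

print("logical:", logical_shape, "per device:", shard_shape)
print(logical_mlir)
```

The printed MLIR should show the argument as `tensor<512x512xf32>` together with a sharding attribute, while `shard_shape` reports what each device holds. Printing `partitioned_hlo` is a way to inspect what the compiler produced after partitioning, which is where any inserted collectives would be expected to appear.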

Answer selected by yuanfz98