Understanding the generated MLIR in a multi-device context #30271
-
Hello Community, I have a custom PJRT plugin that compiles an MLIR module.
The first generated MLIR module is:
I understand this as XLA partitioning the original tensor. But the second MLIR module gives:
I am confused here, since every device holds only a 256x256xf32 tensor, so I expected something like:
Any help will be appreciated, thanks.
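For context, the setup is roughly the sketch below (simplified: the real computation is different, and the mesh/axis names here are only illustrative). A 512x512 array is sharded over a 2x2 mesh of the plugin's 4 devices, so each device holds a 256x256 shard:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 2x2 mesh over the plugin's 4 devices; both dimensions of a 512x512
# array are sharded, so each device ends up holding a 256x256 shard.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

x = jax.device_put(jnp.ones((512, 512), jnp.float32), sharding)
for shard in x.addressable_shards:
    print(shard.device, shard.data.shape)  # prints (256, 256) for each device
```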
-
Well, I see that pjrt_c_api_client.cc doesn't lower the module the way cpu_client.cc does for
-
You're observing the difference between logical and physical tensor representations in JAX's compilation pipeline. The second MLIR module is correct: it shows the logical view, where
- tensor<512x512xf32> is the logical (unpartitioned) shape used in the function signature, and
- {devices=[2,2]<=[4]} tells the compiler how to physically distribute this logical tensor across the 4 devices.
This is because the compiler can perform cross-shard optimizations when it sees the full logical computation graph. It automatically inserts collective operations where needed based on the sharding annotations, and the functions remain readable without explicit per-device logic. At runtime with your sharding, each device computes on its local 256x256 shard.
Your expected MLIR with explicit 256x256 signatures would require the compiler to generate separate functions for each device, which would complicate optimization and communication insertion. The current approach lets the compiler handle distribution automatically based on the sharding annotations.
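If it helps, here is a minimal sketch of how to see both views from JAX (assuming 4 visible devices; the matmul and the mesh/axis names are just placeholders): the lowered StableHLO keeps the logical 512x512 shapes with sharding annotations, while the compiled, SPMD-partitioned HLO shows the per-device 256x256 shapes and the inserted collectives.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=("x", "y"))
sharding = NamedSharding(mesh, P("x", "y"))

@jax.jit
def f(a, b):
    return a @ b  # placeholder computation

x = jax.device_put(jnp.ones((512, 512), jnp.float32), sharding)
y = jax.device_put(jnp.ones((512, 512), jnp.float32), sharding)

lowered = f.lower(x, y)

# Logical view: arguments keep the full tensor<512x512xf32> shapes and
# carry sharding attributes; this is the StableHLO handed to the PJRT plugin.
print(lowered.as_text())

# Physical view: after SPMD partitioning the parameters are per-device
# 256x256 shards and the needed collectives have been inserted.
print(lowered.compile().as_text())
```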