
Question about the completeness of JAX operators. #18147

Answered by pschuh
Dong-Jiahuan asked this question in Q&A

JAX can lower to everything in XLA. TensorFlow also lowers a subset of its operations to XLA: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla . Note that XLA is the only way to target TPUs from TensorFlow, so on TPUs JAX and TensorFlow have full feature parity. The TF operations that are not supported are all CPU/GPU-only operations such as JPEG decoding, datasets, etc. XLA has one major limitation compared with other approaches: all of its shapes are static (which helps XLA plan tiling and buffer assignment). However, most ML models do not use dynamic shapes, and there are good workarounds, such as padding and masking, for a limited "bounded dynamism" regime.
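To make the padding-and-masking workaround concrete, here is a minimal sketch (an illustrative example, not from the discussion): variable-length inputs are padded to a fixed maximum size so that XLA sees only one static shape, and a mask keeps the padded entries from affecting the result.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 8  # fixed (static) buffer size chosen up front

@jax.jit
def masked_mean(padded, length):
    # `padded` always has shape (MAX_LEN,), so XLA compiles a single
    # program regardless of the true sequence length.
    mask = jnp.arange(MAX_LEN) < length       # True for real entries
    total = jnp.sum(jnp.where(mask, padded, 0.0))
    return total / length

# A length-3 sequence padded out to MAX_LEN with zeros.
x = jnp.array([1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0])
print(masked_mean(x, 3))  # mean over the first 3 entries only
```

Because `length` is a traced array value rather than a Python int, calling `masked_mean` with different lengths reuses the same compiled program instead of triggering recompilation.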

Answer selected by Dong-Jiahuan