Question about the completeness of JAX operators. #18147
-
I want to use JAX for an AI project, so I'm curious about the completeness of JAX's operator set. (I know JAX is positioned differently from PyTorch or TensorFlow. If an operator that is native in PyTorch or TensorFlow can be constructed from JAX's native operators, I won't count its absence as a shortcoming in the completeness of JAX's operators.)
Replies: 1 comment 1 reply
-
JAX can lower to everything in XLA. TensorFlow also lowers a subset of its operations to XLA: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla . Note that XLA is the only way to target TPUs from TensorFlow, so on TPUs JAX and TensorFlow have full feature parity. The TF operations that are not supported are all CPU/GPU-only operations like decoding JPEGs, datasets, etc. XLA has one major limitation compared with other approaches: its shapes are all static (this helps XLA better plan tiling and buffer assignment). However, most ML models do not use dynamic shapes, and there are good workarounds like padding and masking for a limited 'bounded dynamism' regime.
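As a quick illustration of what "lower to XLA" means in practice (a minimal sketch; the exact printed IR depends on your JAX version), you can inspect the program that `jax.jit` hands to XLA for an ordinary function:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # A tiny model: every op here (dot, add, tanh) maps directly onto XLA ops.
    return jnp.tanh(x @ w + 1.0)

w = jnp.ones((4, 8))
x = jnp.ones((2, 4))

# Lower the jitted function ahead of time and print the StableHLO/HLO text
# that XLA will compile for these (static) input shapes.
lowered = jax.jit(predict).lower(w, x)
print(lowered.as_text())
```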
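To make the padding-and-masking workaround concrete, here is a minimal sketch (the helper `pad_to` and the bucket size `MAX_LEN` are illustrative choices, not a JAX API): variable-length inputs are padded to one fixed length so `jax.jit` only ever sees a single static shape, and a boolean mask keeps the padded entries from affecting the result.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 16  # one static 'bucket' size; jit compiles once for this shape

def pad_to(xs, max_len):
    # Hypothetical helper, run outside jit: pad a Python list of floats to a
    # fixed length and return (padded array, boolean validity mask).
    n = len(xs)
    padded = jnp.zeros(max_len).at[:n].set(jnp.asarray(xs))
    mask = jnp.arange(max_len) < n
    return padded, mask

@jax.jit
def masked_mean(x, mask):
    # Mean over valid entries only; padded positions are zeroed out and
    # excluded from the count, so the result matches the unpadded computation.
    total = jnp.sum(jnp.where(mask, x, 0.0))
    count = jnp.sum(mask)
    return total / count

x, mask = pad_to([1.0, 2.0, 3.0], MAX_LEN)
print(masked_mean(x, mask))  # 2.0, the mean of the original three values
```

Because every call sees the same `(MAX_LEN,)` shape, `jax.jit` compiles the function once instead of retracing and recompiling for each input length.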