Using JAX to generate code for other languages #11203
Unanswered
EelcoHoogendoorn
asked this question in
General
Replies: 2 comments 1 reply
-
As far as I know, the only close-to-complete project of this kind is jax/experimental/jax2tf, which uses JAX's transformation framework to emit TensorFlow code. It would certainly be possible to use a similar approach to output code in other languages and/or for other backends, but as you said, it would be a lot of work to make such a project feature-complete.
-
Hi. I think JAX ==> MLIR =(LLVM)=> HLSL/DX is possible, but LLVM might not currently support HLSL/DX: https://discourse.llvm.org/t/rfc-adding-hlsl-and-directx-support-to-clang-llvm/60783
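For what it's worth, the first arrow of that pipeline (JAX to MLIR) can be inspected directly. Below is a minimal sketch using JAX's ahead-of-time lowering API; the function `f` and its input shape are made up for illustration, and the exact MLIR dialect in the output depends on the installed JAX version. Nothing here covers the later MLIR-to-HLSL/DX step.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function; any jit-compatible function works.
    return jnp.sin(x) * 2.0


# Lower the function for a concrete input shape/dtype without running it.
lowered = jax.jit(f).lower(jnp.ones((4,), jnp.float32))

# Textual MLIR module for the traced computation.
mlir_text = lowered.as_text()
print(mlir_text)
```

The printed module is what a downstream MLIR/LLVM toolchain would have to consume; whether that toolchain can then target HLSL/DX is the open question from the RFC linked above.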
-
I recently came across this: https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
That looks super interesting; I've been interested in running my JAX programs inside a rendering context, such as HLSL, but sending objects back and forth seems like a nightmare. However... how hard would it actually be to use this custom interpreter functionality to emit a valid code string in some language like HLSL? The simple case seems simple... but it's probably quite the rabbit hole to implement a full suite of functionality.
I was wondering if anyone has done something similar, that is, used JAX to trace, transform, and cross-compile their numpy expressions to other languages. My googling comes up empty-handed, but is anyone aware of some other project doing something similar?
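To make the "simple case" concrete, here is a rough sketch of what I have in mind: trace a function with `jax.make_jaxpr`, walk the resulting equations, and print a C-like (HLSL-style) scalar function. Everything beyond `jax.make_jaxpr` and the jaxpr fields is an assumption for illustration: the `TEMPLATES` table, the `emit` helper, and the `float`-only output are hypothetical, and the generated string has not been checked against an actual HLSL compiler.

```python
import jax
import jax.numpy as jnp

# Hypothetical table: JAX primitive name -> C-like expression template.
# Only a handful of scalar primitives are covered in this sketch.
TEMPLATES = {
    "add": "({0} + {1})",
    "sub": "({0} - {1})",
    "mul": "({0} * {1})",
    "div": "({0} / {1})",
    "neg": "(-{0})",
    "sin": "sin({0})",
    "cos": "cos({0})",
    "exp": "exp({0})",
}


def emit(fn, *example_args, name="f"):
    """Trace fn and return C-like source for a scalar float function."""
    jaxpr = jax.make_jaxpr(fn)(*example_args).jaxpr
    if jaxpr.constvars:
        raise NotImplementedError("captured constants are not handled")
    if len(jaxpr.outvars) != 1:
        raise NotImplementedError("exactly one output is supported")

    # Map each jaxpr input variable to a parameter name.
    names = {v: f"x{i}" for i, v in enumerate(jaxpr.invars)}

    def ref(atom):
        # Literals carry their value in `.val`; variables are looked up by name.
        if hasattr(atom, "val"):
            return repr(float(atom.val))
        return names[atom]

    body = []
    for i, eqn in enumerate(jaxpr.eqns):
        prim = eqn.primitive.name
        if prim not in TEMPLATES:
            raise NotImplementedError(f"no template for primitive {prim!r}")
        expr = TEMPLATES[prim].format(*(ref(a) for a in eqn.invars))
        names[eqn.outvars[0]] = f"t{i}"
        body.append(f"    float t{i} = {expr};")

    params = ", ".join(f"float {names[v]}" for v in jaxpr.invars)
    body.append(f"    return {ref(jaxpr.outvars[0])};")
    return f"float {name}({params})\n{{\n" + "\n".join(body) + "\n}\n"


# Scalar example inputs are only used for tracing, not for their values.
hlsl_src = emit(lambda x, y: jnp.sin(x) * 2.0 + y, 1.0, 2.0, name="shade")
print(hlsl_src)
```

Even this toy version shows where the rabbit hole starts: arrays, broadcasting, dtypes, control flow (`cond`, `scan`, `while`), and nested calls would each need their own handling.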