While Numba enables JIT-compiled and parallelized Python code on the CPU, its GPU support is lacking in features and regressing (see the deprecation of AMD support). One aspect of this problem is the batching of data, #24.
JAX is a solid alternative to Numba: it ships with similar features and extends them even further. Some notable differences are:
- Code can be executed on CPU, GPU, and TPU (Tensor Processing Unit, usually used on Google Cloud) by switching an environment variable (see the first sketch after this list)
- GPU support covers NVIDIA for now, but an experimental AMD integration is available and works to a certain degree
- Vector mapping (`jax.vmap`) enables easy parallelization of the code, and batching is possible using `jax.lax.map` (see the second sketch below)
- Using an environment variable, one can toggle between single and double precision; single precision is usually enough and speeds things up (see the last sketch below)
- Compared to Numba, JAX has a whole "ecosystem" of modules built on top of it, making scientific computing with JAX much easier
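
A minimal sketch of the backend switch, assuming the `JAX_PLATFORMS` environment variable (the older `JAX_PLATFORM_NAME` also works); the toy computation is made up for illustration:

```python
import os
# Must be set before jax is imported; "cpu", "gpu", or "tpu" where available
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

x = jnp.arange(10.0)
print(jax.devices())    # lists devices of the selected platform
print(jnp.sum(x ** 2))  # the computation runs on that platform
```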
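A sketch of the two mapping styles; the per-sample function `body` is a hypothetical stand-in for our actual kernel:

```python
import jax
import jax.numpy as jnp

def body(x):
    # hypothetical per-sample computation
    return jnp.sin(x) ** 2 + jnp.cos(x)

batch = jnp.linspace(0.0, 1.0, 1024).reshape(256, 4)

# jax.vmap vectorizes over the leading axis in one fused pass
vectorized = jax.vmap(body)(batch)

# jax.lax.map applies the function sequentially over the leading axis,
# trading speed for memory when the whole batch doesn't fit on the device
batched = jax.lax.map(body, batch)

assert jnp.allclose(vectorized, batched)
```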
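And the precision toggle, either via `JAX_ENABLE_X64=True` in the environment or the equivalent config flag shown here:

```python
import jax
# JAX defaults to single precision; this enables float64 globally
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.ones(3).dtype)  # float64 instead of the default float32
```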
A package rework is planned and will be done on a separate branch until all tests pass!